diff --git a/src/generated_schema/2024_11_05/mcp_schema.rs b/src/generated_schema/2024_11_05/mcp_schema.rs index be232cc..2f034f4 100644 --- a/src/generated_schema/2024_11_05/mcp_schema.rs +++ b/src/generated_schema/2024_11_05/mcp_schema.rs @@ -1,12 +1,12 @@ /// ---------------------------------------------------------------------------- -/// This file is auto-generated by mcp-schema-gen v0.1.3. +/// This file is auto-generated by mcp-schema-gen v0.1.4. /// WARNING: /// It is not recommended to modify this file directly. You are free to /// modify or extend the implementations as needed, but please do so at your own risk. /// /// Generated from : /// Hash : 63e1dbb75456b359b9ed8b27d21f4ac68cbb753e -/// Generated at : 2025-02-17 18:04:41 +/// Generated at : 2025-02-18 08:30:28 /// ---------------------------------------------------------------------------- /// /// MCP Protocol Version diff --git a/src/generated_schema/2024_11_05/schema_utils.rs b/src/generated_schema/2024_11_05/schema_utils.rs index 51262b5..1c09656 100644 --- a/src/generated_schema/2024_11_05/schema_utils.rs +++ b/src/generated_schema/2024_11_05/schema_utils.rs @@ -4,13 +4,31 @@ use serde_json::{json, Value}; use std::hash::{Hash, Hasher}; use std::{fmt::Display, str::FromStr}; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum MessageTypes { Request, Response, Notification, Error, } +/// Implements the `Display` trait for the `MessageTypes` enum, +/// allowing it to be converted into a human-readable string. +impl Display for MessageTypes { + /// Formats the `MessageTypes` enum variant as a string. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + // Match the current enum variant and return a corresponding string + match self { + MessageTypes::Request => "Request", + MessageTypes::Response => "Response", + MessageTypes::Notification => "Notification", + MessageTypes::Error => "Error", + } + ) + } +} /// A utility function used internally to detect the message type from the payload. /// This function is used when deserializing a `ClientMessage` into strongly-typed structs that represent the specific message received. @@ -38,12 +56,18 @@ fn detect_message_type(value: &serde_json::Value) -> MessageTypes { MessageTypes::Request } -pub trait MCPMessage { +/// Represents a generic MCP (Model Content Protocol) message. +/// This trait defines methods to classify and extract information from messages. +pub trait RPCMessage { + fn request_id(&self) -> Option<&RequestId>; +} + +pub trait MCPMessage: RPCMessage { + fn message_type(&self) -> MessageTypes; fn is_response(&self) -> bool; fn is_request(&self) -> bool; fn is_notification(&self) -> bool; fn is_error(&self) -> bool; - fn request_id(&self) -> Option<&RequestId>; } //*******************************// @@ -94,6 +118,22 @@ pub enum ClientMessage { Error(JsonrpcError), } +impl RPCMessage for ClientMessage { + // Retrieves the request ID associated with the message, if applicable + fn request_id(&self) -> Option<&RequestId> { + match self { + // If the message is a request, return the associated request ID + ClientMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id), + // Notifications do not have request IDs + ClientMessage::Notification(_) => None, + // If the message is a response, return the associated request ID + ClientMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id), + // If the message is an error, return the associated request ID + ClientMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id), + } + } +} + // Implementing the `MCPMessage` trait for `ClientMessage` impl MCPMessage for ClientMessage { // Returns true if the message is a response type @@ -116,17 +156,13 @@ impl MCPMessage for ClientMessage { matches!(self, ClientMessage::Error(_)) } - // Retrieves the request ID associated with the message, if applicable - fn request_id(&self) -> Option<&RequestId> { + /// Determines the type of the message and returns the corresponding `MessageTypes` variant. + fn message_type(&self) -> MessageTypes { match self { - // If the message is a request, return the associated request ID - ClientMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id), - // Notifications do not have request IDs - ClientMessage::Notification(_) => None, - // If the message is a response, return the associated request ID - ClientMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id), - // If the message is an error, return the associated request ID - ClientMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id), + ClientMessage::Request(_) => MessageTypes::Request, + ClientMessage::Notification(_) => MessageTypes::Notification, + ClientMessage::Response(_) => MessageTypes::Response, + ClientMessage::Error(_) => MessageTypes::Error, } } } @@ -464,6 +500,22 @@ pub enum ServerMessage { Error(JsonrpcError), } +impl RPCMessage for ServerMessage { + // Retrieves the request ID associated with the message, if applicable + fn request_id(&self) -> Option<&RequestId> { + match self { + // If the message is a request, return the associated request ID + ServerMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id), + // Notifications do not have request IDs + ServerMessage::Notification(_) => None, + // If the message is a response, return the associated request ID + ServerMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id), + // If the message is an error, return the associated request ID + ServerMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id), + } + } +} + // Implementing the `MCPMessage` trait for `ServerMessage` impl MCPMessage for ServerMessage { // Returns true if the message is a response type @@ -486,17 +538,13 @@ impl MCPMessage for ServerMessage { matches!(self, ServerMessage::Error(_)) } - // Retrieves the request ID associated with the message, if applicable - fn request_id(&self) -> Option<&RequestId> { + /// Determines the type of the message and returns the corresponding `MessageTypes` variant. + fn message_type(&self) -> MessageTypes { match self { - // If the message is a request, return the associated request ID - ServerMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id), - // Notifications do not have request IDs - ServerMessage::Notification(_) => None, - // If the message is a response, return the associated request ID - ServerMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id), - // If the message is an error, return the associated request ID - ServerMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id), + ServerMessage::Request(_) => MessageTypes::Request, + ServerMessage::Notification(_) => MessageTypes::Notification, + ServerMessage::Response(_) => MessageTypes::Response, + ServerMessage::Error(_) => MessageTypes::Error, } } } diff --git a/src/generated_schema/draft/mcp_schema.rs b/src/generated_schema/draft/mcp_schema.rs index fe644ee..8657970 100644 --- a/src/generated_schema/draft/mcp_schema.rs +++ b/src/generated_schema/draft/mcp_schema.rs @@ -1,12 +1,12 @@ /// ---------------------------------------------------------------------------- -/// This file is auto-generated by mcp-schema-gen v0.1.3. +/// This file is auto-generated by mcp-schema-gen v0.1.4. /// WARNING: /// It is not recommended to modify this file directly. You are free to /// modify or extend the implementations as needed, but please do so at your own risk. /// /// Generated from : /// Hash : 63e1dbb75456b359b9ed8b27d21f4ac68cbb753e -/// Generated at : 2025-02-17 18:04:41 +/// Generated at : 2025-02-18 08:30:28 /// ---------------------------------------------------------------------------- /// /// MCP Protocol Version diff --git a/src/generated_schema/draft/schema_utils.rs b/src/generated_schema/draft/schema_utils.rs index a53eaa7..07789ab 100644 --- a/src/generated_schema/draft/schema_utils.rs +++ b/src/generated_schema/draft/schema_utils.rs @@ -4,13 +4,31 @@ use serde_json::{json, Value}; use std::hash::{Hash, Hasher}; use std::{fmt::Display, str::FromStr}; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum MessageTypes { Request, Response, Notification, Error, } +/// Implements the `Display` trait for the `MessageTypes` enum, +/// allowing it to be converted into a human-readable string. +impl Display for MessageTypes { + /// Formats the `MessageTypes` enum variant as a string. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + // Match the current enum variant and return a corresponding string + match self { + MessageTypes::Request => "Request", + MessageTypes::Response => "Response", + MessageTypes::Notification => "Notification", + MessageTypes::Error => "Error", + } + ) + } +} /// A utility function used internally to detect the message type from the payload. /// This function is used when deserializing a `ClientMessage` into strongly-typed structs that represent the specific message received. @@ -44,6 +62,7 @@ pub trait MCPMessage { fn is_notification(&self) -> bool; fn is_error(&self) -> bool; fn request_id(&self) -> Option<&RequestId>; + fn message_type(&self) -> MessageTypes; } //*******************************// @@ -129,6 +148,16 @@ impl MCPMessage for ClientMessage { ClientMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id), } } + + /// Determines the type of the message and returns the corresponding `MessageTypes` variant. + fn message_type(&self) -> MessageTypes { + match self { + ClientMessage::Request(_) => MessageTypes::Request, + ClientMessage::Notification(_) => MessageTypes::Notification, + ClientMessage::Response(_) => MessageTypes::Response, + ClientMessage::Error(_) => MessageTypes::Error, + } + } } //**************************// @@ -499,6 +528,16 @@ impl MCPMessage for ServerMessage { ServerMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id), } } + + /// Determines the type of the message and returns the corresponding `MessageTypes` variant. + fn message_type(&self) -> MessageTypes { + match self { + ServerMessage::Request(_) => MessageTypes::Request, + ServerMessage::Notification(_) => MessageTypes::Notification, + ServerMessage::Response(_) => MessageTypes::Response, + ServerMessage::Error(_) => MessageTypes::Error, + } + } } impl FromStr for ServerMessage { diff --git a/tests/miscellaneous.rs b/tests/miscellaneous.rs new file mode 100644 index 0000000..4a4c0be --- /dev/null +++ b/tests/miscellaneous.rs @@ -0,0 +1,26 @@ +#[path = "common/common.rs"] +pub mod common; + +mod miscellaneous_tests { + use rust_mcp_schema::schema_utils::*; + + #[test] + fn test_display_request() { + assert_eq!(MessageTypes::Request.to_string(), "Request"); + } + + #[test] + fn test_display_response() { + assert_eq!(MessageTypes::Response.to_string(), "Response"); + } + + #[test] + fn test_display_notification() { + assert_eq!(MessageTypes::Notification.to_string(), "Notification"); + } + + #[test] + fn test_display_error() { + assert_eq!(MessageTypes::Error.to_string(), "Error"); + } +} diff --git a/tests/test_serialize.rs b/tests/test_serialize.rs index 5e5bdf6..8d643d5 100644 --- a/tests/test_serialize.rs +++ b/tests/test_serialize.rs @@ -41,6 +41,7 @@ mod test_serialize { assert!(!message.is_response()); assert!(!message.is_notification()); assert!(!message.is_error()); + assert!(message.message_type() == MessageTypes::Request); assert!( matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) ); @@ -228,6 +229,7 @@ mod test_serialize { assert!(message.is_response()); assert!(!message.is_notification()); assert!(!message.is_error()); + assert!(message.message_type() == MessageTypes::Response); assert!( matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) ); @@ -279,6 +281,8 @@ mod test_serialize { assert!(message.is_response()); assert!(!message.is_notification()); assert!(!message.is_error()); + assert!(message.message_type() == MessageTypes::Response); + assert!( matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) ); @@ -415,7 +419,7 @@ mod test_serialize { assert!(!message.is_response()); assert!(message.is_notification()); assert!(!message.is_error()); - + assert!(message.message_type() == MessageTypes::Notification); assert!(message.request_id().is_none()); assert!(matches!(message, ClientMessage::Notification(client_message) @@ -509,7 +513,7 @@ mod test_serialize { assert!(!message.is_response()); assert!(message.is_notification()); assert!(!message.is_error()); - + assert!(message.message_type() == MessageTypes::Notification); assert!(message.request_id().is_none()); assert!(matches!(message, ServerMessage::Notification(client_message) @@ -560,6 +564,8 @@ mod test_serialize { assert!(!message.is_response()); assert!(!message.is_notification()); assert!(!message.is_error()); + assert!(message.message_type() == MessageTypes::Request); + assert!( matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) ); @@ -609,6 +615,8 @@ mod test_serialize { assert!(!message.is_response()); assert!(!message.is_notification()); assert!(message.is_error()); + assert!(message.message_type() == MessageTypes::Error); + assert!( matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) ); @@ -628,6 +636,8 @@ mod test_serialize { assert!(!message.is_response()); assert!(!message.is_notification()); assert!(message.is_error()); + assert!(message.message_type() == MessageTypes::Error); + assert!( matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) );