diff --git a/src/generated_schema/2024_11_05/mcp_schema.rs b/src/generated_schema/2024_11_05/mcp_schema.rs index 0721bef..db8d7f9 100644 --- a/src/generated_schema/2024_11_05/mcp_schema.rs +++ b/src/generated_schema/2024_11_05/mcp_schema.rs @@ -6,7 +6,7 @@ /// /// Generated from : /// Hash : 63e1dbb75456b359b9ed8b27d21f4ac68cbb753e -/// Generated at : 2025-02-17 17:12:03 +/// Generated at : 2025-02-17 17:23:32 /// ---------------------------------------------------------------------------- /// /// MCP Protocol Version diff --git a/src/generated_schema/2024_11_05/schema_utils.rs b/src/generated_schema/2024_11_05/schema_utils.rs index edc7b2a..e3b41aa 100644 --- a/src/generated_schema/2024_11_05/schema_utils.rs +++ b/src/generated_schema/2024_11_05/schema_utils.rs @@ -1,6 +1,7 @@ use crate::generated_schema::*; use serde::ser::SerializeStruct; use serde_json::{json, Value}; +use std::hash::{Hash, Hasher}; use std::{fmt::Display, str::FromStr}; #[derive(Debug)] @@ -37,6 +38,47 @@ fn detect_message_type(value: &serde_json::Value) -> MessageTypes { MessageTypes::Request } +pub trait MCPMessage { + fn is_response(&self) -> bool; + fn is_request(&self) -> bool; + fn is_notification(&self) -> bool; + fn is_error(&self) -> bool; + fn request_id(&self) -> Option<&RequestId>; +} + +//*******************************// +//** RequestId Implementations **// +//*******************************// + +// Implement PartialEq and Eq for RequestId +impl PartialEq for RequestId { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (RequestId::String(a), RequestId::String(b)) => a == b, + (RequestId::Integer(a), RequestId::Integer(b)) => a == b, + _ => false, // Different variants are never equal + } + } +} + +impl Eq for RequestId {} + +// Implement Hash for RequestId, so we can store it in HashMaps, HashSets, etc. +impl Hash for RequestId { + fn hash(&self, state: &mut H) { + match self { + RequestId::String(s) => { + 0u8.hash(state); // Prefix with 0 for String variant + s.hash(state); + } + RequestId::Integer(i) => { + 1u8.hash(state); // Prefix with 1 for Integer variant + i.hash(state); + } + } + } +} + //*******************// //** ClientMessage **// //*******************// @@ -52,6 +94,43 @@ pub enum ClientMessage { Error(JsonrpcError), } +// Implementing the `MCPMessage` trait for `ClientMessage` +impl MCPMessage for ClientMessage { + // Returns true if the message is a response type + fn is_response(&self) -> bool { + matches!(self, ClientMessage::Response(_)) + } + + // Returns true if the message is a request type + fn is_request(&self) -> bool { + matches!(self, ClientMessage::Request(_)) + } + + // Returns true if the message is a notification type (i.e., does not expect a response) + fn is_notification(&self) -> bool { + matches!(self, ClientMessage::Notification(_)) + } + + // Returns true if the message represents an error + fn is_error(&self) -> bool { + matches!(self, ClientMessage::Error(_)) + } + + // Retrieves the request ID associated with the message, if applicable + fn request_id(&self) -> Option<&RequestId> { + match self { + // If the message is a request, return the associated request ID + ClientMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id), + // Notifications do not have request IDs + ClientMessage::Notification(_) => None, + // If the message is a response, return the associated request ID + ClientMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id), + // If the message is an error, return the associated request ID + ClientMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id), + } + } +} + //**************************// //** ClientJsonrpcRequest **// //**************************// @@ -385,6 +464,43 @@ pub enum ServerMessage { Error(JsonrpcError), } +// Implementing the `MCPMessage` trait for `ServerMessage` +impl MCPMessage for ServerMessage { + // Returns true if the message is a response type + fn is_response(&self) -> bool { + matches!(self, ServerMessage::Response(_)) + } + + // Returns true if the message is a request type + fn is_request(&self) -> bool { + matches!(self, ServerMessage::Request(_)) + } + + // Returns true if the message is a notification type (i.e., does not expect a response) + fn is_notification(&self) -> bool { + matches!(self, ServerMessage::Notification(_)) + } + + // Returns true if the message represents an error + fn is_error(&self) -> bool { + matches!(self, ServerMessage::Error(_)) + } + + // Retrieves the request ID associated with the message, if applicable + fn request_id(&self) -> Option<&RequestId> { + match self { + // If the message is a request, return the associated request ID + ServerMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id), + // Notifications do not have request IDs + ServerMessage::Notification(_) => None, + // If the message is a response, return the associated request ID + ServerMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id), + // If the message is an error, return the associated request ID + ServerMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id), + } + } +} + impl FromStr for ServerMessage { type Err = JsonrpcErrorError; diff --git a/src/generated_schema/draft/mcp_schema.rs b/src/generated_schema/draft/mcp_schema.rs index 22a81c9..15969fd 100644 --- a/src/generated_schema/draft/mcp_schema.rs +++ b/src/generated_schema/draft/mcp_schema.rs @@ -6,7 +6,7 @@ /// /// Generated from : /// Hash : 63e1dbb75456b359b9ed8b27d21f4ac68cbb753e -/// Generated at : 2025-02-17 17:12:04 +/// Generated at : 2025-02-17 17:23:32 /// ---------------------------------------------------------------------------- /// /// MCP Protocol Version diff --git a/src/generated_schema/draft/schema_utils.rs b/src/generated_schema/draft/schema_utils.rs index f10f66f..cfef8d0 100644 --- a/src/generated_schema/draft/schema_utils.rs +++ b/src/generated_schema/draft/schema_utils.rs @@ -1,6 +1,7 @@ use crate::generated_schema::*; use serde::ser::SerializeStruct; use serde_json::{json, Value}; +use std::hash::{Hash, Hasher}; use std::{fmt::Display, str::FromStr}; #[derive(Debug)] @@ -37,6 +38,47 @@ fn detect_message_type(value: &serde_json::Value) -> MessageTypes { MessageTypes::Request } +pub trait MCPMessage { + fn is_response(&self) -> bool; + fn is_request(&self) -> bool; + fn is_notification(&self) -> bool; + fn is_error(&self) -> bool; + fn request_id(&self) -> Option<&RequestId>; +} + +//*******************************// +//** RequestId Implementations **// +//*******************************// + +// Implement PartialEq and Eq for RequestId +impl PartialEq for RequestId { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (RequestId::String(a), RequestId::String(b)) => a == b, + (RequestId::Integer(a), RequestId::Integer(b)) => a == b, + _ => false, // Different variants are never equal + } + } +} + +impl Eq for RequestId {} + +// Implement Hash for RequestId, so we can store it in HashMaps, HashSets, etc. +impl Hash for RequestId { + fn hash(&self, state: &mut H) { + match self { + RequestId::String(s) => { + 0u8.hash(state); // Prefix with 0 for String variant + s.hash(state); + } + RequestId::Integer(i) => { + 1u8.hash(state); // Prefix with 1 for Integer variant + i.hash(state); + } + } + } +} + //*******************// //** ClientMessage **// //*******************// @@ -52,6 +94,43 @@ pub enum ClientMessage { Error(JsonrpcError), } +// Implementing the `MCPMessage` trait for `ClientMessage` +impl MCPMessage for ClientMessage { + // Returns true if the message is a response type + fn is_response(&self) -> bool { + matches!(self, ClientMessage::Response(_)) + } + + // Returns true if the message is a request type + fn is_request(&self) -> bool { + matches!(self, ClientMessage::Request(_)) + } + + // Returns true if the message is a notification type (i.e., does not expect a response) + fn is_notification(&self) -> bool { + matches!(self, ClientMessage::Notification(_)) + } + + // Returns true if the message represents an error + fn is_error(&self) -> bool { + matches!(self, ClientMessage::Error(_)) + } + + // Retrieves the request ID associated with the message, if applicable + fn request_id(&self) -> Option<&RequestId> { + match self { + // If the message is a request, return the associated request ID + ClientMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id), + // Notifications do not have request IDs + ClientMessage::Notification(_) => None, + // If the message is a response, return the associated request ID + ClientMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id), + // If the message is an error, return the associated request ID + ClientMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id), + } + } +} + //**************************// //** ClientJsonrpcRequest **// //**************************// @@ -385,6 +464,43 @@ pub enum ServerMessage { Error(JsonrpcError), } +// Implementing the `MCPMessage` trait for `ServerMessage` +impl MCPMessage for ServerMessage { + // Returns true if the message is a response type + fn is_response(&self) -> bool { + matches!(self, ServerMessage::Response(_)) + } + + // Returns true if the message is a request type + fn is_request(&self) -> bool { + matches!(self, ServerMessage::Request(_)) + } + + // Returns true if the message is a notification type (i.e., does not expect a response) + fn is_notification(&self) -> bool { + matches!(self, ServerMessage::Notification(_)) + } + + // Returns true if the message represents an error + fn is_error(&self) -> bool { + matches!(self, ServerMessage::Error(_)) + } + + // Retrieves the request ID associated with the message, if applicable + fn request_id(&self) -> Option<&RequestId> { + match self { + // If the message is a request, return the associated request ID + ServerMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id), + // Notifications do not have request IDs + ServerMessage::Notification(_) => None, + // If the message is a response, return the associated request ID + ServerMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id), + // If the message is an error, return the associated request ID + ServerMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id), + } + } +} + impl FromStr for ServerMessage { type Err = JsonrpcErrorError; diff --git a/tests/test_serialize.rs b/tests/test_serialize.rs index cd21032..5e5bdf6 100644 --- a/tests/test_serialize.rs +++ b/tests/test_serialize.rs @@ -37,6 +37,14 @@ mod test_serialize { let message: ClientMessage = re_serialize(message); + assert!(message.is_request()); + assert!(!message.is_response()); + assert!(!message.is_notification()); + assert!(!message.is_error()); + assert!( + matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) + ); + assert!(matches!(message, ClientMessage::Request(client_message) if matches!(&client_message.request, RequestFromClient::ClientRequest(client_request) if matches!(client_request, ClientRequest::InitializeRequest(_))) @@ -216,6 +224,14 @@ mod test_serialize { let message: ClientMessage = re_serialize(message); + assert!(!message.is_request()); + assert!(message.is_response()); + assert!(!message.is_notification()); + assert!(!message.is_error()); + assert!( + matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) + ); + assert!(matches!(message, ClientMessage::Response(client_message) if matches!(&client_message.result, ResultFromClient::ClientResult(client_result) if matches!( client_result, ClientResult::CreateMessageResult(_)) @@ -259,6 +275,14 @@ mod test_serialize { let message: ServerMessage = re_serialize(message); + assert!(!message.is_request()); + assert!(message.is_response()); + assert!(!message.is_notification()); + assert!(!message.is_error()); + assert!( + matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) + ); + assert!(matches!(message, ServerMessage::Response(server_message) if matches!(&server_message.result, ResultFromServer::ServerResult(server_result) if matches!(server_result, ServerResult::InitializeResult(_))) @@ -387,6 +411,13 @@ mod test_serialize { let message: ClientMessage = re_serialize(message); + assert!(!message.is_request()); + assert!(!message.is_response()); + assert!(message.is_notification()); + assert!(!message.is_error()); + + assert!(message.request_id().is_none()); + assert!(matches!(message, ClientMessage::Notification(client_message) if matches!(&client_message.notification,NotificationFromClient::ClientNotification(client_notification) if matches!( client_notification, ClientNotification::InitializedNotification(_))) @@ -474,6 +505,13 @@ mod test_serialize { let message: ServerMessage = re_serialize(message); + assert!(!message.is_request()); + assert!(!message.is_response()); + assert!(message.is_notification()); + assert!(!message.is_error()); + + assert!(message.request_id().is_none()); + assert!(matches!(message, ServerMessage::Notification(client_message) if matches!(&client_message.notification,NotificationFromServer::ServerNotification(client_notification) if matches!( client_notification, ServerNotification::CancelledNotification(_))) @@ -518,6 +556,14 @@ mod test_serialize { let message: ServerMessage = re_serialize(message); + assert!(message.is_request()); + assert!(!message.is_response()); + assert!(!message.is_notification()); + assert!(!message.is_error()); + assert!( + matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) + ); + assert!(matches!(message, ServerMessage::Request(server_message) if matches!(&server_message.request,RequestFromServer::ServerRequest(server_request) if matches!( server_request, ServerRequest::CreateMessageRequest(_))) @@ -559,6 +605,13 @@ mod test_serialize { let message: ClientMessage = re_serialize(message); assert!(matches!(message, ClientMessage::Error(_))); + assert!(!message.is_request()); + assert!(!message.is_response()); + assert!(!message.is_notification()); + assert!(message.is_error()); + assert!( + matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) + ); let message: ServerMessage = ServerMessage::Error(JsonrpcError::create( RequestId::Integer(15), @@ -570,6 +623,14 @@ mod test_serialize { let message: ServerMessage = re_serialize(message); assert!(matches!(message, ServerMessage::Error(_))); + + assert!(!message.is_request()); + assert!(!message.is_response()); + assert!(!message.is_notification()); + assert!(message.is_error()); + assert!( + matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15)) + ); } /* ---------------------- JsonrpcErrorError ---------------------- */