diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt index e7061073..6ae6ea04 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt @@ -1,7 +1,12 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeResult import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.testing.MockTransport import kotlinx.coroutines.test.runTest import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.boolean @@ -31,7 +36,24 @@ class ClientMetaParameterTest { @BeforeTest fun setup() = runTest { - mockTransport = MockTransport() + mockTransport = MockTransport { + // configure mock transport behavior + onMessageReplyResult(Method.Defined.Initialize) { + InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("mock-server", "1.0.0"), + ) + } + onMessageReplyResult(Method.Defined.ToolsCall) { + CallToolResult( + content = listOf(), + isError = false, + ) + } + } client = Client(clientInfo = clientInfo) mockTransport.setupInitializationResponse() client.connect(mockTransport) diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt deleted file mode 100644 index c987619d..00000000 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt +++ /dev/null @@ -1,94 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.client - -import io.modelcontextprotocol.kotlin.sdk.CallToolResult -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.InitializeResult -import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage -import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest -import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities -import io.modelcontextprotocol.kotlin.sdk.shared.Transport -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock - -class MockTransport : Transport { - private val _sentMessages = mutableListOf() - private val _receivedMessages = mutableListOf() - private val mutex = Mutex() - - suspend fun getSentMessages() = mutex.withLock { _sentMessages.toList() } - suspend fun getReceivedMessages() = mutex.withLock { _receivedMessages.toList() } - - private var onMessageBlock: (suspend (JSONRPCMessage) -> Unit)? = null - private var onCloseBlock: (() -> Unit)? = null - private var onErrorBlock: ((Throwable) -> Unit)? = null - - override suspend fun start() = Unit - - override suspend fun send(message: JSONRPCMessage) { - mutex.withLock { - _sentMessages += message - } - - // Auto-respond to initialization and tool calls - when (message) { - is JSONRPCRequest -> { - when (message.method) { - "initialize" -> { - val initResponse = JSONRPCResponse( - id = message.id, - result = InitializeResult( - protocolVersion = "2024-11-05", - capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(listChanged = null), - ), - serverInfo = Implementation("mock-server", "1.0.0"), - ), - ) - onMessageBlock?.invoke(initResponse) - } - - "tools/call" -> { - val toolResponse = JSONRPCResponse( - id = message.id, - result = CallToolResult( - content = listOf(), - isError = false, - ), - ) - onMessageBlock?.invoke(toolResponse) - } - } - } - - else -> { - // Handle other message types if needed - } - } - } - - override suspend fun close() { - onCloseBlock?.invoke() - } - - override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { - onMessageBlock = { message -> - mutex.withLock { - _receivedMessages += message - } - block(message) - } - } - - override fun onClose(block: () -> Unit) { - onCloseBlock = block - } - - override fun onError(block: (Throwable) -> Unit) { - onErrorBlock = block - } - - fun setupInitializationResponse() { - // This method helps set up the mock for proper initialization - } -} diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index d4d4b9d5..98d60937 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -3364,3 +3364,24 @@ public final class io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTranspo public static final field MCP_SUBPROTOCOL Ljava/lang/String; } +public class io/modelcontextprotocol/kotlin/sdk/testing/MockTransport : io/modelcontextprotocol/kotlin/sdk/shared/Transport { + public fun ()V + public fun (Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun awaitMessage-ePrTys8 (JJLjava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun awaitMessage-ePrTys8$default (Lio/modelcontextprotocol/kotlin/sdk/testing/MockTransport;JJLjava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getReceivedMessages (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getSentMessages (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun onClose (Lkotlin/jvm/functions/Function0;)V + public fun onError (Lkotlin/jvm/functions/Function1;)V + public fun onMessage (Lkotlin/jvm/functions/Function2;)V + public final fun onMessageReply (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function2;)V + public final fun onMessageReplyError (Lio/modelcontextprotocol/kotlin/sdk/Method;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun onMessageReplyError$default (Lio/modelcontextprotocol/kotlin/sdk/testing/MockTransport;Lio/modelcontextprotocol/kotlin/sdk/Method;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public final fun onMessageReplyResult (Lio/modelcontextprotocol/kotlin/sdk/Method;Lkotlin/jvm/functions/Function1;)V + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun setupInitializationResponse ()V + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + diff --git a/kotlin-sdk-core/build.gradle.kts b/kotlin-sdk-core/build.gradle.kts index 68e56185..3e81d3d0 100644 --- a/kotlin-sdk-core/build.gradle.kts +++ b/kotlin-sdk-core/build.gradle.kts @@ -124,6 +124,7 @@ kotlin { implementation(kotlin("test")) implementation(libs.kotest.assertions.core) implementation(libs.kotest.assertions.json) + implementation(libs.kotlinx.coroutines.test) } } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransport.kt new file mode 100644 index 00000000..06f243c7 --- /dev/null +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransport.kt @@ -0,0 +1,213 @@ +package io.modelcontextprotocol.kotlin.sdk.testing + +import io.ktor.util.collections.ConcurrentSet +import io.modelcontextprotocol.kotlin.sdk.ErrorCode +import io.modelcontextprotocol.kotlin.sdk.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.RequestResult +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import kotlinx.coroutines.delay +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlin.time.Clock +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds +import kotlin.time.ExperimentalTime + +private typealias RequestPredicate = (JSONRPCRequest) -> Boolean +private typealias RequestHandler = suspend (JSONRPCRequest) -> JSONRPCResponse + +/** + * A mock transport implementation for testing JSON-RPC communication. + * + * This class simulates transport that can be used to test server and client interactions by + * allowing the registration of handlers for incoming requests and the ability to record + * messages sent and received. + * + * The mock transport supports: + * - Recording all sent and received messages (via `getSentMessages` and `getReceivedMessages`) + * - Registering request handlers that respond to specific message predicates (e.g., by method) + * - Setting up responses that can be either successful or with errors + * - Waiting for specific messages to be received + * + * Note: This class is designed to be used as a test helper and should not be used in production. + */ +@Suppress("TooManyFunctions") +public open class MockTransport(configurer: MockTransport.() -> Unit = {}) : Transport { + private val _sentMessages = mutableListOf() + private val _receivedMessages = mutableListOf() + + private val requestHandlers = ConcurrentSet>() + private val mutex = Mutex() + + public suspend fun getSentMessages(): List = mutex.withLock { _sentMessages.toList() } + + public suspend fun getReceivedMessages(): List = mutex.withLock { _receivedMessages.toList() } + + private var onMessageBlock: (suspend (JSONRPCMessage) -> Unit)? = null + private var onCloseBlock: (() -> Unit)? = null + private var onErrorBlock: ((Throwable) -> Unit)? = null + + init { + configurer.invoke(this) + } + + override suspend fun start(): Unit = Unit + + override suspend fun send(message: JSONRPCMessage) { + mutex.withLock { + _sentMessages += message + } + + // Auto-respond to using preconfigured request handlers + when (message) { + is JSONRPCRequest -> { + val response = requestHandlers.firstOrNull { + it.first.invoke(message) + }?.second?.invoke(message) + + checkNotNull(response) { + "No request handler found for $message." + } + onMessageBlock?.invoke(response) + } + + else -> { + // TODO("Not implemented yet") + } + } + } + + override suspend fun close() { + onCloseBlock?.invoke() + } + + override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { + onMessageBlock = { message -> + mutex.withLock { + _receivedMessages += message + } + block(message) + } + } + + override fun onClose(block: () -> Unit) { + onCloseBlock = block + } + + override fun onError(block: (Throwable) -> Unit) { + onErrorBlock = block + } + + public fun setupInitializationResponse() { + // This method helps set up the mock for proper initialization + } + + /** + * Registers a handler that will be called when a message matching the given predicate is received. + * + * The handler is expected to return a `RequestResult` which will be used as the response to the request. + * + * @param predicate A predicate that matches the incoming `JSONRPCMessage` + * for which the handler should be triggered. + * @param block A function that processes the incoming `JSONRPCMessage` and returns a `RequestResult` + * to be used as the response. + */ + public fun onMessageReply(predicate: RequestPredicate, block: RequestHandler) { + requestHandlers.add(Pair(predicate, block)) + } + + /** + * Registers a handler for responses to a specific method. + * + * This method allows registering a handler that will be called when a message with the specified method + * is received. The handler is expected to return a `RequestResult` which is the response to the request. + * + * @param method The method (from the `Method` enum) that the handler should respond to. + * @param block A function that processes the incoming `JSONRPCRequest` and returns a `RequestResult`. + * The returned `RequestResult` will be used as the result of the response. + */ + public fun onMessageReplyResult(method: Method, block: (JSONRPCRequest) -> T) { + onMessageReply( + predicate = { + it.method == method.value + }, + block = { + JSONRPCResponse( + id = it.id, + result = block.invoke(it), + ) + }, + ) + } + + /** + * Registers a handler that will be called when a request with the specified method is received + * and an error response is to be generated. + * + * This handler is used to respond to requests with a specific method by returning an error response. + * The handler is triggered when a request message with the given `method` is received. + * + * @param method The method (from the `Method` enum) that the handler should respond to with an error. + * @param block A function that processes the incoming `JSONRPCRequest` and returns a `JSONRPCError` + * to be used as the error response. + * The default block returns an internal error with the message "Expected error". + */ + public fun onMessageReplyError( + method: Method, + block: (JSONRPCRequest) -> JSONRPCError = { + JSONRPCError( + code = ErrorCode.Defined.InternalError, + message = "Expected error", + ) + }, + ) { + onMessageReply( + predicate = { + it.method == method.value + }, + block = { + JSONRPCResponse( + id = it.id, + error = block.invoke(it), + ) + }, + ) + } + + /** + * Waits for a JSON-RPC message that matches the given predicate in the received messages. + * + * @param poolInterval The interval at which the function polls the received messages. Default is 50 milliseconds. + * @param timeout The maximum time to wait for a matching message. Default is 3 seconds. + * @param timeoutMessage The error message to throw when the timeout is reached. + * Default is "No message received matching predicate". + * @param predicate A predicate function that returns true if the message matches the criteria. + * @return The first JSON-RPC message that matches the predicate. + */ + @OptIn(ExperimentalTime::class) + public suspend fun awaitMessage( + poolInterval: Duration = 50.milliseconds, + timeout: Duration = 3.seconds, + timeoutMessage: String = "No message received matching predicate", + predicate: (JSONRPCMessage) -> Boolean, + ): JSONRPCMessage { + val clock = Clock.System + val startTime = clock.now() + val finishTime = startTime + timeout + while (clock.now() < finishTime) { + val found = mutex.withLock { + _receivedMessages.firstOrNull { predicate(it) } + } + if (found != null) { + return found + } + delay(poolInterval) + } + error(timeoutMessage) + } +} diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransportTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransportTest.kt new file mode 100644 index 00000000..5cdc0b62 --- /dev/null +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/MockTransportTest.kt @@ -0,0 +1,575 @@ +package io.modelcontextprotocol.kotlin.sdk.testing + +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.RequestId +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import kotlinx.coroutines.async +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds + +class MockTransportTest { + + private lateinit var transport: MockTransport + + @BeforeTest + fun beforeTest() { + transport = MockTransport { + // configure mock transport behavior + onMessageReplyResult(Method.Defined.Initialize) { + InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("mock-server", "1.0.0"), + ) + } + } + + // Set up onMessage callback to add messages + transport.onMessage { } + } + + @Test + fun `awaitMessage should return message when predicate matches`() = runTest { + // Trigger the onMessage callback directly via send + launch { + transport.send( + JSONRPCRequest( + id = RequestId.StringId("some-id"), + method = "initialize", + ), + ) + delay(200) + transport.send( + JSONRPCRequest( + id = RequestId.StringId("test-id"), + method = "initialize", + params = buildJsonObject { + put("foo", JsonPrimitive("bar")) + }, + ), + ) + transport.send( + JSONRPCRequest( + id = RequestId.StringId("other-id"), + method = "initialize", + ), + ) + } + + // Wait for the auto-response + val message = transport.awaitMessage { + it is JSONRPCResponse && + it.id == RequestId.StringId("test-id") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("test-id"), message.id) + } + + @Test + fun `awaitMessage should timeout when no matching message arrives`() = runTest { + val exception = assertFailsWith { + transport.awaitMessage( + timeout = 100.milliseconds, + timeoutMessage = "Custom timeout message", + ) { false } // Predicate that never matches + } + + assertEquals("Custom timeout message", exception.message) + } + + @Test + fun `awaitMessage should filter messages by predicate`() = runTest { + transport.onMessageReply(predicate = { true }) { + JSONRPCResponse( + id = it.id, + result = EmptyRequestResult(), + ) + } + // Send multiple messages + transport.send( + JSONRPCRequest( + id = RequestId.StringId("req-1"), + method = "test1", + params = buildJsonObject { }, + ), + ) + transport.send( + JSONRPCRequest( + id = RequestId.StringId("req-2"), + method = "test2", + params = buildJsonObject { }, + ), + ) + + // Wait for response with specific id - note: no auto-response for non-initialize/tools methods + // So this test will timeout unless we manually trigger a response + // Let's send an initialize to get a response + transport.send( + JSONRPCRequest( + id = RequestId.StringId("req-2"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + val message = transport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("req-2") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("req-2"), message.id) + } + + @Test + fun `awaitMessage should return first matching message`() = runTest { + // Send initialize request to get auto-response + transport.send( + JSONRPCRequest( + id = RequestId.StringId("init-1"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Wait for any response + val message = transport.awaitMessage { it is JSONRPCResponse } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("init-1"), message.id) + } + + @Test + fun `awaitMessage should handle concurrent access safely`() = runTest { + // Send a message that will trigger auto-response + transport.send( + JSONRPCRequest( + id = RequestId.StringId("concurrent-test"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Launch multiple concurrent awaitMessage calls + val deferred1 = async { + transport.awaitMessage { it is JSONRPCResponse } + } + + val deferred2 = async { + transport.awaitMessage { it is JSONRPCResponse } + } + + val deferred3 = async { + transport.awaitMessage { it is JSONRPCResponse } + } + + // All should successfully find the message + val message1 = deferred1.await() + val message2 = deferred2.await() + val message3 = deferred3.await() + + assertNotNull(message1) + assertNotNull(message2) + assertNotNull(message3) + + // All should be the same message + assertTrue(message1 is JSONRPCResponse) + assertTrue(message2 is JSONRPCResponse) + assertTrue(message3 is JSONRPCResponse) + assertEquals(RequestId.StringId("concurrent-test"), message1.id) + assertEquals(RequestId.StringId("concurrent-test"), message2.id) + assertEquals(RequestId.StringId("concurrent-test"), message3.id) + } + + @Test + fun `awaitMessage should wait for message to arrive`() = runTest { + // Launch awaitMessage before message arrives + val deferred = async { + transport.awaitMessage(timeout = 2.seconds) { it is JSONRPCResponse } + } + + // Wait a bit before sending message + delay(100.milliseconds) + + // Now send the message + transport.send( + JSONRPCRequest( + id = RequestId.StringId("delayed"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Should successfully receive it + val message = deferred.await() + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("delayed"), message.id) + } + + @Test + fun `awaitMessage should use custom pool interval`() = runTest { + // Send message + transport.send( + JSONRPCRequest( + id = RequestId.StringId("pool-test"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Should work with custom pool interval + val message = transport.awaitMessage( + poolInterval = 10.milliseconds, + timeout = 1.seconds, + ) { it is JSONRPCResponse } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + } + + @Test + fun `awaitMessage should handle tools call auto-response`() = runTest { + transport.onMessageReplyResult(Method.Defined.ToolsCall) { + CallToolResult(content = listOf()) + } + + // Send tools/call request + transport.send( + JSONRPCRequest( + id = RequestId.StringId("tool-1"), + method = "tools/call", + params = buildJsonObject { }, + ), + ) + + // Should receive auto-response + val message = transport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("tool-1") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("tool-1"), message.id) + } + + @Test + fun `awaitMessage should return existing message immediately`() = runTest { + // Send message first + transport.send( + JSONRPCRequest( + id = RequestId.StringId("existing"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Give it time to be received + delay(50.milliseconds) + + // Now await should return immediately without waiting + val message = transport.awaitMessage( + timeout = 100.milliseconds, + ) { it is JSONRPCResponse } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("existing"), message.id) + } + + @Test + fun `awaitMessage with complex predicate`() = runTest { + transport.onMessageReply(predicate = { true }) { + JSONRPCResponse( + id = it.id, + result = EmptyRequestResult(), + ) + } + // Send multiple requests + transport.send( + JSONRPCRequest( + id = RequestId.StringId("req-1"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + transport.send( + JSONRPCRequest( + id = RequestId.StringId("req-2"), + method = "tools/call", + params = buildJsonObject { }, + ), + ) + + // Wait for response with specific criteria + val message = transport.awaitMessage { msg -> + msg is JSONRPCResponse && msg.id == RequestId.StringId("req-2") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("req-2"), message.id) + } + + @Test + fun `onMessageReply should register handler with custom predicate`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register handler that only responds to requests with "custom" method + customTransport.onMessageReply( + predicate = { request -> request.method == "custom-method" }, + ) { request -> + JSONRPCResponse( + id = request.id, + result = EmptyRequestResult(), + ) + } + + // Send matching request + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("test-1"), + method = "custom-method", + params = buildJsonObject { }, + ), + ) + + // Verify response was received + val message = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("test-1") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("test-1"), message.id) + assertNotNull(message.result) + } + + @Test + fun `onMessageReply should support multiple handlers with different predicates`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register first handler for "method-a" + customTransport.onMessageReply( + predicate = { it.method == "method-a" }, + ) { request -> + JSONRPCResponse( + id = request.id, + result = CallToolResult(content = listOf()), + ) + } + + // Register second handler for "method-b" + customTransport.onMessageReply( + predicate = { it.method == "method-b" }, + ) { request -> + JSONRPCResponse( + id = request.id, + result = EmptyRequestResult(), + ) + } + + // Test first handler + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("req-a"), + method = "method-a", + params = buildJsonObject { }, + ), + ) + + val messageA = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("req-a") + } + + assertTrue(messageA is JSONRPCResponse) + assertTrue(messageA.result is CallToolResult) + + // Test second handler + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("req-b"), + method = "method-b", + params = buildJsonObject { }, + ), + ) + + val messageB = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("req-b") + } + + assertTrue(messageB is JSONRPCResponse) + assertTrue(messageB.result is EmptyRequestResult) + } + + @Test + fun `onMessageReplyResult should create response with result for matching method`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register handler using onMessageReplyResult + customTransport.onMessageReplyResult(Method.Defined.Initialize) { _ -> + InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("test-server", "1.0.0"), + ) + } + + // Send matching request + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("init-test"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Verify response with result + val message = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("init-test") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("init-test"), message.id) + assertNotNull(message.result) + assertTrue(message.result is InitializeResult) + val result = message.result + assertEquals("2024-11-05", result.protocolVersion) + assertEquals("test-server", result.serverInfo.name) + } + + @Test + fun `onMessageReplyResult should only respond to specified method`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register handler only for Initialize + customTransport.onMessageReplyResult(Method.Defined.Initialize) { + InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("test", "1.0"), + ) + } + + // Also register a catch-all handler for other methods + customTransport.onMessageReply(predicate = { it.method != "initialize" }) { + JSONRPCResponse( + id = it.id, + result = EmptyRequestResult(), + ) + } + + // Send non-matching request + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("other-method"), + method = "other-method", + params = buildJsonObject { }, + ), + ) + + // Should get response from catch-all handler + val message = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("other-method") + } + + assertTrue(message is JSONRPCResponse) + assertTrue(message.result is EmptyRequestResult) + } + + @Test + fun `onMessageReplyError should create response with error for matching method`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register error handler with custom error + customTransport.onMessageReplyError(Method.Defined.ToolsCall) { _ -> + io.modelcontextprotocol.kotlin.sdk.JSONRPCError( + code = io.modelcontextprotocol.kotlin.sdk.ErrorCode.Defined.InvalidParams, + message = "Custom error message", + ) + } + + // Send matching request + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("error-test"), + method = "tools/call", + params = buildJsonObject { }, + ), + ) + + // Verify response with error + val message = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("error-test") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("error-test"), message.id) + assertNotNull(message.error) + assertEquals(io.modelcontextprotocol.kotlin.sdk.ErrorCode.Defined.InvalidParams, message.error?.code) + assertEquals("Custom error message", message.error.message) + } + + @Test + fun `onMessageReplyError should use default error when block not provided`() = runTest { + val customTransport = MockTransport() + customTransport.onMessage { } + + // Register error handler without custom block (using default) + customTransport.onMessageReplyError(Method.Defined.Initialize) + + // Send matching request + customTransport.send( + JSONRPCRequest( + id = RequestId.StringId("default-error-test"), + method = "initialize", + params = buildJsonObject { }, + ), + ) + + // Verify response with default error + val message = customTransport.awaitMessage { + it is JSONRPCResponse && it.id == RequestId.StringId("default-error-test") + } + + assertNotNull(message) + assertTrue(message is JSONRPCResponse) + assertEquals(RequestId.StringId("default-error-test"), message.id) + assertNotNull(message.error) + assertEquals(io.modelcontextprotocol.kotlin.sdk.ErrorCode.Defined.InternalError, message.error?.code) + assertEquals("Expected error", message.error.message) + } +}