Skip to content

Commit c8a4a86

Browse files
committed
fixup! Add MCP conformance test coverage
1 parent 65b244f commit c8a4a86

File tree

1 file changed

+66
-68
lines changed
  • kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance

1 file changed

+66
-68
lines changed

kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt

Lines changed: 66 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents
4040
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
4141
import kotlinx.coroutines.CancellationException
4242
import kotlinx.coroutines.CompletableDeferred
43+
import kotlinx.coroutines.CoroutineScope
44+
import kotlinx.coroutines.Dispatchers
4345
import kotlinx.coroutines.channels.Channel
46+
import kotlinx.coroutines.launch
4447
import kotlinx.coroutines.runBlocking
4548
import kotlinx.coroutines.withTimeoutOrNull
4649
import kotlinx.serialization.json.Json
@@ -60,6 +63,10 @@ private val logger = KotlinLogging.logger {}
6063
private val serverTransports = ConcurrentHashMap<String, HttpServerTransport>()
6164
private val jsonFormat = Json { ignoreUnknownKeys = true }
6265

66+
private const val SESSION_CREATION_TIMEOUT_MS = 2000L
67+
private const val REQUEST_TIMEOUT_MS = 10_000L
68+
private const val MESSAGE_QUEUE_CAPACITY = 256
69+
6370
private fun isInitializeRequest(json: JsonElement): Boolean =
6471
json is JsonObject && json["method"]?.jsonPrimitive?.contentOrNull == "initialize"
6572

@@ -72,15 +79,15 @@ fun main(args: Array<String>) {
7279
routing {
7380
get("/mcp") {
7481
val sessionId = call.request.header("mcp-session-id")
75-
if (sessionId == null) {
76-
call.respond(HttpStatusCode.BadRequest, "Missing mcp-session-id header")
77-
return@get
78-
}
82+
?: run {
83+
call.respond(HttpStatusCode.BadRequest, "Missing mcp-session-id header")
84+
return@get
85+
}
7986
val transport = serverTransports[sessionId]
80-
if (transport == null) {
81-
call.respond(HttpStatusCode.BadRequest, "Invalid mcp-session-id")
82-
return@get
83-
}
87+
?: run {
88+
call.respond(HttpStatusCode.BadRequest, "Invalid mcp-session-id")
89+
return@get
90+
}
8491
transport.stream(call)
8592
}
8693

@@ -99,92 +106,77 @@ fun main(args: Array<String>) {
99106
HttpStatusCode.BadRequest,
100107
jsonFormat.encodeToString(
101108
JsonObject.serializer(),
102-
JsonObject(
103-
mapOf(
104-
"jsonrpc" to JsonPrimitive("2.0"),
105-
"error" to JsonObject(
106-
mapOf(
107-
"code" to JsonPrimitive(-32700),
108-
"message" to JsonPrimitive("Parse error: ${e.message}"),
109-
),
110-
),
111-
"id" to JsonNull,
112-
),
113-
),
114-
),
109+
buildJsonObject {
110+
put("jsonrpc", "2.0")
111+
put("error", buildJsonObject {
112+
put("code", -32700)
113+
put("message", "Parse error: ${e.message}")
114+
})
115+
put("id", JsonNull)
116+
}
117+
)
115118
)
116119
return@post
117120
}
118121

119-
if (sessionId != null && serverTransports.containsKey(sessionId)) {
122+
val transport = sessionId?.let { serverTransports[it] }
123+
if (transport != null) {
120124
logger.debug { "Using existing transport for session: $sessionId" }
121-
val transport = serverTransports[sessionId]!!
122125
transport.handleRequest(call, jsonElement)
123126
} else {
124127
if (isInitializeRequest(jsonElement)) {
125128
val newSessionId = UUID.randomUUID().toString()
126129
logger.info { "Creating new session with ID: $newSessionId" }
127130

128-
val transport = HttpServerTransport(newSessionId)
129-
serverTransports[newSessionId] = transport
131+
val newTransport = HttpServerTransport(newSessionId)
132+
serverTransports[newSessionId] = newTransport
130133

131134
val mcpServer = createConformanceServer()
132135
call.response.header("mcp-session-id", newSessionId)
133136

134137
val sessionReady = CompletableDeferred<Unit>()
135-
Thread {
136-
runBlocking {
137-
try {
138-
mcpServer.createSession(transport)
139-
sessionReady.complete(Unit)
140-
} catch (e: Exception) {
141-
logger.error(e) { "Failed to create session" }
142-
sessionReady.completeExceptionally(e)
143-
}
138+
CoroutineScope(Dispatchers.IO).launch {
139+
try {
140+
mcpServer.createSession(newTransport)
141+
sessionReady.complete(Unit)
142+
} catch (e: Exception) {
143+
logger.error(e) { "Failed to create session" }
144+
sessionReady.completeExceptionally(e)
144145
}
145-
}.start()
146-
147-
runBlocking {
148-
withTimeoutOrNull(2000) {
149-
sessionReady.await()
150-
} ?: logger.warn { "Session creation timed out, proceeding anyway" }
151146
}
152147

153-
transport.handleRequest(call, jsonElement)
148+
withTimeoutOrNull(SESSION_CREATION_TIMEOUT_MS) {
149+
sessionReady.await()
150+
} ?: logger.warn { "Session creation timed out, proceeding anyway" }
151+
152+
newTransport.handleRequest(call, jsonElement)
154153
} else {
155154
logger.warn { "Invalid request: no session ID or not an initialization request" }
156155
call.respond(
157156
HttpStatusCode.BadRequest,
158157
jsonFormat.encodeToString(
159158
JsonObject.serializer(),
160-
JsonObject(
161-
mapOf(
162-
"jsonrpc" to JsonPrimitive("2.0"),
163-
"error" to JsonObject(
164-
mapOf(
165-
"code" to JsonPrimitive(-32000),
166-
"message" to
167-
JsonPrimitive("Bad Request: No valid session ID provided"),
168-
),
169-
),
170-
"id" to JsonNull,
171-
),
172-
),
173-
),
159+
buildJsonObject {
160+
put("jsonrpc", "2.0")
161+
put("error", buildJsonObject {
162+
put("code", -32000)
163+
put("message", "Bad Request: No valid session ID provided")
164+
})
165+
put("id", JsonNull)
166+
}
167+
)
174168
)
175169
}
176170
}
177171
}
178172

179173
delete("/mcp") {
180174
val sessionId = call.request.header("mcp-session-id")
181-
if (sessionId != null && serverTransports.containsKey(sessionId)) {
175+
val transport = sessionId?.let { serverTransports[it] }
176+
if (transport != null) {
182177
logger.info { "Terminating session: $sessionId" }
183-
val transport = serverTransports[sessionId]!!
184178
serverTransports.remove(sessionId)
185-
runBlocking {
186-
transport.close()
187-
}
179+
transport.close()
188180
call.respond(HttpStatusCode.OK)
189181
} else {
190182
logger.warn { "Invalid session termination request: $sessionId" }
@@ -270,7 +262,7 @@ private fun createConformanceServer(): Server {
270262
private class HttpServerTransport(private val sessionId: String) : AbstractTransport() {
271263
private val logger = KotlinLogging.logger {}
272264
private val pendingResponses = ConcurrentHashMap<String, CompletableDeferred<JSONRPCMessage>>()
273-
private val messageQueue = Channel<JSONRPCMessage>(Channel.UNLIMITED)
265+
private val messageQueue = Channel<JSONRPCMessage>(MESSAGE_QUEUE_CAPACITY)
274266

275267
suspend fun stream(call: ApplicationCall) {
276268
logger.debug { "Starting SSE stream for session $sessionId" }
@@ -300,17 +292,20 @@ private class HttpServerTransport(private val sessionId: String) : AbstractTrans
300292

301293
when (message) {
302294
is JSONRPCRequest -> {
303-
val id = message.id.toString()
295+
val idKey = when (val id = message.id) {
296+
is RequestId.NumberId -> id.value.toString()
297+
is RequestId.StringId -> id.value
298+
}
304299
val responseDeferred = CompletableDeferred<JSONRPCMessage>()
305-
pendingResponses[id] = responseDeferred
300+
pendingResponses[idKey] = responseDeferred
306301

307302
_onMessage.invoke(message)
308303

309-
val response = withTimeoutOrNull(10_000) { responseDeferred.await() }
304+
val response = withTimeoutOrNull(REQUEST_TIMEOUT_MS) { responseDeferred.await() }
310305
if (response != null) {
311306
call.respondText(McpJson.encodeToString(response), ContentType.Application.Json)
312307
} else {
313-
logger.warn { "Timeout for request $id" }
308+
logger.warn { "Timeout for request $idKey" }
314309
call.respondText(
315310
McpJson.encodeToString(
316311
JSONRPCError(
@@ -351,9 +346,12 @@ private class HttpServerTransport(private val sessionId: String) : AbstractTrans
351346
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
352347
when (message) {
353348
is JSONRPCResponse -> {
354-
val id = message.id.toString()
355-
pendingResponses.remove(id)?.complete(message) ?: run {
356-
logger.warn { "No pending response for ID $id, queueing" }
349+
val idKey = when (val id = message.id) {
350+
is RequestId.NumberId -> id.value.toString()
351+
is RequestId.StringId -> id.value
352+
}
353+
pendingResponses.remove(idKey)?.complete(message) ?: run {
354+
logger.warn { "No pending response for ID $idKey, queueing" }
357355
messageQueue.send(message)
358356
}
359357
}

0 commit comments

Comments
 (0)