diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index d4d4b9d5..bf280077 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -1071,16 +1071,18 @@ public final class io/modelcontextprotocol/kotlin/sdk/InitializedNotification$Pa public final class io/modelcontextprotocol/kotlin/sdk/JSONRPCError : io/modelcontextprotocol/kotlin/sdk/JSONRPCMessage { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError$Companion; - public fun (Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)V - public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/ErrorCode; - public final fun component2 ()Ljava/lang/String; - public final fun component3 ()Lkotlinx/serialization/json/JsonObject; - public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError; - public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError; + public fun (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/RequestId; + public final fun component2 ()Lio/modelcontextprotocol/kotlin/sdk/ErrorCode; + public final fun component3 ()Ljava/lang/String; + public final fun component4 ()Lkotlinx/serialization/json/JsonObject; + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;Lio/modelcontextprotocol/kotlin/sdk/RequestId;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError; public fun equals (Ljava/lang/Object;)Z public final fun getCode ()Lio/modelcontextprotocol/kotlin/sdk/ErrorCode; public final fun getData ()Lkotlinx/serialization/json/JsonObject; + public final fun getId ()Lio/modelcontextprotocol/kotlin/sdk/RequestId; public final fun getMessage ()Ljava/lang/String; public fun hashCode ()I public fun toString ()Ljava/lang/String; diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 6eedfe62..5e178ed0 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -257,7 +257,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio JSONRPCResponse( id = request.id, error = JSONRPCError( - ErrorCode.Defined.MethodNotFound, + code = ErrorCode.Defined.MethodNotFound, message = "Server does not support ${request.method}", ), ), diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index 927967af..bdebf342 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -292,8 +292,12 @@ public sealed interface ErrorCode { * A response to a request that indicates an error occurred. */ @Serializable -public data class JSONRPCError(val code: ErrorCode, val message: String, val data: JsonObject = EmptyJsonObject) : - JSONRPCMessage +public data class JSONRPCError( + val id: RequestId? = null, + val code: ErrorCode, + val message: String, + val data: JsonObject = EmptyJsonObject, +) : JSONRPCMessage /** * Base interface for notification parameters with optional metadata. diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index 8ee5af28..57ac05c2 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -1,8 +1,17 @@ +public abstract interface class io/modelcontextprotocol/kotlin/sdk/server/EventStore { + public abstract fun replayEventsAfter (Ljava/lang/String;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun storeEvent (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { public static final fun MCP (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/routing/Routing;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function1;)V + public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V } public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt { @@ -127,6 +136,24 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTranspor public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public static final field STANDALONE_SSE_STREAM_ID Ljava/lang/String; + public fun ()V + public fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;)V + public synthetic fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getSessionId ()Ljava/lang/String; + public final fun handleDeleteRequest (Lio/ktor/server/sse/ServerSSESession;Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handleGetRequest (Lio/ktor/server/sse/ServerSSESession;Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handlePostRequest (Lio/ktor/server/sse/ServerSSESession;Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handleRequest (Lio/ktor/server/sse/ServerSSESession;Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun setOnSessionClosed (Lkotlin/jvm/functions/Function1;)V + public final fun setOnSessionInitialized (Lkotlin/jvm/functions/Function1;)V + public final fun setSessionIdGenerator (Lkotlin/jvm/functions/Function0;)V + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensionsKt { public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 934ba049..3df29826 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -4,6 +4,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.http.HttpStatusCode import io.ktor.server.application.Application import io.ktor.server.application.install +import io.ktor.server.request.header import io.ktor.server.response.respond import io.ktor.server.routing.Routing import io.ktor.server.routing.RoutingContext @@ -19,16 +20,20 @@ import kotlinx.atomicfu.atomic import kotlinx.atomicfu.update import kotlinx.collections.immutable.PersistentMap import kotlinx.collections.immutable.toPersistentMap +import io.modelcontextprotocol.kotlin.sdk.ErrorCode +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport private val logger = KotlinLogging.logger {} -internal class SseTransportManager(transports: Map = emptyMap()) { - private val transports: AtomicRef> = atomic(transports.toPersistentMap()) +internal class TransportManager(transports: Map = emptyMap()) { + private val transports: AtomicRef> = atomic(transports.toPersistentMap()) - fun getTransport(sessionId: String): SseServerTransport? = transports.value[sessionId] + fun hasTransport(sessionId: String): Boolean = transports.value.containsKey(sessionId) - fun addTransport(transport: SseServerTransport) { - transports.update { it.put(transport.sessionId, transport) } + fun getTransport(sessionId: String): AbstractTransport? = transports.value[sessionId] + + fun addTransport(sessionId: String, transport: AbstractTransport) { + transports.update { it.put(sessionId, transport) } } fun removeTransport(sessionId: String) { @@ -48,14 +53,14 @@ public fun Routing.mcp(path: String, block: ServerSSESession.() -> Server) { */ @KtorDsl public fun Routing.mcp(block: ServerSSESession.() -> Server) { - val sseTransportManager = SseTransportManager() + val transportManager = TransportManager() sse { - mcpSseEndpoint("", sseTransportManager, block) + mcpSseEndpoint("", transportManager, block) } post { - mcpPostEndpoint(sseTransportManager) + mcpPostEndpoint(transportManager) } } @@ -74,18 +79,71 @@ public fun Application.mcp(block: ServerSSESession.() -> Server) { } } -internal suspend fun ServerSSESession.mcpSseEndpoint( +/* +* Configures the Ktor Application to handle Model Context Protocol (MCP) over Streamable Http. +* It currently only works with JSON response. +*/ +@KtorDsl +public fun Application.mcpStreamableHttp( + enableDnsRebindingProtection: Boolean = false, + allowedHosts: List? = null, + allowedOrigins: List? = null, + eventStore: EventStore? = null, + block: RoutingContext.() -> Server, +) { + val transportManager = TransportManager() + + routing { + post("/mcp") { + mcpStreamableHttpEndpoint( + transportManager, + enableDnsRebindingProtection, + allowedHosts, + allowedOrigins, + eventStore, + block, + ) + } + } +} + +/* +* Configures the Ktor Application to handle Model Context Protocol (MCP) over stateless Streamable Http. +* It currently only works with JSON response. +*/ +@KtorDsl +public fun Application.mcpStatelessStreamableHttp( + enableDnsRebindingProtection: Boolean = false, + allowedHosts: List? = null, + allowedOrigins: List? = null, + eventStore: EventStore? = null, + block: RoutingContext.() -> Server, +) { + routing { + post("/mcp") { + mcpStatelessStreamableHttpEndpoint( + enableDnsRebindingProtection, + allowedHosts, + allowedOrigins, + eventStore, + block, + ) + } + } +} + +private suspend fun ServerSSESession.mcpSseEndpoint( postEndpoint: String, - sseTransportManager: SseTransportManager, + transportManager: TransportManager, block: ServerSSESession.() -> Server, ) { - val transport = mcpSseTransport(postEndpoint, sseTransportManager) + val transport = mcpSseTransport(postEndpoint, transportManager) val server = block() server.onClose { logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } - sseTransportManager.removeTransport(transport.sessionId) + transportManager.removeTransport(transport.sessionId) } server.connect(transport) @@ -95,16 +153,98 @@ internal suspend fun ServerSSESession.mcpSseEndpoint( internal fun ServerSSESession.mcpSseTransport( postEndpoint: String, - sseTransportManager: SseTransportManager, + transportManager: TransportManager, ): SseServerTransport { val transport = SseServerTransport(postEndpoint, this) - sseTransportManager.addTransport(transport) + transportManager.addTransport(transport.sessionId, transport) logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" } return transport } -internal suspend fun RoutingContext.mcpPostEndpoint(sseTransportManager: SseTransportManager) { +internal suspend fun RoutingContext.mcpStreamableHttpEndpoint( + transportManager: TransportManager, + enableDnsRebindingProtection: Boolean = false, + allowedHosts: List? = null, + allowedOrigins: List? = null, + eventStore: EventStore? = null, + block: RoutingContext.() -> Server, +) { + val sessionId = this.call.request.header(MCP_SESSION_ID_HEADER) + val transport = if (sessionId != null && transportManager.hasTransport(sessionId)) { + transportManager.getTransport(sessionId) + } else if (sessionId == null) { + val transport = StreamableHttpServerTransport( + enableDnsRebindingProtection = enableDnsRebindingProtection, + allowedHosts = allowedHosts, + allowedOrigins = allowedOrigins, + eventStore = eventStore, + enableJsonResponse = true, + ) + + transport.setOnSessionInitialized { sessionId -> + transportManager.addTransport(sessionId, transport) + + logger.info { "New StreamableHttp connection established and stored with sessionId: $sessionId" } + } + + val server = block() + server.onClose { + logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } + } + + server.connect(transport) + + transport + } else { + null + } + + if (transport == null) { + this.call.reject( + HttpStatusCode.BadRequest, + ErrorCode.Unknown(-32000), + "Bad Request: No valid session ID provided", + ) + return + } + + (transport as StreamableHttpServerTransport).handleRequest(null, this.call) + logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" } +} + +internal suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint( + enableDnsRebindingProtection: Boolean = false, + allowedHosts: List? = null, + allowedOrigins: List? = null, + eventStore: EventStore? = null, + block: RoutingContext.() -> Server, +) { + val transport = StreamableHttpServerTransport( + enableDnsRebindingProtection = enableDnsRebindingProtection, + allowedHosts = allowedHosts, + allowedOrigins = allowedOrigins, + eventStore = eventStore, + enableJsonResponse = true, + ) + transport.setSessionIdGenerator(null) + + logger.info { "New stateless StreamableHttp connection established without sessionId" } + + val server = block() + + server.onClose { + logger.info { "Server connection closed without sessionId" } + } + + server.connect(transport) + + transport.handleRequest(null, this.call) + + logger.debug { "Server connected to transport without sessionId" } +} + +internal suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportManager) { val sessionId: String = call.request.queryParameters["sessionId"] ?: run { call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided") return @@ -112,7 +252,7 @@ internal suspend fun RoutingContext.mcpPostEndpoint(sseTransportManager: SseTran logger.debug { "Received message for sessionId: $sessionId" } - val transport = sseTransportManager.getTransport(sessionId) + val transport = transportManager.getTransport(sessionId) as SseServerTransport? if (transport == null) { logger.warn { "Session not found for sessionId: $sessionId" } call.respond(HttpStatusCode.NotFound, "Session not found") diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt new file mode 100644 index 00000000..82f8be0c --- /dev/null +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -0,0 +1,576 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.ApplicationCall +import io.ktor.server.request.contentType +import io.ktor.server.request.header +import io.ktor.server.request.host +import io.ktor.server.request.httpMethod +import io.ktor.server.request.receiveText +import io.ktor.server.response.header +import io.ktor.server.response.respond +import io.ktor.server.response.respondBytes +import io.ktor.server.response.respondNullable +import io.ktor.server.sse.ServerSSESession +import io.ktor.util.collections.ConcurrentMap +import io.modelcontextprotocol.kotlin.sdk.ErrorCode +import io.modelcontextprotocol.kotlin.sdk.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.RequestId +import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.job +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.decodeFromJsonElement +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +internal const val MCP_SESSION_ID_HEADER = "mcp-session-id" +private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" +private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" + +/** + * Interface for resumability support via event storage + */ +public interface EventStore { + /** + * Stores an event for later retrieval + * @param streamId ID of the stream the event belongs to + * @param message The JSON-RPC message to store + * @returns The generated event ID for the stored event + */ + public suspend fun storeEvent(streamId: String, message: JSONRPCMessage): String + + /** + * Replays events after the specified event ID + * @param lastEventId The last event ID that was received + * @param sender Function to send events + * @return The stream ID for the replayed events + */ + public suspend fun replayEventsAfter( + lastEventId: String, + sender: suspend (eventId: String, message: JSONRPCMessage) -> Unit, + ): String +} + +/** + * A holder for an active request call. + * If enableJsonResponse is true, session is null. + * Otherwise, session is not null. + */ +private data class SessionContext(val session: ServerSSESession?, val call: ApplicationCall) + +/** + * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. + * It supports both SSE streaming and direct HTTP responses. + * + * In stateful mode: + * - Session ID is generated and included in response headers + * - Session ID is always included in initialization responses + * - Requests with invalid session IDs are rejected with 404 Not Found + * - Non-initialization requests without a session ID are rejected with 400 Bad Request + * - State is maintained in-memory (connections, message history) + * + * In stateless mode: + * - No Session ID is included in any responses + * - No session validation is performed + * + * @param enableJsonResponse If true, the server will return JSON responses instead of starting an SSE stream. + * This can be useful for simple request/response scenarios without streaming. + * Default is false (SSE streams are preferred). + * @param enableDnsRebindingProtection Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + * Default is false for backwards compatibility. + * @param allowedHosts List of allowed host header values for DNS rebinding protection. + * If not specified, host validation is disabled. + * @param allowedOrigins List of allowed origin header values for DNS rebinding protection. + * If not specified, origin validation is disabled. + * @param eventStore Event store for resumability support + * If provided, resumability will be enabled, allowing clients to reconnect and resume messages + */ +@OptIn(ExperimentalUuidApi::class, ExperimentalAtomicApi::class) +public class StreamableHttpServerTransport( + private val enableJsonResponse: Boolean = false, + private val enableDnsRebindingProtection: Boolean = false, + private val allowedHosts: List? = null, + private val allowedOrigins: List? = null, + private val eventStore: EventStore? = null, +) : AbstractTransport() { + public var sessionId: String? = null + private set + + private var sessionIdGenerator: (() -> String)? = { Uuid.random().toString() } + private var onSessionInitialized: ((sessionId: String) -> Unit)? = null + private var onSessionClosed: ((sessionId: String) -> Unit)? = null + + private val started: AtomicBoolean = AtomicBoolean(false) + private val initialized: AtomicBoolean = AtomicBoolean(false) + + private val streamsMapping: ConcurrentMap = ConcurrentMap() + private val requestToStreamMapping: ConcurrentMap = ConcurrentMap() + private val requestToResponseMapping: ConcurrentMap = ConcurrentMap() + + private val sessionMutex = Mutex() + private val streamMutex = Mutex() + + private companion object { + const val STANDALONE_SSE_STREAM_ID = "_GET_stream" + } + + /** + * Function that generates a session ID for the transport. + * The session ID SHOULD be globally unique and cryptographically secure + * (e.g., a securely generated UUID, a JWT, or a cryptographic hash) + * + * Set undefined to disable session management. + */ + public fun setSessionIdGenerator(block: (() -> String)?) { + sessionIdGenerator = block + } + + /** + * A callback for session initialization events + * This is called when the server initializes a new session. + * Useful in cases when you need to register multiple mcp sessions + * and need to keep track of them. + */ + public fun setOnSessionInitialized(block: ((String) -> Unit)?) { + onSessionInitialized = block + } + + /** + * A callback for session close events + * This is called when the server closes a session due to a DELETE request. + * Useful in cases when you need to clean up resources associated with the session. + * Note that this is different from the transport closing, if you are handling + * HTTP requests from multiple nodes you might want to close each + * StreamableHTTPServerTransport after a request is completed while still keeping the + * session open/running. + */ + public fun setOnSessionClosed(block: ((String) -> Unit)?) { + onSessionClosed = block + } + + override suspend fun start() { + check(started.compareAndSet(expectedValue = false, newValue = true)) { + "StreamableHttpServerTransport already started! If using Server class, note that connect() calls start() automatically." + } + } + + override suspend fun send(message: JSONRPCMessage) { + val requestId: RequestId? = when (message) { + is JSONRPCResponse -> message.id + is JSONRPCError -> message.id + else -> null + } + + // Standalone SSE stream + if (requestId == null) { + require(message !is JSONRPCResponse && message !is JSONRPCError) { + "Cannot send a response on a standalone SSE stream unless resuming a previous client request" + } + val standaloneStream = streamsMapping[STANDALONE_SSE_STREAM_ID] ?: return + emitOnStream(STANDALONE_SSE_STREAM_ID, standaloneStream.session!!, message) + return + } + + val streamId = requestToStreamMapping[requestId] + ?: error("No connection established for request ID: $requestId") + val activeStream = streamsMapping[streamId] + + if (!enableJsonResponse) { + activeStream?.let { stream -> + emitOnStream(streamId, stream.session!!, message) + } + } + + val isTerminated = message is JSONRPCResponse || message is JSONRPCError + if (!isTerminated) return + + requestToResponseMapping[requestId] = message + val relatedIds = requestToStreamMapping.filterValues { it == streamId }.keys + + val allResponseReady = relatedIds.all { it in requestToResponseMapping } + if (!allResponseReady) return + + streamMutex.withLock { + if (activeStream == null) error("No connection established for request ID: $requestId") + + if (enableJsonResponse) { + activeStream.call.response.header(HttpHeaders.ContentType, ContentType.Application.Json.toString()) + sessionId?.let { activeStream.call.response.header(MCP_SESSION_ID_HEADER, it) } + val responses = relatedIds + .mapNotNull { requestToResponseMapping[it] } + .map { McpJson.encodeToString(it) } + val payload = if (responses.size == 1) { + responses.first() + } else { + responses + } + activeStream.call.respond(payload) + } else { + activeStream.session!!.close() + } + + // Clean up + relatedIds.forEach { requestId -> + requestToResponseMapping.remove(requestId) + requestToStreamMapping.remove(requestId) + } + } + } + + override suspend fun close() { + streamMutex.withLock { + streamsMapping.values.forEach { + try { + it.session?.close() + } catch (_: Exception) {} + } + streamsMapping.clear() + requestToResponseMapping.clear() + _onClose() + } + } + + /** + * Handles an incoming HTTP request, whether GET, POST or DELETE + */ + public suspend fun handleRequest(session: ServerSSESession?, call: ApplicationCall) { + validateHeaders(call)?.let { reason -> + call.reject(HttpStatusCode.Forbidden, ErrorCode.Unknown(-32000), reason) + _onError(Error(reason)) + return + } + + when (call.request.httpMethod) { + HttpMethod.Post -> handlePostRequest(session, call) + + HttpMethod.Get -> handleGetRequest(session, call) + + HttpMethod.Delete -> handleDeleteRequest(session, call) + + else -> call.run { + response.header(HttpHeaders.Allow, "GET, POST, DELETE") + reject(HttpStatusCode.MethodNotAllowed, ErrorCode.Unknown(-32000), "Method not allowed.") + } + } + } + + /** + * Handles POST requests containing JSON-RPC messages + */ + public suspend fun handlePostRequest(session: ServerSSESession?, call: ApplicationCall) { + try { + if (!enableJsonResponse && session == null) error("Server session can't be null with json response") + + val acceptHeader = call.request.header(HttpHeaders.Accept) + val isAcceptEventStream = acceptHeader.accepts(ContentType.Text.EventStream) + val isAcceptJson = acceptHeader.accepts(ContentType.Application.Json) + + if (!isAcceptEventStream || !isAcceptJson) { + call.reject( + HttpStatusCode.NotAcceptable, + ErrorCode.Unknown(-32000), + "Not Acceptable: Client must accept both application/json and text/event-stream", + ) + return + } + + if (!call.request.contentType().match(ContentType.Application.Json)) { + call.reject( + HttpStatusCode.UnsupportedMediaType, + ErrorCode.Unknown(-32000), + "Unsupported Media Type: Content-Type must be application/json", + ) + return + } + + val messages = parseBody(call) ?: return + val isInitializationRequest = messages.any { + it is JSONRPCRequest && it.method == Method.Defined.Initialize.value + } + + if (isInitializationRequest) { + if (initialized.load() && sessionId != null) { + call.reject( + HttpStatusCode.BadRequest, + ErrorCode.Defined.InvalidRequest, + "Invalid Request: Server already initialized", + ) + return + } + if (messages.size > 1) { + call.reject( + HttpStatusCode.BadRequest, + ErrorCode.Defined.InvalidRequest, + "Invalid Request: Only one initialization request is allowed", + ) + return + } + + sessionMutex.withLock { + if (sessionId != null) return@withLock + sessionId = sessionIdGenerator?.invoke() + initialized.store(true) + sessionId?.let { onSessionInitialized?.invoke(it) } + } + } else { + if (!validateSession(call) || !validateProtocolVersion(call)) return + } + + val hasRequest = messages.any { it is JSONRPCRequest } + if (!hasRequest) { + call.respondBytes(status = HttpStatusCode.Accepted, bytes = ByteArray(0)) + messages.forEach { message -> _onMessage(message) } + return + } + + val streamId = Uuid.random().toString() + if (!enableJsonResponse) { + call.appendSseHeaders() + session!!.send(data = "") // flush headers immediately + } + + streamMutex.withLock { + streamsMapping[streamId] = SessionContext(session, call) + messages.filterIsInstance().forEach { requestToStreamMapping[it.id] = streamId } + } + call.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(streamId) } + + messages.forEach { message -> _onMessage(message) } + } catch (e: Exception) { + call.reject( + HttpStatusCode.BadRequest, + ErrorCode.Defined.ParseError, + "Parse error: ${e.message}", + ) + _onError(e) + } + } + + public suspend fun handleGetRequest(session: ServerSSESession?, call: ApplicationCall) { + if (enableJsonResponse) { + call.reject( + HttpStatusCode.MethodNotAllowed, + ErrorCode.Unknown(-32000), + "Method not allowed.", + ) + return + } + session!! + + val acceptHeader = call.request.header(HttpHeaders.Accept) + if (!acceptHeader.accepts(ContentType.Text.EventStream)) { + call.reject( + HttpStatusCode.NotAcceptable, + ErrorCode.Unknown(-32000), + "Not Acceptable: Client must accept text/event-stream", + ) + return + } + + if (!validateSession(call) || !validateProtocolVersion(call)) return + + eventStore?.let { store -> + call.request.header(MCP_RESUMPTION_TOKEN_HEADER)?.let { lastEventId -> + replayEvents(store, lastEventId, session) + return + } + } + + if (STANDALONE_SSE_STREAM_ID in streamsMapping) { + call.reject( + HttpStatusCode.Conflict, + ErrorCode.Unknown(-32000), + "Conflict: Only one SSE stream is allowed per session", + ) + return + } + + call.appendSseHeaders() + session.send(data = "") // flush headers immediately + streamsMapping[STANDALONE_SSE_STREAM_ID] = SessionContext(session, call) + session.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(STANDALONE_SSE_STREAM_ID) } + } + + public suspend fun handleDeleteRequest(session: ServerSSESession?, call: ApplicationCall) { + if (enableJsonResponse) { + call.reject( + HttpStatusCode.MethodNotAllowed, + ErrorCode.Unknown(-32000), + "Method not allowed.", + ) + } + + if (!validateSession(call) || !validateProtocolVersion(call)) return + sessionId?.let { onSessionClosed?.invoke(it) } + close() + call.respondNullable(status = HttpStatusCode.OK, message = null) + } + + private suspend fun replayEvents(store: EventStore, lastEventId: String, session: ServerSSESession) { + val call: ApplicationCall = session.call + + try { + call.appendSseHeaders() + val streamId = store.replayEventsAfter(lastEventId) { eventId, message -> + try { + session.send( + event = "message", + id = eventId, + data = McpJson.encodeToString(message), + ) + } catch (e: Exception) { + _onError(e) + } + } + streamsMapping[streamId] = SessionContext(session, call) + } catch (e: Exception) { + _onError(e) + } + } + + private suspend fun validateSession(call: ApplicationCall): Boolean { + if (sessionIdGenerator == null) return true + + if (!initialized.load()) { + call.reject( + HttpStatusCode.BadRequest, + ErrorCode.Unknown(-32000), + "Bad Request: Server not initialized", + ) + return false + } + + val headerId = call.request.header(MCP_SESSION_ID_HEADER) + + return when { + headerId == null -> { + call.reject( + HttpStatusCode.BadRequest, + ErrorCode.Unknown(-32000), + "Bad Request: Mcp-Session-Id header is required", + ) + false + } + + headerId != sessionId -> { + call.reject( + HttpStatusCode.NotFound, + ErrorCode.Unknown(-32001), + "Session not found", + ) + false + } + + else -> true + } + } + + private suspend fun validateProtocolVersion(call: ApplicationCall): Boolean { + val version = call.request.header(MCP_PROTOCOL_VERSION_HEADER) ?: LATEST_PROTOCOL_VERSION + + return when (version) { + !in SUPPORTED_PROTOCOL_VERSIONS -> { + call.reject( + HttpStatusCode.BadRequest, + ErrorCode.Unknown(-32000), + "Bad Request: Unsupported protocol version (supported versions: ${ + SUPPORTED_PROTOCOL_VERSIONS.joinToString( + ", ", + ) + })", + ) + false + } + + else -> true + } + } + + private fun validateHeaders(call: ApplicationCall): String? { + if (!enableDnsRebindingProtection) return null + + allowedHosts?.let { hosts -> + val hostHeader = call.request.host().substringBefore(':').lowercase() + if (hostHeader !in hosts.map { it.substringBefore(':').lowercase() }) { + return "Invalid Host header: $hostHeader" + } + } + + allowedOrigins?.let { origins -> + val originHeader = call.request.headers[HttpHeaders.Origin]?.removeSuffix("/")?.lowercase() + if (originHeader !in origins.map { it.removeSuffix("/").lowercase() }) { + return "Invalid Origin header: $originHeader" + } + } + + return null + } + + private suspend fun parseBody(call: ApplicationCall): List? { + val body = call.receiveText() + return when (val element = McpJson.parseToJsonElement(body)) { + is JsonObject -> listOf(McpJson.decodeFromJsonElement(element)) + + is JsonArray -> McpJson.decodeFromJsonElement>(element) + + else -> { + call.reject( + HttpStatusCode.BadRequest, + ErrorCode.Defined.InvalidRequest, + "Invalid Request: unable to parse JSON body", + ) + return null + } + } + } + + private fun String?.accepts(mime: ContentType): Boolean { + if (this == null) return false + + val escaped = Regex.escape(mime.toString()) + val pattern = Regex("""(^|,\s*)$escaped(\s*(;|,|$))""", RegexOption.IGNORE_CASE) + return pattern.containsMatchIn(this) + } + + private suspend fun emitOnStream(streamId: String, session: ServerSSESession, message: JSONRPCMessage) { + val eventId = eventStore?.storeEvent(streamId, message) + try { + session.send(event = "message", id = eventId, data = McpJson.encodeToString(message)) + } catch (_: Exception) { + streamsMapping.remove(streamId) + } + } + + private fun ApplicationCall.appendSseHeaders() { + this.response.headers.append(HttpHeaders.ContentType, ContentType.Text.EventStream.toString()) + this.response.headers.append(HttpHeaders.CacheControl, "no-cache, no-transform") + this.response.headers.append(HttpHeaders.Connection, "keep-alive") + sessionId?.let { this.response.headers.append(MCP_SESSION_ID_HEADER, it) } + this.response.status(HttpStatusCode.OK) + } +} + +internal suspend fun ApplicationCall.reject(status: HttpStatusCode, code: ErrorCode, message: String) { + this.response.status(status) + this.respond( + JSONRPCResponse( + id = RequestId.StringId("server-error"), + error = JSONRPCError(message = message, code = code), + ), + ) +}