Skip to content

Commit 1f6288d

Browse files
committed
Add/fix StdioClientTransport tests
Enhance `StdioClientTransportTest` with clean shutdown validation, error handling, and added assertions. Extend timeout for stability.
1 parent 93c4e68 commit 1f6288d

File tree

7 files changed

+288
-6
lines changed

7 files changed

+288
-6
lines changed

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ public class StreamableHttpClientTransport(
105105
resumptionToken: String?,
106106
onResumptionToken: ((String) -> Unit)? = null,
107107
) {
108+
check(initialized.load()) { "Transport is not started" }
108109
logger.debug { "Client sending message via POST to $url: ${McpJson.encodeToString(message)}" }
109110

110111
// If we have a resumption token, reconnect the SSE stream with it
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package io.modelcontextprotocol.kotlin.sdk.client
2+
3+
import io.kotest.assertions.throwables.shouldThrow
4+
import io.kotest.matchers.shouldBe
5+
import io.kotest.matchers.string.shouldContain
6+
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
7+
import io.modelcontextprotocol.kotlin.sdk.types.PingRequest
8+
import io.modelcontextprotocol.kotlin.sdk.types.toJSON
9+
import kotlinx.coroutines.delay
10+
import kotlinx.coroutines.test.runTest
11+
import kotlin.test.BeforeTest
12+
import kotlin.test.Test
13+
import kotlin.time.Duration.Companion.milliseconds
14+
15+
abstract class AbstractClientTransportLifecycleTest<T : AbstractTransport> {
16+
17+
protected lateinit var transport: T
18+
19+
@BeforeTest
20+
fun beforeEach() {
21+
transport = createTransport()
22+
}
23+
24+
@Test
25+
fun `should throw when started twice`() = runTest {
26+
transport.start()
27+
28+
val exception = shouldThrow<IllegalStateException> {
29+
transport.start()
30+
}
31+
exception.message shouldContain "already started"
32+
}
33+
34+
@Test
35+
fun `should be idempotent when closed twice`() = runTest {
36+
val transport = createTransport()
37+
38+
transport.start()
39+
transport.close()
40+
41+
// Second close should not throw
42+
transport.close()
43+
}
44+
45+
@Test
46+
fun `should throw when sending before start`() = runTest {
47+
val transport = createTransport()
48+
49+
val exception = shouldThrow<IllegalStateException> {
50+
transport.send(PingRequest().toJSON())
51+
}
52+
exception.message shouldContain "not started"
53+
}
54+
55+
@Test
56+
fun `should throw when sending after close`() = runTest {
57+
val transport = createTransport()
58+
59+
transport.start()
60+
delay(50.milliseconds)
61+
transport.close()
62+
63+
shouldThrow<Exception> {
64+
transport.send(PingRequest().toJSON())
65+
}
66+
}
67+
68+
@Test
69+
fun `should call onClose exactly once`() = runTest {
70+
val transport = createTransport()
71+
72+
var closeCallCount = 0
73+
transport.onClose { closeCallCount++ }
74+
75+
transport.start()
76+
delay(50.milliseconds)
77+
78+
// Multiple close attempts
79+
transport.close()
80+
transport.close()
81+
82+
closeCallCount shouldBe 1
83+
}
84+
85+
protected abstract fun createTransport(): T
86+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package io.modelcontextprotocol.kotlin.sdk.client.stdio
2+
3+
import io.modelcontextprotocol.kotlin.sdk.client.AbstractClientTransportLifecycleTest
4+
import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport
5+
import kotlinx.io.Buffer
6+
import kotlin.test.Ignore
7+
import kotlin.test.Test
8+
9+
class StdioClientTransportLifecycleTest : AbstractClientTransportLifecycleTest<StdioClientTransport>() {
10+
11+
/**
12+
* Dummy method to make IDE treat this class as a test
13+
*/
14+
@Test
15+
@Ignore
16+
fun dummyTest() = Unit
17+
18+
override fun createTransport(): StdioClientTransport {
19+
val inputBuffer = Buffer()
20+
val outputBuffer = Buffer()
21+
return StdioClientTransport(
22+
input = inputBuffer,
23+
output = outputBuffer,
24+
)
25+
}
26+
}

kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt renamed to kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package io.modelcontextprotocol.kotlin.sdk.client
1+
package io.modelcontextprotocol.kotlin.sdk.client.streamable.http
22

33
import io.ktor.client.HttpClient
44
import io.ktor.client.engine.mock.MockEngine
@@ -12,6 +12,8 @@ import io.ktor.http.HttpStatusCode
1212
import io.ktor.http.content.TextContent
1313
import io.ktor.http.headersOf
1414
import io.ktor.utils.io.ByteReadChannel
15+
import io.modelcontextprotocol.kotlin.sdk.client.Client
16+
import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport
1517
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
1618
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
1719
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package io.modelcontextprotocol.kotlin.sdk.client.streamable.http
2+
3+
import io.ktor.client.HttpClient
4+
import io.ktor.client.engine.mock.MockEngine
5+
import io.ktor.client.engine.mock.respond
6+
import io.ktor.client.plugins.sse.SSE
7+
import io.ktor.http.ContentType
8+
import io.ktor.http.HttpHeaders
9+
import io.ktor.http.HttpStatusCode
10+
import io.ktor.http.headersOf
11+
import io.modelcontextprotocol.kotlin.sdk.client.AbstractClientTransportLifecycleTest
12+
import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport
13+
import kotlin.test.Ignore
14+
import kotlin.test.Test
15+
import kotlin.time.Duration.Companion.seconds
16+
17+
class StreamingHttpClientTransportLifecycleTest :
18+
AbstractClientTransportLifecycleTest<StreamableHttpClientTransport>() {
19+
20+
/**
21+
* Dummy method to make IDE treat this class as a test
22+
*/
23+
@Test
24+
@Ignore
25+
fun dummyTest() = Unit
26+
27+
override fun createTransport(): StreamableHttpClientTransport {
28+
val mockEngine = MockEngine {
29+
respond(
30+
"this is not valid json",
31+
status = HttpStatusCode.OK,
32+
headers = headersOf(HttpHeaders.ContentType, ContentType.Application.Json.toString()),
33+
)
34+
}
35+
val httpClient = HttpClient(mockEngine) {
36+
install(SSE) {
37+
reconnectionTime = 1.seconds
38+
}
39+
}
40+
41+
return StreamableHttpClientTransport(httpClient, url = "http://localhost:8080/mcp")
42+
}
43+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package io.modelcontextprotocol.kotlin.sdk.client.stdio
2+
3+
import io.kotest.matchers.booleans.shouldBeFalse
4+
import io.kotest.matchers.shouldBe
5+
import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport
6+
import kotlinx.coroutines.delay
7+
import kotlinx.coroutines.test.runTest
8+
import kotlinx.io.Buffer
9+
import kotlinx.io.writeString
10+
import kotlin.concurrent.atomics.AtomicBoolean
11+
import kotlin.concurrent.atomics.ExperimentalAtomicApi
12+
import kotlin.test.Test
13+
import kotlin.time.Duration.Companion.milliseconds
14+
15+
/**
16+
* Tests for StdioClientTransport error handling: EOF, IO errors, and edge cases.
17+
*/
18+
class StdioClientTransportErrorHandlingTest {
19+
20+
private lateinit var transport: StdioClientTransport
21+
22+
@OptIn(ExperimentalAtomicApi::class)
23+
@Test
24+
fun `should continue on stderr EOF`() = runTest {
25+
val stderrBuffer = Buffer()
26+
// Empty stderr = immediate EOF
27+
28+
val inputBuffer = Buffer()
29+
inputBuffer.writeString("""data: {"jsonrpc":"2.0","method":"ping","id":1}\n\n""")
30+
val outputBuffer = Buffer()
31+
32+
transport = StdioClientTransport(
33+
input = inputBuffer,
34+
output = outputBuffer,
35+
error = stderrBuffer,
36+
)
37+
38+
val closeCalled = AtomicBoolean(false)
39+
transport.onClose { closeCalled.store(true) }
40+
41+
transport.start()
42+
delay(200.milliseconds)
43+
44+
// Stderr EOF should not close transport
45+
closeCalled.load() shouldBe false
46+
47+
transport.close()
48+
closeCalled.load() shouldBe true
49+
}
50+
51+
@Test
52+
fun `should call onClose exactly once on error scenarios`() = runTest {
53+
val stderrBuffer = Buffer()
54+
stderrBuffer.write("FATAL: critical error\n".encodeToByteArray())
55+
56+
val inputBuffer = Buffer()
57+
val outputBuffer = Buffer()
58+
59+
var closeCallCount = 0
60+
61+
transport = StdioClientTransport(
62+
input = inputBuffer,
63+
output = outputBuffer,
64+
error = stderrBuffer,
65+
classifyStderr = { StdioClientTransport.StderrSeverity.FATAL },
66+
)
67+
68+
transport.onClose { closeCallCount++ }
69+
70+
transport.start()
71+
delay(100.milliseconds)
72+
73+
// Explicit close after error already closed it
74+
transport.close()
75+
76+
closeCallCount shouldBe 1
77+
}
78+
79+
@Test
80+
fun `should handle empty input gracefully`() = runTest {
81+
val inputBuffer = Buffer()
82+
val outputBuffer = Buffer()
83+
84+
transport = StdioClientTransport(
85+
input = inputBuffer,
86+
output = outputBuffer,
87+
)
88+
89+
var errorCalled = false
90+
transport.onError { errorCalled = true }
91+
92+
transport.start()
93+
delay(100.milliseconds)
94+
95+
// Empty input should close cleanly without error
96+
errorCalled.shouldBeFalse()
97+
}
98+
}

kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package io.modelcontextprotocol.kotlin.sdk.client
33
import io.modelcontextprotocol.kotlin.sdk.shared.BaseTransportTest
44
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
55
import io.modelcontextprotocol.kotlin.sdk.types.McpException
6+
import kotlinx.coroutines.delay
67
import kotlinx.coroutines.runBlocking
78
import kotlinx.coroutines.test.runTest
89
import kotlinx.io.asSink
@@ -12,12 +13,18 @@ import org.junit.jupiter.api.Test
1213
import org.junit.jupiter.api.Timeout
1314
import org.junit.jupiter.api.assertThrows
1415
import java.util.concurrent.TimeUnit
15-
16-
@Timeout(20, unit = TimeUnit.SECONDS)
16+
import kotlin.concurrent.atomics.AtomicBoolean
17+
import kotlin.concurrent.atomics.ExperimentalAtomicApi
18+
import kotlin.test.assertFalse
19+
import kotlin.test.assertTrue
20+
import kotlin.test.fail
21+
import kotlin.time.Duration.Companion.milliseconds
22+
import kotlin.time.Duration.Companion.seconds
23+
24+
@Timeout(30, unit = TimeUnit.SECONDS)
1725
class StdioClientTransportTest : BaseTransportTest() {
1826

1927
@Test
20-
@Timeout(30, unit = TimeUnit.SECONDS)
2128
fun `handle stdio error`(): Unit = runBlocking {
2229
val processBuilder = if (System.getProperty("os.name").lowercase().contains("win")) {
2330
ProcessBuilder("cmd", "/c", "pause 1 && echo simulated error 1>&2 && exit 1")
@@ -37,7 +44,7 @@ class StdioClientTransportTest : BaseTransportTest() {
3744
error = stderr,
3845
) {
3946
println("💥Ah-oh!, error: \"$it\"")
40-
true
47+
StdioClientTransport.StderrSeverity.FATAL
4148
}
4249

4350
val client = Client(
@@ -55,6 +62,7 @@ class StdioClientTransportTest : BaseTransportTest() {
5562
process.destroyForcibly()
5663
}
5764

65+
@OptIn(ExperimentalAtomicApi::class)
5866
@Test
5967
fun `should start then close cleanly`() = runTest {
6068
// Run process "/usr/bin/tee"
@@ -63,15 +71,33 @@ class StdioClientTransportTest : BaseTransportTest() {
6371

6472
val input = process.inputStream.asSource().buffered()
6573
val output = process.outputStream.asSink().buffered()
74+
val error = process.errorStream.asSource().buffered()
6675

6776
val transport = StdioClientTransport(
6877
input = input,
6978
output = output,
79+
error = error,
7080
)
7181

72-
testTransportOpenClose(transport)
82+
transport.onError { error ->
83+
fail("Unexpected error: $error")
84+
}
85+
86+
val didClose = AtomicBoolean(false)
87+
transport.onClose { didClose.store(true) }
88+
89+
transport.start()
90+
delay(1.seconds)
7391

92+
assertFalse(didClose.load(), "Transport should not be closed immediately after start")
93+
94+
// Destroy process BEFORE close() to unblock stdin reader
7495
process.destroyForcibly()
96+
delay(100.milliseconds) // Give time for EOF to propagate
97+
98+
transport.close()
99+
100+
assertTrue(didClose.load(), "Transport should be closed after close() call")
75101
}
76102

77103
@Test

0 commit comments

Comments
 (0)