@@ -20,7 +20,13 @@ internal sealed class McpSession : IDisposable
2020 private readonly RequestHandlers _requestHandlers ;
2121 private readonly NotificationHandlers _notificationHandlers ;
2222
23+ /// <summary>Collection of requests sent on this session and waiting for responses.</summary>
2324 private readonly ConcurrentDictionary < RequestId , TaskCompletionSource < IJsonRpcMessage > > _pendingRequests = [ ] ;
25+ /// <summary>
26+ /// Collection of requests received on this session and currently being handled. The value provides a <see cref="CancellationTokenSource"/>
27+ /// that can be used to request cancellation of the in-flight handler.
28+ /// </summary>
29+ private readonly ConcurrentDictionary < RequestId , CancellationTokenSource > _handlingRequests = new ( ) ;
2430 private readonly JsonSerializerOptions _jsonOptions ;
2531 private readonly ILogger _logger ;
2632
@@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
6975 {
7076 _logger . TransportMessageRead ( EndpointName , message . GetType ( ) . Name ) ;
7177
72- // Fire and forget the message handling task to avoid blocking the transport
73- // If awaiting the task, the transport will not be able to read more messages,
74- // which could lead to a deadlock if the handler sends a message back
7578 _ = ProcessMessageAsync ( ) ;
7679 async Task ProcessMessageAsync ( )
7780 {
81+ IJsonRpcMessageWithId ? messageWithId = message as IJsonRpcMessageWithId ;
82+ CancellationTokenSource ? combinedCts = null ;
83+ try
84+ {
85+ // Register before we yield, so that the tracking is guaranteed to be there
86+ // when subsequent messages arrive, even if the asynchronous processing happens
87+ // out of order.
88+ if ( messageWithId is not null )
89+ {
90+ combinedCts = CancellationTokenSource . CreateLinkedTokenSource ( cancellationToken ) ;
91+ _handlingRequests [ messageWithId . Id ] = combinedCts ;
92+ }
93+
94+ // Fire and forget the message handling to avoid blocking the transport
95+ // If awaiting the task, the transport will not be able to read more messages,
96+ // which could lead to a deadlock if the handler sends a message back
97+
7898#if NET
79- await Task . CompletedTask . ConfigureAwait ( ConfigureAwaitOptions . ForceYielding ) ;
99+ await Task . CompletedTask . ConfigureAwait ( ConfigureAwaitOptions . ForceYielding ) ;
80100#else
81- await default ( ForceYielding ) ;
101+ await default ( ForceYielding ) ;
82102#endif
83- try
84- {
85- await HandleMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
103+
104+ // Handle the message.
105+ await HandleMessageAsync ( message , combinedCts ? . Token ?? cancellationToken ) . ConfigureAwait ( false ) ;
86106 }
87107 catch ( Exception ex )
88108 {
89- var payload = JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ;
90- _logger . MessageHandlerError ( EndpointName , message . GetType ( ) . Name , payload , ex ) ;
109+ // Only send responses for request errors that aren't user-initiated cancellation.
110+ bool isUserCancellation =
111+ ex is OperationCanceledException &&
112+ ! cancellationToken . IsCancellationRequested &&
113+ combinedCts ? . IsCancellationRequested is true ;
114+
115+ if ( ! isUserCancellation && message is JsonRpcRequest request )
116+ {
117+ _logger . RequestHandlerError ( EndpointName , request . Method , ex ) ;
118+ await _transport . SendMessageAsync ( new JsonRpcError
119+ {
120+ Id = request . Id ,
121+ JsonRpc = "2.0" ,
122+ Error = new JsonRpcErrorDetail
123+ {
124+ Code = ErrorCodes . InternalError ,
125+ Message = ex . Message
126+ }
127+ } , cancellationToken ) . ConfigureAwait ( false ) ;
128+ }
129+ else if ( ex is not OperationCanceledException )
130+ {
131+ var payload = JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ;
132+ _logger . MessageHandlerError ( EndpointName , message . GetType ( ) . Name , payload , ex ) ;
133+ }
134+ }
135+ finally
136+ {
137+ if ( messageWithId is not null )
138+ {
139+ _handlingRequests . TryRemove ( messageWithId . Id , out _ ) ;
140+ combinedCts ! . Dispose ( ) ;
141+ }
91142 }
92143 }
93144 }
@@ -123,6 +174,25 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
123174
124175 private async Task HandleNotification ( JsonRpcNotification notification )
125176 {
177+ // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
178+ if ( notification . Method == NotificationMethods . CancelledNotification )
179+ {
180+ try
181+ {
182+ if ( GetCancelledNotificationParams ( notification . Params ) is CancelledNotification cn &&
183+ _handlingRequests . TryGetValue ( cn . RequestId , out var cts ) )
184+ {
185+ await cts . CancelAsync ( ) . ConfigureAwait ( false ) ;
186+ _logger . RequestCanceled ( cn . RequestId , cn . Reason ) ;
187+ }
188+ }
189+ catch
190+ {
191+ // "Invalid cancellation notifications SHOULD be ignored"
192+ }
193+ }
194+
195+ // Handle user-defined notifications.
126196 if ( _notificationHandlers . TryGetValue ( notification . Method , out var handlers ) )
127197 {
128198 foreach ( var notificationHandler in handlers )
@@ -161,33 +231,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
161231 {
162232 if ( _requestHandlers . TryGetValue ( request . Method , out var handler ) )
163233 {
164- try
165- {
166- _logger . RequestHandlerCalled ( EndpointName , request . Method ) ;
167- var result = await handler ( request , cancellationToken ) . ConfigureAwait ( false ) ;
168- _logger . RequestHandlerCompleted ( EndpointName , request . Method ) ;
169- await _transport . SendMessageAsync ( new JsonRpcResponse
170- {
171- Id = request . Id ,
172- JsonRpc = "2.0" ,
173- Result = result
174- } , cancellationToken ) . ConfigureAwait ( false ) ;
175- }
176- catch ( Exception ex )
234+ _logger . RequestHandlerCalled ( EndpointName , request . Method ) ;
235+ var result = await handler ( request , cancellationToken ) . ConfigureAwait ( false ) ;
236+ _logger . RequestHandlerCompleted ( EndpointName , request . Method ) ;
237+ await _transport . SendMessageAsync ( new JsonRpcResponse
177238 {
178- _logger . RequestHandlerError ( EndpointName , request . Method , ex ) ;
179- // Send error response
180- await _transport . SendMessageAsync ( new JsonRpcError
181- {
182- Id = request . Id ,
183- JsonRpc = "2.0" ,
184- Error = new JsonRpcErrorDetail
185- {
186- Code = - 32000 , // Implementation defined error
187- Message = ex . Message
188- }
189- } , cancellationToken ) . ConfigureAwait ( false ) ;
190- }
239+ Id = request . Id ,
240+ JsonRpc = "2.0" ,
241+ Result = result
242+ } , cancellationToken ) . ConfigureAwait ( false ) ;
191243 }
192244 else
193245 {
@@ -273,7 +325,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
273325 }
274326 }
275327
276- public Task SendMessageAsync ( IJsonRpcMessage message , CancellationToken cancellationToken = default )
328+ public async Task SendMessageAsync ( IJsonRpcMessage message , CancellationToken cancellationToken = default )
277329 {
278330 Throw . IfNull ( message ) ;
279331
@@ -288,7 +340,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
288340 _logger . SendingMessage ( EndpointName , JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ) ;
289341 }
290342
291- return _transport . SendMessageAsync ( message , cancellationToken ) ;
343+ await _transport . SendMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
344+
345+ // If the sent notification was a cancellation notification, cancel the pending request's await, as either the
346+ // server won't be sending a response, or per the specification, the response should be ignored. There are inherent
347+ // race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
348+ if ( message is JsonRpcNotification { Method : NotificationMethods . CancelledNotification } notification &&
349+ GetCancelledNotificationParams ( notification . Params ) is CancelledNotification cn &&
350+ _pendingRequests . TryRemove ( cn . RequestId , out var tcs ) )
351+ {
352+ tcs . TrySetCanceled ( default ) ;
353+ }
354+ }
355+
356+ private static CancelledNotification ? GetCancelledNotificationParams ( object ? notificationParams )
357+ {
358+ try
359+ {
360+ switch ( notificationParams )
361+ {
362+ case null :
363+ return null ;
364+
365+ case CancelledNotification cn :
366+ return cn ;
367+
368+ case JsonElement je :
369+ return JsonSerializer . Deserialize ( je , McpJsonUtilities . DefaultOptions . GetTypeInfo < CancelledNotification > ( ) ) ;
370+
371+ default :
372+ return JsonSerializer . Deserialize (
373+ JsonSerializer . Serialize ( notificationParams , McpJsonUtilities . DefaultOptions . GetTypeInfo < object ? > ( ) ) ,
374+ McpJsonUtilities . DefaultOptions . GetTypeInfo < CancelledNotification > ( ) ) ;
375+ }
376+ }
377+ catch
378+ {
379+ return null ;
380+ }
292381 }
293382
294383 public void Dispose ( )
0 commit comments