From 0f8a61e4d72adc1d36c71bbacfa2b80f0bff0abb Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Wed, 19 Feb 2025 15:23:54 +0800 Subject: [PATCH 1/4] init --- .../DispatcherHelper.cs | 2 +- .../ICallerClientResultsManager.cs | 7 ++++-- .../MultiEndpointMessageWriter.cs | 10 ++++++++- ...MultiEndpointServiceConnectionContainer.cs | 10 +++++++-- .../ServiceConnectionContainerFactory.cs | 7 +++++- ...MultiEndpointConnectionContainerFactory.cs | 9 +++++--- .../ServiceHubContextImpl.cs | 6 ++--- .../CallerClientResultsManager.cs | 12 ++++++++-- .../HubHost/ServiceHubDispatcher.cs | 3 ++- ...EndpointServiceConnectionContainerTests.cs | 2 +- .../MultiEndpointMessageWriterTests.cs | 2 +- .../ClientInvocationManagerTests.cs | 22 ++++++++++++++----- ...EndpointServiceConnectionContainerTests.cs | 6 ++--- 13 files changed, 71 insertions(+), 27 deletions(-) diff --git a/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs b/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs index d31596727..436f4c719 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs @@ -141,7 +141,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder configuration.Resolver.Register(typeof(IServiceConnectionFactory), () => scf); } - var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, loggerFactory); + var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, null, loggerFactory); if (hubs?.Count > 0) { diff --git a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs index 4be33c706..08140d712 100644 --- a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Threading; @@ -31,5 +31,8 @@ internal interface ICallerClientResultsManager : IClientResultsManager bool TryCompleteResult(string connectionId, ErrorCompletionMessage message); void RemoveInvocation(string invocationId); + + void SetAckNumber(string invocationId, int ackNumber); + } -} \ No newline at end of file +} diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs index 516496eef..f419ae916 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs @@ -21,11 +21,13 @@ namespace Microsoft.Azure.SignalR; internal class MultiEndpointMessageWriter : IServiceMessageWriter, IPresenceManager { private readonly ILogger _logger; + private readonly IClientInvocationManager _clientInvocationManager; internal HubServiceEndpoint[] TargetEndpoints { get; } - public MultiEndpointMessageWriter(IReadOnlyCollection targetEndpoints, ILoggerFactory loggerFactory) + public MultiEndpointMessageWriter(IReadOnlyCollection targetEndpoints, IClientInvocationManager invocationManager, ILoggerFactory loggerFactory) { + _clientInvocationManager = invocationManager; _logger = loggerFactory.CreateLogger(); var normalized = new List(); if (targetEndpoints != null) @@ -52,6 +54,12 @@ public MultiEndpointMessageWriter(IReadOnlyCollection targetEnd public Task WriteAsync(ServiceMessage serviceMessage) { + if (serviceMessage is ClientInvocationMessage invocationMessage) + { + // Accroding to target endpoints in method `WriteMultiEndpointMessageAsync` + _clientInvocationManager.Caller.SetAckNumber(invocationMessage.InvocationId, TargetEndpoints.Length); + } + return WriteMultiEndpointMessageAsync(serviceMessage, connection => connection.WriteAsync(serviceMessage)); } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs index e90ffb4f9..fc34434af 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs @@ -30,6 +30,8 @@ internal class MultiEndpointServiceConnectionContainer : IServiceConnectionConta private readonly object _lock = new object(); + private readonly IClientInvocationManager _clientInvocationManager; + private (bool needRouter, IReadOnlyList endpoints) _routerEndpoints; private int _started; @@ -56,6 +58,7 @@ public MultiEndpointServiceConnectionContainer( int? maxCount, IServiceEndpointManager endpointManager, IMessageRouter router, + IClientInvocationManager clientInvocationManager, ILoggerFactory loggerFactory, TimeSpan? scaleTimeout = null ) : this( @@ -63,6 +66,7 @@ public MultiEndpointServiceConnectionContainer( endpoint => CreateContainer(serviceConnectionFactory, endpoint, count, maxCount, loggerFactory), endpointManager, router, + clientInvocationManager, loggerFactory, scaleTimeout) { @@ -73,6 +77,7 @@ internal MultiEndpointServiceConnectionContainer( Func generator, IServiceEndpointManager endpointManager, IMessageRouter router, + IClientInvocationManager clientInvocationManager, ILoggerFactory loggerFactory, TimeSpan? scaleTimeout = null) { @@ -90,6 +95,7 @@ internal MultiEndpointServiceConnectionContainer( _loggerFactory = loggerFactory; _logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); _serviceEndpointManager = endpointManager; + _clientInvocationManager = clientInvocationManager; _scaleTimeout = scaleTimeout ?? Constants.Periods.DefaultScaleTimeout; // Reserve generator for potential scale use. @@ -158,7 +164,7 @@ public Task WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel public IAsyncEnumerable ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default) { var targetEndpoints = _routerEndpoints.needRouter ? _router.GetEndpointsForGroup(groupName, _routerEndpoints.endpoints) : _routerEndpoints.endpoints; - var messageWriter = new MultiEndpointMessageWriter(targetEndpoints?.ToList(), _loggerFactory); + var messageWriter = new MultiEndpointMessageWriter(targetEndpoints?.ToList(), _clientInvocationManager, _loggerFactory); return messageWriter.ListConnectionsInGroupAsync(groupName, top, tracingId, token); } @@ -271,7 +277,7 @@ private static IServiceConnectionContainer CreateContainer(IServiceConnectionFac private MultiEndpointMessageWriter CreateMessageWriter(ServiceMessage serviceMessage) { var targetEndpoints = GetRoutedEndpoints(serviceMessage)?.ToList(); - return new MultiEndpointMessageWriter(targetEndpoints, _loggerFactory); + return new MultiEndpointMessageWriter(targetEndpoints, _clientInvocationManager, _loggerFactory); } private void OnAdd(HubServiceEndpoint endpoint) diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs index 1f87e2bd5..c4aa18f67 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; @@ -18,16 +18,20 @@ internal class ServiceConnectionContainerFactory : IServiceConnectionContainerFa private readonly IServiceConnectionFactory _serviceConnectionFactory; + private readonly IClientInvocationManager _clientInvocationManager; + public ServiceConnectionContainerFactory(IServiceConnectionFactory serviceConnectionFactory, IServiceEndpointManager serviceEndpointManager, IMessageRouter router, IServiceEndpointOptions options, + IClientInvocationManager clientInvocationManager, ILoggerFactory loggerFactory) { _serviceConnectionFactory = serviceConnectionFactory; _serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager)); _router = router ?? throw new ArgumentNullException(nameof(router)); _options = options; + _clientInvocationManager = clientInvocationManager; _loggerFactory = loggerFactory; } @@ -39,6 +43,7 @@ public IServiceConnectionContainer Create(string hub, TimeSpan? serviceScaleTime _options.MaxHubServerConnectionCount, _serviceEndpointManager, _router, + _clientInvocationManager, _loggerFactory, serviceScaleTimeout); } diff --git a/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/MultiEndpointConnectionContainerFactory.cs b/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/MultiEndpointConnectionContainerFactory.cs index a1b1f2af1..1a68bf642 100644 --- a/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/MultiEndpointConnectionContainerFactory.cs +++ b/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/MultiEndpointConnectionContainerFactory.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using Microsoft.Azure.SignalR.Common; @@ -14,14 +14,16 @@ internal class MultiEndpointConnectionContainerFactory private readonly IServiceEndpointManager _endpointManager; private readonly int _connectionCount; private readonly IEndpointRouter _router; + private readonly IClientInvocationManager _clientInvocationManager; - public MultiEndpointConnectionContainerFactory(IServiceConnectionFactory connectionFactory, ILoggerFactory loggerFactory, IServiceEndpointManager serviceEndpointManager, IOptions options, IEndpointRouter router = null) + public MultiEndpointConnectionContainerFactory(IServiceConnectionFactory connectionFactory, ILoggerFactory loggerFactory, IServiceEndpointManager serviceEndpointManager, IOptions options, IEndpointRouter router = null, IClientInvocationManager clientInvocationManager = null) { _connectionFactory = connectionFactory; _loggerFactory = loggerFactory; _endpointManager = serviceEndpointManager; _connectionCount = options.Value.ConnectionCount; _router = router; + _clientInvocationManager = clientInvocationManager; } public MultiEndpointServiceConnectionContainer Create(string hubName) @@ -31,8 +33,9 @@ public MultiEndpointServiceConnectionContainer Create(string hubName) endpoint => new WeakServiceConnectionContainer(_connectionFactory, _connectionCount, endpoint, _loggerFactory.CreateLogger()), _endpointManager, _router, + _clientInvocationManager, _loggerFactory); return container; } } -} \ No newline at end of file +} diff --git a/src/Microsoft.Azure.SignalR.Management/ServiceHubContextImpl.cs b/src/Microsoft.Azure.SignalR.Management/ServiceHubContextImpl.cs index 147399dcf..59c9a4aa6 100644 --- a/src/Microsoft.Azure.SignalR.Management/ServiceHubContextImpl.cs +++ b/src/Microsoft.Azure.SignalR.Management/ServiceHubContextImpl.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; @@ -96,7 +96,7 @@ public override ServiceHubContext WithEndpoints(IEnumerable end private sealed class MessageWriterServiceContainerWrapper : MultiEndpointMessageWriter, IServiceConnectionContainer { public MessageWriterServiceContainerWrapper(IReadOnlyCollection targetEndpoints, ILoggerFactory loggerFactory) - : base(targetEndpoints, loggerFactory) { } + : base(targetEndpoints, null, loggerFactory) { } public Task StartAsync() => Task.CompletedTask; @@ -125,4 +125,4 @@ public void Dispose() #endregion Not supported method or properties } } -} \ No newline at end of file +} diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs index ed86f9534..4652acabb 100644 --- a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. #if NET7_0_OR_GREATER using System; @@ -206,6 +206,14 @@ public void RemoveInvocation(string invocationId) _pendingInvocations.TryRemove(invocationId, out _); } + public void SetAckNumber(string invocationId, int ackNumber) + { + if (_pendingInvocations.TryGetValue(invocationId, out var item)) + { + _ackHandler.SetExpectedCount(item.AckId, ackNumber); + } + } + // Unused, here to honor the IInvocationBinder interface but should never be called public IReadOnlyList GetParameterTypes(string methodName) => throw new NotImplementedException(); @@ -218,4 +226,4 @@ private record PendingInvocation(Type Type, string ConnectionId, object Tcs, int } } } -#endif \ No newline at end of file +#endif diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs index fe9810363..b06da9f4b 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; @@ -195,6 +195,7 @@ private IServiceConnectionContainer GetServiceConnectionContainer(ConnectionDele _serviceEndpointManager, _router, _options, + _clientInvocationManager, _loggerFactory ); } diff --git a/test/Microsoft.Azure.SignalR.AspNet.Tests/MultiEndpointServiceConnectionContainerTests.cs b/test/Microsoft.Azure.SignalR.AspNet.Tests/MultiEndpointServiceConnectionContainerTests.cs index 0632afc5c..d1652828b 100644 --- a/test/Microsoft.Azure.SignalR.AspNet.Tests/MultiEndpointServiceConnectionContainerTests.cs +++ b/test/Microsoft.Azure.SignalR.AspNet.Tests/MultiEndpointServiceConnectionContainerTests.cs @@ -486,7 +486,7 @@ public TestMultiEndpointServiceConnectionContainer(string hub, Func generator, IServiceEndpointManager endpoint, IEndpointRouter router, - ILoggerFactory loggerFactory) : base(hub, generator, endpoint, router, loggerFactory) + ILoggerFactory loggerFactory) : base(hub, generator, endpoint, router, null, loggerFactory) { } } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceConnections/MultiEndpointMessageWriterTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceConnections/MultiEndpointMessageWriterTests.cs index 19d52c685..157d53b39 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceConnections/MultiEndpointMessageWriterTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceConnections/MultiEndpointMessageWriterTests.cs @@ -36,7 +36,7 @@ public async Task ListConnectionsInGroup(int? top, int resultCount, params int?[ endpoint.ConnectionContainer = containerMock.Object; targetEndpoints.Add(endpoint); } - var multiEndpointWriter = new MultiEndpointMessageWriter(targetEndpoints, Mock.Of()); + var multiEndpointWriter = new MultiEndpointMessageWriter(targetEndpoints, null, Mock.Of()); var resultMembers = new List(); await foreach (var member in multiEndpointWriter.ListConnectionsInGroupAsync("group", top)) { diff --git a/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs b/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs index 93cb436b5..4300dd64b 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs @@ -36,13 +36,18 @@ public class ClientInvocationManagerTests private const string SuccessCompleteResult = "success-result"; private const string ErrorCompleteResult = "error-result"; - private static ClientInvocationManager GetTestClientInvocationManager(int endpointCount = 1) + private static ClientInvocationManager GetTestClientInvocationManager(int endpointCount = 1, int badEndpointsCount = 0) { var services = new ServiceCollection(); var endpoints = Enumerable.Range(0, endpointCount) .Select(i => new ServiceEndpoint($"Endpoint=https://test{i}connectionstring;AccessKey=1")) .ToArray(); + for (var i = 0; i < badEndpointsCount && i < endpointCount; i++) + { + endpoints[i].Online = false; + } + var config = new ConfigurationBuilder().Build(); var serviceProvider = services.AddLogging() @@ -175,14 +180,19 @@ public void TestCallerManagerCancellation() } [Theory] - [InlineData(true, 2)] - [InlineData(false, 2)] - [InlineData(true, 3)] - [InlineData(false, 3)] + [InlineData(true, 2, 0)] + [InlineData(true, 2, 1)] + [InlineData(true, 2, 2)] + [InlineData(false, 2, 0)] + [InlineData(false, 2, 1)] + [InlineData(false, 2, 2)] + [InlineData(true, 3, 0)] + [InlineData(false, 3, 0)] // isCompletionWithResult: the invocation is completed with result or error - public async Task TestCompleteWithMultiEndpointAtLast(bool isCompletionWithResult, int endpointsCount) + public async Task TestCompleteWithMultiEndpointAtLast(bool isCompletionWithResult, int endpointsCount, int badEndpointsCount) { Assert.True(endpointsCount > 1); + Assert.True(endpointsCount >= badEndpointsCount); var clientInvocationManager = GetTestClientInvocationManager(endpointsCount); var connectionId = TestConnectionIds[0]; var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId); diff --git a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs index 34f02f5a9..128669184 100644 --- a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs @@ -174,9 +174,9 @@ public async Task TestStatusPingChangesEndpointStatus() var connectionFactory1 = new TestServiceConnectionFactory(); var connectionFactory2 = new TestServiceConnectionFactory(); - var hub1 = new MultiEndpointServiceConnectionContainer(connectionFactory1, "hub1", 2, null, sem, router, + var hub1 = new MultiEndpointServiceConnectionContainer(connectionFactory1, "hub1", 2, null, sem, router, null, loggerFactory); - var hub2 = new MultiEndpointServiceConnectionContainer(connectionFactory2, "hub2", 2, null, sem, router, + var hub2 = new MultiEndpointServiceConnectionContainer(connectionFactory2, "hub2", 2, null, sem, router, null, loggerFactory); var connections = connectionFactory1.CreatedConnections.SelectMany(kv => kv.Value).ToArray(); @@ -1985,7 +1985,7 @@ public TestMultiEndpointServiceConnectionContainer(string hub, IEndpointRouter router, ILoggerFactory loggerFactory, TimeSpan? _ = null - ) : base(hub, generator, endpoint, router, loggerFactory) + ) : base(hub, generator, endpoint, router, null, loggerFactory) { } From 9431689e88efa3949cae5849b0cfd4779f24c3c7 Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Wed, 19 Feb 2025 15:39:21 +0800 Subject: [PATCH 2/4] format --- .../ClientInvocation/ICallerClientResultsManager.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs index 08140d712..474a47209 100644 --- a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs @@ -33,6 +33,5 @@ internal interface ICallerClientResultsManager : IClientResultsManager void RemoveInvocation(string invocationId); void SetAckNumber(string invocationId, int ackNumber); - } } From 68b7aa7266fe408cd3986847385e4d5e05e4538e Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Mon, 24 Feb 2025 12:58:27 +0800 Subject: [PATCH 3/4] resolve comments and fix encoding issue --- .../ClientInvocation/ICallerClientResultsManager.cs | 2 +- .../ServiceConnections/ServiceConnectionContainerFactory.cs | 2 +- .../MultiEndpointConnectionContainerFactory.cs | 2 +- .../ServiceHubContextImpl.cs | 2 +- .../ClientInvocation/CallerClientResultsManager.cs | 4 +--- src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs | 2 +- src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs | 3 ++- 7 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs index 474a47209..93e6b8fdb 100644 --- a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Threading; diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs index c4aa18f67..9a5966616 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; diff --git a/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/MultiEndpointConnectionContainerFactory.cs b/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/MultiEndpointConnectionContainerFactory.cs index 1a68bf642..1b1cabc55 100644 --- a/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/MultiEndpointConnectionContainerFactory.cs +++ b/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/MultiEndpointConnectionContainerFactory.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using Microsoft.Azure.SignalR.Common; diff --git a/src/Microsoft.Azure.SignalR.Management/ServiceHubContextImpl.cs b/src/Microsoft.Azure.SignalR.Management/ServiceHubContextImpl.cs index 59c9a4aa6..aa1e02a97 100644 --- a/src/Microsoft.Azure.SignalR.Management/ServiceHubContextImpl.cs +++ b/src/Microsoft.Azure.SignalR.Management/ServiceHubContextImpl.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs index 4652acabb..16c68a927 100644 --- a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. #if NET7_0_OR_GREATER using System; @@ -48,8 +48,6 @@ public Task AddInvocation(string hub, string connectionId, string invocati var multiAck = _ackHandler.CreateMultiAck(out var ackId); - _ackHandler.SetExpectedCount(ackId, ackNumber); - // When the caller server is also the client router, Azure SignalR service won't send a ServiceMappingMessage to server. // To handle this condition, CallerClientResultsManager itself should record this mapping information rather than waiting for a ServiceMappingMessage sent by service. Only in this condition, this method is called with instanceId != null. var result = _pendingInvocations.TryAdd(invocationId, diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs index b06da9f4b..711efa934 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs index 2c60e3b62..b4e20e7ab 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs @@ -129,8 +129,9 @@ public override async Task InvokeConnectionAsync(string connectionId, stri var invocationId = _clientInvocationManager.Caller.GenerateInvocationId(connectionId); var message = AppendMessageTracingId(new ClientInvocationMessage(invocationId, connectionId, _callerId, SerializeAllProtocols(methodName, args, invocationId))); - await WriteAsync(message); + // The ack number of invocation will be set inside `WriteAsync`. So adding invocation should be first. var task = _clientInvocationManager.Caller.AddInvocation(_hub, connectionId, invocationId, cancellationToken); + await WriteAsync(message); // Exception handling follows https://source.dot.net/#Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs,349 try From db344b2377b6833fc403319d4869951b9b4d94d5 Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Mon, 24 Feb 2025 17:27:28 +0800 Subject: [PATCH 4/4] resolve comments and fix ut --- .../MultiEndpointMessageWriter.cs | 10 ++++- .../CallerClientResultsManager.cs | 5 +-- .../HubHost/ServiceLifetimeManager.cs | 5 +++ .../ClientInvocationManagerTests.cs | 39 +++++++++++-------- 4 files changed, 37 insertions(+), 22 deletions(-) diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs index f419ae916..ed1ec544f 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointMessageWriter.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Azure.SignalR.Common; using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; @@ -58,6 +59,13 @@ public Task WriteAsync(ServiceMessage serviceMessage) { // Accroding to target endpoints in method `WriteMultiEndpointMessageAsync` _clientInvocationManager.Caller.SetAckNumber(invocationMessage.InvocationId, TargetEndpoints.Length); + if (TargetEndpoints.Length == 0) + { + _clientInvocationManager.Caller.TryCompleteResult( + invocationMessage.ConnectionId, + CompletionMessage.WithError(invocationMessage.InvocationId, "No available endpoint to send invocation message.") + ); + } } return WriteMultiEndpointMessageAsync(serviceMessage, connection => connection.WriteAsync(serviceMessage)); diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs index 16c68a927..481d5a931 100644 --- a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. #if NET7_0_OR_GREATER using System; @@ -43,9 +43,6 @@ public Task AddInvocation(string hub, string connectionId, string invocati cancellationToken, () => TryCompleteResult(connectionId, CompletionMessage.WithError(invocationId, "Canceled"))); - var serviceEndpoints = _serviceEndpointManager.GetEndpoints(hub); - var ackNumber = _endpointRouter.GetEndpointsForConnection(connectionId, serviceEndpoints).Count(); - var multiAck = _ackHandler.CreateMultiAck(out var ackId); // When the caller server is also the client router, Azure SignalR service won't send a ServiceMappingMessage to server. diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs index b4e20e7ab..4bdc00c7d 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs @@ -132,6 +132,11 @@ public override async Task InvokeConnectionAsync(string connectionId, stri // The ack number of invocation will be set inside `WriteAsync`. So adding invocation should be first. var task = _clientInvocationManager.Caller.AddInvocation(_hub, connectionId, invocationId, cancellationToken); await WriteAsync(message); + if (ServiceConnectionContainer is not MultiEndpointServiceConnectionContainer) + { + // `WriteAsync` in test class `TestServiceConnectionHandler` does not set ack number. Set the number manually. + _clientInvocationManager.Caller.SetAckNumber(invocationId, 1); + } // Exception handling follows https://source.dot.net/#Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs,349 try diff --git a/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs b/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs index 4300dd64b..28c6ca3b2 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. #if NET7_0_OR_GREATER @@ -34,7 +34,8 @@ public class ClientInvocationManagerTests private static readonly List TestInstanceIds = new() { "instance0", "instance1" }; private static readonly List TestServerIds = new() { "server1", "server2" }; private const string SuccessCompleteResult = "success-result"; - private const string ErrorCompleteResult = "error-result"; + private const string CommonErrorCompleteResult = "error-result"; + private const string NoEndpointErrorCompleteResult = "No available endpoint to send invocation message."; private static ClientInvocationManager GetTestClientInvocationManager(int endpointCount = 1, int badEndpointsCount = 0) { @@ -82,6 +83,7 @@ public async Task TestCompleteWithoutRouterServer(bool isCompletionWithResult) var cancellationToken = new CancellationToken(); // Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId` var task = clientInvocationManager.Caller.AddInvocation("TestHub", connectionId, invocationId, cancellationToken); + clientInvocationManager.Caller.SetAckNumber(invocationId, 1); var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out var t); @@ -90,7 +92,7 @@ public async Task TestCompleteWithoutRouterServer(bool isCompletionWithResult) var completionMessage = isCompletionWithResult ? CompletionMessage.WithResult(invocationId, SuccessCompleteResult) - : CompletionMessage.WithError(invocationId, ErrorCompleteResult); + : CompletionMessage.WithError(invocationId, CommonErrorCompleteResult); ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, completionMessage); Assert.True(ret); @@ -104,7 +106,7 @@ public async Task TestCompleteWithoutRouterServer(bool isCompletionWithResult) catch (Exception e) { Assert.False(isCompletionWithResult); - Assert.Equal(ErrorCompleteResult, e.Message); + Assert.Equal(CommonErrorCompleteResult, e.Message); } } @@ -131,12 +133,14 @@ public async Task TestCompleteWithRouterServer(string protocol, bool isCompletio var cancellationToken = new CancellationToken(); // Server 1 doesn't know the InstanceId of Client 2, so `instaceId` is null for `AddInvocation` var task = ciManagers[0].Caller.AddInvocation("TestHub", TestConnectionIds[0], invocationId, cancellationToken); + ciManagers[0].Caller.SetAckNumber(invocationId, 1); ciManagers[0].Caller.AddServiceMapping(new ServiceMappingMessage(invocationId, TestConnectionIds[1], TestInstanceIds[1])); + ciManagers[1].Router.AddInvocation(TestConnectionIds[1], invocationId, serverIds[0], new CancellationToken()); var completionMessage = isCompletionWithResult ? CompletionMessage.WithResult(invocationId, SuccessCompleteResult) - : CompletionMessage.WithError(invocationId, ErrorCompleteResult); + : CompletionMessage.WithError(invocationId, CommonErrorCompleteResult); var ret = ciManagers[1].Router.TryCompleteResult(TestConnectionIds[1], completionMessage); Assert.True(ret); @@ -156,17 +160,18 @@ public async Task TestCompleteWithRouterServer(string protocol, bool isCompletio catch (Exception e) { Assert.False(isCompletionWithResult); - Assert.Equal(ErrorCompleteResult, e.Message); + Assert.Equal(CommonErrorCompleteResult, e.Message); } } [Fact] public void TestCallerManagerCancellation() { - var clientInvocationManager = GetTestClientInvocationManager(); + var clientInvocationManager = GetTestClientInvocationManager(1); var invocationId = clientInvocationManager.Caller.GenerateInvocationId(TestConnectionIds[0]); var cts = new CancellationTokenSource(); var task = clientInvocationManager.Caller.AddInvocation("TestHub", TestConnectionIds[0], invocationId, cts.Token); + clientInvocationManager.Caller.SetAckNumber(invocationId, 1); // Check if the invocation is existing Assert.True(clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out _)); @@ -182,7 +187,6 @@ public void TestCallerManagerCancellation() [Theory] [InlineData(true, 2, 0)] [InlineData(true, 2, 1)] - [InlineData(true, 2, 2)] [InlineData(false, 2, 0)] [InlineData(false, 2, 1)] [InlineData(false, 2, 2)] @@ -191,15 +195,16 @@ public void TestCallerManagerCancellation() // isCompletionWithResult: the invocation is completed with result or error public async Task TestCompleteWithMultiEndpointAtLast(bool isCompletionWithResult, int endpointsCount, int badEndpointsCount) { - Assert.True(endpointsCount > 1); - Assert.True(endpointsCount >= badEndpointsCount); - var clientInvocationManager = GetTestClientInvocationManager(endpointsCount); + var availableEndpointsCount = endpointsCount - badEndpointsCount; + Assert.True(endpointsCount > 0 && availableEndpointsCount >= 0); + var clientInvocationManager = GetTestClientInvocationManager(endpointsCount, badEndpointsCount); var connectionId = TestConnectionIds[0]; var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId); var cancellationToken = new CancellationToken(); // Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId` var task = clientInvocationManager.Caller.AddInvocation("TestHub", connectionId, invocationId, cancellationToken); + clientInvocationManager.Caller.SetAckNumber(invocationId, endpointsCount - badEndpointsCount); var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out var t); @@ -207,18 +212,17 @@ public async Task TestCompleteWithMultiEndpointAtLast(bool isCompletionWithResul Assert.Equal(typeof(string), t); var completionMessage = CompletionMessage.WithResult(invocationId, SuccessCompleteResult); - var errorCompletionMessage = CompletionMessage.WithError(invocationId, ErrorCompleteResult); + var errorCompletionMessage = CompletionMessage.WithError(invocationId, availableEndpointsCount > 0 ? CommonErrorCompleteResult : NoEndpointErrorCompleteResult); // The first `endpointsCount - 1` CompletionMessage complete the invocation with error // The last one completes the invocation according to `isCompletionWithResult` // The invocation should be uncompleted until the last one CompletionMessage - for (var i = 0; i < endpointsCount - 1; i++) + for (var i = 0; i < availableEndpointsCount - 1; i++) { var currentCompletionMessage = errorCompletionMessage; ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, currentCompletionMessage); Assert.False(ret); } - ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, isCompletionWithResult ? completionMessage : errorCompletionMessage); Assert.True(ret); @@ -226,12 +230,12 @@ public async Task TestCompleteWithMultiEndpointAtLast(bool isCompletionWithResul { var result = await task; Assert.True(isCompletionWithResult); - Assert.Equal(SuccessCompleteResult, result); + Assert.Equal(completionMessage.Result, result); } catch (Exception e) { Assert.False(isCompletionWithResult); - Assert.Equal(ErrorCompleteResult, e.Message); + Assert.Equal(errorCompletionMessage.Error, e.Message); } } @@ -248,6 +252,7 @@ public async Task TestCompleteWithMultiEndpointAtMiddle(int endpointsCount) var cancellationToken = new CancellationToken(); // Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId` var task = clientInvocationManager.Caller.AddInvocation("TestHub", connectionId, invocationId, cancellationToken); + clientInvocationManager.Caller.SetAckNumber(invocationId, endpointsCount); var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out var t); @@ -255,7 +260,7 @@ public async Task TestCompleteWithMultiEndpointAtMiddle(int endpointsCount) Assert.Equal(typeof(string), t); var successCompletionMessage = CompletionMessage.WithResult(invocationId, SuccessCompleteResult); - var errorCompletionMessage = CompletionMessage.WithError(invocationId, ErrorCompleteResult); + var errorCompletionMessage = CompletionMessage.WithError(invocationId, CommonErrorCompleteResult); // The first `endpointsCount - 2` CompletionMessage complete the invocation with error // The next one completes the invocation with result