From 8383e19722a295f51743a11a660c9f4ff3bcce21 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Fri, 10 Oct 2025 18:10:00 -0500 Subject: [PATCH 01/30] Extract AsyncHelper from SqlUtil.cs into the utilities namespace --- .../src/Microsoft.Data.SqlClient.csproj | 3 + .../Data/SqlClient/SqlCommand.netcore.cs | 2 +- .../SqlClient/SqlInternalConnectionTds.cs | 1 + .../netfx/src/Microsoft.Data.SqlClient.csproj | 3 + .../Data/SqlClient/SqlCommand.netfx.cs | 5 +- .../SqlClient/SqlInternalConnectionTds.cs | 2 +- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 1 + .../Data/SqlClient/SqlCommand.NonQuery.cs | 1 + .../Data/SqlClient/SqlCommand.Reader.cs | 2 +- .../Data/SqlClient/SqlCommand.Xml.cs | 1 + .../Microsoft/Data/SqlClient/SqlConnection.cs | 3 +- .../src/Microsoft/Data/SqlClient/SqlUtil.cs | 244 ----------------- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 11 +- .../Data/SqlClient/TdsParserStateObject.cs | 9 +- .../Data/SqlClient/Utilities/AsyncHelper.cs | 256 ++++++++++++++++++ 15 files changed, 278 insertions(+), 266 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 62c2f04ec8..b71ade16ef 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -762,6 +762,9 @@ Microsoft\Data\SqlClient\SSPI\SspiAuthenticationParameters.cs + + Microsoft\Data\SqlClient\Utilities\AsyncHelper.cs + Microsoft\Data\SqlClient\Utilities\ObjectPool.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs index e02abd4d54..c60ff1bf50 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs @@ -10,13 +10,13 @@ using System.Data.Common; using System.Data.SqlTypes; using System.Diagnostics; -using System.IO; using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; using Microsoft.Data.SqlClient.Diagnostics; +using Microsoft.Data.SqlClient.Utilities; // NOTE: The current Microsoft.VSDesigner editor attributes are implemented for System.Data.SqlClient, and are not publicly available. // New attributes that are designed to work with Microsoft.Data.SqlClient and are publicly documented should be included in future. diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 06736ab693..b0dfc22b16 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -18,6 +18,7 @@ using Microsoft.Data.Common.ConnectionString; using Microsoft.Data.ProviderBase; using Microsoft.Data.SqlClient.ConnectionPool; +using Microsoft.Data.SqlClient.Utilities; using Microsoft.Identity.Client; namespace Microsoft.Data.SqlClient diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 8cf025e182..95de12d922 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -957,6 +957,9 @@ Microsoft\Data\SqlClient\TdsValueSetter.cs + + Microsoft\Data\SqlClient\Utilities\AsyncHelper.cs + Microsoft\Data\SqlClient\Utilities\BufferWriterExtensions.netfx.cs diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs index 68033e95cd..e6780e24db 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs @@ -3,21 +3,20 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.ObjectModel; using System.Data; using System.Data.Common; using System.Data.SqlTypes; using System.Diagnostics; -using System.IO; using System.Runtime.CompilerServices; -using System.Runtime.ConstrainedExecution; using System.Security.Permissions; using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; -using System.Collections.Concurrent; +using Microsoft.Data.SqlClient.Utilities; // NOTE: The current Microsoft.VSDesigner editor attributes are implemented for System.Data.SqlClient, and are not publicly available. // New attributes that are designed to work with Microsoft.Data.SqlClient and are publicly documented should be included in future. diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 9f4a612951..04c4348313 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -9,7 +9,6 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Net.Http.Headers; -using System.Runtime.CompilerServices; using System.Security; using System.Text; using System.Threading; @@ -19,6 +18,7 @@ using Microsoft.Data.Common.ConnectionString; using Microsoft.Data.ProviderBase; using Microsoft.Data.SqlClient.ConnectionPool; +using Microsoft.Data.SqlClient.Utilities; using Microsoft.Identity.Client; namespace Microsoft.Data.SqlClient diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 9784117a5e..cabacf2e9a 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -15,6 +15,7 @@ using System.Threading.Tasks; using System.Xml; using Microsoft.Data.Common; +using Microsoft.Data.SqlClient.Utilities; namespace Microsoft.Data.SqlClient { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs index df5c372890..f86ea680ce 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.SqlClient.Utilities; #if NETFRAMEWORK using System.Security.Permissions; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index 91900a13d2..5d516abd72 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -10,10 +10,10 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.SqlClient.Utilities; #if NETFRAMEWORK using System.Security.Permissions; -using Microsoft.Data.SqlClient.Utilities; #endif namespace Microsoft.Data.SqlClient diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs index 1a9a778cd0..8a5c9eb4b4 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs @@ -10,6 +10,7 @@ using System.Xml; using Microsoft.Data.Common; using Microsoft.Data.SqlClient.Server; +using Microsoft.Data.SqlClient.Utilities; #if NETFRAMEWORK using System.Security.Permissions; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs index 276201ec1d..c9784bcab8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -22,7 +22,8 @@ using Microsoft.Data.ProviderBase; using Microsoft.Data.SqlClient.ConnectionPool; using Microsoft.Data.SqlClient.Diagnostics; -using Microsoft.SqlServer.Server; +using Microsoft.Data.SqlClient.Utilities; + #if NETFRAMEWORK using System.Runtime.CompilerServices; using System.Security.Permissions; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs index e13a77bd73..2c6d575eec 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -49,250 +49,6 @@ internal static ArgumentOutOfRangeException InvalidMinAndMaxPair(string minParam => new ArgumentOutOfRangeException(minParamName, StringsHelper.GetString(Strings.SqlRetryLogic_InvalidMinMaxPair, minValue, maxValue, minParamName, maxParamName)); } - internal static class AsyncHelper - { - internal static Task CreateContinuationTask( - Task task, - Action onSuccess, - Action onFailure = null) - { - if (task == null) - { - onSuccess(); - return null; - } - else - { - TaskCompletionSource completion = new TaskCompletionSource(); - ContinueTaskWithState( - task, - completion, - state: Tuple.Create(onSuccess, onFailure, completion), - onSuccess: static (object state) => - { - var parameters = (Tuple, TaskCompletionSource>)state; - Action success = parameters.Item1; - TaskCompletionSource taskCompletionSource = parameters.Item3; - success(); - taskCompletionSource.SetResult(null); - }, - onFailure: static (Exception exception, object state) => - { - var parameters = (Tuple, TaskCompletionSource>)state; - Action failure = parameters.Item2; - failure?.Invoke(exception); - } - ); - return completion.Task; - } - } - - internal static Task CreateContinuationTaskWithState(Task task, object state, Action onSuccess, Action onFailure = null) - { - if (task == null) - { - onSuccess(state); - return null; - } - else - { - var completion = new TaskCompletionSource(); - ContinueTaskWithState(task, completion, state, - onSuccess: (object continueState) => - { - onSuccess(continueState); - completion.SetResult(null); - }, - onFailure: onFailure - ); - return completion.Task; - } - } - - internal static Task CreateContinuationTask( - Task task, - Action onSuccess, - T1 arg1, - T2 arg2, - Action onFailure = null) - { - return CreateContinuationTask(task, () => onSuccess(arg1, arg2), onFailure); - } - - internal static void ContinueTask(Task task, - TaskCompletionSource completion, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) - { - task.ContinueWith( - tsk => - { - if (tsk.Exception != null) - { - Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) - { - exc = exceptionConverter(exc); - } - try - { - onFailure?.Invoke(exc); - } - finally - { - completion.TrySetException(exc); - } - } - else if (tsk.IsCanceled) - { - try - { - onCancellation?.Invoke(); - } - finally - { - completion.TrySetCanceled(); - } - } - else - { - try - { - onSuccess(); - } - // @TODO: CER Exception Handling was removed here (see GH#3581) - catch (Exception e) - { - completion.SetException(e); - } - } - }, TaskScheduler.Default - ); - } - - // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure - // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties - internal static void ContinueTaskWithState(Task task, - TaskCompletionSource completion, - object state, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) - { - task.ContinueWith( - (Task tsk, object state2) => - { - if (tsk.Exception != null) - { - Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) - { - exc = exceptionConverter(exc); - } - - try - { - onFailure?.Invoke(exc, state2); - } - finally - { - completion.TrySetException(exc); - } - } - else if (tsk.IsCanceled) - { - try - { - onCancellation?.Invoke(state2); - } - finally - { - completion.TrySetCanceled(); - } - } - else - { - try - { - onSuccess(state2); - } - // @TODO: CER Exception Handling was removed here (see GH#3581) - catch (Exception e) - { - completion.SetException(e); - } - } - }, - state: state, - scheduler: TaskScheduler.Default - ); - } - - internal static void WaitForCompletion(Task task, int timeout, Action onTimeout = null, bool rethrowExceptions = true) - { - try - { - task.Wait(timeout > 0 ? (1000 * timeout) : Timeout.Infinite); - } - catch (AggregateException ae) - { - if (rethrowExceptions) - { - Debug.Assert(ae.InnerExceptions.Count == 1, "There is more than one exception in AggregateException"); - ExceptionDispatchInfo.Capture(ae.InnerException).Throw(); - } - } - if (!task.IsCompleted) - { - task.ContinueWith(static t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception - onTimeout?.Invoke(); - } - } - - internal static void SetTimeoutException(TaskCompletionSource completion, int timeout, Func onFailure, CancellationToken ctoken) - { - if (timeout > 0) - { - Task.Delay(timeout * 1000, ctoken).ContinueWith( - (Task task) => - { - if (!task.IsCanceled && !completion.Task.IsCompleted) - { - completion.TrySetException(onFailure()); - } - } - ); - } - } - - internal static void SetTimeoutExceptionWithState( - TaskCompletionSource completion, - int timeout, - object state, - Func onFailure, - CancellationToken cancellationToken) - { - if (timeout <= 0) - { - return; - } - - Task.Delay(timeout * 1000, cancellationToken).ContinueWith( - (task, innerState) => - { - if (!task.IsCanceled && !completion.Task.IsCompleted) - { - completion.TrySetException(onFailure(innerState)); - } - }, - state: state, - cancellationToken: CancellationToken.None); - } - } - internal static class SQL { // The class SQL defines the exceptions that are specific to the SQL Adapter. diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index d88bf218f6..83a536c7a9 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -12,17 +12,11 @@ using System.Globalization; using System.IO; using System.Security.Authentication; -#if NETFRAMEWORK -using System.Runtime.CompilerServices; -#endif using System.Text; using System.Threading; using System.Threading.Tasks; using System.Xml; using Interop.Common.Sni; -#if NETFRAMEWORK -using Interop.Windows.Sni; -#endif using Microsoft.Data.Common; using Microsoft.Data.ProviderBase; using Microsoft.Data.Sql; @@ -30,10 +24,13 @@ using Microsoft.Data.SqlClient.LocalDb; using Microsoft.Data.SqlClient.Server; using Microsoft.Data.SqlClient.Utilities; +using Microsoft.SqlServer.Server; + #if NETFRAMEWORK +using System.Runtime.CompilerServices; +using Interop.Windows.Sni; using Microsoft.Data.SqlTypes; #endif -using Microsoft.SqlServer.Server; namespace Microsoft.Data.SqlClient { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 6a9b2d7369..1613c47eea 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -16,17 +16,10 @@ using Microsoft.Data.Common; using Microsoft.Data.ProviderBase; using Microsoft.Data.SqlClient.ManagedSni; - -#if NETFRAMEWORK -using System.Runtime.ConstrainedExecution; -#endif +using Microsoft.Data.SqlClient.Utilities; namespace Microsoft.Data.SqlClient { -#if NETFRAMEWORK - using RuntimeHelpers = System.Runtime.CompilerServices.RuntimeHelpers; -#endif - sealed internal class LastIOTimer { internal long _value; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs new file mode 100644 index 0000000000..3885d9da9e --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -0,0 +1,256 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Data.SqlClient.Utilities +{ + internal static class AsyncHelper + { + internal static void ContinueTask(Task task, + TaskCompletionSource completion, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null, + Func exceptionConverter = null) + { + task.ContinueWith( + tsk => + { + if (tsk.Exception != null) + { + Exception exc = tsk.Exception.InnerException; + if (exceptionConverter != null) + { + exc = exceptionConverter(exc); + } + try + { + onFailure?.Invoke(exc); + } + finally + { + completion.TrySetException(exc); + } + } + else if (tsk.IsCanceled) + { + try + { + onCancellation?.Invoke(); + } + finally + { + completion.TrySetCanceled(); + } + } + else + { + try + { + onSuccess(); + } + // @TODO: CER Exception Handling was removed here (see GH#3581) + catch (Exception e) + { + completion.SetException(e); + } + } + }, TaskScheduler.Default + ); + } + + // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure + // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties + internal static void ContinueTaskWithState(Task task, + TaskCompletionSource completion, + object state, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null, + Func exceptionConverter = null) + { + task.ContinueWith( + (Task tsk, object state2) => + { + if (tsk.Exception != null) + { + Exception exc = tsk.Exception.InnerException; + if (exceptionConverter != null) + { + exc = exceptionConverter(exc); + } + + try + { + onFailure?.Invoke(exc, state2); + } + finally + { + completion.TrySetException(exc); + } + } + else if (tsk.IsCanceled) + { + try + { + onCancellation?.Invoke(state2); + } + finally + { + completion.TrySetCanceled(); + } + } + else + { + try + { + onSuccess(state2); + } + // @TODO: CER Exception Handling was removed here (see GH#3581) + catch (Exception e) + { + completion.SetException(e); + } + } + }, + state: state, + scheduler: TaskScheduler.Default + ); + } + + internal static Task CreateContinuationTask( + Task task, + Action onSuccess, + Action onFailure = null) + { + if (task == null) + { + onSuccess(); + return null; + } + else + { + TaskCompletionSource completion = new TaskCompletionSource(); + ContinueTaskWithState( + task, + completion, + state: Tuple.Create(onSuccess, onFailure, completion), + onSuccess: static (object state) => + { + var parameters = (Tuple, TaskCompletionSource>)state; + Action success = parameters.Item1; + TaskCompletionSource taskCompletionSource = parameters.Item3; + success(); + taskCompletionSource.SetResult(null); + }, + onFailure: static (Exception exception, object state) => + { + var parameters = (Tuple, TaskCompletionSource>)state; + Action failure = parameters.Item2; + failure?.Invoke(exception); + } + ); + return completion.Task; + } + } + + internal static Task CreateContinuationTask( + Task task, + Action onSuccess, + T1 arg1, + T2 arg2, + Action onFailure = null) + { + return CreateContinuationTask(task, () => onSuccess(arg1, arg2), onFailure); + } + + internal static Task CreateContinuationTaskWithState(Task task, object state, Action onSuccess, Action onFailure = null) + { + if (task == null) + { + onSuccess(state); + return null; + } + else + { + var completion = new TaskCompletionSource(); + ContinueTaskWithState(task, completion, state, + onSuccess: (object continueState) => + { + onSuccess(continueState); + completion.SetResult(null); + }, + onFailure: onFailure + ); + return completion.Task; + } + } + + internal static void SetTimeoutException(TaskCompletionSource completion, int timeout, Func onFailure, CancellationToken ctoken) + { + if (timeout > 0) + { + Task.Delay(timeout * 1000, ctoken).ContinueWith( + (Task task) => + { + if (!task.IsCanceled && !completion.Task.IsCompleted) + { + completion.TrySetException(onFailure()); + } + } + ); + } + } + + internal static void SetTimeoutExceptionWithState( + TaskCompletionSource completion, + int timeout, + object state, + Func onFailure, + CancellationToken cancellationToken) + { + if (timeout <= 0) + { + return; + } + + Task.Delay(timeout * 1000, cancellationToken).ContinueWith( + (task, innerState) => + { + if (!task.IsCanceled && !completion.Task.IsCompleted) + { + completion.TrySetException(onFailure(innerState)); + } + }, + state: state, + cancellationToken: CancellationToken.None); + } + + internal static void WaitForCompletion(Task task, int timeout, Action onTimeout = null, bool rethrowExceptions = true) + { + try + { + task.Wait(timeout > 0 ? (1000 * timeout) : Timeout.Infinite); + } + catch (AggregateException ae) + { + if (rethrowExceptions) + { + Debug.Assert(ae.InnerExceptions.Count == 1, "There is more than one exception in AggregateException"); + ExceptionDispatchInfo.Capture(ae.InnerException).Throw(); + } + } + if (!task.IsCompleted) + { + task.ContinueWith(static t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception + onTimeout?.Invoke(); + } + } + } +} From 8119c0ce8a467d2d5ecde2318671602e56074654 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Fri, 10 Oct 2025 18:32:12 -0500 Subject: [PATCH 02/30] Add generic ContinueTaskWithState - and it's static the whole way through! --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 3885d9da9e..b55b86706b 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -65,6 +65,73 @@ internal static void ContinueTask(Task task, ); } + internal static void ContinueTaskWithState( + Task taskToContinue, + TaskCompletionSource taskCompletionSource, + TState state, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null) + { + TaskCompletionSourceContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State: state, + TaskCompletionSource: taskCompletionSource); + + taskToContinue.ContinueWith( + static (task, state2) => + { + TaskCompletionSourceContinuationState typedState2 = + (TaskCompletionSourceContinuationState)state2; + + if (task.Exception is not null) + { + // @TODO: Exception converter? + try + { + typedState2.OnFailure?.Invoke(typedState2.State, task.Exception); + } + finally + { + typedState2.TaskCompletionSource.TrySetException(task.Exception); + } + } + else if (task.IsCanceled) + { + try + { + typedState2.OnCancellation?.Invoke(typedState2.State); + } + finally + { + typedState2.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState2.OnSuccess(typedState2.State); + } + catch (Exception e) + { + typedState2.TaskCompletionSource.SetException(e); + } + } + }, + state: continuationState, + scheduler: TaskScheduler.Default); + } + + private record TaskCompletionSourceContinuationState( + Action OnCancellation, + Action OnFailure, + Action OnSuccess, + TState State, + TaskCompletionSource TaskCompletionSource); + // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties internal static void ContinueTaskWithState(Task task, From 3df7400bc345b02d15406cbd6dd19f17545d3c18 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Fri, 10 Oct 2025 21:30:05 -0500 Subject: [PATCH 03/30] Use generic version where not difficult --- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 29 ++++++------- .../Data/SqlClient/SqlCommand.NonQuery.cs | 15 +++---- .../Data/SqlClient/SqlCommand.Reader.cs | 42 ++++++++----------- .../Data/SqlClient/SqlCommand.Xml.cs | 11 ++--- 4 files changed, 43 insertions(+), 54 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index cabacf2e9a..c638042d2e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -2684,18 +2684,19 @@ private Task CopyBatchesAsyncContinued(BulkCopySimpleResultSet internalResults, task, source, state: this, - onSuccess: (object state) => + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - Task continuedTask = sqlBulkCopy.CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); + Task continuedTask = this2.CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); if (continuedTask == null) { // Continuation finished sync, recall into CopyBatchesAsync to continue - sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + this2.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); } }, - onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false), - onCancellation: static (object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: true)); + onFailure: static (this2, _) => + this2.CopyBatchesAsyncContinuedOnError(cleanupParser: false), + onCancellation: static this2 => + this2.CopyBatchesAsyncContinuedOnError(cleanupParser: true)); return source.Task; } @@ -2746,24 +2747,24 @@ private Task CopyBatchesAsyncContinuedOnSuccess(BulkCopySimpleResultSet internal writeTask, source, state: this, - onSuccess: (object state) => + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; try { - sqlBulkCopy.RunParser(); - sqlBulkCopy.CommitTransaction(); + this2.RunParser(); + this2.CommitTransaction(); } catch (Exception) { - sqlBulkCopy.CopyBatchesAsyncContinuedOnError(cleanupParser: false); + this2.CopyBatchesAsyncContinuedOnError(cleanupParser: false); throw; } // Always call back into CopyBatchesAsync - sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + this2.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); }, - onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false)); + onFailure: static (this2, _) => + this2.CopyBatchesAsyncContinuedOnError(cleanupParser: false)); return source.Task; } } @@ -3007,7 +3008,7 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio reconnectTask, cancellableReconnectTS, state: cancellableReconnectTS, - onSuccess: static state => ((TaskCompletionSource)state).SetResult(null)); + onSuccess: static state => state.SetResult(null)); // No need to cancel timer since SqlBulkCopy creates specific task source for reconnection. AsyncHelper.SetTimeoutExceptionWithState( diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs index f86ea680ce..9833a2f7e3 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs @@ -220,14 +220,11 @@ private IAsyncResult BeginExecuteNonQueryInternal( if (execNonQuery is not null) { AsyncHelper.ContinueTaskWithState( - task: execNonQuery, - completion: localCompletion, + taskToContinue: execNonQuery, + taskCompletionSource: localCompletion, state: Tuple.Create(this, localCompletion), onSuccess: static state => - { - var parameters = (Tuple>)state; - parameters.Item1.BeginExecuteNonQueryInternalReadStage(parameters.Item2); - }); + state.Item1.BeginExecuteNonQueryInternalReadStage(state.Item2)); } else { @@ -897,10 +894,10 @@ private void RunExecuteNonQueryTdsSetupReconnnectContinuation( else { AsyncHelper.ContinueTaskWithState( - subTask, - completion, + taskToContinue: subTask, + taskCompletionSource: completion, state: completion, - onSuccess: static state => ((TaskCompletionSource)state).SetResult(null)); + onSuccess: static state => state.SetResult(null)); } }); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index 5d516abd72..d5e6616987 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -308,14 +308,11 @@ private IAsyncResult BeginExecuteReaderInternal( if (writeTask is not null) { AsyncHelper.ContinueTaskWithState( - writeTask, - localCompletion, + taskToContinue: writeTask, + taskCompletionSource: localCompletion, state: Tuple.Create(this, localCompletion), onSuccess: static state => - { - var parameters = (Tuple>)state; - parameters.Item1.BeginExecuteReaderInternalReadStage(parameters.Item2); - }); + state.Item1.BeginExecuteReaderInternalReadStage(state.Item2)); } else { @@ -1652,10 +1649,10 @@ private void RunExecuteReaderTdsSetupReconnectContinuation( else { AsyncHelper.ContinueTaskWithState( - subTask, - completion, + taskToContinue: subTask, + taskCompletionSource: completion, state: completion, - onSuccess: static state => ((TaskCompletionSource)state).SetResult(null)); + onSuccess: static state => state.SetResult(null)); } }); } @@ -1688,14 +1685,13 @@ private SqlDataReader RunExecuteReaderTdsWithTransparentParameterEncryption( // @TODO: This is a prime candidate for proper async-await execution TaskCompletionSource completion = new TaskCompletionSource(); AsyncHelper.ContinueTaskWithState( - task: describeParameterEncryptionTask, - completion: completion, + taskToContinue: describeParameterEncryptionTask, + taskCompletionSource: completion, state: this, - onSuccess: state => + onSuccess: this2 => { - SqlCommand command = (SqlCommand)state; - command.GenerateEnclavePackage(); - command.RunExecuteReaderTds( + this2.GenerateEnclavePackage(); + this2.RunExecuteReaderTds( cmdBehavior, runBehavior, returnStream, @@ -1714,24 +1710,22 @@ private SqlDataReader RunExecuteReaderTdsWithTransparentParameterEncryption( else { AsyncHelper.ContinueTaskWithState( - task: subTask, - completion: completion, + taskToContinue: subTask, + taskCompletionSource: completion, state: completion, - onSuccess: static state => ((TaskCompletionSource)state).SetResult(null)); + onSuccess: static state => state.SetResult(null)); } }, - onFailure: static (exception, state) => + onFailure: static (this2, exception) => { - ((SqlCommand)state).CachedAsyncState?.ResetAsyncState(); + this2.CachedAsyncState?.ResetAsyncState(); if (exception is not null) { + // @TODO: This doesn't do anything, afaik. throw exception; } }, - onCancellation: static state => - { - ((SqlCommand)state).CachedAsyncState?.ResetAsyncState(); - }); + onCancellation: static this2 => this2.CachedAsyncState?.ResetAsyncState()); task = completion.Task; return ds; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs index 8a5c9eb4b4..dc97f95696 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs @@ -269,15 +269,12 @@ private IAsyncResult BeginExecuteXmlReaderInternal( if (writeTask is not null) { - AsyncHelper.ContinueTaskWithState( - task: writeTask, - completion: localCompletion, + AsyncHelper.ContinueTaskWithState>>( + taskToContinue: writeTask, + taskCompletionSource: localCompletion, state: Tuple.Create(this, localCompletion), onSuccess: static state => - { - var parameters = (Tuple>)state; - parameters.Item1.BeginExecuteXmlReaderInternalReadStage(parameters.Item2); - }); + state.Item1.BeginExecuteXmlReaderInternalReadStage(state.Item2)); } else { From 65002403e323b9cef3a1c9501586548f6e78d19d Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Fri, 10 Oct 2025 22:12:09 -0500 Subject: [PATCH 04/30] Add stateful version of CreateContinuationTask --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index b55b86706b..857d1e06fc 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -114,6 +114,7 @@ internal static void ContinueTaskWithState( try { typedState2.OnSuccess(typedState2.State); + // @TODO: The one unpleasant thing with this code is that the TCS is not set completed and left to the caller to do or not do (which is more unpleasant) } catch (Exception e) { @@ -191,6 +192,83 @@ internal static void ContinueTaskWithState(Task task, ); } + internal static Task CreateContinuationTaskWithState( + Task taskToContinue, + TState state, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null) + { + // Note: this code is almost identical to ContinueTaskWithState, but creates its own + // task completion source and completes it on success. + // Yes, we could just chain into the ContinueTaskWithState, but that requires wrapping + // more state in a tuple and confusing the heck out of people. So, duplicating code + // just makes things more clean. Besides, @TODO: We should get rid of these helpers and + // just use async/await natives. + + if (taskToContinue is null) + { + onSuccess(state); + return null; + } + + // @TODO: Can totally use a non-generic TaskCompletionSource + TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSourceContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State: state, + TaskCompletionSource: taskCompletionSource); + + taskToContinue.ContinueWith( + static (task, state2) => + { + TaskCompletionSourceContinuationState typedState2 = + (TaskCompletionSourceContinuationState)state2; + + if (task.Exception is not null) + { + try + { + typedState2.OnFailure?.Invoke(typedState2.State, task.Exception); + } + finally + { + typedState2.TaskCompletionSource.TrySetException(task.Exception); + } + } + else if (task.IsCanceled) + { + try + { + typedState2.OnCancellation?.Invoke(typedState2.State); + } + finally + { + typedState2.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState2.OnSuccess(typedState2.State); + typedState2.TaskCompletionSource.SetResult(null); + } + catch (Exception e) + { + typedState2.TaskCompletionSource.SetException(e); + } + } + + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + return taskCompletionSource.Task; + } + internal static Task CreateContinuationTask( Task task, Action onSuccess, From f4eea9f7bddee6c1d6e7880c0f8b01fd3fcaf164 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Fri, 10 Oct 2025 22:26:12 -0500 Subject: [PATCH 05/30] Introduce stateless version that chains into the stateful version --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 857d1e06fc..d7b2340e2c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -192,6 +192,20 @@ internal static void ContinueTaskWithState(Task task, ); } + internal static Task CreateContinuationTask( + Task taskToContinue, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null) + { + return CreateContinuationTaskWithState( + taskToContinue, + state: Tuple.Create(onSuccess, onFailure, onCancellation), + onSuccess: static state => state.Item1(), + onFailure: static (state, exception) => state.Item2?.Invoke(exception), + onCancellation: static state => state.Item3?.Invoke()); + } + internal static Task CreateContinuationTaskWithState( Task taskToContinue, TState state, From 8450aaa21404b182b3a636b8e14421b32a4f06ee Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Tue, 14 Oct 2025 17:15:56 -0500 Subject: [PATCH 06/30] Introduce two-generic ContinueTaskWithState --- .../Data/SqlClient/SqlCommand.NonQuery.cs | 7 +- .../Data/SqlClient/SqlCommand.Reader.cs | 7 +- .../Data/SqlClient/SqlCommand.Xml.cs | 9 +- .../Data/SqlClient/Utilities/AsyncHelper.cs | 83 +++++++++++++++++-- 4 files changed, 89 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs index 9833a2f7e3..32198ea0a9 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs @@ -222,9 +222,10 @@ private IAsyncResult BeginExecuteNonQueryInternal( AsyncHelper.ContinueTaskWithState( taskToContinue: execNonQuery, taskCompletionSource: localCompletion, - state: Tuple.Create(this, localCompletion), - onSuccess: static state => - state.Item1.BeginExecuteNonQueryInternalReadStage(state.Item2)); + state1: this, + state2: localCompletion, + onSuccess: static (this2, localCompletion2) => + this2.BeginExecuteNonQueryInternalReadStage(localCompletion2)); } else { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index d5e6616987..c5a60cb51c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -310,9 +310,10 @@ private IAsyncResult BeginExecuteReaderInternal( AsyncHelper.ContinueTaskWithState( taskToContinue: writeTask, taskCompletionSource: localCompletion, - state: Tuple.Create(this, localCompletion), - onSuccess: static state => - state.Item1.BeginExecuteReaderInternalReadStage(state.Item2)); + state1: this, + state2: localCompletion, + onSuccess: static (this2, localCompletion2) => + this2.BeginExecuteReaderInternalReadStage(localCompletion2)); } else { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs index dc97f95696..d6b12d03d4 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs @@ -269,12 +269,13 @@ private IAsyncResult BeginExecuteXmlReaderInternal( if (writeTask is not null) { - AsyncHelper.ContinueTaskWithState>>( + AsyncHelper.ContinueTaskWithState( taskToContinue: writeTask, taskCompletionSource: localCompletion, - state: Tuple.Create(this, localCompletion), - onSuccess: static state => - state.Item1.BeginExecuteXmlReaderInternalReadStage(state.Item2)); + state1: this, + state2: localCompletion, + onSuccess: static (this2, localCompletion2) => + this2.BeginExecuteXmlReaderInternalReadStage(localCompletion2)); } else { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index d7b2340e2c..1261973109 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -73,7 +73,7 @@ internal static void ContinueTaskWithState( Action onFailure = null, Action onCancellation = null) { - TaskCompletionSourceContinuationState continuationState = new( + ContinuationState continuationState = new( OnCancellation: onCancellation, OnFailure: onFailure, OnSuccess: onSuccess, @@ -83,8 +83,8 @@ internal static void ContinueTaskWithState( taskToContinue.ContinueWith( static (task, state2) => { - TaskCompletionSourceContinuationState typedState2 = - (TaskCompletionSourceContinuationState)state2; + ContinuationState typedState2 = + (ContinuationState)state2; if (task.Exception is not null) { @@ -126,13 +126,82 @@ internal static void ContinueTaskWithState( scheduler: TaskScheduler.Default); } - private record TaskCompletionSourceContinuationState( + internal static void ContinueTaskWithState( + Task taskToContinue, + TaskCompletionSource taskCompletionSource, + TState1 state1, + TState2 state2, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null) + { + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State1: state1, + State2: state2, + TaskCompletionSource: taskCompletionSource); + + taskToContinue.ContinueWith( + static (task, state2) => + { + ContinuationState typedState2 = (ContinuationState)state2; + + if (task.Exception is not null) + { + // @TODO: Exception converter? + try + { + typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, task.Exception); + } + finally + { + typedState2.TaskCompletionSource.TrySetException(task.Exception); + } + } + else if (task.IsCanceled) + { + try + { + typedState2.OnCancellation?.Invoke(typedState2.State1, typedState2.State2); + } + finally + { + typedState2.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState2.OnSuccess(typedState2.State1, typedState2.State2); + } + catch (Exception e) + { + typedState2.TaskCompletionSource.SetException(e); + } + } + }, + state: continuationState, + scheduler: TaskScheduler.Default); + } + + private record ContinuationState( Action OnCancellation, Action OnFailure, Action OnSuccess, TState State, TaskCompletionSource TaskCompletionSource); + private record ContinuationState( + Action OnCancellation, + Action OnFailure, + Action OnSuccess, + TState1 State1, + TState2 State2, + TaskCompletionSource TaskCompletionSource); + // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties internal static void ContinueTaskWithState(Task task, @@ -228,7 +297,7 @@ internal static Task CreateContinuationTaskWithState( // @TODO: Can totally use a non-generic TaskCompletionSource TaskCompletionSource taskCompletionSource = new(); - TaskCompletionSourceContinuationState continuationState = new( + ContinuationState continuationState = new( OnCancellation: onCancellation, OnFailure: onFailure, OnSuccess: onSuccess, @@ -238,8 +307,8 @@ internal static Task CreateContinuationTaskWithState( taskToContinue.ContinueWith( static (task, state2) => { - TaskCompletionSourceContinuationState typedState2 = - (TaskCompletionSourceContinuationState)state2; + ContinuationState typedState2 = + (ContinuationState)state2; if (task.Exception is not null) { From 547af516b637d3c0bb9594f45730ddc4da49fe66 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Tue, 14 Oct 2025 17:37:36 -0500 Subject: [PATCH 07/30] Introduce two-generic CreateContinuationTaskWithState --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 225 ++++++++++++------ 1 file changed, 146 insertions(+), 79 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 1261973109..82a7118169 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -187,100 +187,89 @@ internal static void ContinueTaskWithState( scheduler: TaskScheduler.Default); } - private record ContinuationState( - Action OnCancellation, - Action OnFailure, - Action OnSuccess, - TState State, - TaskCompletionSource TaskCompletionSource); + internal static Task CreateContinuationTaskWithState( + Task taskToContinue, + TState state, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null) + { + // Note: this code is almost identical to ContinueTaskWithState, but creates its own + // task completion source and completes it on success. + // Yes, we could just chain into the ContinueTaskWithState, but that requires wrapping + // more state in a tuple and confusing the heck out of people. So, duplicating code + // just makes things more clean. Besides, @TODO: We should get rid of these helpers and + // just use async/await natives. - private record ContinuationState( - Action OnCancellation, - Action OnFailure, - Action OnSuccess, - TState1 State1, - TState2 State2, - TaskCompletionSource TaskCompletionSource); + if (taskToContinue is null) + { + onSuccess(state); + return null; + } - // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure - // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties - internal static void ContinueTaskWithState(Task task, - TaskCompletionSource completion, - object state, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) - { - task.ContinueWith( - (Task tsk, object state2) => + // @TODO: Can totally use a non-generic TaskCompletionSource + TaskCompletionSource taskCompletionSource = new(); + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State: state, + TaskCompletionSource: taskCompletionSource); + + taskToContinue.ContinueWith( + static (task, state2) => { - if (tsk.Exception != null) - { - Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) - { - exc = exceptionConverter(exc); - } + ContinuationState typedState2 = (ContinuationState)state2; + if (task.Exception is not null) + { try { - onFailure?.Invoke(exc, state2); + typedState2.OnFailure?.Invoke(typedState2.State, task.Exception); } finally { - completion.TrySetException(exc); + typedState2.TaskCompletionSource.TrySetException(task.Exception); } } - else if (tsk.IsCanceled) + else if (task.IsCanceled) { try { - onCancellation?.Invoke(state2); + typedState2.OnCancellation?.Invoke(typedState2.State); } finally { - completion.TrySetCanceled(); + typedState2.TaskCompletionSource.TrySetCanceled(); } } else { try { - onSuccess(state2); + typedState2.OnSuccess(typedState2.State); + typedState2.TaskCompletionSource.SetResult(null); } - // @TODO: CER Exception Handling was removed here (see GH#3581) catch (Exception e) { - completion.SetException(e); + typedState2.TaskCompletionSource.SetException(e); } } + }, - state: state, - scheduler: TaskScheduler.Default - ); - } + state: continuationState, + scheduler: TaskScheduler.Default); - internal static Task CreateContinuationTask( - Task taskToContinue, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null) - { - return CreateContinuationTaskWithState( - taskToContinue, - state: Tuple.Create(onSuccess, onFailure, onCancellation), - onSuccess: static state => state.Item1(), - onFailure: static (state, exception) => state.Item2?.Invoke(exception), - onCancellation: static state => state.Item3?.Invoke()); + return taskCompletionSource.Task; } - internal static Task CreateContinuationTaskWithState( + internal static Task CreateContinuationTaskWithState( Task taskToContinue, - TState state, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null) + TState1 state1, + TState2 state2, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null) { // Note: this code is almost identical to ContinueTaskWithState, but creates its own // task completion source and completes it on success. @@ -291,30 +280,30 @@ internal static Task CreateContinuationTaskWithState( if (taskToContinue is null) { - onSuccess(state); + onSuccess(state1, state2); return null; } // @TODO: Can totally use a non-generic TaskCompletionSource TaskCompletionSource taskCompletionSource = new(); - ContinuationState continuationState = new( + ContinuationState continuationState = new( OnCancellation: onCancellation, OnFailure: onFailure, OnSuccess: onSuccess, - State: state, + State1: state1, + State2: state2, TaskCompletionSource: taskCompletionSource); taskToContinue.ContinueWith( static (task, state2) => { - ContinuationState typedState2 = - (ContinuationState)state2; + ContinuationState typedState2 = (ContinuationState)state2; if (task.Exception is not null) { try { - typedState2.OnFailure?.Invoke(typedState2.State, task.Exception); + typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, task.Exception); } finally { @@ -325,7 +314,7 @@ internal static Task CreateContinuationTaskWithState( { try { - typedState2.OnCancellation?.Invoke(typedState2.State); + typedState2.OnCancellation?.Invoke(typedState2.State1, typedState2.State2); } finally { @@ -336,7 +325,7 @@ internal static Task CreateContinuationTaskWithState( { try { - typedState2.OnSuccess(typedState2.State); + typedState2.OnSuccess(typedState2.State1, typedState2.State2); typedState2.TaskCompletionSource.SetResult(null); } catch (Exception e) @@ -352,6 +341,94 @@ internal static Task CreateContinuationTaskWithState( return taskCompletionSource.Task; } + private record ContinuationState( + Action OnCancellation, + Action OnFailure, + Action OnSuccess, + TState State, + TaskCompletionSource TaskCompletionSource); + + private record ContinuationState( + Action OnCancellation, + Action OnFailure, + Action OnSuccess, + TState1 State1, + TState2 State2, + TaskCompletionSource TaskCompletionSource); + + // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure + // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties + internal static void ContinueTaskWithState(Task task, + TaskCompletionSource completion, + object state, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null, + Func exceptionConverter = null) + { + task.ContinueWith( + (Task tsk, object state2) => + { + if (tsk.Exception != null) + { + Exception exc = tsk.Exception.InnerException; + if (exceptionConverter != null) + { + exc = exceptionConverter(exc); + } + + try + { + onFailure?.Invoke(exc, state2); + } + finally + { + completion.TrySetException(exc); + } + } + else if (tsk.IsCanceled) + { + try + { + onCancellation?.Invoke(state2); + } + finally + { + completion.TrySetCanceled(); + } + } + else + { + try + { + onSuccess(state2); + } + // @TODO: CER Exception Handling was removed here (see GH#3581) + catch (Exception e) + { + completion.SetException(e); + } + } + }, + state: state, + scheduler: TaskScheduler.Default + ); + } + + internal static Task CreateContinuationTask( + Task taskToContinue, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null) + { + return CreateContinuationTaskWithState( + taskToContinue, + state: Tuple.Create(onSuccess, onFailure, onCancellation), + onSuccess: static state => state.Item1(), + onFailure: static (state, exception) => state.Item2?.Invoke(exception), + onCancellation: static state => state.Item3?.Invoke()); + } + internal static Task CreateContinuationTask( Task task, Action onSuccess, @@ -388,16 +465,6 @@ internal static Task CreateContinuationTask( } } - internal static Task CreateContinuationTask( - Task task, - Action onSuccess, - T1 arg1, - T2 arg2, - Action onFailure = null) - { - return CreateContinuationTask(task, () => onSuccess(arg1, arg2), onFailure); - } - internal static Task CreateContinuationTaskWithState(Task task, object state, Action onSuccess, Action onFailure = null) { if (task == null) From 847fd254ab2822ff0d46d66018e63234ffe3b5e1 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Wed, 15 Oct 2025 11:56:44 -0500 Subject: [PATCH 08/30] Replacing usages that didn't show up before? --- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 127 +++++++++--------- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 20 +-- .../Data/SqlClient/TdsParserStateObject.cs | 5 +- 3 files changed, 81 insertions(+), 71 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index c638042d2e..7ce6fc54fd 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -2051,10 +2051,11 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok } else { - AsyncHelper.ContinueTaskWithState(writeTask, tcs, + AsyncHelper.ContinueTaskWithState( + taskToContinue: writeTask, + taskCompletionSource: tcs, state: tcs, - onSuccess: static (object state) => ((TaskCompletionSource)state).SetResult(null) - ); + onSuccess: static tcs2 => tcs2.SetResult(null)); } }, ctoken); // We do not need to propagate exception, etc, from reconnect task, we just need to wait for it to finish. return tcs.Task; @@ -2363,15 +2364,15 @@ private Task CopyColumnsAsync(int col, TaskCompletionSource source = nul private void CopyColumnsAsyncSetupContinuation(TaskCompletionSource source, Task task, int i) { AsyncHelper.ContinueTaskWithState( - task, - source, + taskToContinue: task, + taskCompletionSource: source, state: this, - onSuccess: (object state) => + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - if (i + 1 < sqlBulkCopy._sortedColumnMappings.Count) + if (i + 1 < this2._sortedColumnMappings.Count) { - sqlBulkCopy.CopyColumnsAsync(i + 1, source); //continue from the next column + // continue from the next column + this2.CopyColumnsAsync(i + 1, source); } else { @@ -2506,18 +2507,17 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, Task readTask = ReadFromRowSourceAsync(cts); // Read the next row. Caution: more is only valid if the task returns null. Otherwise, we wait for Task.Result if (readTask != null) { - if (source == null) - { - source = new TaskCompletionSource(); - } + source ??= new TaskCompletionSource(); resultTask = source.Task; AsyncHelper.ContinueTaskWithState( - readTask, - source, + taskToContinue: readTask, + taskCompletionSource: source, state: this, - onSuccess: (object state) => ((SqlBulkCopy)state).CopyRowsAsync(i + 1, totalRows, cts, source)); - return resultTask; // Associated task will be completed when all rows are copied to server/exception/cancelled. + onSuccess: this2 => this2.CopyRowsAsync(i + 1, totalRows, cts, source)); + + // Associated task will be completed when all rows are copied to server/exception/cancelled. + return resultTask; } } else @@ -2525,34 +2525,35 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, source = source ?? new TaskCompletionSource(); resultTask = source.Task; - AsyncHelper.ContinueTaskWithState(task, source, this, - onSuccess: (object state) => + AsyncHelper.ContinueTaskWithState( + taskToContinue: task, + taskCompletionSource: source, + state: this, + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - sqlBulkCopy.CheckAndRaiseNotification(); // Check for notification now as the current row copy is done at this moment. + // Check for notification now as the current row copy is done at this moment. + this2.CheckAndRaiseNotification(); - Task readTask = sqlBulkCopy.ReadFromRowSourceAsync(cts); - if (readTask == null) + Task readTask = this2.ReadFromRowSourceAsync(cts); + if (readTask is null) { - sqlBulkCopy.CopyRowsAsync(i + 1, totalRows, cts, source); + this2.CopyRowsAsync(i + 1, totalRows, cts, source); } else { AsyncHelper.ContinueTaskWithState( - readTask, - source, - state: sqlBulkCopy, - onSuccess: (object state2) => ((SqlBulkCopy)state2).CopyRowsAsync(i + 1, totalRows, cts, source)); + taskToContinue: readTask, + taskCompletionSource: source, + state: this2, + onSuccess: this3 => this3.CopyRowsAsync(i + 1, totalRows, cts, source)); } }); return resultTask; } } - if (source != null) - { - source.TrySetResult(null); // This is set only on the last call of async copy. But may not be set if everything runs synchronously. - } + // This is set only on the last call of async copy. But may not be set if everything runs synchronously. + source?.TrySetResult(null); } catch (Exception ex) { @@ -2614,17 +2615,21 @@ private Task CopyBatchesAsync(BulkCopySimpleResultSet internalResults, string up } AsyncHelper.ContinueTaskWithState( - commandTask, - source, + taskToContinue: commandTask, + taskCompletionSource: source, state: this, - onSuccess: (object state) => + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - Task continuedTask = sqlBulkCopy.CopyBatchesAsyncContinued(internalResults, updateBulkCommandText, cts, source); - if (continuedTask == null) + Task continuedTask = this2.CopyBatchesAsyncContinued( + internalResults, + updateBulkCommandText, + cts, + source); + + if (continuedTask is null) { // Continuation finished sync, recall into CopyBatchesAsync to continue - sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + this2.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); } }); return source.Task; @@ -2861,24 +2866,21 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int if (task != null) { - if (source == null) - { - source = new TaskCompletionSource(); - } + source ??= new TaskCompletionSource(); AsyncHelper.ContinueTaskWithState( - task, - source, + taskToContinue: task, + taskCompletionSource: source, state: this, - onSuccess: (object state) => + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + // @TODO: Split into oncancellation, onfailure, etc. // Bulk copy task is completed at this moment. if (task.IsCanceled) { - sqlBulkCopy._localColumnMappings = null; + this2._localColumnMappings = null; try { - sqlBulkCopy.CleanUpStateObject(); + this2.CleanUpStateObject(); } finally { @@ -2891,10 +2893,10 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int } else { - sqlBulkCopy._localColumnMappings = null; + this2._localColumnMappings = null; try { - sqlBulkCopy.CleanUpStateObject(isCancelRequested: false); + this2.CleanUpStateObject(isCancelRequested: false); } finally { @@ -3088,10 +3090,11 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio if (internalResultsTask != null) { AsyncHelper.ContinueTaskWithState( - internalResultsTask, - source, + taskToContinue: internalResultsTask, + taskCompletionSource: source, state: this, - onSuccess: (object state) => ((SqlBulkCopy)state).WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source)); + onSuccess: this2 => + this2.WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source)); } else { @@ -3160,17 +3163,21 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken) else { Debug.Assert(_isAsyncBulkCopy, "Read must not return a Task in the Sync mode"); - AsyncHelper.ContinueTaskWithState(readTask, source, this, - onSuccess: (object state) => + AsyncHelper.ContinueTaskWithState( + taskToContinue: readTask, + taskCompletionSource: source, + state: this, + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - if (!sqlBulkCopy._hasMoreRowToCopy) + if (!this2._hasMoreRowToCopy) { - source.SetResult(null); // No rows to copy! + // No rows to copy! + source.SetResult(null); } else { - sqlBulkCopy.WriteToServerInternalRestAsync(ctoken, source); // Passing the same completion which will be completed by the Callee. + // Passing the same completion which will be completed by the Callee. + this2.WriteToServerInternalRestAsync(ctoken, source); } }); return resultTask; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index 83a536c7a9..cb46736569 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -12215,11 +12215,11 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy } else { - return AsyncHelper.CreateContinuationTask( - unterminatedWriteTask, - onSuccess: WriteInt, - arg1: 0, - arg2: stateObj); + return AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: unterminatedWriteTask, + state1: this, + state2: stateObj, + onSuccess: static (this2, stateObj2) => this2.WriteInt(0, stateObj2)); } } else @@ -13182,11 +13182,11 @@ private Task WriteEncryptionMetadata(Task terminatedWriteTask, SqlColumnEncrypti else { // Otherwise, create a continuation task to write the encryption metadata after the previous write completes. - return AsyncHelper.CreateContinuationTask( - terminatedWriteTask, - onSuccess: WriteEncryptionMetadata, - arg1: columnEncryptionParameterInfo, - arg2: stateObj); + return AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: terminatedWriteTask, + state1: columnEncryptionParameterInfo, + state2: stateObj, + onSuccess: WriteEncryptionMetadata); } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 1613c47eea..537565d15c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -3044,7 +3044,10 @@ internal Task WritePacket(byte flushMode, bool canAccumulate = false) if (willCancel) { // If we have been canceled, then ensure that we write the ATTN packet as well - task = AsyncHelper.CreateContinuationTask(task, CancelWritePacket); + task = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: task, + state: this, + onSuccess: static this2 => this2.CancelWritePacket()); } return task; From cb6d1f983f0f97f91deaea2a2628e1394f2bf0d9 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 16 Oct 2025 12:49:50 -0500 Subject: [PATCH 09/30] Make non-generic ContinueTaskWithState private, remove exteernal usages of it. --- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 33 +++++++++++++------ .../Data/SqlClient/Utilities/AsyncHelper.cs | 2 +- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 7ce6fc54fd..2b0d8283fd 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -3023,25 +3023,38 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio ); AsyncHelper.ContinueTaskWithState( - task: cancellableReconnectTS.Task, - completion: source, + taskToContinue:cancellableReconnectTS.Task, + taskCompletionSource: source, state: regReconnectCancel, - onSuccess: (object state) => + onSuccess: regReconnectCancel2 => { - ((StrongBox)state).Value.Dispose(); - if (_parserLock != null) + regReconnectCancel2.Value.Dispose(); + + if (_parserLock is not null) { _parserLock.Release(); - _parserLock = null; + _parserLock = null; // @TODO: Can be omitted b/c we reassign it directly below } _parserLock = _connection.GetOpenTdsConnection()._parserLock; _parserLock.Wait(canReleaseFromAnyThread: true); WriteToServerInternalRestAsync(cts, source); }, - onFailure: static (_, state) => ((StrongBox)state).Value.Dispose(), - onCancellation: static state => ((StrongBox)state).Value.Dispose(), - exceptionConverter: ex => SQL.BulkLoadInvalidDestinationTable(_destinationTableName, ex) - ); + onFailure: (regReconnectCancel2, exception) => + { + regReconnectCancel2.Value.Dispose(); + + // Convert exception and set it on the source + // Note: This is safe because the helper will only try to set the + // exception and b/c it is already set will pass without setting + // to the original exception. + Exception convertedException = SQL.BulkLoadInvalidDestinationTable( + _destinationTableName, + exception); + source.TrySetException(convertedException); + }, + onCancellation: static regReconnectCancel2 => + regReconnectCancel2.Value.Dispose()); + return; } else diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 82a7118169..aa2f51f673 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -358,7 +358,7 @@ private record ContinuationState( // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties - internal static void ContinueTaskWithState(Task task, + private static void ContinueTaskWithState(Task task, TaskCompletionSource completion, object state, Action onSuccess, From dd48b593363b2b591772d5ec849fdd161fa32734 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 16 Oct 2025 13:37:40 -0500 Subject: [PATCH 10/30] Rewrite CreateContinuationTask --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 84 +++++++++++++++---- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index aa2f51f673..d7f664ee3e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -187,6 +187,70 @@ internal static void ContinueTaskWithState( scheduler: TaskScheduler.Default); } + internal static Task CreateContinuationTask( + Task taskToContinue, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null) + { + if (taskToContinue is null) + { + onSuccess(); + return null; + } + + // @TODO: Can totally use a non-generic TaskCompletionSource + TaskCompletionSource taskCompletionSource = new(); + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + TaskCompletionSource: taskCompletionSource); + + taskToContinue.ContinueWith(static (task, continuationState2) => + { + ContinuationState typedState = (ContinuationState)continuationState2; + if (task.Exception is not null) + { + try + { + typedState.OnFailure?.Invoke(task.Exception); + } + finally + { + typedState.TaskCompletionSource.TrySetException(task.Exception); + } + } + else if (task.IsCanceled) + { + try + { + typedState.OnCancellation?.Invoke(); + } + finally + { + typedState.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState.OnSuccess(); + typedState.TaskCompletionSource.SetResult(null); + } + catch (Exception e) + { + typedState.TaskCompletionSource.SetException(e); + } + } + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + return taskCompletionSource.Task; + } + internal static Task CreateContinuationTaskWithState( Task taskToContinue, TState state, @@ -341,6 +405,12 @@ internal static Task CreateContinuationTaskWithState( return taskCompletionSource.Task; } + private record ContinuationState( + Action OnCancellation, + Action OnFailure, + Action OnSuccess, + TaskCompletionSource TaskCompletionSource); + private record ContinuationState( Action OnCancellation, Action OnFailure, @@ -415,20 +485,6 @@ private static void ContinueTaskWithState(Task task, ); } - internal static Task CreateContinuationTask( - Task taskToContinue, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null) - { - return CreateContinuationTaskWithState( - taskToContinue, - state: Tuple.Create(onSuccess, onFailure, onCancellation), - onSuccess: static state => state.Item1(), - onFailure: static (state, exception) => state.Item2?.Invoke(exception), - onCancellation: static state => state.Item3?.Invoke()); - } - internal static Task CreateContinuationTask( Task task, Action onSuccess, From 5b29b7fd3b90893c2379b7ed044980dac81fb7a9 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 16 Oct 2025 13:41:54 -0500 Subject: [PATCH 11/30] Remove seemingly duplicated CreateContinuationTask overload --- .../Data/SqlClient/SqlCommand.netfx.cs | 109 +++++++++--------- .../Data/SqlClient/Utilities/AsyncHelper.cs | 36 ------ 2 files changed, 55 insertions(+), 90 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs index e6780e24db..75bd98e1ab 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs @@ -1248,69 +1248,70 @@ private void PrepareForTransparentEncryption( // Mark that we should not process the finally block since we have async execution pending. // Note that this should be done outside the task's continuation delegate. processFinallyBlock = false; - returnTask = AsyncHelper.CreateContinuationTask(fetchInputParameterEncryptionInfoTask, () => - { - bool processFinallyBlockAsync = true; - bool decrementAsyncCountInFinallyBlockAsync = true; - - try + returnTask = AsyncHelper.CreateContinuationTask( + taskToContinue: fetchInputParameterEncryptionInfoTask, + onSuccess: () => { - // Check for any exceptions on network write, before reading. - CheckThrowSNIException(); + bool processFinallyBlockAsync = true; + bool decrementAsyncCountInFinallyBlockAsync = true; - // If it is async, then TryFetchInputParameterEncryptionInfo-> RunExecuteReaderTds would have incremented the async count. - // Decrement it when we are about to complete async execute reader. - SqlInternalConnectionTds internalConnectionTds = _activeConnection.GetOpenTdsConnection(); - if (internalConnectionTds != null) + try { - internalConnectionTds.DecrementAsyncCount(); - decrementAsyncCountInFinallyBlockAsync = false; - } + // Check for any exceptions on network write, before reading. + CheckThrowSNIException(); - // Complete executereader. - describeParameterEncryptionDataReader = CompleteAsyncExecuteReader(isInternal: false, forDescribeParameterEncryption: true); - Debug.Assert(_stateObj == null, "non-null state object in PrepareForTransparentEncryption."); + // If it is async, then TryFetchInputParameterEncryptionInfo-> RunExecuteReaderTds would have incremented the async count. + // Decrement it when we are about to complete async execute reader. + SqlInternalConnectionTds internalConnectionTds = _activeConnection.GetOpenTdsConnection(); + if (internalConnectionTds != null) + { + internalConnectionTds.DecrementAsyncCount(); + decrementAsyncCountInFinallyBlockAsync = false; + } - // Read the results of describe parameter encryption. - ReadDescribeEncryptionParameterResults( - describeParameterEncryptionDataReader, - describeParameterEncryptionRpcOriginalRpcMap, - isRetry); + // Complete executereader. + describeParameterEncryptionDataReader = CompleteAsyncExecuteReader(isInternal: false, forDescribeParameterEncryption: true); + Debug.Assert(_stateObj == null, "non-null state object in PrepareForTransparentEncryption."); -#if DEBUG - // Failpoint to force the thread to halt to simulate cancellation of SqlCommand. - if (_sleepAfterReadDescribeEncryptionParameterResults) + // Read the results of describe parameter encryption. + ReadDescribeEncryptionParameterResults( + describeParameterEncryptionDataReader, + describeParameterEncryptionRpcOriginalRpcMap, + isRetry); + + #if DEBUG + // Failpoint to force the thread to halt to simulate cancellation of SqlCommand. + if (_sleepAfterReadDescribeEncryptionParameterResults) + { + Thread.Sleep(10000); + } + #endif //DEBUG + } + catch (Exception e) { - Thread.Sleep(10000); + processFinallyBlockAsync = ADP.IsCatchableExceptionType(e); + throw; } -#endif //DEBUG - } - catch (Exception e) - { - processFinallyBlockAsync = ADP.IsCatchableExceptionType(e); - throw; - } - finally - { - PrepareTransparentEncryptionFinallyBlock(closeDataReader: processFinallyBlockAsync, - decrementAsyncCount: decrementAsyncCountInFinallyBlockAsync, - clearDataStructures: processFinallyBlockAsync, - wasDescribeParameterEncryptionNeeded: describeParameterEncryptionNeeded, - describeParameterEncryptionRpcOriginalRpcMap: describeParameterEncryptionRpcOriginalRpcMap, - describeParameterEncryptionDataReader: describeParameterEncryptionDataReader); - } - }, - onFailure: ((exception) => - { - if (CachedAsyncState != null) - { - CachedAsyncState.ResetAsyncState(); - } - if (exception != null) + finally + { + PrepareTransparentEncryptionFinallyBlock( + closeDataReader: processFinallyBlockAsync, + decrementAsyncCount: decrementAsyncCountInFinallyBlockAsync, + clearDataStructures: processFinallyBlockAsync, + wasDescribeParameterEncryptionNeeded: describeParameterEncryptionNeeded, + describeParameterEncryptionRpcOriginalRpcMap: + describeParameterEncryptionRpcOriginalRpcMap, + describeParameterEncryptionDataReader: describeParameterEncryptionDataReader); + } + }, + onFailure: exception => { - throw exception; - } - })); + CachedAsyncState?.ResetAsyncState(); + if (exception != null) + { + throw exception; + } + }); decrementAsyncCountInFinallyBlock = false; } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index d7f664ee3e..f3f3fb7e03 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -485,42 +485,6 @@ private static void ContinueTaskWithState(Task task, ); } - internal static Task CreateContinuationTask( - Task task, - Action onSuccess, - Action onFailure = null) - { - if (task == null) - { - onSuccess(); - return null; - } - else - { - TaskCompletionSource completion = new TaskCompletionSource(); - ContinueTaskWithState( - task, - completion, - state: Tuple.Create(onSuccess, onFailure, completion), - onSuccess: static (object state) => - { - var parameters = (Tuple, TaskCompletionSource>)state; - Action success = parameters.Item1; - TaskCompletionSource taskCompletionSource = parameters.Item3; - success(); - taskCompletionSource.SetResult(null); - }, - onFailure: static (Exception exception, object state) => - { - var parameters = (Tuple, TaskCompletionSource>)state; - Action failure = parameters.Item2; - failure?.Invoke(exception); - } - ); - return completion.Task; - } - } - internal static Task CreateContinuationTaskWithState(Task task, object state, Action onSuccess, Action onFailure = null) { if (task == null) From 153b76407d383ea5d3546cd4d5a158e3ba2fc0f3 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 16 Oct 2025 13:56:19 -0500 Subject: [PATCH 12/30] Remove the non-generic CreateContinuationTaskWithState overload --- .../Data/SqlClient/SqlCommand.netcore.cs | 36 +++++++++---------- .../Data/SqlClient/SqlCommand.Reader.cs | 18 +++++----- .../Data/SqlClient/TdsParserStateObject.cs | 9 +++-- .../Data/SqlClient/Utilities/AsyncHelper.cs | 22 ------------ 4 files changed, 30 insertions(+), 55 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs index c60ff1bf50..d9466a347e 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs @@ -1334,21 +1334,23 @@ private SqlDataReader GetParameterEncryptionDataReader(out Task returnTask, Task SqlDataReader describeParameterEncryptionDataReader, ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap, bool describeParameterEncryptionNeeded, bool isRetry) { - returnTask = AsyncHelper.CreateContinuationTaskWithState(fetchInputParameterEncryptionInfoTask, this, - (object state) => + // @TODO: Why use the with state version if we can't make the lambda static? + returnTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: fetchInputParameterEncryptionInfoTask, + state: this, + onSuccess: this2 => { - SqlCommand command = (SqlCommand)state; bool processFinallyBlockAsync = true; bool decrementAsyncCountInFinallyBlockAsync = true; try { // Check for any exceptions on network write, before reading. - command.CheckThrowSNIException(); + this2.CheckThrowSNIException(); // If it is async, then TryFetchInputParameterEncryptionInfo-> RunExecuteReaderTds would have incremented the async count. // Decrement it when we are about to complete async execute reader. - SqlInternalConnectionTds internalConnectionTds = command._activeConnection.GetOpenTdsConnection(); + SqlInternalConnectionTds internalConnectionTds = this2._activeConnection.GetOpenTdsConnection(); if (internalConnectionTds != null) { internalConnectionTds.DecrementAsyncCount(); @@ -1356,19 +1358,21 @@ private SqlDataReader GetParameterEncryptionDataReader(out Task returnTask, Task } // Complete executereader. - describeParameterEncryptionDataReader = command.CompleteAsyncExecuteReader(isInternal: false, forDescribeParameterEncryption: true); - Debug.Assert(command._stateObj == null, "non-null state object in PrepareForTransparentEncryption."); + describeParameterEncryptionDataReader = this2.CompleteAsyncExecuteReader( + isInternal: false, + forDescribeParameterEncryption: true); + Debug.Assert(this2._stateObj == null, "non-null state object in PrepareForTransparentEncryption."); // Read the results of describe parameter encryption. - command.ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap, isRetry); + this2.ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap, isRetry); -#if DEBUG + #if DEBUG // Failpoint to force the thread to halt to simulate cancellation of SqlCommand. if (_sleepAfterReadDescribeEncryptionParameterResults) { Thread.Sleep(10000); } -#endif + #endif } catch (Exception e) { @@ -1377,7 +1381,8 @@ private SqlDataReader GetParameterEncryptionDataReader(out Task returnTask, Task } finally { - command.PrepareTransparentEncryptionFinallyBlock(closeDataReader: processFinallyBlockAsync, + this2.PrepareTransparentEncryptionFinallyBlock( + closeDataReader: processFinallyBlockAsync, decrementAsyncCount: decrementAsyncCountInFinallyBlockAsync, clearDataStructures: processFinallyBlockAsync, wasDescribeParameterEncryptionNeeded: describeParameterEncryptionNeeded, @@ -1385,14 +1390,9 @@ private SqlDataReader GetParameterEncryptionDataReader(out Task returnTask, Task describeParameterEncryptionDataReader: describeParameterEncryptionDataReader); } }, - onFailure: static (Exception exception, object state) => + onFailure: static (this2, exception) => { - SqlCommand command = (SqlCommand)state; - if (command.CachedAsyncState != null) - { - command.CachedAsyncState.ResetAsyncState(); - } - + this2.CachedAsyncState?.ResetAsyncState(); if (exception != null) { throw exception; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index c5a60cb51c..a44f54ee1a 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -1580,21 +1580,19 @@ private Task RunExecuteReaderTdsSetupContinuation( string optionSettings, Task writeTask) { - // @TODO: Why use the state version if we can't make this a static helper? return AsyncHelper.CreateContinuationTaskWithState( - task: writeTask, - state: _activeConnection, - onSuccess: state => + taskToContinue: writeTask, + state1: this, + state2: Tuple.Create(ds, runBehavior, optionSettings), + onSuccess: static (this2, parameters) => { // This will throw if the connection is closed. // @TODO: So... can we have something that specifically does that? - ((SqlConnection)state).GetOpenTdsConnection(); - CachedAsyncState.SetAsyncReaderState(ds, runBehavior, optionSettings); + this2._activeConnection.GetOpenTdsConnection(); + this2.CachedAsyncState.SetAsyncReaderState(parameters.Item1, parameters.Item2, parameters.Item3); }, - onFailure: static (exception, state) => - { - ((SqlConnection)state).GetOpenTdsConnection().DecrementAsyncCount(); - }); + onFailure: static (this2, _, exception) => + this2._activeConnection.GetOpenTdsConnection().DecrementAsyncCount()); } // @TODO: This is way too many parameters being shoveled back and forth. We can do better. diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 537565d15c..058741229f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -1270,13 +1270,12 @@ internal Task ExecuteFlush() else { return AsyncHelper.CreateContinuationTaskWithState( - task: writePacketTask, + taskToContinue: writePacketTask, state: this, - onSuccess: static (object state) => + onSuccess: static this2 => { - TdsParserStateObject stateObject = (TdsParserStateObject)state; - stateObject.HasPendingData = true; - stateObject._messageStatus = 0; + this2.HasPendingData = true; + this2._messageStatus = 0; } ); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index f3f3fb7e03..0dce784da5 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -485,28 +485,6 @@ private static void ContinueTaskWithState(Task task, ); } - internal static Task CreateContinuationTaskWithState(Task task, object state, Action onSuccess, Action onFailure = null) - { - if (task == null) - { - onSuccess(state); - return null; - } - else - { - var completion = new TaskCompletionSource(); - ContinueTaskWithState(task, completion, state, - onSuccess: (object continueState) => - { - onSuccess(continueState); - completion.SetResult(null); - }, - onFailure: onFailure - ); - return completion.Task; - } - } - internal static void SetTimeoutException(TaskCompletionSource completion, int timeout, Func onFailure, CancellationToken ctoken) { if (timeout > 0) From 3584c99fac654b3c7f612082be70c0d4a65b6b4c Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 16 Oct 2025 14:08:21 -0500 Subject: [PATCH 13/30] Rewrite SetTimeoutException --- .../Data/SqlClient/SqlCommand.Reader.cs | 2 +- .../Data/SqlClient/Utilities/AsyncHelper.cs | 101 +++++------------- 2 files changed, 27 insertions(+), 76 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index a44f54ee1a..02d4b10704 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -1613,7 +1613,7 @@ private void RunExecuteReaderTdsSetupReconnectContinuation( AsyncHelper.SetTimeoutException( completion, timeout, - onFailure: static () => SQL.CR_ReconnectTimeout(), + onTimeout: static () => SQL.CR_ReconnectTimeout(), timeoutCts.Token); // @TODO: With an object to pass around we can use the state-based version diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 0dce784da5..c03e328a24 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -405,6 +405,32 @@ internal static Task CreateContinuationTaskWithState( return taskCompletionSource.Task; } + // @TODO: This is a pretty wonky way of doing timeouts, imo. + internal static void SetTimeoutException( + TaskCompletionSource taskCompletionSource, + int timeoutInSeconds, + Func onTimeout, + CancellationToken cancellationToken) + { + if (timeoutInSeconds <= 0) + { + return; + } + + Task.Delay(TimeSpan.FromSeconds(timeoutInSeconds), cancellationToken) + .ContinueWith( + task => + { + // If the timeout ran to completion AND the task to complete did not complete + // then the timeout expired first, run the timeout handler + if (!task.IsCanceled && !taskCompletionSource.Task.IsCompleted) + { + taskCompletionSource.TrySetException(onTimeout()); + } + }, + cancellationToken); + } + private record ContinuationState( Action OnCancellation, Action OnFailure, @@ -426,81 +452,6 @@ private record ContinuationState( TState2 State2, TaskCompletionSource TaskCompletionSource); - // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure - // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties - private static void ContinueTaskWithState(Task task, - TaskCompletionSource completion, - object state, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) - { - task.ContinueWith( - (Task tsk, object state2) => - { - if (tsk.Exception != null) - { - Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) - { - exc = exceptionConverter(exc); - } - - try - { - onFailure?.Invoke(exc, state2); - } - finally - { - completion.TrySetException(exc); - } - } - else if (tsk.IsCanceled) - { - try - { - onCancellation?.Invoke(state2); - } - finally - { - completion.TrySetCanceled(); - } - } - else - { - try - { - onSuccess(state2); - } - // @TODO: CER Exception Handling was removed here (see GH#3581) - catch (Exception e) - { - completion.SetException(e); - } - } - }, - state: state, - scheduler: TaskScheduler.Default - ); - } - - internal static void SetTimeoutException(TaskCompletionSource completion, int timeout, Func onFailure, CancellationToken ctoken) - { - if (timeout > 0) - { - Task.Delay(timeout * 1000, ctoken).ContinueWith( - (Task task) => - { - if (!task.IsCanceled && !completion.Task.IsCompleted) - { - completion.TrySetException(onFailure()); - } - } - ); - } - } - internal static void SetTimeoutExceptionWithState( TaskCompletionSource completion, int timeout, From 14fed2390d67d8f555508f40dab30e981b3d1447 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 16 Oct 2025 14:40:18 -0500 Subject: [PATCH 14/30] SetTimeoutExceptionWithState --- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 6 +-- .../Data/SqlClient/Utilities/AsyncHelper.cs | 54 ++++++++++--------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 2b0d8283fd..bbd8730bbf 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -3014,10 +3014,10 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio // No need to cancel timer since SqlBulkCopy creates specific task source for reconnection. AsyncHelper.SetTimeoutExceptionWithState( - completion: cancellableReconnectTS, - timeout: BulkCopyTimeout, + taskCompletionSource: cancellableReconnectTS, + timeoutInSeconds: BulkCopyTimeout, state: _destinationTableName, - onFailure: static state => + onTimeout: static state => SQL.BulkLoadInvalidDestinationTable((string)state, SQL.CR_ReconnectTimeout()), cancellationToken: CancellationToken.None ); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index c03e328a24..53093eb121 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -405,7 +405,6 @@ internal static Task CreateContinuationTaskWithState( return taskCompletionSource.Task; } - // @TODO: This is a pretty wonky way of doing timeouts, imo. internal static void SetTimeoutException( TaskCompletionSource taskCompletionSource, int timeoutInSeconds, @@ -428,7 +427,34 @@ internal static void SetTimeoutException( taskCompletionSource.TrySetException(onTimeout()); } }, - cancellationToken); + cancellationToken: CancellationToken.None); + } + + internal static void SetTimeoutExceptionWithState( + TaskCompletionSource taskCompletionSource, + int timeoutInSeconds, + TState state, + Func onTimeout, + CancellationToken cancellationToken) + { + if (timeoutInSeconds <= 0) + { + return; + } + + Task.Delay(TimeSpan.FromSeconds(timeoutInSeconds), cancellationToken) + .ContinueWith( + (task, state2) => + { + // If the timeout ran to completion AND the task to complete did not complete + // then the timeout expired first, run the timeout handler + if (!task.IsCanceled && !taskCompletionSource.Task.IsCompleted) + { + taskCompletionSource.TrySetException(onTimeout((TState)state2)); + } + }, + state: state, + cancellationToken: CancellationToken.None); } private record ContinuationState( @@ -452,30 +478,6 @@ private record ContinuationState( TState2 State2, TaskCompletionSource TaskCompletionSource); - internal static void SetTimeoutExceptionWithState( - TaskCompletionSource completion, - int timeout, - object state, - Func onFailure, - CancellationToken cancellationToken) - { - if (timeout <= 0) - { - return; - } - - Task.Delay(timeout * 1000, cancellationToken).ContinueWith( - (task, innerState) => - { - if (!task.IsCanceled && !completion.Task.IsCompleted) - { - completion.TrySetException(onFailure(innerState)); - } - }, - state: state, - cancellationToken: CancellationToken.None); - } - internal static void WaitForCompletion(Task task, int timeout, Action onTimeout = null, bool rethrowExceptions = true) { try From 7684fe7240a00b753565629a6851df195e96d6fc Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 16 Oct 2025 14:54:44 -0500 Subject: [PATCH 15/30] Cleanup WaitForCompletion --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 49 +++++++++++-------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 53093eb121..e62686a683 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -457,6 +457,34 @@ internal static void SetTimeoutExceptionWithState( cancellationToken: CancellationToken.None); } + internal static void WaitForCompletion( + Task task, + int timeoutInSeconds, + Action onTimeout = null, + bool rethrowExceptions = true) + { + try + { + task.Wait(timeoutInSeconds > 0 ? 1000 * timeoutInSeconds : Timeout.Infinite); + } + catch (AggregateException ae) + { + if (rethrowExceptions) + { + Debug.Assert(ae.InnerException is not null, "Inner exception is null"); + Debug.Assert(ae.InnerExceptions.Count == 1, "There is more than one exception in AggregateException"); + ExceptionDispatchInfo.Capture(ae.InnerException).Throw(); + } + } + + if (!task.IsCompleted) + { + //Ensure the task does not leave an unobserved exception + task.ContinueWith(static t => { var ignored = t.Exception; }); + onTimeout?.Invoke(); + } + } + private record ContinuationState( Action OnCancellation, Action OnFailure, @@ -477,26 +505,5 @@ private record ContinuationState( TState1 State1, TState2 State2, TaskCompletionSource TaskCompletionSource); - - internal static void WaitForCompletion(Task task, int timeout, Action onTimeout = null, bool rethrowExceptions = true) - { - try - { - task.Wait(timeout > 0 ? (1000 * timeout) : Timeout.Infinite); - } - catch (AggregateException ae) - { - if (rethrowExceptions) - { - Debug.Assert(ae.InnerExceptions.Count == 1, "There is more than one exception in AggregateException"); - ExceptionDispatchInfo.Capture(ae.InnerException).Throw(); - } - } - if (!task.IsCompleted) - { - task.ContinueWith(static t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception - onTimeout?.Invoke(); - } - } } } From da6ef3c426d5a3f5c86c6c164f9df6578ad4a139 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Mon, 20 Oct 2025 13:12:09 -0500 Subject: [PATCH 16/30] Add reference to Moq in unit test project --- src/Directory.Packages.props | 1 + .../tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index cfb7f85797..768981a4f2 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -35,6 +35,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj index 16f09d82cc..ce4108bc3d 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj @@ -11,7 +11,6 @@ - runtime; build; native; contentfiles; analyzers; buildtransitive @@ -28,6 +27,8 @@ + + From f8ccc85a41fa0ccb4b8ea575a7a6d639dc5e7eb5 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Mon, 20 Oct 2025 18:20:10 -0500 Subject: [PATCH 17/30] Adding tests --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 22 +- .../SqlClient/Utilities/AsyncHelperTest.cs | 1165 +++++++++++++++++ .../UnitTests/Utilities/MockExtensions.cs | 60 + 3 files changed, 1232 insertions(+), 15 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs create mode 100644 src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index e62686a683..b6c4242d42 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -12,14 +12,15 @@ namespace Microsoft.Data.SqlClient.Utilities { internal static class AsyncHelper { - internal static void ContinueTask(Task task, - TaskCompletionSource completion, + internal static void ContinueTask( + Task taskToContinue, + TaskCompletionSource taskCompletionSource, Action onSuccess, Action onFailure = null, Action onCancellation = null, Func exceptionConverter = null) { - task.ContinueWith( + taskToContinue.ContinueWith( tsk => { if (tsk.Exception != null) @@ -35,7 +36,7 @@ internal static void ContinueTask(Task task, } finally { - completion.TrySetException(exc); + taskCompletionSource.TrySetException(exc); } } else if (tsk.IsCanceled) @@ -46,7 +47,7 @@ internal static void ContinueTask(Task task, } finally { - completion.TrySetCanceled(); + taskCompletionSource.TrySetCanceled(); } } else @@ -58,7 +59,7 @@ internal static void ContinueTask(Task task, // @TODO: CER Exception Handling was removed here (see GH#3581) catch (Exception e) { - completion.SetException(e); + taskCompletionSource.SetException(e); } } }, TaskScheduler.Default @@ -88,7 +89,6 @@ internal static void ContinueTaskWithState( if (task.Exception is not null) { - // @TODO: Exception converter? try { typedState2.OnFailure?.Invoke(typedState2.State, task.Exception); @@ -150,7 +150,6 @@ internal static void ContinueTaskWithState( if (task.Exception is not null) { - // @TODO: Exception converter? try { typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, task.Exception); @@ -258,13 +257,6 @@ internal static Task CreateContinuationTaskWithState( Action onFailure = null, Action onCancellation = null) { - // Note: this code is almost identical to ContinueTaskWithState, but creates its own - // task completion source and completes it on success. - // Yes, we could just chain into the ContinueTaskWithState, but that requires wrapping - // more state in a tuple and confusing the heck out of people. So, duplicating code - // just makes things more clean. Besides, @TODO: We should get rid of these helpers and - // just use async/await natives. - if (taskToContinue is null) { onSuccess(state); diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs new file mode 100644 index 0000000000..4f199a8362 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -0,0 +1,1165 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient.UnitTests.Utilities; +using Microsoft.Data.SqlClient.Utilities; +using Moq; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.Microsoft.Data.SqlClient.Utilities +{ + public class AsyncHelperTest + { + #region ContinueTask + + [Fact] + public async Task ContinueTask_TaskCompletes() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = new(); + Mock> onFailure = new(); + Mock onCancellation = new(); + + // Note: We have to set up onSuccess to set a result on the task completion source, + // since the AsyncHelper will not do it, and without that, we cannot reliably + // know when the continuation completed. We will use SetResult b/c it will throw + // if it has already been set. + Mock onSuccess = new(); + onSuccess.Setup(action => action()) + .Callback(() => taskCompletionSource.SetResult(0)); + + // Act + AsyncHelper.ContinueTask( + taskToContinue: taskToContinue, + taskCompletionSource: taskCompletionSource, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + onSuccess.Verify(action => action(), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTask_TaskCompletesHandlerThrows() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = new(); + + Mock onSuccess = new(); + onSuccess.SetupThrows(); + + Mock> onFailure = new(); + Mock onCancellation = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + onSuccess.Verify(action => action(), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTask_TaskCancels(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = new(); + + Mock onCancellation = new(); + if (handlerShouldThrow) + { + onCancellation.SetupThrows(); + } + + Mock onSuccess = new(); + Mock> onFailure = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of onCancellation throwing + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + onCancellation.Verify(action => action(), Times.Once); + } + + [Fact] + public async Task ContinueTask_TaskCancelsNoHandler() + { + // Arrange + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = new(); + Mock onSuccess = new(); + Mock> onFailure = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + onSuccess.Object, + onFailure.Object, + onCancellation: null); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTask_TaskFaults(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = new(); + + Mock> onFailure = new(); + if (handlerShouldThrow) + { + onFailure.SetupThrows(); + } + + Mock onSuccess = new(); + Mock onCancellation = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of onSuccess throwing + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + onSuccess.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + onFailure.Verify(action => action(It.IsAny()), Times.Once); + } + + [Fact] + public async Task ContinueTask_TaskFaultsNoHandler() + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = new(); + Mock onSuccess = new(); + Mock onCancellation = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + onSuccess.Object, + onFailure: null, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + onSuccess.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + #endregion + + #region ContinueTaskWithState + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskCompletes() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + + Mock> onFailure = new(); + Mock> onCancellation = new(); + + // Note: We have to set up onSuccess to set a result on the task completion source, + // since the AsyncHelper will not do it, and without that, we cannot reliably + // know when the continuation completed. We will use SetResult b/c it will throw + // if it has already been set. + Mock> onSuccess = new(); + onSuccess.Setup(action => action(state1)) + .Callback(_ => taskCompletionSource.SetResult(0)); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue: taskToContinue, + taskCompletionSource: taskCompletionSource, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + onSuccess.Verify(action => action(state1), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskCompletesHandlerThrows() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + + Mock> onSuccess = new(); + onSuccess.Setup(action => action(It.IsAny())).Throws(); + + Mock> onFailure = new(); + Mock> onCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have faulted + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + onSuccess.Verify(action => action(state1), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_1Generic_TaskCancels(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + + Mock> onCancellation = new(); + if (handlerShouldThrow) + { + onCancellation.SetupThrows(); + } + + Mock> onSuccess = new(); + Mock> onFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of onCancellation throwing + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + onCancellation.Verify(action => action(state1), Times.Once); + } + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskCancelsNoHandler() + { + // Arrange + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + + Mock> onSuccess = new(); + Mock> onFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation: null); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_1Generic_TaskFaults(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + + Mock> onFailure = new(); + if (handlerShouldThrow) + { + onFailure.SetupThrows(); + } + + Mock> onSuccess = new(); + Mock> onCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of onSuccess throwing + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + onSuccess.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + onFailure.Verify(action => action(state1, It.IsAny()), Times.Once); + } + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskFaultsNoHandler() + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + + Mock> onSuccess = new(); + Mock> onCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + onSuccess.Object, + onFailure: null, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + onSuccess.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + #endregion + + #region ContinueTaskWithState + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskCompletes() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + const int state2 = 234; + + Mock> onFailure = new(); + Mock> onCancellation = new(); + + // Note: We have to set up onSuccess to set a result on the task completion source, + // since the AsyncHelper will not do it, and without that, we cannot reliably + // know when the continuation completed. We will use SetResult b/c it will throw + // if it has already been set. + Mock> onSuccess = new(); + onSuccess.Setup(action => action(state1, state2)) + .Callback((_, _) => taskCompletionSource.SetResult(0)); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue: taskToContinue, + taskCompletionSource: taskCompletionSource, + state1, + state2, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + onSuccess.Verify(action => action(state1, state2), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskCompletesHandlerThrows() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + + Mock> onSuccess = new(); + onSuccess.Setup(o => o(It.IsAny())).Throws(); + + Mock> onFailure = new(); + Mock> onCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have faulted + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + // - onSuccess was called with state obj + onSuccess.Verify(action => action(state1), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_2Generics_TaskCancels(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + const int state2 = 234; + + Mock> onCancellation = new(); + if (handlerShouldThrow) + { + onCancellation.SetupThrows(); + } + + Mock> onSuccess = new(); + Mock> onFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of onCancellation throwing + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + onCancellation.Verify(action => action(state1, state2), Times.Once); + } + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskCancelsNoHandler() + { + // Arrange + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + const int state2 = 234; + + Mock> onSuccess = new(); + Mock> onFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + onSuccess.Object, + onFailure.Object, + onCancellation: null); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_2Generics_TaskFaults(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + const int state2 = 234; + + Mock> onFailure = new(); + if (handlerShouldThrow) + { + onFailure.SetupThrows(); + } + + Mock> onSuccess = new(); + Mock> onCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of onSuccess throwing + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + onSuccess.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + onFailure.Verify(action => action(state1, state2, It.IsAny()), Times.Once); + } + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskFaultsNoHandler() + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = new(); + const int state1 = 123; + const int state2 = 234; + + Mock> onSuccess = new(); + Mock> onCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + onSuccess.Object, + onFailure: null, + onCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + onSuccess.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + #endregion + + #region CreateContinuationTask + + [Fact] + public async Task CreateContinuationTask_TaskCompletes() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + Mock onSuccess = new(); + Mock> onFailure = new(); + Mock onCancellation = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + onSuccess.Verify(action => action(), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTask_TaskCompletesHandlerThrows() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + Mock> onFailure = new(); + Mock onCancellation = new(); + + Mock onSuccess = new(); + onSuccess.SetupThrows(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + onSuccess.Verify(action => action(), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTask_TaskCancels(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = GetCancelledTask(); + Mock> onFailure = new(); + Mock onSuccess = new(); + + Mock onCancellation = new(); + if (handlerShouldThrow) + { + onCancellation.SetupThrows(); + } + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + onCancellation.Verify(action => action(), Times.Once); + } + + [Fact] + public async Task CreateContinuationTask_TaskCancelsNoHandler() + { + // Arrange + Task taskToContinue = GetCancelledTask(); + Mock> onFailure = new(); + Mock onSuccess = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + onSuccess.Object, + onFailure.Object, + onCancellation: null); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTask_TaskFaults(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + Mock onSuccess = new(); + Mock onCancellation = new(); + + Mock> onFailure = new(); + if (handlerShouldThrow) + { + onFailure.SetupThrows(); + } + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onFailure.Verify(action => action(It.IsAny()), Times.Once); + onCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTask_TaskFaultsNoHandler() + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + Mock onSuccess = new(); + Mock onCancellation = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + onSuccess.Object, + onFailure: null, + onCancellation: null); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + #endregion + + #region CreateContinuationTaskWithState + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskCompletes() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + + Mock> onSuccess = new(); + Mock> onFailure = new(); + Mock> onCancellation = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + onSuccess.Verify(action => action(state1), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskCompletesHandlerThrows() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + + Mock> onFailure = new(); + Mock> onCancellation = new(); + + Mock> onSuccess = new(); + onSuccess.SetupThrows(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + onSuccess.Verify(action => action(state1), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_1Generic_TaskCancels(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + + Mock> onFailure = new(); + Mock> onSuccess = new(); + + Mock> onCancellation = new(); + if (handlerShouldThrow) + { + onCancellation.SetupThrows(); + } + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + onCancellation.Verify(action => action(state1), Times.Once); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskCancelsNoHandler() + { + // Arrange + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + + Mock> onFailure = new(); + Mock> onSuccess = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation: null); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_1Generic_TaskFaults(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + + Mock> onSuccess = new(); + Mock> onCancellation = new(); + + Mock> onFailure = new(); + if (handlerShouldThrow) + { + onFailure.SetupThrows(); + } + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onFailure.Verify(action => action(state1, It.IsAny()), Times.Once); + onCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskFaultsNoHandler() + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + + Mock> onSuccess = new(); + Mock> onCancellation = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + onSuccess.Object, + onFailure: null, + onCancellation: null); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + #endregion + + #region CreateContinuationTaskWithState + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskCompletes() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + const int state2 = 234; + + Mock> onSuccess = new(); + Mock> onFailure = new(); + Mock> onCancellation = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + onSuccess.Verify(action => action(state1, state2), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskCompletesHandlerThrows() + { + // Arrange + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + const int state2 = 234; + + Mock> onFailure = new(); + Mock> onCancellation = new(); + + Mock> onSuccess = new(); + onSuccess.SetupThrows(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + onSuccess.Verify(action => action(state1, state2), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_2Generics_TaskCancels(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + const int state2 = 234; + + Mock> onFailure = new(); + Mock> onSuccess = new(); + + Mock> onCancellation = new(); + if (handlerShouldThrow) + { + onCancellation.SetupThrows(); + } + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + onCancellation.Verify(action => action(state1, state2), Times.Once); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskCancelsNoHandler() + { + // Arrange + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + const int state2 = 234; + + Mock> onFailure = new(); + Mock> onSuccess = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + onSuccess.Object, + onFailure.Object, + onCancellation: null); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_2Generics_TaskFaults(bool handlerShouldThrow) + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + const int state2 = 234; + + Mock> onSuccess = new(); + Mock> onCancellation = new(); + + Mock> onFailure = new(); + if (handlerShouldThrow) + { + onFailure.SetupThrows(); + } + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onFailure.Verify(action => action(state1, state2, It.IsAny()), Times.Once); + onCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskFaultsNoHandler() + { + // Arrange + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + const int state2 = 234; + + Mock> onSuccess = new(); + Mock> onCancellation = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + onSuccess.Object, + onFailure: null, + onCancellation: null); + await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + onSuccess.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + + #endregion + + private static Task GetCancelledTask() + { + using CancellationTokenSource cts = new(); + cts.Cancel(); + + return Task.FromCanceled(cts.Token); + } + + private static async Task RunWithTimeout(Task taskToRun, TimeSpan timeout) + { + Task winner = await Task.WhenAny(taskToRun, Task.Delay(timeout)); + if (winner != taskToRun) + { + Assert.Fail("Timeout elapsed."); + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs new file mode 100644 index 0000000000..b765ca5f3e --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs @@ -0,0 +1,60 @@ +using System; +using Moq; + +namespace Microsoft.Data.SqlClient.UnitTests.Utilities +{ + public static class MockExtensions + { + public static void SetupThrows(this Mock mock) + where TException : Exception, new() + { + mock.Setup(action => action()) + .Throws(); + } + + public static void SetupThrows(this Mock> mock) + where TException : Exception, new() + { + mock.Setup(action => action(It.IsAny())) + .Throws(); + } + + public static void SetupThrows(this Mock> mock) + where TException : Exception, new() + { + mock.Setup(action => action(It.IsAny(), It.IsAny())) + .Throws(); + } + + public static void SetupThrows(this Mock> mock) + where TException : Exception, new() + { + mock.Setup(action => action(It.IsAny(), It.IsAny(), It.IsAny())) + .Throws(); + } + + public static void VerifyNeverCalled(this Mock mock) => + mock.Verify(action => action(), Times.Never); + + public static void VerifyNeverCalled(this Mock> mock) + { + mock.Verify( + action => action(It.IsAny()), + Times.Never); + } + + public static void VerifyNeverCalled(this Mock> mock) + { + mock.Verify( + action => action(It.IsAny(), It.IsAny()), + Times.Never); + } + + public static void VerifyNeverCalled(this Mock> mock) + { + mock.Verify( + action => action(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Never); + } + } +} From 45ef3729e85630499b2e49aa3772648e72513c1b Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Mon, 20 Oct 2025 18:27:23 -0500 Subject: [PATCH 18/30] Removing exception converter from state-less continue task helper --- .../Data/SqlClient/SqlCommand.NonQuery.cs | 4 ++-- .../Data/SqlClient/SqlCommand.Reader.cs | 7 +++---- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 21 ++++++++++++++----- .../Data/SqlClient/TdsParserStateObject.cs | 13 +++++++++--- .../Data/SqlClient/Utilities/AsyncHelper.cs | 11 +++------- 5 files changed, 34 insertions(+), 22 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs index 32198ea0a9..eb3cc4baa4 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs @@ -870,8 +870,8 @@ private void RunExecuteNonQueryTdsSetupReconnnectContinuation( timeoutCts.Token); AsyncHelper.ContinueTask( - reconnectTask, - completion, + taskToContinue: reconnectTask, + taskCompletionSource: completion, onSuccess: () => { if (completion.Task.IsCompleted) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index 02d4b10704..29b30ccf67 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -1616,10 +1616,9 @@ private void RunExecuteReaderTdsSetupReconnectContinuation( onTimeout: static () => SQL.CR_ReconnectTimeout(), timeoutCts.Token); - // @TODO: With an object to pass around we can use the state-based version AsyncHelper.ContinueTask( - reconnectTask, - completion, + taskToContinue: reconnectTask, + taskCompletionSource: completion, onSuccess: () => { if (completion.Task.IsCompleted) @@ -1651,7 +1650,7 @@ private void RunExecuteReaderTdsSetupReconnectContinuation( taskToContinue: subTask, taskCompletionSource: completion, state: completion, - onSuccess: static state => state.SetResult(null)); + onSuccess: static completion2 => completion2.SetResult(null)); } }); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index cb46736569..62536c56fe 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -10761,11 +10761,23 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet } // This is in its own method to avoid always allocating the lambda in TDSExecuteRPCParameter - private void TDSExecuteRPCParameterSetupWriteCompletion(SqlCommand cmd, IList<_SqlRPC> rpcArray, int timeout, bool inSchema, SqlNotificationRequest notificationRequest, TdsParserStateObject stateObj, bool isCommandProc, bool sync, TaskCompletionSource completion, int startRpc, int startParam, Task writeParamTask) + private void TDSExecuteRPCParameterSetupWriteCompletion( + SqlCommand cmd, + IList<_SqlRPC> rpcArray, + int timeout, + bool inSchema, + SqlNotificationRequest notificationRequest, + TdsParserStateObject stateObj, + bool isCommandProc, + bool sync, + TaskCompletionSource completion, + int startRpc, + int startParam, + Task writeParamTask) { AsyncHelper.ContinueTask( - writeParamTask, - completion, + taskToContinue: writeParamTask, + taskCompletionSource: completion, onSuccess: () => TdsExecuteRPC( cmd, rpcArray, @@ -10777,8 +10789,7 @@ private void TDSExecuteRPCParameterSetupWriteCompletion(SqlCommand cmd, IList<_S sync, completion, startRpc, - startParam - ), + startParam), onFailure: exc => TdsExecuteRPC_OnFailure(exc, stateObj)); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 058741229f..48bcf01f2a 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -4353,9 +4353,16 @@ private Task WriteBytes(ReadOnlySpan b, int len, int offsetBuffer, bool ca // This is in its own method to avoid always allocating the lambda in WriteBytes private void WriteBytesSetupContinuation(byte[] array, int len, TaskCompletionSource completion, int offset, Task packetTask) { - AsyncHelper.ContinueTask(packetTask, completion, - onSuccess: () => WriteBytes(ReadOnlySpan.Empty, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion, array) - ); + AsyncHelper.ContinueTask( + taskToContinue: packetTask, + taskCompletionSource: completion, + onSuccess: () => WriteBytes( + ReadOnlySpan.Empty, + len: len, + offsetBuffer: offset, + canAccumulate: false, + completion: completion, + array)); } /// diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index b6c4242d42..926555f18c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -17,8 +17,7 @@ internal static void ContinueTask( TaskCompletionSource taskCompletionSource, Action onSuccess, Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) + Action onCancellation = null) { taskToContinue.ContinueWith( tsk => @@ -26,10 +25,6 @@ internal static void ContinueTask( if (tsk.Exception != null) { Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) - { - exc = exceptionConverter(exc); - } try { onFailure?.Invoke(exc); @@ -62,8 +57,8 @@ internal static void ContinueTask( taskCompletionSource.SetException(e); } } - }, TaskScheduler.Default - ); + }, + TaskScheduler.Default); } internal static void ContinueTaskWithState( From 0e3250f0fe3b80dfebb6489300a59bd08baa0921 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Mon, 20 Oct 2025 18:33:11 -0500 Subject: [PATCH 19/30] Adding null task cases for CreateContinuationTask --- .../SqlClient/Utilities/AsyncHelperTest.cs | 81 +++++++++++++++++++ .../UnitTests/Utilities/MockExtensions.cs | 4 + 2 files changed, 85 insertions(+) diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs index 4f199a8362..5bbb62a0db 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using System; using System.Threading; using System.Threading.Tasks; @@ -619,6 +623,29 @@ public async Task ContinueTaskWithState_2Generics_TaskFaultsNoHandler() #region CreateContinuationTask + [Fact] + public void CreateContinuationTask_NullTask() + { + // Arrange + Mock onSuccess = new(); + Mock> onFailure = new(); + Mock onCancellation = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue: null, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + + // Assert + Assert.Null(continuationTask); + + onSuccess.Verify(action => action(), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + [Fact] public async Task CreateContinuationTask_TaskCompletes() { @@ -638,6 +665,7 @@ public async Task CreateContinuationTask_TaskCompletes() // Assert Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + onSuccess.Verify(action => action(), Times.Once); onFailure.VerifyNeverCalled(); onCancellation.VerifyNeverCalled(); @@ -779,6 +807,31 @@ public async Task CreateContinuationTask_TaskFaultsNoHandler() #region CreateContinuationTaskWithState + [Fact] + public void CreateContinuationTaskWithState_1Generic_NullTask() + { + // Arrange + const int state1 = 123; + Mock> onSuccess = new(); + Mock> onFailure = new(); + Mock> onCancellation = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: null, + state1, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + + // Assert + Assert.Null(continuationTask); + + onSuccess.Verify(action => action(state1), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + [Fact] public async Task CreateContinuationTaskWithState_1Generic_TaskCompletes() { @@ -957,6 +1010,34 @@ public async Task CreateContinuationTaskWithState_1Generic_TaskFaultsNoHandler() #region CreateContinuationTaskWithState + [Fact] + public void CreateContinuationTaskWithState_2Generics_NullTask() + { + // Arrange + const int state1 = 123; + const int state2 = 234; + + Mock> onSuccess = new(); + Mock> onFailure = new(); + Mock> onCancellation = new(); + + // Act + Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: null, + state1, + state2, + onSuccess.Object, + onFailure.Object, + onCancellation.Object); + + // Assert + Assert.Null(continuationTask); + + onSuccess.Verify(action => action(state1, state2), Times.Once); + onFailure.VerifyNeverCalled(); + onCancellation.VerifyNeverCalled(); + } + [Fact] public async Task CreateContinuationTaskWithState_2Generics_TaskCompletes() { diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs index b765ca5f3e..d1cfe89e2e 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using System; using Moq; From aabd04fc2b07b43373a5a241e120ea6f7195879b Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Wed, 22 Oct 2025 14:02:24 -0500 Subject: [PATCH 20/30] Migrating SqlHelperTest in functional tests to AsyncHelper in unit tests --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 4 +- ...soft.Data.SqlClient.FunctionalTests.csproj | 1 - .../tests/FunctionalTests/SqlHelperTest.cs | 62 ------------------- .../SqlClient/Utilities/AsyncHelperTest.cs | 42 +++++++++++++ 4 files changed, 44 insertions(+), 65 deletions(-) delete mode 100644 src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 926555f18c..974615fc02 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -466,8 +466,8 @@ internal static void WaitForCompletion( if (!task.IsCompleted) { - //Ensure the task does not leave an unobserved exception - task.ContinueWith(static t => { var ignored = t.Exception; }); + // Ensure the task does not leave an unobserved exception + task.ContinueWith(static t => { _ = t.Exception; }); onTimeout?.Invoke(); } } diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj index 7f6d8abd2c..28389b0335 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj @@ -63,7 +63,6 @@ - diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs deleted file mode 100644 index 44286b8c0e..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs +++ /dev/null @@ -1,62 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Xunit; - -namespace Microsoft.Data.SqlClient.Tests -{ - public class SqlHelperTest - { - private void TimeOutATask() - { - var sqlClientAssembly = Assembly.GetAssembly(typeof(SqlCommand)); - //We're using reflection to avoid exposing the internals - MethodInfo waitForCompletion = sqlClientAssembly.GetType("Microsoft.Data.SqlClient.AsyncHelper") - ?.GetMethod("WaitForCompletion", BindingFlags.Static | BindingFlags.NonPublic); - - Assert.False(waitForCompletion == null, "Running a test on SqlUtil.WaitForCompletion but could not find this method"); - TaskCompletionSource tcs = new TaskCompletionSource(); - waitForCompletion.Invoke(null, new object[] { tcs.Task, 1, null, true }); //Will time out as task uncompleted - tcs.SetException(new TimeoutException("Dummy timeout exception")); //Our task now completes with an error - } - - private Exception UnwrapException(Exception e) - { - return e?.InnerException != null ? UnwrapException(e.InnerException) : e; - } - - [Fact] - public void WaitForCompletion_DoesNotCreateUnobservedException() - { - var unobservedExceptionHappenedEvent = new AutoResetEvent(false); - Exception unhandledException = null; - void handleUnobservedException(object o, UnobservedTaskExceptionEventArgs a) - { unhandledException = a.Exception; unobservedExceptionHappenedEvent.Set(); } - - TaskScheduler.UnobservedTaskException += handleUnobservedException; - - try - { - TimeOutATask(); //Create the task in another function so the task has no reference remaining - GC.Collect(); //Force collection of unobserved task - GC.WaitForPendingFinalizers(); - - bool unobservedExceptionHappend = unobservedExceptionHappenedEvent.WaitOne(1); - if (unobservedExceptionHappend) //Save doing string interpolation in the happy case - { - var e = UnwrapException(unhandledException); - Assert.Fail($"Did not expect an unobserved exception, but found a {e?.GetType()} with message \"{e?.Message}\""); - } - } - finally - { - TaskScheduler.UnobservedTaskException -= handleUnobservedException; - } - } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs index 5bbb62a0db..de8c4fb421 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -1226,6 +1226,48 @@ public async Task CreateContinuationTaskWithState_2Generics_TaskFaultsNoHandler( #endregion + #region WaitForCompletion + + [Fact] + public void WaitForCompletion_DoesNotCreateUnobservedException() + { + // Arrange + Exception? unhandledException = null; + EventHandler handleUnobservedException = + (_, args) => unhandledException = args.Exception; + + // @TODO: Can we do this with a custom scheduler to avoid changing global state? + TaskScheduler.UnobservedTaskException += handleUnobservedException; + + try + { + // Act + // - Run task that will always time out + TaskCompletionSource tcs = new(); + AsyncHelper.WaitForCompletion( + tcs.Task, + timeoutInSeconds: 1, + onTimeout: null, + rethrowExceptions: true); + + // - Force collection of unobserved task + GC.Collect(); + GC.WaitForPendingFinalizers(); + + // Assert + // - Make sure no unobserved tasks happened + Assert.Null(unhandledException); + } + finally + { + // Cleanup + // - Remove the unobserved task handler + TaskScheduler.UnobservedTaskException -= handleUnobservedException; + } + } + + #endregion + private static Task GetCancelledTask() { using CancellationTokenSource cts = new(); From ba65f50c8296bd5863a6e25f4bdee884d4de7555 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Wed, 22 Oct 2025 18:25:53 -0500 Subject: [PATCH 21/30] Address copilot comments, address unobserved exception issue that breaks WaitForCompletion test --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 75 ++++++++++++------- .../SqlClient/Utilities/AsyncHelperTest.cs | 3 + 2 files changed, 52 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 974615fc02..d8177b08b2 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -19,19 +19,19 @@ internal static void ContinueTask( Action onFailure = null, Action onCancellation = null) { - taskToContinue.ContinueWith( + Task continuationTask = taskToContinue.ContinueWith( tsk => { if (tsk.Exception != null) { - Exception exc = tsk.Exception.InnerException; + Exception innerException = tsk.Exception.InnerException; try { - onFailure?.Invoke(exc); + onFailure?.Invoke(innerException); } finally { - taskCompletionSource.TrySetException(exc); + taskCompletionSource.TrySetException(innerException); } } else if (tsk.IsCanceled) @@ -59,6 +59,9 @@ internal static void ContinueTask( } }, TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); } internal static void ContinueTaskWithState( @@ -76,7 +79,7 @@ internal static void ContinueTaskWithState( State: state, TaskCompletionSource: taskCompletionSource); - taskToContinue.ContinueWith( + Task continuationTask = taskToContinue.ContinueWith( static (task, state2) => { ContinuationState typedState2 = @@ -84,13 +87,14 @@ internal static void ContinueTaskWithState( if (task.Exception is not null) { + Exception innerException = task.Exception.InnerException; try { - typedState2.OnFailure?.Invoke(typedState2.State, task.Exception); + typedState2.OnFailure?.Invoke(typedState2.State, innerException); } finally { - typedState2.TaskCompletionSource.TrySetException(task.Exception); + typedState2.TaskCompletionSource.TrySetException(innerException); } } else if (task.IsCanceled) @@ -119,6 +123,9 @@ internal static void ContinueTaskWithState( }, state: continuationState, scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); } internal static void ContinueTaskWithState( @@ -138,20 +145,21 @@ internal static void ContinueTaskWithState( State2: state2, TaskCompletionSource: taskCompletionSource); - taskToContinue.ContinueWith( + Task continuationTask = taskToContinue.ContinueWith( static (task, state2) => { ContinuationState typedState2 = (ContinuationState)state2; if (task.Exception is not null) { + Exception innerException = task.Exception.InnerException; try { - typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, task.Exception); + typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, innerException); } finally { - typedState2.TaskCompletionSource.TrySetException(task.Exception); + typedState2.TaskCompletionSource.TrySetException(innerException); } } else if (task.IsCanceled) @@ -179,6 +187,9 @@ internal static void ContinueTaskWithState( }, state: continuationState, scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); } internal static Task CreateContinuationTask( @@ -201,18 +212,19 @@ internal static Task CreateContinuationTask( OnSuccess: onSuccess, TaskCompletionSource: taskCompletionSource); - taskToContinue.ContinueWith(static (task, continuationState2) => + Task continuationTask = taskToContinue.ContinueWith(static (task, continuationState2) => { ContinuationState typedState = (ContinuationState)continuationState2; if (task.Exception is not null) { + Exception innerException = task.Exception.InnerException; try { - typedState.OnFailure?.Invoke(task.Exception); + typedState.OnFailure?.Invoke(innerException); } finally { - typedState.TaskCompletionSource.TrySetException(task.Exception); + typedState.TaskCompletionSource.TrySetException(innerException); } } else if (task.IsCanceled) @@ -242,6 +254,9 @@ internal static Task CreateContinuationTask( state: continuationState, scheduler: TaskScheduler.Default); + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + return taskCompletionSource.Task; } @@ -267,20 +282,21 @@ internal static Task CreateContinuationTaskWithState( State: state, TaskCompletionSource: taskCompletionSource); - taskToContinue.ContinueWith( + Task continuationTask = taskToContinue.ContinueWith( static (task, state2) => { ContinuationState typedState2 = (ContinuationState)state2; if (task.Exception is not null) { + Exception innerException = task.Exception.InnerException; try { - typedState2.OnFailure?.Invoke(typedState2.State, task.Exception); + typedState2.OnFailure?.Invoke(typedState2.State, innerException); } finally { - typedState2.TaskCompletionSource.TrySetException(task.Exception); + typedState2.TaskCompletionSource.TrySetException(innerException); } } else if (task.IsCanceled) @@ -311,6 +327,9 @@ internal static Task CreateContinuationTaskWithState( state: continuationState, scheduler: TaskScheduler.Default); + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + return taskCompletionSource.Task; } @@ -322,13 +341,6 @@ internal static Task CreateContinuationTaskWithState( Action onFailure = null, Action onCancellation = null) { - // Note: this code is almost identical to ContinueTaskWithState, but creates its own - // task completion source and completes it on success. - // Yes, we could just chain into the ContinueTaskWithState, but that requires wrapping - // more state in a tuple and confusing the heck out of people. So, duplicating code - // just makes things more clean. Besides, @TODO: We should get rid of these helpers and - // just use async/await natives. - if (taskToContinue is null) { onSuccess(state1, state2); @@ -345,20 +357,21 @@ internal static Task CreateContinuationTaskWithState( State2: state2, TaskCompletionSource: taskCompletionSource); - taskToContinue.ContinueWith( + Task continuationTask = taskToContinue.ContinueWith( static (task, state2) => { ContinuationState typedState2 = (ContinuationState)state2; if (task.Exception is not null) { + Exception innerException = task.Exception.InnerException; try { - typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, task.Exception); + typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, innerException); } finally { - typedState2.TaskCompletionSource.TrySetException(task.Exception); + typedState2.TaskCompletionSource.TrySetException(innerException); } } else if (task.IsCanceled) @@ -389,6 +402,9 @@ internal static Task CreateContinuationTaskWithState( state: continuationState, scheduler: TaskScheduler.Default); + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + return taskCompletionSource.Task; } @@ -472,6 +488,13 @@ internal static void WaitForCompletion( } } + private static void ObserveContinuationException(Task continuationTask) + { + continuationTask.ContinueWith( + static task => _ = task.Exception, + TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously); + } + private record ContinuationState( Action OnCancellation, Action OnFailure, diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs index de8c4fb421..5ce8ed5cb8 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -1283,6 +1283,9 @@ private static async Task RunWithTimeout(Task taskToRun, TimeSpan timeout) { Assert.Fail("Timeout elapsed."); } + + // Force observation of any exception + _ = taskToRun.Exception; } } } From db584d7f84fa0e3b7d38e76af1d841f1414cb66a Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Mon, 27 Oct 2025 17:49:29 -0500 Subject: [PATCH 22/30] Addressing a couple bits of PR comments --- .../src/Microsoft/Data/SqlClient/SqlBulkCopy.cs | 11 ++++++----- .../src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index bbd8730bbf..bf972048ba 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -2366,17 +2366,18 @@ private void CopyColumnsAsyncSetupContinuation(TaskCompletionSource sour AsyncHelper.ContinueTaskWithState( taskToContinue: task, taskCompletionSource: source, - state: this, - onSuccess: this2 => + state1: this, + state2: Tuple.Create(source, i), + onSuccess: static (this2, parameters) => { - if (i + 1 < this2._sortedColumnMappings.Count) + if (parameters.Item2 + 1 < this2._sortedColumnMappings.Count) { // continue from the next column - this2.CopyColumnsAsync(i + 1, source); + this2.CopyColumnsAsync(parameters.Item2 + 1, parameters.Item1); } else { - source.SetResult(null); + parameters.Item1.SetResult(null); } }); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index 29b30ccf67..ed074673f3 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -1591,7 +1591,7 @@ private Task RunExecuteReaderTdsSetupContinuation( this2._activeConnection.GetOpenTdsConnection(); this2.CachedAsyncState.SetAsyncReaderState(parameters.Item1, parameters.Item2, parameters.Item3); }, - onFailure: static (this2, _, exception) => + onFailure: static (this2, _, _) => this2._activeConnection.GetOpenTdsConnection().DecrementAsyncCount()); } From 2340c83f3d3df15f6f3e86f3055f7750ef51c6f4 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 30 Oct 2025 15:56:39 -0500 Subject: [PATCH 23/30] Address PR comments * Enable nullable * Add comment blocks * Use TimeSpan * Rename state variables to match better --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 376 ++++++-- .../SqlClient/Utilities/AsyncHelperTest.cs | 817 ++++++++++-------- 2 files changed, 745 insertions(+), 448 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index d8177b08b2..d05ebb1197 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -8,69 +8,150 @@ using System.Threading; using System.Threading.Tasks; +#nullable enable + namespace Microsoft.Data.SqlClient.Utilities { + /// + /// Provides helpers for interacting with asynchronous tasks. + /// + /// + /// These helpers mainly provide continuation and timeout functionality. They utilize + /// at their core, and as such are fairly antiquated + /// implementations. If possible these methods should be utilized less and async/await native + /// constructs should be used. + /// internal static class AsyncHelper { + /// + /// Continues a task and signals failure of the continuation via the provided + /// . + /// + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * Will try to set exception on + /// * Cancelled + /// * is called (if provided) + /// * Will try to set cancelled on + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// exception is set on the . + /// * is *not* with result on success. This + /// is to allow the task completion source to be continued even more after this current + /// continuation. + /// + /// Task to continue with provided callbacks + /// + /// Completion source used to track completion of the continuation, see remarks for details + /// + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) internal static void ContinueTask( Task taskToContinue, - TaskCompletionSource taskCompletionSource, + TaskCompletionSource taskCompletionSource, Action onSuccess, - Action onFailure = null, - Action onCancellation = null) + Action? onFailure = null, + Action? onCancellation = null) { + ContinuationState continuationState = new ContinuationState( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + TaskCompletionSource: taskCompletionSource); + Task continuationTask = taskToContinue.ContinueWith( - tsk => + static (tsk, continuationState2) => { + ContinuationState typedState = (ContinuationState)continuationState2!; + if (tsk.Exception != null) { - Exception innerException = tsk.Exception.InnerException; + Exception innerException = tsk.Exception.InnerException ?? tsk.Exception; try { - onFailure?.Invoke(innerException); + typedState.OnFailure?.Invoke(innerException); } finally { - taskCompletionSource.TrySetException(innerException); + typedState.TaskCompletionSource.TrySetException(innerException); } } else if (tsk.IsCanceled) { try { - onCancellation?.Invoke(); + typedState.OnCancellation?.Invoke(); } finally { - taskCompletionSource.TrySetCanceled(); + typedState.TaskCompletionSource.TrySetCanceled(); } } else { try { - onSuccess(); + typedState.OnSuccess(); } // @TODO: CER Exception Handling was removed here (see GH#3581) catch (Exception e) { - taskCompletionSource.SetException(e); + typedState.TaskCompletionSource.TrySetException(e); } } }, - TaskScheduler.Default); + state: continuationState, + scheduler: TaskScheduler.Default); // Explicitly follow up by observing any exception thrown during continuation ObserveContinuationException(continuationTask); } + /// + /// Continues a task and signals failure of the continuation via the provided + /// . This overload provides a single state object + /// to the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * Will try to set exception on + /// * Cancelled + /// * is called (if provided) + /// * Will try to set cancelled on + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// exception is set on the . + /// * is *not* with result on success. This + /// is to allow the task completion source to be continued even more after this current + /// continuation. + /// + /// Type of the state object to provide to the callbacks + /// Task to continue with provided callbacks + /// + /// Completion source used to track completion of the continuation, see remarks for details + /// + /// State object to provide to callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) internal static void ContinueTaskWithState( Task taskToContinue, - TaskCompletionSource taskCompletionSource, + TaskCompletionSource taskCompletionSource, TState state, Action onSuccess, - Action onFailure = null, - Action onCancellation = null) + Action? onFailure = null, + Action? onCancellation = null) { ContinuationState continuationState = new( OnCancellation: onCancellation, @@ -80,14 +161,13 @@ internal static void ContinueTaskWithState( TaskCompletionSource: taskCompletionSource); Task continuationTask = taskToContinue.ContinueWith( - static (task, state2) => + static (task, continuationState2) => { - ContinuationState typedState2 = - (ContinuationState)state2; + ContinuationState typedState2 = (ContinuationState)continuationState2!; if (task.Exception is not null) { - Exception innerException = task.Exception.InnerException; + Exception innerException = task.Exception.InnerException ?? task.Exception; try { typedState2.OnFailure?.Invoke(typedState2.State, innerException); @@ -117,7 +197,7 @@ internal static void ContinueTaskWithState( } catch (Exception e) { - typedState2.TaskCompletionSource.SetException(e); + typedState2.TaskCompletionSource.TrySetException(e); } } }, @@ -128,14 +208,49 @@ internal static void ContinueTaskWithState( ObserveContinuationException(continuationTask); } + /// + /// Continues a task and signals failure of the continuation via the provided + /// . This overload provides two state objects to + /// the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * Will try to set exception on + /// * Cancelled + /// * is called (if provided) + /// * Will try to set cancelled on + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// exception is set on the . + /// * is *not* with result on success. This + /// is to allow the task completion source to be continued even more subsequent to + /// this current continuation. + /// + /// Task to continue with provided callbacks + /// + /// Completion source used to track completion of the continuation, see remarks for details + /// + /// Type of the first state object to provide to callbacks + /// Type of the second state object to provide to callbacks + /// First state object to provide to callbacks + /// Second state object to provide to callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) internal static void ContinueTaskWithState( Task taskToContinue, - TaskCompletionSource taskCompletionSource, + TaskCompletionSource taskCompletionSource, TState1 state1, TState2 state2, Action onSuccess, - Action onFailure = null, - Action onCancellation = null) + Action? onFailure = null, + Action? onCancellation = null) { ContinuationState continuationState = new( OnCancellation: onCancellation, @@ -146,13 +261,14 @@ internal static void ContinueTaskWithState( TaskCompletionSource: taskCompletionSource); Task continuationTask = taskToContinue.ContinueWith( - static (task, state2) => + static (task, continuationState2) => { - ContinuationState typedState2 = (ContinuationState)state2; + ContinuationState typedState2 = + (ContinuationState)continuationState2!; if (task.Exception is not null) { - Exception innerException = task.Exception.InnerException; + Exception innerException = task.Exception.InnerException ?? task.Exception; try { typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, innerException); @@ -192,11 +308,37 @@ internal static void ContinueTaskWithState( ObserveContinuationException(continuationTask); } - internal static Task CreateContinuationTask( - Task taskToContinue, + /// + /// Continues a task and returns the continuation task. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * The task will be completed with an exception. + /// * Cancelled + /// * is called (if provided) + /// * The task will be completed as cancelled. + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// task will be completed with the exception. + /// * The task will be completed as successful. + /// + /// + /// Task to continue with provided callbacks, if null, null will be returned. + /// + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static Task? CreateContinuationTask( + Task? taskToContinue, Action onSuccess, - Action onFailure = null, - Action onCancellation = null) + Action? onFailure = null, + Action? onCancellation = null) { if (taskToContinue is null) { @@ -205,19 +347,20 @@ internal static Task CreateContinuationTask( } // @TODO: Can totally use a non-generic TaskCompletionSource - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); ContinuationState continuationState = new( OnCancellation: onCancellation, OnFailure: onFailure, OnSuccess: onSuccess, TaskCompletionSource: taskCompletionSource); - Task continuationTask = taskToContinue.ContinueWith(static (task, continuationState2) => + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => { - ContinuationState typedState = (ContinuationState)continuationState2; + ContinuationState typedState = (ContinuationState)continuationState2!; if (task.Exception is not null) { - Exception innerException = task.Exception.InnerException; + Exception innerException = task.Exception.InnerException ?? task.Exception; try { typedState.OnFailure?.Invoke(innerException); @@ -260,12 +403,41 @@ internal static Task CreateContinuationTask( return taskCompletionSource.Task; } - internal static Task CreateContinuationTaskWithState( - Task taskToContinue, + /// + /// Continues a task and returns the continuation task. This overload allows a state object + /// to be passed into the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * The task will be completed with an exception. + /// * Cancelled + /// * is called (if provided) + /// * The task will be completed as cancelled. + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// task will be completed with the exception. + /// * The task will be completed as successful. + /// + /// Type of the state object to pass to callbacks + /// + /// Task to continue with provided callbacks, if null, null will be returned. + /// + /// State object to pass to the callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static Task? CreateContinuationTaskWithState( + Task? taskToContinue, TState state, Action onSuccess, - Action onFailure = null, - Action onCancellation = null) + Action? onFailure = null, + Action? onCancellation = null) { if (taskToContinue is null) { @@ -274,7 +446,7 @@ internal static Task CreateContinuationTaskWithState( } // @TODO: Can totally use a non-generic TaskCompletionSource - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); ContinuationState continuationState = new( OnCancellation: onCancellation, OnFailure: onFailure, @@ -283,13 +455,13 @@ internal static Task CreateContinuationTaskWithState( TaskCompletionSource: taskCompletionSource); Task continuationTask = taskToContinue.ContinueWith( - static (task, state2) => + static (task, continuationState2) => { - ContinuationState typedState2 = (ContinuationState)state2; + ContinuationState typedState2 = (ContinuationState)continuationState2!; if (task.Exception is not null) { - Exception innerException = task.Exception.InnerException; + Exception innerException = task.Exception.InnerException ?? task.Exception; try { typedState2.OnFailure?.Invoke(typedState2.State, innerException); @@ -333,13 +505,44 @@ internal static Task CreateContinuationTaskWithState( return taskCompletionSource.Task; } - internal static Task CreateContinuationTaskWithState( - Task taskToContinue, + /// + /// Continues a task and returns the continuation task. This overload allows two state + /// objects to be passed into the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * The task will be completed with an exception. + /// * Cancelled + /// * is called (if provided) + /// * The task will be completed as cancelled. + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// task will be completed with the exception. + /// * The task will be completed as successful. + /// + /// Type of the first state object to pass to callbacks + /// Type of the second state object to pass to callbacks + /// + /// Task to continue with provided callbacks, if null, null will be returned. + /// + /// First state object to pass to the callbacks + /// Second state object to pass to the callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static Task? CreateContinuationTaskWithState( + Task? taskToContinue, TState1 state1, TState2 state2, Action onSuccess, - Action onFailure = null, - Action onCancellation = null) + Action? onFailure = null, + Action? onCancellation = null) { if (taskToContinue is null) { @@ -348,7 +551,7 @@ internal static Task CreateContinuationTaskWithState( } // @TODO: Can totally use a non-generic TaskCompletionSource - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); ContinuationState continuationState = new( OnCancellation: onCancellation, OnFailure: onFailure, @@ -358,13 +561,14 @@ internal static Task CreateContinuationTaskWithState( TaskCompletionSource: taskCompletionSource); Task continuationTask = taskToContinue.ContinueWith( - static (task, state2) => + static (task, continuationState2) => { - ContinuationState typedState2 = (ContinuationState)state2; + ContinuationState typedState2 = + (ContinuationState)continuationState2!; if (task.Exception is not null) { - Exception innerException = task.Exception.InnerException; + Exception innerException = task.Exception.InnerException ?? task.Exception; try { typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, innerException); @@ -394,7 +598,7 @@ internal static Task CreateContinuationTaskWithState( } catch (Exception e) { - typedState2.TaskCompletionSource.SetException(e); + typedState2.TaskCompletionSource.TrySetException(e); } } @@ -408,6 +612,19 @@ internal static Task CreateContinuationTaskWithState( return taskCompletionSource.Task; } + /// + /// Executes a timeout task in parallel with the provided + /// . If the timeout completes before the task + /// completion source, the provided is executed and the + /// exception returned is set as the exception that completes the task completion source. + /// + /// Task to execute with a timeout + /// Number of seconds to wait until timing out the task + /// + /// Callback to execute when the task does not complete within the allotted time. The + /// exception returned by the callback is set on the . + /// + /// Cancellation token to prematurely cancel timeout internal static void SetTimeoutException( TaskCompletionSource taskCompletionSource, int timeoutInSeconds, @@ -433,6 +650,21 @@ internal static void SetTimeoutException( cancellationToken: CancellationToken.None); } + /// + /// Executes a timeout task in parallel with the provided + /// . If the timeout completes before the task + /// completion source, the provided is executed and the + /// exception returned is set as the exception that completes the task completion source. + /// This overload provides a state object to the timeout callback. + /// + /// Task to execute with a timeout + /// Number of seconds to wait until timing out the task + /// State object to pass to the callback + /// + /// Callback to execute when the task does not complete within the allotted time. The + /// exception returned by the callback is set on the . + /// + /// Cancellation token to prematurely cancel timeout internal static void SetTimeoutExceptionWithState( TaskCompletionSource taskCompletionSource, int timeoutInSeconds, @@ -453,22 +685,38 @@ internal static void SetTimeoutExceptionWithState( // then the timeout expired first, run the timeout handler if (!task.IsCanceled && !taskCompletionSource.Task.IsCompleted) { - taskCompletionSource.TrySetException(onTimeout((TState)state2)); + taskCompletionSource.TrySetException(onTimeout((TState)state2!)); } }, state: state, cancellationToken: CancellationToken.None); } + /// + /// Waits for a maximum of seconds for completion of + /// the provided . + /// + /// Task to execute with a timeout + /// Number of seconds to wait until timing out the task + /// + /// Callback to execute when the task does not complete within the allotted time. + /// + /// + /// If true, the inner exception of any raised + /// during execution, including timeout of the task, will be rethrown. + /// internal static void WaitForCompletion( Task task, int timeoutInSeconds, - Action onTimeout = null, + Action? onTimeout = null, bool rethrowExceptions = true) { try { - task.Wait(timeoutInSeconds > 0 ? 1000 * timeoutInSeconds : Timeout.Infinite); + TimeSpan timeout = timeoutInSeconds > 0 + ? TimeSpan.FromSeconds(timeoutInSeconds) + : Timeout.InfiniteTimeSpan; + task.Wait(timeout); } catch (AggregateException ae) { @@ -476,7 +724,7 @@ internal static void WaitForCompletion( { Debug.Assert(ae.InnerException is not null, "Inner exception is null"); Debug.Assert(ae.InnerExceptions.Count == 1, "There is more than one exception in AggregateException"); - ExceptionDispatchInfo.Capture(ae.InnerException).Throw(); + ExceptionDispatchInfo.Capture(ae.InnerException!).Throw(); } } @@ -496,24 +744,24 @@ private static void ObserveContinuationException(Task continuationTask) } private record ContinuationState( - Action OnCancellation, - Action OnFailure, + Action? OnCancellation, + Action? OnFailure, Action OnSuccess, - TaskCompletionSource TaskCompletionSource); + TaskCompletionSource TaskCompletionSource); private record ContinuationState( - Action OnCancellation, - Action OnFailure, + Action? OnCancellation, + Action? OnFailure, Action OnSuccess, TState State, - TaskCompletionSource TaskCompletionSource); + TaskCompletionSource TaskCompletionSource); private record ContinuationState( - Action OnCancellation, - Action OnFailure, + Action? OnCancellation, + Action? OnFailure, Action OnSuccess, TState1 State1, TState2 State2, - TaskCompletionSource TaskCompletionSource); + TaskCompletionSource TaskCompletionSource); } } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs index 5ce8ed5cb8..07595f5d93 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; using Microsoft.Data.SqlClient.UnitTests.Utilities; @@ -20,61 +21,64 @@ public class AsyncHelperTest public async Task ContinueTask_TaskCompletes() { // Arrange + // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); - Mock> onFailure = new(); - Mock onCancellation = new(); + TaskCompletionSource taskCompletionSource = new(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); - // Note: We have to set up onSuccess to set a result on the task completion source, + // Note: We have to set up mockOnSuccess to set a result on the task completion source, // since the AsyncHelper will not do it, and without that, we cannot reliably // know when the continuation completed. We will use SetResult b/c it will throw // if it has already been set. - Mock onSuccess = new(); - onSuccess.Setup(action => action()) + Mock mockOnSuccess = new(); + mockOnSuccess.Setup(action => action()) .Callback(() => taskCompletionSource.SetResult(0)); // Act AsyncHelper.ContinueTask( taskToContinue: taskToContinue, taskCompletionSource: taskCompletionSource, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert - onSuccess.Verify(action => action(), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task ContinueTask_TaskCompletesHandlerThrows() { // Arrange + // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); - Mock onSuccess = new(); - onSuccess.SetupThrows(); + // - mockOnSuccess handler throws + Mock mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); - Mock> onFailure = new(); - Mock onCancellation = new(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); // Act AsyncHelper.ContinueTask( taskToContinue, taskCompletionSource, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); - onSuccess.Verify(action => action(), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); } [Theory] @@ -83,59 +87,61 @@ public async Task ContinueTask_TaskCompletesHandlerThrows() public async Task ContinueTask_TaskCancels(bool handlerShouldThrow) { // Arrange + // - Task to continue that is cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); - Mock onCancellation = new(); + Mock mockOnCancellation = new(); if (handlerShouldThrow) { - onCancellation.SetupThrows(); + mockOnCancellation.SetupThrows(); } - Mock onSuccess = new(); - Mock> onFailure = new(); + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); // Act AsyncHelper.ContinueTask( taskToContinue, taskCompletionSource, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert - // - taskCompletionSource should have been cancelled, regardless of onCancellation throwing + // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); - onCancellation.Verify(action => action(), Times.Once); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(), Times.Once); } [Fact] public async Task ContinueTask_TaskCancelsNoHandler() { // Arrange + // - Task to continue that is cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); - Mock onSuccess = new(); - Mock> onFailure = new(); + TaskCompletionSource taskCompletionSource = new(); + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); // Act AsyncHelper.ContinueTask( taskToContinue, taskCompletionSource, - onSuccess.Object, - onFailure.Object, + mockOnSuccess.Object, + mockOnFailure.Object, onCancellation: null); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert // - taskCompletionSource should have been cancelled Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); } [Theory] @@ -144,60 +150,62 @@ public async Task ContinueTask_TaskCancelsNoHandler() public async Task ContinueTask_TaskFaults(bool handlerShouldThrow) { // Arrange + // - Task to continue that is faulted Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); - Mock> onFailure = new(); + Mock> mockOnFailure = new(); if (handlerShouldThrow) { - onFailure.SetupThrows(); + mockOnFailure.SetupThrows(); } - Mock onSuccess = new(); - Mock onCancellation = new(); + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); // Act AsyncHelper.ContinueTask( taskToContinue, taskCompletionSource, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert - // - taskCompletionSource should have been cancelled, regardless of onSuccess throwing + // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); - onFailure.Verify(action => action(It.IsAny()), Times.Once); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(It.IsAny()), Times.Once); } [Fact] public async Task ContinueTask_TaskFaultsNoHandler() { // Arrange + // - Task to continue that is cancelled Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); - Mock onSuccess = new(); - Mock onCancellation = new(); + TaskCompletionSource taskCompletionSource = new(); + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); // Act AsyncHelper.ContinueTask( taskToContinue, taskCompletionSource, - onSuccess.Object, + mockOnSuccess.Object, onFailure: null, - onCancellation.Object); + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert // - taskCompletionSource should have been cancelled Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } #endregion @@ -208,19 +216,20 @@ public async Task ContinueTask_TaskFaultsNoHandler() public async Task ContinueTaskWithState_1Generic_TaskCompletes() { // Arrange + // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; - Mock> onFailure = new(); - Mock> onCancellation = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); - // Note: We have to set up onSuccess to set a result on the task completion source, + // Note: We have to set up mockOnSuccess to set a result on the task completion source, // since the AsyncHelper will not do it, and without that, we cannot reliably // know when the continuation completed. We will use SetResult b/c it will throw // if it has already been set. - Mock> onSuccess = new(); - onSuccess.Setup(action => action(state1)) + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(action => action(state1)) .Callback(_ => taskCompletionSource.SetResult(0)); // Act @@ -228,48 +237,50 @@ public async Task ContinueTaskWithState_1Generic_TaskCompletes() taskToContinue: taskToContinue, taskCompletionSource: taskCompletionSource, state1, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert - onSuccess.Verify(action => action(state1), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task ContinueTaskWithState_1Generic_TaskCompletesHandlerThrows() { // Arrange + // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; - Mock> onSuccess = new(); - onSuccess.Setup(action => action(It.IsAny())).Throws(); + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(action => action(It.IsAny())).Throws(); - Mock> onFailure = new(); - Mock> onCancellation = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); // Act AsyncHelper.ContinueTaskWithState( taskToContinue, taskCompletionSource, state1, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert // - taskCompletionSource should have faulted Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); - onSuccess.Verify(action => action(state1), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Theory] @@ -278,64 +289,66 @@ public async Task ContinueTaskWithState_1Generic_TaskCompletesHandlerThrows() public async Task ContinueTaskWithState_1Generic_TaskCancels(bool handlerShouldThrow) { // Arrange + // - Task to continue that was cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; - Mock> onCancellation = new(); + Mock> mockOnCancellation = new(); if (handlerShouldThrow) { - onCancellation.SetupThrows(); + mockOnCancellation.SetupThrows(); } - Mock> onSuccess = new(); - Mock> onFailure = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); // Act AsyncHelper.ContinueTaskWithState( taskToContinue, taskCompletionSource, state1, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert - // - taskCompletionSource should have been cancelled, regardless of onCancellation throwing + // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); - onCancellation.Verify(action => action(state1), Times.Once); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1), Times.Once); } [Fact] public async Task ContinueTaskWithState_1Generic_TaskCancelsNoHandler() { // Arrange + // - Task to continue that was cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; - Mock> onSuccess = new(); - Mock> onFailure = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); // Act AsyncHelper.ContinueTaskWithState( taskToContinue, taskCompletionSource, state1, - onSuccess.Object, - onFailure.Object, + mockOnSuccess.Object, + mockOnFailure.Object, onCancellation: null); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert // - taskCompletionSource should have been cancelled Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); } [Theory] @@ -344,64 +357,66 @@ public async Task ContinueTaskWithState_1Generic_TaskCancelsNoHandler() public async Task ContinueTaskWithState_1Generic_TaskFaults(bool handlerShouldThrow) { // Arrange + // - Task to continue that faulted Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; - Mock> onFailure = new(); + Mock> mockOnFailure = new(); if (handlerShouldThrow) { - onFailure.SetupThrows(); + mockOnFailure.SetupThrows(); } - Mock> onSuccess = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); // Act AsyncHelper.ContinueTaskWithState( taskToContinue, taskCompletionSource, state1, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert - // - taskCompletionSource should have been cancelled, regardless of onSuccess throwing + // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); - onFailure.Verify(action => action(state1, It.IsAny()), Times.Once); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, It.IsAny()), Times.Once); } [Fact] public async Task ContinueTaskWithState_1Generic_TaskFaultsNoHandler() { // Arrange + // - Task to continue that faulted Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; - Mock> onSuccess = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); // Act AsyncHelper.ContinueTaskWithState( taskToContinue, taskCompletionSource, state1, - onSuccess.Object, + mockOnSuccess.Object, onFailure: null, - onCancellation.Object); + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert // - taskCompletionSource should have been cancelled Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } #endregion @@ -412,20 +427,21 @@ public async Task ContinueTaskWithState_1Generic_TaskFaultsNoHandler() public async Task ContinueTaskWithState_2Generics_TaskCompletes() { // Arrange + // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; const int state2 = 234; - Mock> onFailure = new(); - Mock> onCancellation = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); - // Note: We have to set up onSuccess to set a result on the task completion source, + // Note: We have to set up mockOnSuccess to set a result on the task completion source, // since the AsyncHelper will not do it, and without that, we cannot reliably // know when the continuation completed. We will use SetResult b/c it will throw // if it has already been set. - Mock> onSuccess = new(); - onSuccess.Setup(action => action(state1, state2)) + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(action => action(state1, state2)) .Callback((_, _) => taskCompletionSource.SetResult(0)); // Act @@ -434,49 +450,51 @@ public async Task ContinueTaskWithState_2Generics_TaskCompletes() taskCompletionSource: taskCompletionSource, state1, state2, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert - onSuccess.Verify(action => action(state1, state2), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task ContinueTaskWithState_2Generics_TaskCompletesHandlerThrows() { // Arrange + // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; - Mock> onSuccess = new(); - onSuccess.Setup(o => o(It.IsAny())).Throws(); + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(o => o(It.IsAny())).Throws(); - Mock> onFailure = new(); - Mock> onCancellation = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); // Act AsyncHelper.ContinueTaskWithState( taskToContinue, taskCompletionSource, state1, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert // - taskCompletionSource should have faulted Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); - // - onSuccess was called with state obj - onSuccess.Verify(action => action(state1), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + // - mockOnSuccess was called with state obj + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Theory] @@ -485,19 +503,20 @@ public async Task ContinueTaskWithState_2Generics_TaskCompletesHandlerThrows() public async Task ContinueTaskWithState_2Generics_TaskCancels(bool handlerShouldThrow) { // Arrange + // - Task to continue that was cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; const int state2 = 234; - Mock> onCancellation = new(); + Mock> mockOnCancellation = new(); if (handlerShouldThrow) { - onCancellation.SetupThrows(); + mockOnCancellation.SetupThrows(); } - Mock> onSuccess = new(); - Mock> onFailure = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); // Act AsyncHelper.ContinueTaskWithState( @@ -505,31 +524,32 @@ public async Task ContinueTaskWithState_2Generics_TaskCancels(bool handlerShould taskCompletionSource, state1, state2, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert - // - taskCompletionSource should have been cancelled, regardless of onCancellation throwing + // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); - onCancellation.Verify(action => action(state1, state2), Times.Once); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1, state2), Times.Once); } [Fact] public async Task ContinueTaskWithState_2Generics_TaskCancelsNoHandler() { // Arrange + // - Task to continue that was cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; const int state2 = 234; - Mock> onSuccess = new(); - Mock> onFailure = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); // Act AsyncHelper.ContinueTaskWithState( @@ -537,16 +557,16 @@ public async Task ContinueTaskWithState_2Generics_TaskCancelsNoHandler() taskCompletionSource, state1, state2, - onSuccess.Object, - onFailure.Object, + mockOnSuccess.Object, + mockOnFailure.Object, onCancellation: null); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert // - taskCompletionSource should have been cancelled Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); } [Theory] @@ -555,19 +575,20 @@ public async Task ContinueTaskWithState_2Generics_TaskCancelsNoHandler() public async Task ContinueTaskWithState_2Generics_TaskFaults(bool handlerShouldThrow) { // Arrange + // - Task to continue that faulted Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; const int state2 = 234; - Mock> onFailure = new(); + Mock> mockOnFailure = new(); if (handlerShouldThrow) { - onFailure.SetupThrows(); + mockOnFailure.SetupThrows(); } - Mock> onSuccess = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); // Act AsyncHelper.ContinueTaskWithState( @@ -575,31 +596,32 @@ public async Task ContinueTaskWithState_2Generics_TaskFaults(bool handlerShouldT taskCompletionSource, state1, state2, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert - // - taskCompletionSource should have been cancelled, regardless of onSuccess throwing + // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); - onFailure.Verify(action => action(state1, state2, It.IsAny()), Times.Once); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, state2, It.IsAny()), Times.Once); } [Fact] public async Task ContinueTaskWithState_2Generics_TaskFaultsNoHandler() { // Arrange + // - Task to continue that faulted Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = new(); const int state1 = 123; const int state2 = 234; - Mock> onSuccess = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); // Act AsyncHelper.ContinueTaskWithState( @@ -607,16 +629,16 @@ public async Task ContinueTaskWithState_2Generics_TaskFaultsNoHandler() taskCompletionSource, state1, state2, - onSuccess.Object, + mockOnSuccess.Object, onFailure: null, - onCancellation.Object); + mockOnCancellation.Object); await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); // Assert // - taskCompletionSource should have been cancelled Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); - onSuccess.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } #endregion @@ -627,74 +649,77 @@ public async Task ContinueTaskWithState_2Generics_TaskFaultsNoHandler() public void CreateContinuationTask_NullTask() { // Arrange - Mock onSuccess = new(); - Mock> onFailure = new(); - Mock onCancellation = new(); + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTask( + Task? continuationTask = AsyncHelper.CreateContinuationTask( taskToContinue: null, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); // Assert Assert.Null(continuationTask); - onSuccess.Verify(action => action(), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task CreateContinuationTask_TaskCompletes() { // Arrange + // - Task to continue completed successfully Task taskToContinue = Task.CompletedTask; - Mock onSuccess = new(); - Mock> onFailure = new(); - Mock onCancellation = new(); + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTask( + Task? continuationTask = AsyncHelper.CreateContinuationTask( taskToContinue, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); - onSuccess.Verify(action => action(), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task CreateContinuationTask_TaskCompletesHandlerThrows() { // Arrange + // - Task to continue completed successfully Task taskToContinue = Task.CompletedTask; - Mock> onFailure = new(); - Mock onCancellation = new(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); - Mock onSuccess = new(); - onSuccess.SetupThrows(); + // - mockOnSuccess handler throws + Mock mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTask( + Task? continuationTask = AsyncHelper.CreateContinuationTask( taskToContinue, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); - onSuccess.Verify(action => action(), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Theory] @@ -703,51 +728,53 @@ public async Task CreateContinuationTask_TaskCompletesHandlerThrows() public async Task CreateContinuationTask_TaskCancels(bool handlerShouldThrow) { // Arrange + // - Task to continue was cancelled Task taskToContinue = GetCancelledTask(); - Mock> onFailure = new(); - Mock onSuccess = new(); + Mock> mockOnFailure = new(); + Mock mockOnSuccess = new(); - Mock onCancellation = new(); + Mock mockOnCancellation = new(); if (handlerShouldThrow) { - onCancellation.SetupThrows(); + mockOnCancellation.SetupThrows(); } // Act - Task continuationTask = AsyncHelper.CreateContinuationTask( + Task? continuationTask = AsyncHelper.CreateContinuationTask( taskToContinue, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); - onCancellation.Verify(action => action(), Times.Once); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(), Times.Once); } [Fact] public async Task CreateContinuationTask_TaskCancelsNoHandler() { // Arrange + // - Task to continue completed successfully Task taskToContinue = GetCancelledTask(); - Mock> onFailure = new(); - Mock onSuccess = new(); + Mock> mockOnFailure = new(); + Mock mockOnSuccess = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTask( + Task? continuationTask = AsyncHelper.CreateContinuationTask( taskToContinue, - onSuccess.Object, - onFailure.Object, + mockOnSuccess.Object, + mockOnFailure.Object, onCancellation: null); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); } [Theory] @@ -756,51 +783,53 @@ public async Task CreateContinuationTask_TaskCancelsNoHandler() public async Task CreateContinuationTask_TaskFaults(bool handlerShouldThrow) { // Arrange + // - Task to continue faulted Task taskToContinue = Task.FromException(new Exception()); - Mock onSuccess = new(); - Mock onCancellation = new(); + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); - Mock> onFailure = new(); + Mock> mockOnFailure = new(); if (handlerShouldThrow) { - onFailure.SetupThrows(); + mockOnFailure.SetupThrows(); } // Act - Task continuationTask = AsyncHelper.CreateContinuationTask( + Task? continuationTask = AsyncHelper.CreateContinuationTask( taskToContinue, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onFailure.Verify(action => action(It.IsAny()), Times.Once); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task CreateContinuationTask_TaskFaultsNoHandler() { // Arrange + // - Task to continue completed successfully Task taskToContinue = Task.FromException(new Exception()); - Mock onSuccess = new(); - Mock onCancellation = new(); + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTask( + Task? continuationTask = AsyncHelper.CreateContinuationTask( taskToContinue, - onSuccess.Object, + mockOnSuccess.Object, onFailure: null, onCancellation: null); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } #endregion @@ -812,80 +841,83 @@ public void CreateContinuationTaskWithState_1Generic_NullTask() { // Arrange const int state1 = 123; - Mock> onSuccess = new(); - Mock> onFailure = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue: null, state1, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); // Assert Assert.Null(continuationTask); - onSuccess.Verify(action => action(state1), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task CreateContinuationTaskWithState_1Generic_TaskCompletes() { // Arrange + // - Task to continue completed successfully Task taskToContinue = Task.CompletedTask; const int state1 = 123; - Mock> onSuccess = new(); - Mock> onFailure = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); - onSuccess.Verify(action => action(state1), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task CreateContinuationTaskWithState_1Generic_TaskCompletesHandlerThrows() { // Arrange + // - Task to continue completed successfully Task taskToContinue = Task.CompletedTask; const int state1 = 123; - Mock> onFailure = new(); - Mock> onCancellation = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); - Mock> onSuccess = new(); - onSuccess.SetupThrows(); + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); - onSuccess.Verify(action => action(state1), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Theory] @@ -894,57 +926,59 @@ public async Task CreateContinuationTaskWithState_1Generic_TaskCompletesHandlerT public async Task CreateContinuationTaskWithState_1Generic_TaskCancels(bool handlerShouldThrow) { // Arrange + // - Task to continue was cancelled Task taskToContinue = GetCancelledTask(); const int state1 = 123; - Mock> onFailure = new(); - Mock> onSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); - Mock> onCancellation = new(); + Mock> mockOnCancellation = new(); if (handlerShouldThrow) { - onCancellation.SetupThrows(); + mockOnCancellation.SetupThrows(); } // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); - onCancellation.Verify(action => action(state1), Times.Once); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1), Times.Once); } [Fact] public async Task CreateContinuationTaskWithState_1Generic_TaskCancelsNoHandler() { // Arrange + // - Task to continue was cancelled Task taskToContinue = GetCancelledTask(); const int state1 = 123; - Mock> onFailure = new(); - Mock> onSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, - onSuccess.Object, - onFailure.Object, + mockOnSuccess.Object, + mockOnFailure.Object, onCancellation: null); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); } [Theory] @@ -953,57 +987,59 @@ public async Task CreateContinuationTaskWithState_1Generic_TaskCancelsNoHandler( public async Task CreateContinuationTaskWithState_1Generic_TaskFaults(bool handlerShouldThrow) { // Arrange + // - Task to continue faulted Task taskToContinue = Task.FromException(new Exception()); const int state1 = 123; - Mock> onSuccess = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); - Mock> onFailure = new(); + Mock> mockOnFailure = new(); if (handlerShouldThrow) { - onFailure.SetupThrows(); + mockOnFailure.SetupThrows(); } // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onFailure.Verify(action => action(state1, It.IsAny()), Times.Once); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task CreateContinuationTaskWithState_1Generic_TaskFaultsNoHandler() { // Arrange + // - Task to continue faulted Task taskToContinue = Task.FromException(new Exception()); const int state1 = 123; - Mock> onSuccess = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, - onSuccess.Object, + mockOnSuccess.Object, onFailure: null, onCancellation: null); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } #endregion @@ -1017,85 +1053,88 @@ public void CreateContinuationTaskWithState_2Generics_NullTask() const int state1 = 123; const int state2 = 234; - Mock> onSuccess = new(); - Mock> onFailure = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue: null, state1, state2, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); // Assert Assert.Null(continuationTask); - onSuccess.Verify(action => action(state1, state2), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task CreateContinuationTaskWithState_2Generics_TaskCompletes() { // Arrange + // - Task to continue completed successfully Task taskToContinue = Task.CompletedTask; const int state1 = 123; const int state2 = 234; - Mock> onSuccess = new(); - Mock> onFailure = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, state2, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); - onSuccess.Verify(action => action(state1, state2), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task CreateContinuationTaskWithState_2Generics_TaskCompletesHandlerThrows() { // Arrange + // - Task to continue completed successfully Task taskToContinue = Task.CompletedTask; const int state1 = 123; const int state2 = 234; - Mock> onFailure = new(); - Mock> onCancellation = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); - Mock> onSuccess = new(); - onSuccess.SetupThrows(); + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, state2, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); - onSuccess.Verify(action => action(state1, state2), Times.Once); - onFailure.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Theory] @@ -1104,61 +1143,63 @@ public async Task CreateContinuationTaskWithState_2Generics_TaskCompletesHandler public async Task CreateContinuationTaskWithState_2Generics_TaskCancels(bool handlerShouldThrow) { // Arrange + // - Task to continue was cancelled Task taskToContinue = GetCancelledTask(); const int state1 = 123; const int state2 = 234; - Mock> onFailure = new(); - Mock> onSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); - Mock> onCancellation = new(); + Mock> mockOnCancellation = new(); if (handlerShouldThrow) { - onCancellation.SetupThrows(); + mockOnCancellation.SetupThrows(); } // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, state2, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); - onCancellation.Verify(action => action(state1, state2), Times.Once); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1, state2), Times.Once); } [Fact] public async Task CreateContinuationTaskWithState_2Generics_TaskCancelsNoHandler() { // Arrange + // - Task to continue was cancelled Task taskToContinue = GetCancelledTask(); const int state1 = 123; const int state2 = 234; - Mock> onFailure = new(); - Mock> onSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, state2, - onSuccess.Object, - onFailure.Object, + mockOnSuccess.Object, + mockOnFailure.Object, onCancellation: null); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onFailure.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); } [Theory] @@ -1167,61 +1208,63 @@ public async Task CreateContinuationTaskWithState_2Generics_TaskCancelsNoHandler public async Task CreateContinuationTaskWithState_2Generics_TaskFaults(bool handlerShouldThrow) { // Arrange + // - Task to continue faulted Task taskToContinue = Task.FromException(new Exception()); const int state1 = 123; const int state2 = 234; - Mock> onSuccess = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); - Mock> onFailure = new(); + Mock> mockOnFailure = new(); if (handlerShouldThrow) { - onFailure.SetupThrows(); + mockOnFailure.SetupThrows(); } // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, state2, - onSuccess.Object, - onFailure.Object, - onCancellation.Object); + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onFailure.Verify(action => action(state1, state2, It.IsAny()), Times.Once); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, state2, It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); } [Fact] public async Task CreateContinuationTaskWithState_2Generics_TaskFaultsNoHandler() { // Arrange + // - Task to continue faulted Task taskToContinue = Task.FromException(new Exception()); const int state1 = 123; const int state2 = 234; - Mock> onSuccess = new(); - Mock> onCancellation = new(); + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); // Act - Task continuationTask = AsyncHelper.CreateContinuationTaskWithState( + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( taskToContinue, state1, state2, - onSuccess.Object, + mockOnSuccess.Object, onFailure: null, onCancellation: null); await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); - onSuccess.VerifyNeverCalled(); - onCancellation.VerifyNeverCalled(); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } #endregion @@ -1232,6 +1275,7 @@ public async Task CreateContinuationTaskWithState_2Generics_TaskFaultsNoHandler( public void WaitForCompletion_DoesNotCreateUnobservedException() { // Arrange + // - Create a handler to capture any unhandled exception Exception? unhandledException = null; EventHandler handleUnobservedException = (_, args) => unhandledException = args.Exception; @@ -1243,7 +1287,7 @@ public void WaitForCompletion_DoesNotCreateUnobservedException() { // Act // - Run task that will always time out - TaskCompletionSource tcs = new(); + TaskCompletionSource tcs = new(); AsyncHelper.WaitForCompletion( tcs.Task, timeoutInSeconds: 1, @@ -1276,8 +1320,13 @@ private static Task GetCancelledTask() return Task.FromCanceled(cts.Token); } - private static async Task RunWithTimeout(Task taskToRun, TimeSpan timeout) + private static async Task RunWithTimeout([NotNull] Task? taskToRun, TimeSpan timeout) { + if (taskToRun is null) + { + Assert.Fail("Expected non-null task for timeout"); + } + Task winner = await Task.WhenAny(taskToRun, Task.Delay(timeout)); if (winner != taskToRun) { From 76770efadad19d345c1e767d3e4daf5f5d8e200a Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Fri, 31 Oct 2025 12:36:52 -0500 Subject: [PATCH 24/30] Increase timeout and ensure no stack dives on task completion --- .../SqlClient/Utilities/AsyncHelperTest.cs | 113 +++++++++--------- 1 file changed, 59 insertions(+), 54 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs index 07595f5d93..7f15a4708a 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -15,6 +15,8 @@ namespace Microsoft.Data.SqlClient.UnitTests.Microsoft.Data.SqlClient.Utilities { public class AsyncHelperTest { + private static readonly TimeSpan RunTimeout = TimeSpan.FromSeconds(2); + #region ContinueTask [Fact] @@ -23,7 +25,7 @@ public async Task ContinueTask_TaskCompletes() // Arrange // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); Mock> mockOnFailure = new(); Mock mockOnCancellation = new(); @@ -42,7 +44,7 @@ public async Task ContinueTask_TaskCompletes() mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert mockOnSuccess.Verify(action => action(), Times.Once); @@ -56,7 +58,7 @@ public async Task ContinueTask_TaskCompletesHandlerThrows() // Arrange // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); // - mockOnSuccess handler throws Mock mockOnSuccess = new(); @@ -72,7 +74,7 @@ public async Task ContinueTask_TaskCompletesHandlerThrows() mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); @@ -89,7 +91,7 @@ public async Task ContinueTask_TaskCancels(bool handlerShouldThrow) // Arrange // - Task to continue that is cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); Mock mockOnCancellation = new(); if (handlerShouldThrow) @@ -107,7 +109,7 @@ public async Task ContinueTask_TaskCancels(bool handlerShouldThrow) mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing @@ -124,7 +126,7 @@ public async Task ContinueTask_TaskCancelsNoHandler() // Arrange // - Task to continue that is cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); Mock mockOnSuccess = new(); Mock> mockOnFailure = new(); @@ -135,7 +137,7 @@ public async Task ContinueTask_TaskCancelsNoHandler() mockOnSuccess.Object, mockOnFailure.Object, onCancellation: null); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled @@ -152,7 +154,7 @@ public async Task ContinueTask_TaskFaults(bool handlerShouldThrow) // Arrange // - Task to continue that is faulted Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); Mock> mockOnFailure = new(); if (handlerShouldThrow) @@ -170,7 +172,7 @@ public async Task ContinueTask_TaskFaults(bool handlerShouldThrow) mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing @@ -187,7 +189,7 @@ public async Task ContinueTask_TaskFaultsNoHandler() // Arrange // - Task to continue that is cancelled Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); Mock mockOnSuccess = new(); Mock mockOnCancellation = new(); @@ -198,7 +200,7 @@ public async Task ContinueTask_TaskFaultsNoHandler() mockOnSuccess.Object, onFailure: null, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled @@ -218,7 +220,7 @@ public async Task ContinueTaskWithState_1Generic_TaskCompletes() // Arrange // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; Mock> mockOnFailure = new(); @@ -240,7 +242,7 @@ public async Task ContinueTaskWithState_1Generic_TaskCompletes() mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert mockOnSuccess.Verify(action => action(state1), Times.Once); @@ -254,7 +256,7 @@ public async Task ContinueTaskWithState_1Generic_TaskCompletesHandlerThrows() // Arrange // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; // - mockOnSuccess handler throws @@ -272,7 +274,7 @@ public async Task ContinueTaskWithState_1Generic_TaskCompletesHandlerThrows() mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have faulted @@ -291,7 +293,7 @@ public async Task ContinueTaskWithState_1Generic_TaskCancels(bool handlerShouldT // Arrange // - Task to continue that was cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; Mock> mockOnCancellation = new(); @@ -311,7 +313,7 @@ public async Task ContinueTaskWithState_1Generic_TaskCancels(bool handlerShouldT mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing @@ -328,7 +330,7 @@ public async Task ContinueTaskWithState_1Generic_TaskCancelsNoHandler() // Arrange // - Task to continue that was cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; Mock> mockOnSuccess = new(); @@ -342,7 +344,7 @@ public async Task ContinueTaskWithState_1Generic_TaskCancelsNoHandler() mockOnSuccess.Object, mockOnFailure.Object, onCancellation: null); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled @@ -359,7 +361,7 @@ public async Task ContinueTaskWithState_1Generic_TaskFaults(bool handlerShouldTh // Arrange // - Task to continue that faulted Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; Mock> mockOnFailure = new(); @@ -379,7 +381,7 @@ public async Task ContinueTaskWithState_1Generic_TaskFaults(bool handlerShouldTh mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing @@ -396,7 +398,7 @@ public async Task ContinueTaskWithState_1Generic_TaskFaultsNoHandler() // Arrange // - Task to continue that faulted Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; Mock> mockOnSuccess = new(); @@ -410,7 +412,7 @@ public async Task ContinueTaskWithState_1Generic_TaskFaultsNoHandler() mockOnSuccess.Object, onFailure: null, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled @@ -429,7 +431,7 @@ public async Task ContinueTaskWithState_2Generics_TaskCompletes() // Arrange // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; const int state2 = 234; @@ -453,7 +455,7 @@ public async Task ContinueTaskWithState_2Generics_TaskCompletes() mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert mockOnSuccess.Verify(action => action(state1, state2), Times.Once); @@ -467,7 +469,7 @@ public async Task ContinueTaskWithState_2Generics_TaskCompletesHandlerThrows() // Arrange // - Task to continue that completed successfully Task taskToContinue = Task.CompletedTask; - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; // - mockOnSuccess handler throws @@ -485,7 +487,7 @@ public async Task ContinueTaskWithState_2Generics_TaskCompletesHandlerThrows() mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have faulted @@ -505,7 +507,7 @@ public async Task ContinueTaskWithState_2Generics_TaskCancels(bool handlerShould // Arrange // - Task to continue that was cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; const int state2 = 234; @@ -527,7 +529,7 @@ public async Task ContinueTaskWithState_2Generics_TaskCancels(bool handlerShould mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing @@ -544,7 +546,7 @@ public async Task ContinueTaskWithState_2Generics_TaskCancelsNoHandler() // Arrange // - Task to continue that was cancelled Task taskToContinue = GetCancelledTask(); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; const int state2 = 234; @@ -560,7 +562,7 @@ public async Task ContinueTaskWithState_2Generics_TaskCancelsNoHandler() mockOnSuccess.Object, mockOnFailure.Object, onCancellation: null); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled @@ -577,7 +579,7 @@ public async Task ContinueTaskWithState_2Generics_TaskFaults(bool handlerShouldT // Arrange // - Task to continue that faulted Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; const int state2 = 234; @@ -599,7 +601,7 @@ public async Task ContinueTaskWithState_2Generics_TaskFaults(bool handlerShouldT mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing @@ -616,7 +618,7 @@ public async Task ContinueTaskWithState_2Generics_TaskFaultsNoHandler() // Arrange // - Task to continue that faulted Task taskToContinue = Task.FromException(new Exception()); - TaskCompletionSource taskCompletionSource = new(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; const int state2 = 234; @@ -632,7 +634,7 @@ public async Task ContinueTaskWithState_2Generics_TaskFaultsNoHandler() mockOnSuccess.Object, onFailure: null, mockOnCancellation.Object); - await RunWithTimeout(taskCompletionSource.Task, TimeSpan.FromSeconds(1)); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); // Assert // - taskCompletionSource should have been cancelled @@ -684,7 +686,7 @@ public async Task CreateContinuationTask_TaskCompletes() mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); @@ -713,7 +715,7 @@ public async Task CreateContinuationTask_TaskCompletesHandlerThrows() mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); @@ -745,7 +747,7 @@ public async Task CreateContinuationTask_TaskCancels(bool handlerShouldThrow) mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); @@ -769,7 +771,7 @@ public async Task CreateContinuationTask_TaskCancelsNoHandler() mockOnSuccess.Object, mockOnFailure.Object, onCancellation: null); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); @@ -800,7 +802,7 @@ public async Task CreateContinuationTask_TaskFaults(bool handlerShouldThrow) mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); @@ -824,7 +826,7 @@ public async Task CreateContinuationTask_TaskFaultsNoHandler() mockOnSuccess.Object, onFailure: null, onCancellation: null); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); @@ -880,7 +882,7 @@ public async Task CreateContinuationTaskWithState_1Generic_TaskCompletes() mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); @@ -911,7 +913,7 @@ public async Task CreateContinuationTaskWithState_1Generic_TaskCompletesHandlerT mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); @@ -946,7 +948,7 @@ public async Task CreateContinuationTaskWithState_1Generic_TaskCancels(bool hand mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); @@ -973,7 +975,7 @@ public async Task CreateContinuationTaskWithState_1Generic_TaskCancelsNoHandler( mockOnSuccess.Object, mockOnFailure.Object, onCancellation: null); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); @@ -1007,7 +1009,7 @@ public async Task CreateContinuationTaskWithState_1Generic_TaskFaults(bool handl mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); @@ -1034,7 +1036,7 @@ public async Task CreateContinuationTaskWithState_1Generic_TaskFaultsNoHandler() mockOnSuccess.Object, onFailure: null, onCancellation: null); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); @@ -1095,7 +1097,7 @@ public async Task CreateContinuationTaskWithState_2Generics_TaskCompletes() mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); @@ -1128,7 +1130,7 @@ public async Task CreateContinuationTaskWithState_2Generics_TaskCompletesHandler mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); @@ -1165,7 +1167,7 @@ public async Task CreateContinuationTaskWithState_2Generics_TaskCancels(bool han mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); @@ -1194,7 +1196,7 @@ public async Task CreateContinuationTaskWithState_2Generics_TaskCancelsNoHandler mockOnSuccess.Object, mockOnFailure.Object, onCancellation: null); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Canceled, continuationTask.Status); @@ -1230,7 +1232,7 @@ public async Task CreateContinuationTaskWithState_2Generics_TaskFaults(bool hand mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); @@ -1259,7 +1261,7 @@ public async Task CreateContinuationTaskWithState_2Generics_TaskFaultsNoHandler( mockOnSuccess.Object, onFailure: null, onCancellation: null); - await RunWithTimeout(continuationTask, TimeSpan.FromSeconds(1)); + await RunWithTimeout(continuationTask, RunTimeout); // Assert Assert.Equal(TaskStatus.Faulted, continuationTask.Status); @@ -1320,6 +1322,9 @@ private static Task GetCancelledTask() return Task.FromCanceled(cts.Token); } + private static TaskCompletionSource GetTaskCompletionSource() + => new(TaskCreationOptions.RunContinuationsAsynchronously); + private static async Task RunWithTimeout([NotNull] Task? taskToRun, TimeSpan timeout) { if (taskToRun is null) From c7d06b98ba3c38581097529cb241602c4e553593 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Fri, 31 Oct 2025 13:40:26 -0500 Subject: [PATCH 25/30] *Try*Set everything, add comment as per PR comment --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index d05ebb1197..494efa9591 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -297,7 +297,7 @@ internal static void ContinueTaskWithState( } catch (Exception e) { - typedState2.TaskCompletionSource.SetException(e); + typedState2.TaskCompletionSource.TrySetException(e); } } }, @@ -342,6 +342,9 @@ internal static void ContinueTaskWithState( { if (taskToContinue is null) { + // This is a remnant of ye olde async/sync code that return null tasks when + // executing synchronously. It's still desirable that the onSuccess executes + // regardless of whether the preceding action was synchronous or asynchronous. onSuccess(); return null; } @@ -386,11 +389,11 @@ internal static void ContinueTaskWithState( try { typedState.OnSuccess(); - typedState.TaskCompletionSource.SetResult(null); + typedState.TaskCompletionSource.TrySetResult(null); } catch (Exception e) { - typedState.TaskCompletionSource.SetException(e); + typedState.TaskCompletionSource.TrySetException(e); } } }, @@ -441,6 +444,9 @@ internal static void ContinueTaskWithState( { if (taskToContinue is null) { + // This is a remnant of ye olde async/sync code that return null tasks when + // executing synchronously. It's still desirable that the onSuccess executes + // regardless of whether the preceding action was synchronous or asynchronous. onSuccess(state); return null; } @@ -487,11 +493,11 @@ internal static void ContinueTaskWithState( try { typedState2.OnSuccess(typedState2.State); - typedState2.TaskCompletionSource.SetResult(null); + typedState2.TaskCompletionSource.TrySetResult(null); } catch (Exception e) { - typedState2.TaskCompletionSource.SetException(e); + typedState2.TaskCompletionSource.TrySetException(e); } } @@ -546,6 +552,9 @@ internal static void ContinueTaskWithState( { if (taskToContinue is null) { + // This is a remnant of ye olde async/sync code that return null tasks when + // executing synchronously. It's still desirable that the onSuccess executes + // regardless of whether the preceding action was synchronous or asynchronous. onSuccess(state1, state2); return null; } @@ -594,7 +603,7 @@ internal static void ContinueTaskWithState( try { typedState2.OnSuccess(typedState2.State1, typedState2.State2); - typedState2.TaskCompletionSource.SetResult(null); + typedState2.TaskCompletionSource.TrySetResult(null); } catch (Exception e) { From 2035cf51f7f55c49b17fb23ebebadb7e2828da3b Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Mon, 3 Nov 2025 13:35:42 -0600 Subject: [PATCH 26/30] Tweak comments, add logging to unovserved exception observation, increase timeout substantially --- .../Data/SqlClient/SqlClientEventSource.cs | 12 +++++++ .../Data/SqlClient/Utilities/AsyncHelper.cs | 32 +++++++++++++++---- .../SqlClient/Utilities/AsyncHelperTest.cs | 8 ++++- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs index 11fd3a316d..83f780d9a7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs @@ -330,6 +330,18 @@ private string GetFormattedMessage(string className, string memberName, string e #region Trace #region Traces without if statements + + internal void TraceEvent(string message) + { + Trace(message); + } + + [NonEvent] + internal void TraceEvent(string message, T0 args0) + { + Trace(string.Format(message, args0?.ToString() ?? NullStr); + } + [NonEvent] internal void TraceEvent(string message, T0 args0, T1 args1) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 494efa9591..f48eef0463 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -39,7 +39,7 @@ internal static class AsyncHelper /// * Successfully /// * is called /// * IF an exception is thrown during execution of , the - /// exception is set on the . + /// helper will try to set an exception on the . /// * is *not* with result on success. This /// is to allow the task completion source to be continued even more after this current /// continuation. @@ -131,7 +131,7 @@ internal static void ContinueTask( /// * Successfully /// * is called /// * IF an exception is thrown during execution of , the - /// exception is set on the . + /// helper will try to set an exception on the . /// * is *not* with result on success. This /// is to allow the task completion source to be continued even more after this current /// continuation. @@ -227,10 +227,10 @@ internal static void ContinueTaskWithState( /// * Successfully /// * is called /// * IF an exception is thrown during execution of , the - /// exception is set on the . + /// helper will try to set an exception on the . /// * is *not* with result on success. This - /// is to allow the task completion source to be continued even more subsequent to - /// this current continuation. + /// is to allow the task completion source to be continued even more after this + /// current continuation. /// /// Task to continue with provided callbacks /// @@ -745,10 +745,30 @@ internal static void WaitForCompletion( } } + /// + /// This method is intended to be used within the above helpers to ensure that any + /// exceptions thrown during callbacks do not go unobserved. If these exceptions were + /// to go unobserved, they will trigger events to be raised by the default task scheduler. + /// Neither situation is ideal: + /// * If an application assigns a listener to this event, it will generate events that + /// should be reported to us. But, because it happens outside the stack that caused the + /// exception, most of the context of the exception is lost. Furthermore, the event is + /// triggered when the GC runs, so the event happens asynchronous to the action that + /// caused it. + /// * Adding this forced observation of the exception prevents applications from receiving + /// the event, effectively swallowing it. + /// * However, if we log the exception when we observe it, we can still log that the + /// unobserved exception happened without causing undue disruption to the application + /// or leaking resources and causing overhead by raising the event. + /// private static void ObserveContinuationException(Task continuationTask) { continuationTask.ContinueWith( - static task => _ = task.Exception, + static task => + { + SqlClientEventSource.Log.TraceEvent($"Unobserved task exception: {task.Exception}"); + return _ = task.Exception; + }, TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously); } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs index 7f15a4708a..fab234b009 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -15,7 +15,13 @@ namespace Microsoft.Data.SqlClient.UnitTests.Microsoft.Data.SqlClient.Utilities { public class AsyncHelperTest { - private static readonly TimeSpan RunTimeout = TimeSpan.FromSeconds(2); + // This timeout is set fairly high. The tests are expected to complete quickly, but are + // dependent on congestion of the thread pool. If the thread pool is congested, like on a + // full CI run, short timeouts may elapse even if the code under test would behave as + // expected. As such, we set a long timeout to ride out reasonable congestion on the + // thread pool, but still trigger a failure if the code under test hangs. + // @TODO: If suite-level timeouts are added, these timeouts can likely be removed. + private static readonly TimeSpan RunTimeout = TimeSpan.FromSeconds(30); #region ContinueTask From bbc91932820b458fbaed27e66dac386269c0d584 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Tue, 4 Nov 2025 10:34:30 -0600 Subject: [PATCH 27/30] Fix syntax error --- .../src/Microsoft/Data/SqlClient/SqlClientEventSource.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs index 83f780d9a7..1dadde00cd 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs @@ -339,7 +339,7 @@ internal void TraceEvent(string message) [NonEvent] internal void TraceEvent(string message, T0 args0) { - Trace(string.Format(message, args0?.ToString() ?? NullStr); + Trace(string.Format(message, args0?.ToString() ?? NullStr)); } [NonEvent] From 3560027976cceb12c85ca9c280192d2964910ef3 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Tue, 4 Nov 2025 13:56:19 -0600 Subject: [PATCH 28/30] Only trace the event if tracing is enabled (should've read the comments) --- .../src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index f48eef0463..05dcb0d017 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -766,7 +766,7 @@ private static void ObserveContinuationException(Task continuationTask) continuationTask.ContinueWith( static task => { - SqlClientEventSource.Log.TraceEvent($"Unobserved task exception: {task.Exception}"); + SqlClientEventSource.Log.TryTraceEvent($"Unobserved task exception: {task.Exception}"); return _ = task.Exception; }, TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously); From 3f5cfa9c8a1cd2c8d063a1ffbdc0eba83fc405a3 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 6 Nov 2025 17:34:51 -0600 Subject: [PATCH 29/30] Make sure mockOnCancellation is checked on one test, reorder a couple checks b/c I felt like it --- .../Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs index fab234b009..d6e545af2e 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -86,7 +86,7 @@ public async Task ContinueTask_TaskCompletesHandlerThrows() Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); mockOnSuccess.Verify(action => action(), Times.Once); mockOnFailure.VerifyNeverCalled(); - mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); } [Theory] @@ -185,8 +185,8 @@ public async Task ContinueTask_TaskFaults(bool handlerShouldThrow) Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); mockOnSuccess.VerifyNeverCalled(); - mockOnCancellation.VerifyNeverCalled(); mockOnFailure.Verify(action => action(It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); } [Fact] @@ -394,8 +394,8 @@ public async Task ContinueTaskWithState_1Generic_TaskFaults(bool handlerShouldTh Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); mockOnSuccess.VerifyNeverCalled(); - mockOnCancellation.VerifyNeverCalled(); mockOnFailure.Verify(action => action(state1, It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); } [Fact] From 10f39c9061965840803a825b665104dd87371f1b Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Mon, 10 Nov 2025 18:55:03 -0600 Subject: [PATCH 30/30] Fix tests that don't correspond to their method name Fix event source listener test so it is ok with the trace from the unobserved exception --- .../TracingTests/EventSourceTest.cs | 31 ++++++++++++------- .../SqlClient/Utilities/AsyncHelperTest.cs | 12 ++++--- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/EventSourceTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/EventSourceTest.cs index 4992c55974..72be245d3b 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/EventSourceTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/EventSourceTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; using System.Linq; using Xunit; @@ -12,22 +13,28 @@ public class EventSourceTest [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] public void EventSourceTestAll() { - using DataTestUtility.MDSEventListener TraceListener = new(); - using (SqlConnection connection = new(DataTestUtility.TCPConnectionString)) + using DataTestUtility.MDSEventListener traceListener = new(); + + using SqlConnection connection = new(DataTestUtility.TCPConnectionString); + connection.Open(); + + using SqlCommand command = new("SELECT @@VERSION", connection); + using SqlDataReader reader = command.ExecuteReader(); + while (reader.Read()) { - connection.Open(); - using SqlCommand command = new("SELECT @@VERSION", connection); - using SqlDataReader reader = command.ExecuteReader(); - while (reader.Read()) - { - // Flush data - } + // Flush data } - // Need to investigate better way of validating traces in sequential runs, - // For now we're collecting all traces to improve code coverage. + // TODO: Need to investigate better way of validating traces in sequential runs, for now we're collecting all traces to improve code coverage. - Assert.All(TraceListener.IDs, item => { Assert.Contains(item, Enumerable.Range(1, 21)); }); + // Assert + // - Collected trace event IDs are in the range of official trace event IDs + // @TODO: This is brittle, refactor the SqlClientEventSource code so the event IDs it can throw are accessible here + HashSet acceptableEventIds = new HashSet(Enumerable.Range(0, 21)); + foreach (int id in traceListener.IDs) + { + Assert.Contains(id, acceptableEventIds); + } } } } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs index d6e545af2e..39783ed0b3 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -477,19 +477,21 @@ public async Task ContinueTaskWithState_2Generics_TaskCompletesHandlerThrows() Task taskToContinue = Task.CompletedTask; TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); const int state1 = 123; + const int state2 = 234 // - mockOnSuccess handler throws - Mock> mockOnSuccess = new(); - mockOnSuccess.Setup(o => o(It.IsAny())).Throws(); + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(o => o(It.IsAny(), It.IsAny())).Throws(); - Mock> mockOnFailure = new(); - Mock> mockOnCancellation = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); // Act AsyncHelper.ContinueTaskWithState( taskToContinue, taskCompletionSource, state1, + state2, mockOnSuccess.Object, mockOnFailure.Object, mockOnCancellation.Object); @@ -500,7 +502,7 @@ public async Task ContinueTaskWithState_2Generics_TaskCompletesHandlerThrows() Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); // - mockOnSuccess was called with state obj - mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); mockOnFailure.VerifyNeverCalled(); mockOnCancellation.VerifyNeverCalled(); }