diff --git a/src/MySqlConnector/Core/ServerSession.cs b/src/MySqlConnector/Core/ServerSession.cs index 23efbfb44..818b94696 100644 --- a/src/MySqlConnector/Core/ServerSession.cs +++ b/src/MySqlConnector/Core/ServerSession.cs @@ -47,6 +47,7 @@ public ServerSession(ILogger logger, IConnectionPoolMetadata pool) public int ActiveCommandId { get; private set; } public int CancellationTimeout { get; private set; } public int ConnectionId { get; set; } + public string? ServerHostname { get; set; } public byte[]? AuthPluginData { get; set; } public long CreatedTimestamp { get; } public ConnectionPool? Pool { get; } @@ -117,6 +118,24 @@ public void DoCancel(ICancellableCommand commandToCancel, MySqlCommand killComma return; } + // Verify server identity before executing KILL QUERY to prevent cancelling on the wrong server + var killSession = killCommand.Connection!.Session; + if (!string.IsNullOrEmpty(ServerHostname) && !string.IsNullOrEmpty(killSession.ServerHostname)) + { + if (!string.Equals(ServerHostname, killSession.ServerHostname, StringComparison.Ordinal)) + { + Log.IgnoringCancellationForDifferentServer(m_logger, Id, killSession.Id, ServerHostname, killSession.ServerHostname); + return; + } + } + else if (!string.IsNullOrEmpty(ServerHostname) || !string.IsNullOrEmpty(killSession.ServerHostname)) + { + // One session has hostname, the other doesn't - this is a potential mismatch + Log.IgnoringCancellationForDifferentServer(m_logger, Id, killSession.Id, ServerHostname, killSession.ServerHostname); + return; + } + // If both sessions have no hostname, allow the operation for backward compatibility + // NOTE: This command is executed while holding the lock to prevent race conditions during asynchronous cancellation. // For example, if the lock weren't held, the current command could finish and the other thread could set ActiveCommandId // to zero, then start executing a new command. By the time this "KILL QUERY" command reached the server, the wrong @@ -635,6 +654,9 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella ConnectionId = newConnectionId; } + // Get server hostname for KILL QUERY verification + await GetServerHostnameAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); + m_payloadHandler.ByteHandler.RemainingTimeout = Constants.InfiniteTimeout; return redirectionUrl; } @@ -1951,6 +1973,52 @@ private async Task GetRealServerDetailsAsync(IOBehavior ioBehavior, Cancellation } } + private async Task GetServerHostnameAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) + { + Log.GettingServerHostname(m_logger, Id); + try + { + var payload = SupportsQueryAttributes ? s_selectHostnameWithAttributesPayload : s_selectHostnameNoAttributesPayload; + await SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); + + // column count: 1 + _ = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); + + // @@hostname column + _ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); + + if (!SupportsDeprecateEof) + { + payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); + _ = EofPayload.Create(payload.Span); + } + + // first (and only) row + payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); + + var reader = new ByteArrayReader(payload.Span); + var length = reader.ReadLengthEncodedIntegerOrNull(); + var hostname = length > 0 ? Encoding.UTF8.GetString(reader.ReadByteString(length)) : null; + + ServerHostname = hostname; + + Log.RetrievedServerHostname(m_logger, Id, hostname); + + // OK/EOF payload + payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); + if (OkPayload.IsOk(payload.Span, this)) + OkPayload.Verify(payload.Span, this); + else + EofPayload.Create(payload.Span); + } + catch (MySqlException ex) + { + Log.FailedToGetServerHostname(m_logger, ex, Id); + // Set fallback value to ensure operation can continue + ServerHostname = null; + } + } + private void ShutdownSocket() { Log.ClosingStreamSocket(m_logger, Id); @@ -2182,6 +2250,8 @@ protected override void OnStatementBegin(int index) private static readonly PayloadData s_sleepWithAttributesPayload = QueryPayload.Create(true, "SELECT SLEEP(0) INTO @__MySqlConnector__Sleep;"u8); private static readonly PayloadData s_selectConnectionIdVersionNoAttributesPayload = QueryPayload.Create(false, "SELECT CONNECTION_ID(), VERSION();"u8); private static readonly PayloadData s_selectConnectionIdVersionWithAttributesPayload = QueryPayload.Create(true, "SELECT CONNECTION_ID(), VERSION();"u8); + private static readonly PayloadData s_selectHostnameNoAttributesPayload = QueryPayload.Create(false, "SELECT @@hostname;"u8); + private static readonly PayloadData s_selectHostnameWithAttributesPayload = QueryPayload.Create(true, "SELECT @@hostname;"u8); private readonly ILogger m_logger; #if NET9_0_OR_GREATER diff --git a/src/MySqlConnector/Logging/EventIds.cs b/src/MySqlConnector/Logging/EventIds.cs index de87e8590..8dab69056 100644 --- a/src/MySqlConnector/Logging/EventIds.cs +++ b/src/MySqlConnector/Logging/EventIds.cs @@ -86,6 +86,9 @@ internal static class EventIds public const int CertificateErrorUnixSocket = 2158; public const int CertificateErrorNoPassword = 2159; public const int CertificateErrorValidThumbprint = 2160; + public const int GettingServerHostname = 2161; + public const int RetrievedServerHostname = 2162; + public const int FailedToGetServerHostname = 2163; // Command execution events, 2200-2299 public const int CannotExecuteNewCommandInState = 2200; @@ -108,6 +111,7 @@ internal static class EventIds public const int IgnoringCancellationForInactiveCommand = 2306; public const int CancelingCommand = 2307; public const int SendingSleepToClearPendingCancellation = 2308; + public const int IgnoringCancellationForDifferentServer = 2309; // Cached procedure events, 2400-2499 public const int GettingCachedProcedure = 2400; diff --git a/src/MySqlConnector/Logging/Log.cs b/src/MySqlConnector/Logging/Log.cs index e9b4f88bc..c195d1387 100644 --- a/src/MySqlConnector/Logging/Log.cs +++ b/src/MySqlConnector/Logging/Log.cs @@ -189,6 +189,18 @@ internal static partial class Log [LoggerMessage(EventIds.FailedToGetConnectionId, LogLevel.Information, "Session {SessionId} failed to get CONNECTION_ID(), VERSION()")] public static partial void FailedToGetConnectionId(ILogger logger, Exception exception, string sessionId); + [LoggerMessage(EventIds.GettingServerHostname, LogLevel.Debug, "Session {SessionId} getting server hostname")] + public static partial void GettingServerHostname(ILogger logger, string sessionId); + + [LoggerMessage(EventIds.RetrievedServerHostname, LogLevel.Debug, "Session {SessionId} retrieved server hostname: {ServerHostname}")] + public static partial void RetrievedServerHostname(ILogger logger, string sessionId, string? serverHostname); + + [LoggerMessage(EventIds.FailedToGetServerHostname, LogLevel.Information, "Session {SessionId} failed to get server hostname")] + public static partial void FailedToGetServerHostname(ILogger logger, Exception exception, string sessionId); + + [LoggerMessage(EventIds.IgnoringCancellationForDifferentServer, LogLevel.Warning, "Session {SessionId} ignoring cancellation from session {KillSessionId}: server hostname mismatch (this hostname={ServerHostname}, kill hostname={KillServerHostname})")] + public static partial void IgnoringCancellationForDifferentServer(ILogger logger, string sessionId, string killSessionId, string? serverHostname, string? killServerHostname); + [LoggerMessage(EventIds.ClosingStreamSocket, LogLevel.Debug, "Session {SessionId} closing stream/socket")] public static partial void ClosingStreamSocket(ILogger logger, string sessionId); diff --git a/tests/IntegrationTests/ServerIdentificationTests.cs b/tests/IntegrationTests/ServerIdentificationTests.cs new file mode 100644 index 000000000..d893fad9d --- /dev/null +++ b/tests/IntegrationTests/ServerIdentificationTests.cs @@ -0,0 +1,53 @@ +using System.Diagnostics; + +namespace IntegrationTests; + +public class ServerIdentificationTests : IClassFixture, IDisposable +{ + public ServerIdentificationTests(DatabaseFixture database) + { + m_database = database; + } + + public void Dispose() + { + } + + [SkippableFact(ServerFeatures.Timeout)] + public void CancelCommand_WithServerVerification() + { + // This test verifies that cancellation still works with server verification + using var connection = new MySqlConnection(AppConfig.ConnectionString); + connection.Open(); + + using var cmd = new MySqlCommand("SELECT SLEEP(5)", connection); + var task = Task.Run(async () => + { + await Task.Delay(TimeSpan.FromSeconds(0.5)); + cmd.Cancel(); + }); + + var stopwatch = Stopwatch.StartNew(); + TestUtilities.AssertExecuteScalarReturnsOneOrIsCanceled(cmd); + Assert.InRange(stopwatch.ElapsedMilliseconds, 250, 2500); + +#pragma warning disable xUnit1031 // Do not use blocking task operations in test method + task.Wait(); // shouldn't throw +#pragma warning restore xUnit1031 // Do not use blocking task operations in test method + } + + [SkippableFact(ServerFeatures.KnownCertificateAuthority)] + public void ServerHasServerHostname() + { + using var connection = new MySqlConnection(AppConfig.ConnectionString); + connection.Open(); + + // Test that we can query server hostname + using var cmd = new MySqlCommand("SELECT @@hostname", connection); + var hostname = cmd.ExecuteScalar(); + + // Hostname might be null on some server configurations, but the query should succeed + } + + private readonly DatabaseFixture m_database; +} \ No newline at end of file