diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 0a8db81dc2..081edceced 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -87,6 +87,12 @@ Microsoft\Data\Common\NameValuePair.cs + + Microsoft\Data\Common\PacketBuffer.cs + + + Microsoft\Data\Common\ReadOnlySequenceUtilities.cs + Microsoft\Data\DataException.cs diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 5be11f71a7..fefab82b3f 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -279,6 +279,12 @@ Microsoft\Data\Common\NameValuePair.cs + + Microsoft\Data\Common\PacketBuffer.cs + + + Microsoft\Data\Common\ReadOnlySequenceUtilities.cs + Microsoft\Data\DataException.cs diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/PacketBuffer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/PacketBuffer.cs new file mode 100644 index 0000000000..079c2933a5 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/PacketBuffer.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; + +#nullable enable + +namespace Microsoft.Data.Common; + +/// +/// One buffer, which may contain one unparsed packet from a single destination. +/// +internal sealed class PacketBuffer : ReadOnlySequenceSegment +{ + public PacketBuffer(ReadOnlyMemory buffer, PacketBuffer? previous) + { + Memory = buffer; + + if (previous is not null) + { + previous.Next = this; + RunningIndex = previous.RunningIndex + previous.Memory.Length; + } + else + { + RunningIndex = 0; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/ReadOnlySequenceUtilities.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/ReadOnlySequenceUtilities.cs new file mode 100644 index 0000000000..792ca6ddfc --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/ReadOnlySequenceUtilities.cs @@ -0,0 +1,92 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.Buffers.Binary; + +namespace Microsoft.Data.Common; + +internal static class ReadOnlySequenceUtilities +{ + /// + /// Reads the next byte from the sequence, advancing its position by one byte. + /// + /// The sequence to read and to advance from. + /// The first span in the sequence. Reassigned if the next byte can only be read from the next span. + /// Current position in the sequence. Advanced by one byte following a successful read. + /// The value read from . + /// true if was long enough to retrieve the next byte, false otherwise. + public static bool ReadByte(this ref ReadOnlySequence sequence, ref ReadOnlySpan currSpan, ref long currPos, out byte value) + { + if (sequence.Length < sizeof(byte)) + { + value = default; + return false; + } + + currPos += sizeof(byte); + if (currSpan.Length >= sizeof(byte)) + { + value = currSpan[0]; + + sequence = sequence.Slice(sizeof(byte)); + currSpan = currSpan.Slice(sizeof(byte)); + + return true; + } + else + { + Span buffer = stackalloc byte[sizeof(byte)]; + + sequence.Slice(0, sizeof(byte)).CopyTo(buffer); + value = buffer[0]; + + sequence = sequence.Slice(sizeof(byte)); + currSpan = sequence.First.Span; + + return true; + } + } + + /// + /// Reads the next two bytes from the sequence as a , advancing its position by two bytes. + /// + /// The sequence to read and to advance from. + /// The first span in the sequence. Reassigned if the next two bytes can only be read from the next span. + /// Current position in the sequence. Advanced by two bytes following a successful read. + /// The value read from + /// true if was long enough to retrieve the next two bytes, false otherwise. + public static bool ReadLittleEndian(this ref ReadOnlySequence sequence, ref ReadOnlySpan currSpan, ref long currPos, out ushort value) + { + if (sequence.Length < sizeof(ushort)) + { + value = default; + return false; + } + + currPos += sizeof(ushort); + if (currSpan.Length >= sizeof(ushort)) + { + value = BinaryPrimitives.ReadUInt16LittleEndian(currSpan); + + sequence = sequence.Slice(sizeof(ushort)); + currSpan = currSpan.Slice(sizeof(ushort)); + + return true; + } + else + { + Span buffer = stackalloc byte[sizeof(ushort)]; + + sequence.Slice(0, sizeof(ushort)).CopyTo(buffer); + value = BinaryPrimitives.ReadUInt16LittleEndian(buffer); + + sequence = sequence.Slice(sizeof(ushort)); + currSpan = sequence.First.Span; + + return true; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/System/Diagnostics/CodeAnalysis.cs b/src/Microsoft.Data.SqlClient/src/System/Diagnostics/CodeAnalysis.cs index ffb0003f64..8eb8e3429d 100644 --- a/src/Microsoft.Data.SqlClient/src/System/Diagnostics/CodeAnalysis.cs +++ b/src/Microsoft.Data.SqlClient/src/System/Diagnostics/CodeAnalysis.cs @@ -43,5 +43,13 @@ public MemberNotNullWhenAttribute(bool returnValue, params string[] members) internal sealed class NotNullAttribute : Attribute { } + + [AttributeUsage(AttributeTargets.Parameter, Inherited = false)] + internal sealed class NotNullWhenAttribute : Attribute + { + public NotNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; + + public bool ReturnValue { get; } + } #endif } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/Sql/DacResponseProcessorTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/Sql/DacResponseProcessorTest.cs new file mode 100644 index 0000000000..65cd8f2553 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/Sql/DacResponseProcessorTest.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using Xunit; + +namespace Microsoft.Data.Sql.UnitTests; + +public class DacResponseProcessorTest +{ + [Theory] + [MemberData(nameof(SsrpPacketTestData.EmptyPacketBuffer), MemberType = typeof(SsrpPacketTestData), DisableDiscoveryEnumeration = true)] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3700")] + public void Process_EmptyBuffer_ReturnsFalse(ReadOnlySequence packetBuffers) + { + _ = packetBuffers; + } + + [Theory] + [MemberData(nameof(SsrpPacketTestData.InvalidSVR_RESP_DACPackets), MemberType = typeof(SsrpPacketTestData), DisableDiscoveryEnumeration = true)] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3700")] + public void Process_InvalidDacResponse_ReturnsFalse(ReadOnlySequence packetBuffers) + { + _ = packetBuffers; + } + + [Theory] + [MemberData(nameof(SsrpPacketTestData.ValidSVR_RESP_DACPacketBuffer), MemberType = typeof(SsrpPacketTestData), DisableDiscoveryEnumeration = true)] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3700")] + public void Process_ValidDacResponse_ReturnsData(ReadOnlySequence packetBuffers, int expectedDacPort) + { + _ = packetBuffers; + _ = expectedDacPort; + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/Sql/SqlDataSourceResponseProcessorTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/Sql/SqlDataSourceResponseProcessorTest.cs new file mode 100644 index 0000000000..04ffde489b --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/Sql/SqlDataSourceResponseProcessorTest.cs @@ -0,0 +1,62 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using Xunit; + +namespace Microsoft.Data.Sql.UnitTests; + +public class SqlDataSourceResponseProcessorTest +{ + [Theory] + [MemberData(nameof(SsrpPacketTestData.EmptyPacketBuffer), MemberType = typeof(SsrpPacketTestData), DisableDiscoveryEnumeration = true)] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3700")] + public void Process_EmptyBuffer_ReturnsFalse(ReadOnlySequence packetBuffers) + { + _ = packetBuffers; + } + + [Theory] + [MemberData(nameof(SsrpPacketTestData.InvalidSVR_RESPPackets), MemberType = typeof(SsrpPacketTestData), DisableDiscoveryEnumeration = true)] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3700")] + public void Process_InvalidSqlDataSourceResponse_ReturnsFalse(ReadOnlySequence packetBuffers) + { + _ = packetBuffers; + } + + [Theory] + [MemberData(nameof(SsrpPacketTestData.InvalidRESP_DATAPackets), MemberType = typeof(SsrpPacketTestData), DisableDiscoveryEnumeration = true)] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3700")] + public void Process_InvalidSqlDataSourceResponse_RESP_DATA_ReturnsFalse(ReadOnlySequence packetBuffers) + { + _ = packetBuffers; + } + + [Theory] + [MemberData(nameof(SsrpPacketTestData.InvalidTCP_INFOPackets), MemberType = typeof(SsrpPacketTestData), DisableDiscoveryEnumeration = true)] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3700")] + public void Process_InvalidSqlDataSourceResponse_TCP_INFO_ReturnsFalse(ReadOnlySequence packetBuffers) + { + _ = packetBuffers; + } + + [Theory] + [MemberData(nameof(SsrpPacketTestData.Invalid_CLNT_UCAST_INST_SVR_RESPPackets), MemberType = typeof(SsrpPacketTestData), DisableDiscoveryEnumeration = true)] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3700")] + public void Process_InvalidSqlDataSourceResponseToCLNT_UCAST_INST_ReturnsFalse(ReadOnlySequence packetBuffers) + { + _ = packetBuffers; + } + + [Theory] + [MemberData(nameof(SsrpPacketTestData.ValidSVR_RESPPacketBuffer), MemberType = typeof(SsrpPacketTestData), DisableDiscoveryEnumeration = true)] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3700")] + public void Process_ValidSqlDataSourceResponse_ReturnsData(ReadOnlySequence packetBuffers, string expectedVersion, int expectedTcpPort, string? expectedPipeName) + { + _ = packetBuffers; + _ = expectedVersion; + _ = expectedTcpPort; + _ = expectedPipeName; + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/Sql/SsrpPacketTestData.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/Sql/SsrpPacketTestData.cs new file mode 100644 index 0000000000..1fd4e1e5bf --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/Sql/SsrpPacketTestData.cs @@ -0,0 +1,501 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Data.Common; +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Collections.Generic; +using System.Text; +using Xunit; + +namespace Microsoft.Data.Sql.UnitTests; + +/// +/// Test cases used to verify the successful processing of valid SSRP responses and the silent +/// discarding of invalid SSRP responses. +/// +internal static class SsrpPacketTestData +{ + /// + /// One empty packet buffer, which should be successfully processed and contain zero responses. + /// + /// + /// + public static TheoryData> EmptyPacketBuffer => + new(GeneratePacketBuffers([])); + + /// + /// Various combinations of packet buffers containing normal SVR_RESP responses, all of which + /// should be successfully processed. + /// + /// + public static TheoryData, string, int, string?> ValidSVR_RESPPacketBuffer + { + get + { + byte[] complexValidPacket = FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("svr1", "MSSQLSERVER", true, "14.0.0.0", + CreateProtocolParameters(tcpInfo: "tcp;1433", npInfo: @"np;\\svr1\pipe\SampléPipeName", + viaInfo: "via;svr1 1:1433", rpcInfo: "rpc;svr1", spxInfo: "spx;MSSQLSERVER", + adspInfo: "adsp;SQL2000", bvInfo: "bv;item;group;item;group;org", otherParameters: null))); + byte[] validPacket1 = FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("svr1", "MSSQLSERVER", true, "14.0.0.0", + CreateProtocolParameters(tcpInfo: "tcp;1433", null, null, null, null, null, null, null))); + byte[] validPacket2 = FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("svr1", "MSSQLSERVER", true, "14.0.0.0", + CreateProtocolParameters(tcpInfo: "tcp;1434", null, null, null, null, null, null, null))); + byte[] validPacket3 = FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("svr1", "MSSQLSERVER", true, "14.0.0.0", + CreateProtocolParameters(tcpInfo: "tcp;1435", null, null, null, null, null, null, null))); + byte[] validPacket4 = FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("svr1", "MSSQLSERVER", true, "14.0.0.0", + CreateProtocolParameters(tcpInfo: "tcp;1436", null, null, null, null, null, null, null))); + byte[] invalidPacket1 = FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("svr1", "MSSQLSERVER", true, "v14", + CreateProtocolParameters(tcpInfo: "tcp;1433", null, null, null, null, null, null, null))); + + return new() + { + // One buffer, one response + { GeneratePacketBuffers( + complexValidPacket + ), "14.0.0.0", 1433, @"\\svr1\pipe\SampléPipeName" }, + // One response, split into four buffers in the middle of a string + { GeneratePacketBuffers( + complexValidPacket.AsSpan(0, 14).ToArray(), + complexValidPacket.AsSpan(14, 22).ToArray(), + complexValidPacket.AsSpan(36, 71).ToArray(), + // Position 107 is the second byte of the é character when encoded to UTF8. + complexValidPacket.AsSpan(107).ToArray() + ), "14.0.0.0", 1433, @"\\svr1\pipe\SampléPipeName" }, + // Four responses, each with different methods + { GeneratePacketBuffers( + validPacket1, validPacket2, validPacket3, validPacket4 + ), "14.0.0.0", 1436, null }, + // Five responses, with response three invalid + { GeneratePacketBuffers( + complexValidPacket, validPacket2, invalidPacket1, validPacket3, validPacket4 + ), "14.0.0.0", 1436, null } + }; + } + } + + /// + /// Various combinations of packet buffers containing SVR_RESP (DAC) responses, all of which + /// should be successfully processed. + /// + /// + public static TheoryData, int> ValidSVR_RESP_DACPacketBuffer + { + get + { + byte[] validPacket1 = FormatSVR_RESPMessage(0x05, 0x06, CreateRESP_DATA(0x01, 1434)); + byte[] validPacket2 = FormatSVR_RESPMessage(0x05, 0x06, CreateRESP_DATA(0x01, 1435)); + byte[] validPacket3 = FormatSVR_RESPMessage(0x05, 0x06, CreateRESP_DATA(0x01, 1436)); + byte[] validPacket4 = FormatSVR_RESPMessage(0x05, 0x06, CreateRESP_DATA(0x01, 1437)); + byte[] invalidPacket1 = FormatSVR_RESPMessage(0x05, 0x03, CreateRESP_DATA(0x01, 1434)); + + return new() + { + // One buffer, one response + { GeneratePacketBuffers(validPacket1), 1434 }, + + // One response, split into three buffers + { GeneratePacketBuffers(validPacket1.AsSpan(0, 2).ToArray(), + validPacket1.AsSpan(2, 2).ToArray(), + validPacket1.AsSpan(4).ToArray()), 1434 }, + + // Two responses with trailing data + { GeneratePacketBuffers(validPacket1.AsSpan(0, 2).ToArray(), + validPacket1.AsSpan(2, 2).ToArray(), + [..validPacket1.AsSpan(4).ToArray(), 0x05], + validPacket2.AsSpan(0, 2).ToArray(), + validPacket2.AsSpan(2).ToArray(), + [0x05]), 1435 }, + + // Four responses, each with different DAC ports + { GeneratePacketBuffers(validPacket1, validPacket2, validPacket3, validPacket4), 1437 }, + + // Five responses, with response three invalid + { GeneratePacketBuffers(validPacket1, validPacket2, invalidPacket1, validPacket3, validPacket4), 1437 }, + + // Four responses, with three extraneous 0x05 bytes between responses 2 and 3 + { GeneratePacketBuffers(validPacket1, [..validPacket2, 0x05], [0x05], [0x05, ..validPacket3], validPacket4), 1437 } + }; + } + } + + /// + /// Packet buffers containing nothing but invalid SVR_RESP (DAC) responses. + /// + /// + public static TheoryData> InvalidSVR_RESP_DACPackets + { + get + { + return new() + { + // Invalid header byte + { GeneratePacketBuffers(FormatSVR_RESPMessage(0x00, 0x06, CreateRESP_DATA(0x01, 1434))) }, + + // Invalid size + { GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, 0x09, CreateRESP_DATA(0x01, 1434))) }, + + // Invalid protocol version + { GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, 0x06, CreateRESP_DATA(0x02, 1434))) }, + + // Invalid port + { GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, 0x06, CreateRESP_DATA(0x01, 0))) }, + }; + } + } + + /// + /// Packets containing an SVR_RESP response which is a valid response to a CLNT_[B|U]CAST_EX message + /// but not to a CLNT_UCAST_INST message. + /// + /// + public static TheoryData> Invalid_CLNT_UCAST_INST_SVR_RESPPackets + { + get + { + // The RESP_DATA section of the response to a CLNT_UCAST_INST message must be shorter than 1024 bytes. + byte[] longPacket = FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("svr1", "MSSQLSERVER", true, "14.0.0.0", + CreateProtocolParameters(tcpInfo: "tcp;1433", npInfo: @"np;" + new string('a', 1025), + null, null, null, null, null, null))); + + return new() + { + GeneratePacketBuffers(longPacket) + }; + } + } + + /// + /// Packet buffers containing an SSRP message which is failing due to invalid data + /// in the top-level SVR_RESP message fields. + /// + /// + public static TheoryData> InvalidSVR_RESPPackets + { + get + { + return new( + // Invalid SVR_RESP header field value + GeneratePacketBuffers( + FormatSVR_RESPMessage(0x04, 0x06, CreateRESP_DATA(0x01, 1434)) + ), + + // RESP_SIZE too small (DAC response) + GeneratePacketBuffers( + FormatSVR_RESPMessage(0x05, 0x05, CreateRESP_DATA(0x01, 1434)) + ), + + // RESP_SIZE too large (DAC response) + GeneratePacketBuffers( + FormatSVR_RESPMessage(0x05, 0x07, CreateRESP_DATA(0x01, 1434)) + ), + + // RESP_SIZE larger than the buffer (normal response) + GeneratePacketBuffers( + FormatSVR_RESPMessage(0x05, 72, + CreateRESP_DATA("svr1", "MSSQLSERVER", true, "14.0.0")) + ) + ); + } + } + + /// + /// Packet buffers containing an SSRP message with valid top-level SVR_RESP message + /// fields but invalid components of the child RESP_DATA structure. + /// + /// + public static TheoryData> InvalidRESP_DATAPackets + { + get + { + string validTcpInfo = CreateProtocolParameters("tcp;1433", null, null, null, null, null, null, null); + + return new( + // Does not start with "ServerName" string + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "14.0.0.0", validTcpInfo, omitServerName: true))), + + // Server name longer than 255 bytes + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA(new string('a', 256), "MSSQLSERVER", true, "14.0.0.0", validTcpInfo))), + + // Missing semicolons between keys and values + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "14.0.0.0", validTcpInfo, omitKeyValueSeparators: true))), + + // Missing terminating pair of semicolons + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "14.0.0.0", validTcpInfo, omitTrailingSemicolons: true))), + + // Missing "InstanceName" + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "14.0.0.0", validTcpInfo, omitInstanceName: true))), + + // Instance name longer than 255 bytes + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", new string('a', 256), true, "14.0.0.0", validTcpInfo))), + + // Missing "IsClustered" + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "14.0.0.0", validTcpInfo, omitIsClustered: true))), + + // Invalid IsClustered value + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "14.0.0.0", "IsClustered;INVALID;" + validTcpInfo, omitIsClustered: true))), + + // Missing "Version" + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "14.0.0.0", validTcpInfo, omitVersion: true))), + + // Empty version string + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "", validTcpInfo, omitVersion: true))), + + // Version string longer than 16 bytes + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "65535.65535.65.53", validTcpInfo))), + + // Version string not in the correct format: 1*[0-9"."] + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "v14", validTcpInfo))), + + // Protocol components listed twice + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "14.0.0.0", + CreateProtocolParameters("tcp;1434", null, null, null, null, null, null, "tcp;1434")))), + + // Invalid protocol components appear + GeneratePacketBuffers(FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "14.0.0.0", + CreateProtocolParameters("tcp;1434", null, null, null, null, null, null, "invalid_protocol;value")))), + + // Invalid PROTOCOLVERSION field value + GeneratePacketBuffers( + FormatSVR_RESPMessage(0x05, 0x06, CreateRESP_DATA(0x02, 1434)) + ) + ); + } + } + + /// + /// Packet buffers containing an SSRP message with valid top-level SVR_RESP message + /// fields, a valid RESP_DATA child structure but an invalid TCP_INFO structure. + /// + /// + public static TheoryData> InvalidTCP_INFOPackets + { + get + { + return new( + // Port is absent + CreateSVR_RESPMessage("tcp"), + + // Port is non-numeric + CreateSVR_RESPMessage("tcp;one"), + + // Port is > ushort.MaxValue + CreateSVR_RESPMessage("tcp;65536"), + + // Port is < 0 + CreateSVR_RESPMessage("tcp;-1") + ); + + static ReadOnlySequence CreateSVR_RESPMessage(string tcpInfo) => + GeneratePacketBuffers( + FormatSVR_RESPMessage(0x05, + CreateRESP_DATA("srv1", "MSSQLSERVER", true, "14.0.0.0", + CreateProtocolParameters(tcpInfo, null, null, null, null, null, null, null) + ) + ) + ); + } + } + + private static ReadOnlySequence GeneratePacketBuffers(params byte[][] packetBuffers) + { + if (packetBuffers.Length == 0) + { + return ReadOnlySequence.Empty; + } + + PacketBuffer first = new(packetBuffers[0], null); + PacketBuffer curr = first; + PacketBuffer last; + + for (int i = 1; i < packetBuffers.Length; i++) + { + curr = new(packetBuffers[i], curr); + } + last = curr; + + return new ReadOnlySequence(first, 0, last, last.Memory.Length); + } + + /// + /// Generates an SVR_RESP message with a valid length. + /// + /// The SVR_RESP header value. Expected to be 0x05. + /// The serialized RESP_DATA section. + /// + /// A byte representation of one SVR_RESP message. + private static byte[] FormatSVR_RESPMessage(byte header, ReadOnlySpan respData) => + FormatSVR_RESPMessage(header, (ushort)respData.Length, respData); + + /// + /// Generates an SVR_RESP message with specific characteristics. + /// + /// The SVR_RESP header value. Expected to be 0x05. + /// The RESP_SIZE field to be serialized to the header. + /// The serialized RESP_DATA section. + /// If specified, the number of bytes to actually write. + /// A byte representation of one SVR_RESP message. + /// + /// + private static byte[] FormatSVR_RESPMessage(byte header, ushort serializedResponseSize, ReadOnlySpan respData, + int? realResponseSize = null) + { + byte[] realRespData = realResponseSize is null + ? new byte[sizeof(byte) + sizeof(ushort) + respData.Length] + : new byte[realResponseSize.Value]; + + // Pad any free space after RESP_DATA with 0x05 + if (realRespData.Length > sizeof(byte) + sizeof(ushort) + respData.Length) + { + realRespData.AsSpan(sizeof(byte) + sizeof(ushort) + respData.Length).Fill(0x05); + } + + // Write RESP_DATA + if (realRespData.Length > sizeof(byte) + sizeof(ushort)) + { + int bytesToCopy = Math.Min(respData.Length, realRespData.Length - sizeof(byte) - sizeof(ushort)); + + respData.Slice(0, bytesToCopy).CopyTo(realRespData.AsSpan(sizeof(byte) + sizeof(ushort))); + } + + // Write RESP_SIZE + if (realRespData.Length > sizeof(byte)) + { + Span responseSizeBytes = stackalloc byte[sizeof(ushort)]; + int bytesToCopy = Math.Min(responseSizeBytes.Length, realRespData.Length - sizeof(byte)); + + BinaryPrimitives.WriteUInt16LittleEndian(responseSizeBytes, serializedResponseSize); + responseSizeBytes.Slice(0, bytesToCopy).CopyTo(realRespData.AsSpan(sizeof(byte))); + } + + // Write SVR_RESP + if (realRespData.Length > 0) + { + realRespData[0] = header; + } + + return realRespData; + } + + /// + /// Creates a new RESP_DATA section of an SVR_RESP message for the DAC request with a specified protocol version and TCP port number. + /// + /// Protocol version. Expected to be 0x01. + /// TCP port number of the DAC. + /// A byte representation of a RESP_DATA section. + /// + private static byte[] CreateRESP_DATA(byte protocolVersion, ushort dacPort) + { + byte[] data = new byte[sizeof(byte) + sizeof(ushort)]; + + data[0] = protocolVersion; + BinaryPrimitives.WriteUInt16LittleEndian(data.AsSpan(1), dacPort); + return data; + } + + /// + /// Creates a RESP_DATA section of an SVR_RESP message with specific characteristics. + /// + /// ServerName parameter value. + /// InstanceName parameter value. + /// IsClustered parameter value. + /// Version parameter value. + /// If specified, the protocol parameters. Generated by . + /// If true, the ServerName, InstanceName, IsClustered and Version keys will be written in lowercase. + /// If true, the return value will not include the trailing ;;. + /// If true, no separators between the keys and values will be written. + /// If true, the mandatory ServerName parameter value will not be written. + /// If true, the mandatory InstanceName parameter value will not be written. + /// If true, the mandatory IsClustered parameter value will not be written. + /// If true, the mandatory Version parameter value will not be written. + /// If true, the key/value pairs will be written in a non-sequential order. + /// A byte representation of a RESP_DATA section. + /// + private static byte[] CreateRESP_DATA(string serverName, string instanceName, bool isClustered, string version, + string? protocolParameters = null, + bool lowercaseKey = false, bool omitTrailingSemicolons = false, bool omitKeyValueSeparators = false, + bool omitServerName = false, bool omitInstanceName = false, bool omitIsClustered = false, bool omitVersion = false, + bool shuffleKeys = false) + { + string serverNameKey = GenerateKeyValuePair("ServerName", serverName, omitServerName); + string instanceNameKey = GenerateKeyValuePair("InstanceName", instanceName, omitInstanceName); + string isClusteredKey = GenerateKeyValuePair("IsClustered", isClustered ? "Yes" : "No", omitIsClustered); + string versionKey = GenerateKeyValuePair("Version", version, omitVersion); + string[] components = + shuffleKeys + ? [protocolParameters ?? string.Empty, versionKey, isClusteredKey, instanceNameKey, serverNameKey] + : [serverNameKey, instanceNameKey, isClusteredKey, versionKey, protocolParameters ?? string.Empty]; + string outputString = string.Join(";", components) + + (omitTrailingSemicolons ? string.Empty : ";;"); + + return Encoding.UTF8.GetBytes(outputString); + + string GenerateKeyValuePair(string key, string value, bool omitKey) + { + if (omitKey) + { + return string.Empty; + } + + if (lowercaseKey) + { + key = key.ToLower(); + } + + return key + (omitKeyValueSeparators ? string.Empty : ";") + value; + } + } + + /// + /// Creates the protocol parameters for a RESP_DATA section. + /// + /// If non-null, the TCP_INFO data. + /// If non-null, the NP_INFO data. + /// If non-null, the VIA_INFO data. + /// If non-null, the RPC_INFO data + /// If non-null, the SPX_INFO data. + /// If non-null, the ADSP_INFO data. + /// If non-null, the BV_INFO data. + /// Any additional protocol parameters to include. + /// The collated protocol parameters for a RESP_DATA section. + private static string CreateProtocolParameters(string? tcpInfo, string? npInfo, + string? viaInfo, string? rpcInfo, string? spxInfo, string? adspInfo, string? bvInfo, + string? otherParameters) + { + List protocolParameters = []; + ReadOnlySpan allParams = [tcpInfo, npInfo, viaInfo, rpcInfo, spxInfo, adspInfo, bvInfo, otherParameters]; + + foreach (string? param in allParams) + { + if (param is not null) + { + protocolParameters.Add(param); + } + } + + return string.Join(";", protocolParameters); + } +}