From e00077d6220984a7a9533d9eaa951f85699dedeb Mon Sep 17 00:00:00 2001 From: Edward Neal <55035479+edwardneal@users.noreply.github.com> Date: Thu, 9 Oct 2025 14:37:11 +0100 Subject: [PATCH] Allow use of aliases in SqlBulkCopy Aliases are: $to_id, $from_id, $node_id, $edge_id --- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 85 ++++++++++- .../SqlClient/SqlBulkCopyColumnMapping.cs | 8 ++ .../SQL/SqlBulkCopyTest/CopyAllFromReader.cs | 6 +- .../SQL/SqlBulkCopyTest/SqlGraphTables.cs | 134 +++++++++++++++++- 4 files changed, 225 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 280b79b0f4..694b8ae1ec 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -152,12 +152,17 @@ public SourceColumnMetadata(ValueMethod method, bool isSqlType, bool isDataFeed) // Transaction count has only one value in one column and one row // MetaData has n columns but no rows // Collation has 4 columns and n rows + // Column aliases has 3 columns and n rows private const int MetaDataResultId = 1; private const int CollationResultId = 2; private const int CollationId = 3; + private const int ColumnAliasesResultId = 3; + private const int ColumnCanonicalNameColumnId = 0; + private const int ColumnAliasColumnId = 1; + private const int MAX_LENGTH = 0x7FFFFFFF; private const int DefaultCommandTimeout = 30; @@ -467,13 +472,36 @@ private string CreateInitialQuery() // query will then continue to fail with "Invalid object name" rather than with an unusual error because the query being executed // is NULL. // Some hidden columns (e.g. SQL Graph columns) cannot be selected, so we need to exclude them explicitly. + // We also include a list of column aliases. This allows someone to write data to $to_id, $from_id, and other "virtual" columns + // in SQL Server which don't physically exist, but which can be queried by name. + // SQL Server also allows columns to be created with the same name as a "virtual" column; a user may create a SQL Graph Node table + // with a real column named "$node_id". + // In such cases, querying for $node_id will return the virtual column and querying for [$node_id] will return the physical column. + // SqlBulkCopy does not follow this convention; if the table has a real column named "$node_id", mapping to the $node_id column + // will map to the real column rather than the column alias. This is for backwards compatibility purposes. return $""" SELECT @@TRANCOUNT; DECLARE @Column_Names NVARCHAR(MAX) = NULL; +DECLARE @Column_Aliases AS TABLE +( + [Canonical_Column_Name] SYSNAME, + [Canonical_Column_Id] INT, + [Aliased_Column_Name] SYSNAME +) + IF EXISTS (SELECT TOP 1 * FROM sys.all_columns WHERE [object_id] = OBJECT_ID('sys.all_columns') AND [name] = 'graph_type') BEGIN SELECT @Column_Names = COALESCE(@Column_Names + ', ', '') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') AND COALESCE([graph_type], 0) NOT IN (1, 3, 4, 6, 7) ORDER BY [column_id] ASC; + + INSERT INTO @Column_Aliases ([Canonical_Column_Name], [Canonical_Column_Id], [Aliased_Column_Name]) + SELECT [name], [column_id], '$to_id' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') AND COALESCE([graph_type], 0) = 8 + UNION ALL + SELECT [name], [column_id], '$from_id' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') AND COALESCE([graph_type], 0) = 5 + UNION ALL + SELECT [name], [column_id], '$edge_id' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') AND COALESCE([graph_type], 0) = 2 AND [name] LIKE '$edge[_]id[_]%' + UNION ALL + SELECT [name], [column_id], '$node_id' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') AND COALESCE([graph_type], 0) = 2 AND [name] LIKE '$node[_]id[_]%' END ELSE BEGIN @@ -487,6 +515,11 @@ IF EXISTS (SELECT TOP 1 * FROM sys.all_columns WHERE [object_id] = OBJECT_ID('sy SET FMTONLY OFF; EXEC {CatalogName}..{TableCollationsStoredProc} N'{SchemaName}.{TableName}'; + +SELECT [Canonical_Column_Name], [Aliased_Column_Name] +FROM @Column_Aliases +WHERE [Aliased_Column_Name] NOT IN (SELECT [name] FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}')) +ORDER BY [Canonical_Column_Id] ASC """; } @@ -560,9 +593,9 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i // Keep track of any result columns that we don't have a local // mapping for. #if NETFRAMEWORK - HashSet unmatchedColumns = new(); + HashSet unmatchedColumns = new(StringComparer.OrdinalIgnoreCase); #else - HashSet unmatchedColumns = new(_localColumnMappings.Count); + HashSet unmatchedColumns = new(_localColumnMappings.Count, StringComparer.OrdinalIgnoreCase); #endif // Start by assuming all locally mapped Destination columns will be @@ -572,6 +605,50 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i unmatchedColumns.Add(_localColumnMappings[i].DestinationColumn); } + // Apply any necessary column aliases. If an aliased name exists in the + // local column mappings but the canonical name does not, update them. + Result columnAliasResults = internalResults[ColumnAliasesResultId]; + for (int i = 0; i < columnAliasResults.Count; i++) + { + Row aliasRow = columnAliasResults[i]; + SqlString canonicalName = (SqlString)aliasRow[ColumnCanonicalNameColumnId]; + SqlString aliasedName = (SqlString)aliasRow[ColumnAliasColumnId]; + + if (canonicalName.IsNull || aliasedName.IsNull) + { + continue; + } + + string canonical = canonicalName.Value; + bool canonicalNameExists = unmatchedColumns.Contains(canonical) + // The destination columns might be escaped. If so, search for those instead + || unmatchedColumns.Contains(SqlServerEscapeHelper.EscapeIdentifier(canonical)); + + if (canonicalNameExists) + { + continue; + } + + // The canonical name does not exist. Look for a local column mapping which matches + // the alias (or its escaped variant) and replace its name with its canonical name. + string alias = aliasedName.Value; + string escapedAlias = SqlServerEscapeHelper.EscapeIdentifier(alias); + + for (int j = 0; j < _localColumnMappings.Count; j++) + { + if (unmatchedColumns.Comparer.Equals(_localColumnMappings[j].DestinationColumn, alias) + || unmatchedColumns.Comparer.Equals(_localColumnMappings[j].DestinationColumn, escapedAlias)) + { + unmatchedColumns.Remove(_localColumnMappings[j].DestinationColumn); + + unmatchedColumns.Add(canonical); + _localColumnMappings[j].MappedDestinationColumn = canonical; + + break; + } + } + } + // Flag to remember whether or not we need to append a comma before // the next column in the command text. bool appendComma = false; @@ -594,7 +671,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i // Are we missing a mapping between the result column and // this local column (by ordinal or name)? if (localColumn._destinationColumnOrdinal != metadata.ordinal - && UnquotedName(localColumn._destinationColumnName) != metadata.column) + && UnquotedName(localColumn.MappedDestinationColumn) != metadata.column) { // Yes, so move on to the next local column. continue; @@ -604,7 +681,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i matched = true; // Remove it from our unmatched set. - unmatchedColumns.Remove(localColumn.DestinationColumn); + unmatchedColumns.Remove(localColumn.MappedDestinationColumn); // Check for column types that we refuse to bulk load, even // though we found a match. diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopyColumnMapping.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopyColumnMapping.cs index 734e0615a7..0494745323 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopyColumnMapping.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopyColumnMapping.cs @@ -18,6 +18,14 @@ public sealed class SqlBulkCopyColumnMapping // _sourceColumnOrdinal(s) will be copied to _internalSourceColumnOrdinal when WriteToServer executes. internal int _internalDestinationColumnOrdinal; internal int _internalSourceColumnOrdinal; // -1 indicates an undetermined value + internal string _mappedDestinationColumn; + + // Used by SqlBulkCopy to generate the correct column name after mapping alternate names. + internal string MappedDestinationColumn + { + get => _mappedDestinationColumn ?? DestinationColumn; + set => _mappedDestinationColumn = value; + } /// public string DestinationColumn diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/CopyAllFromReader.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/CopyAllFromReader.cs index 5ba727be5d..d1b4409fc0 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/CopyAllFromReader.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/CopyAllFromReader.cs @@ -60,12 +60,12 @@ public static void Test(string srcConstr, string dstConstr, string dstTable) DataTestUtility.AssertEqualsWithDescription((long)3, stats["BuffersReceived"], "Unexpected BuffersReceived value."); DataTestUtility.AssertEqualsWithDescription((long)3, stats["BuffersSent"], "Unexpected BuffersSent value."); - DataTestUtility.AssertEqualsWithDescription((long)0, stats["IduCount"], "Unexpected IduCount value."); - DataTestUtility.AssertEqualsWithDescription((long)6, stats["SelectCount"], "Unexpected SelectCount value."); + DataTestUtility.AssertEqualsWithDescription((long)1, stats["IduCount"], "Unexpected IduCount value."); + DataTestUtility.AssertEqualsWithDescription((long)7, stats["SelectCount"], "Unexpected SelectCount value."); DataTestUtility.AssertEqualsWithDescription((long)3, stats["ServerRoundtrips"], "Unexpected ServerRoundtrips value."); DataTestUtility.AssertEqualsWithDescription((long)9, stats["SelectRows"], "Unexpected SelectRows value."); DataTestUtility.AssertEqualsWithDescription((long)2, stats["SumResultSets"], "Unexpected SumResultSets value."); - DataTestUtility.AssertEqualsWithDescription((long)0, stats["Transactions"], "Unexpected Transactions value."); + DataTestUtility.AssertEqualsWithDescription((long)1, stats["Transactions"], "Unexpected Transactions value."); } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/SqlGraphTables.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/SqlGraphTables.cs index d83693080f..c23890599d 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/SqlGraphTables.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/SqlGraphTables.cs @@ -15,7 +15,7 @@ public class SqlGraphTables public void WriteToServer_CopyToSqlGraphNodeTable_Succeeds() { string connectionString = DataTestUtility.TCPConnectionString; - string destinationTable = DataTestUtility.GetShortName("SqlGraphNodeTable"); + string destinationTable = DataTestUtility.GetShortName("SqlGraph_Node"); using SqlConnection dstConn = new SqlConnection(connectionString); using DataTable nodes = new DataTable() @@ -45,5 +45,137 @@ public void WriteToServer_CopyToSqlGraphNodeTable_Succeeds() DataTestUtility.DropTable(dstConn, destinationTable); } } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureSynapse))] + public void WriteToServer_CopyToAliasedColumnName_Succeeds() + { + string connectionString = DataTestUtility.TCPConnectionString; + string nodeTable = DataTestUtility.GetShortName("SqlGraph_NodeByAlias"); + string edgeTable = DataTestUtility.GetShortName("SqlGraph_EdgeByAlias"); + + using SqlConnection dstConn = new SqlConnection(connectionString); + using DataTable edges = new DataTable() + { + Columns = { new DataColumn("To_ID", typeof(string)), new DataColumn("From_ID", typeof(string)), new DataColumn("Description", typeof(string)) } + }; + + dstConn.Open(); + + try + { + DataTestUtility.CreateTable(dstConn, nodeTable, "(Id INT PRIMARY KEY IDENTITY(1,1), [Name] VARCHAR(100)) AS NODE"); + DataTestUtility.CreateTable(dstConn, edgeTable, "([Description] VARCHAR(100)) AS EDGE"); + + string sampleNodeDataCommand = @$"INSERT INTO {nodeTable} ([Name]) SELECT LEFT([name], 100) FROM sys.sysobjects"; + using (SqlCommand insertSampleNodes = new(sampleNodeDataCommand, dstConn)) + { + insertSampleNodes.ExecuteNonQuery(); + } + + using (SqlCommand nodeQuery = new SqlCommand($"SELECT $node_id FROM {nodeTable}", dstConn)) + using (SqlDataReader reader = nodeQuery.ExecuteReader()) + { + bool firstRead = reader.Read(); + string toId; + string fromId; + + Assert.True(firstRead); + toId = reader.GetString(0); + + while (reader.Read()) + { + fromId = reader.GetString(0); + + edges.Rows.Add(toId, fromId, "Test Description"); + toId = fromId; + } + } + + using SqlBulkCopy edgeCopy = new SqlBulkCopy(dstConn); + + edgeCopy.DestinationTableName = edgeTable; + edgeCopy.ColumnMappings.Add("To_ID", "$to_id"); + edgeCopy.ColumnMappings.Add("From_ID", "$from_id"); + edgeCopy.ColumnMappings.Add("Description", "Description"); + + edgeCopy.WriteToServer(edges); + } + finally + { + DataTestUtility.DropTable(dstConn, nodeTable); + DataTestUtility.DropTable(dstConn, edgeTable); + } + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureSynapse))] + public void WriteToServer_CopyToTableWithSameNameAsColumnAlias_Succeeds() + { + string connectionString = DataTestUtility.TCPConnectionString; + string destinationGraphTable = DataTestUtility.GetShortName("SqlGraph_NodeWithAlias"); + string destinationNormalTable = DataTestUtility.GetShortName("NonGraph_NodeWithAlias"); + + using SqlConnection dstConn = new SqlConnection(connectionString); + using DataTable nodes = new DataTable() + { + Columns = { new DataColumn("Name", typeof(string)) } + }; + + dstConn.Open(); + + for (int i = 0; i < 5; i++) + { + nodes.Rows.Add($"Name {i}"); + } + + try + { + DataTestUtility.CreateTable(dstConn, destinationGraphTable, "(Id INT PRIMARY KEY IDENTITY(1,1), [Name] VARCHAR(100), [$node_id] VARCHAR(100)) AS NODE"); + DataTestUtility.CreateTable(dstConn, destinationNormalTable, "(Id INT PRIMARY KEY IDENTITY(1,1), [Name] VARCHAR(100), [$node_id] VARCHAR(100))"); + + using (SqlBulkCopy nodeCopy = new SqlBulkCopy(dstConn)) + { + nodeCopy.DestinationTableName = destinationGraphTable; + nodeCopy.ColumnMappings.Add("Name", "Name"); + nodeCopy.ColumnMappings.Add("Name", "$node_id"); + nodeCopy.WriteToServer(nodes); + + nodeCopy.DestinationTableName = destinationNormalTable; + nodeCopy.WriteToServer(nodes); + } + + // Read the values back, ensuring that we haven't overwritten the $node_id alias with the contents of the [$node_id] column. + // SELECTing $node_id will read the SQL Graph's node ID, SELECTing [$node_id] will read the column named $node_id. + using (SqlCommand graphVerificationCommand = new SqlCommand($"SELECT Id, $node_id, [$node_id], Name FROM {destinationGraphTable}", dstConn)) + using (SqlDataReader reader = graphVerificationCommand.ExecuteReader()) + { + while (reader.Read()) + { + string aliasNodeId = reader.GetString(1); + string physicalNodeId = reader.GetString(2); + string name = reader.GetString(3); + + Assert.NotEqual(physicalNodeId, aliasNodeId); + Assert.Equal(name, physicalNodeId); + } + } + + using (SqlCommand normalVerificationCommand = new SqlCommand($"SELECT [$node_id], Name FROM {destinationNormalTable}", dstConn)) + using (SqlDataReader reader = normalVerificationCommand.ExecuteReader()) + { + while (reader.Read()) + { + string physicalNodeId = reader.GetString(0); + string name = reader.GetString(1); + + Assert.Equal(name, physicalNodeId); + } + } + } + finally + { + DataTestUtility.DropTable(dstConn, destinationGraphTable); + DataTestUtility.DropTable(dstConn, destinationNormalTable); + } + } } }