Add transactional batch support for Cosmos DB Spark connector #47478
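For context, a hypothetical usage sketch (not part of this PR) of how a user might opt into the new behavior. The "spark.cosmos.write.bulk.transactional" option key is an assumption inferred from the new bulkTransactional flag in the diff below; the actual key name is defined in CosmosWriteConfig and may differ. The other option keys are the connector's existing documented settings.

    // Hypothetical usage sketch; the transactional option key is assumed.
    import org.apache.spark.sql.DataFrame

    def writeTransactionally(df: DataFrame): Unit = {
      df.write
        .format("cosmos.oltp")
        .option("spark.cosmos.accountEndpoint", sys.env("COSMOS_ENDPOINT"))
        .option("spark.cosmos.accountKey", sys.env("COSMOS_KEY"))
        .option("spark.cosmos.database", "myDatabase")
        .option("spark.cosmos.container", "myContainer")
        .option("spark.cosmos.write.bulk.enabled", "true")       // existing bulk-write switch
        .option("spark.cosmos.write.bulk.transactional", "true") // assumed key for the new flag
        .mode("append")
        .save()
    }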
@@ -2,11 +2,14 @@
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.{CosmosAsyncClient, ReadConsistencyStrategy, SparkBridgeInternal}
import com.azure.cosmos.spark.diagnostics.LoggerHelper
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NullOrdering, SortDirection, SortOrder}
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
-import org.apache.spark.sql.connector.write.{BatchWrite, Write, WriteBuilder}
+import org.apache.spark.sql.connector.write.{BatchWrite, RequiresDistributionAndOrdering, Write, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -46,7 +49,7 @@ private class ItemsWriterBuilder
      diagnosticsConfig,
      sparkEnvironmentInfo)

-  private class CosmosWrite extends Write {
+  private class CosmosWrite extends Write with RequiresDistributionAndOrdering {

    private[this] val supportedCosmosMetrics: Array[CustomMetric] = {
      Array(
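For reviewers unfamiliar with the DataSource V2 interface: RequiresDistributionAndOrdering lets a Write declare how input rows must be clustered across tasks and sorted within each task before they reach the data writers. A minimal illustrative sketch of the contract follows (not code from this PR; the single "pk" column is a placeholder), mirroring the pattern the diff uses below with the public Expressions factories.

    // Illustrative sketch of the Spark DSv2 contract this class now implements.
    import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
    import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NullOrdering, SortDirection, SortOrder}
    import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}

    class PkClusteredWrite extends Write with RequiresDistributionAndOrdering {
      // Ask Spark to co-locate all rows with the same "pk" value in one task.
      override def requiredDistribution(): Distribution =
        Distributions.clustered(Array(Expressions.column("pk"): Expression))

      // Ask Spark to sort each task's rows by "pk" before writing.
      override def requiredOrdering(): Array[SortOrder] =
        Array(Expressions.sort(Expressions.column("pk"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))
    }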
@@ -56,22 +59,127 @@ private class ItemsWriterBuilder
      )
    }

    // Extract userConfig conversion to avoid repeated calls
    private[this] val userConfigMap = userConfig.asCaseSensitiveMap().asScala.toMap

    private[this] val writeConfig = CosmosWriteConfig.parseWriteConfig(
      userConfigMap,
      inputSchema
    )

    private[this] val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(
      userConfigMap
    )

    override def toBatch(): BatchWrite =
      new ItemsBatchWriter(
-       userConfig.asCaseSensitiveMap().asScala.toMap,
+       userConfigMap,
        inputSchema,
        cosmosClientStateHandles,
        diagnosticsConfig,
        sparkEnvironmentInfo)

    override def toStreaming: StreamingWrite =
      new ItemsBatchWriter(
-       userConfig.asCaseSensitiveMap().asScala.toMap,
+       userConfigMap,
        inputSchema,
        cosmosClientStateHandles,
        diagnosticsConfig,
        sparkEnvironmentInfo)

    override def supportedCustomMetrics(): Array[CustomMetric] = supportedCosmosMetrics

    override def requiredDistribution(): Distribution = {
      if (writeConfig.bulkEnabled && writeConfig.bulkTransactional) {
        log.logInfo("Transactional batch mode enabled - configuring data distribution by partition key columns")
        // For transactional writes, partition by all partition key columns
        val partitionKeyPaths = getPartitionKeyColumnNames()
        if (partitionKeyPaths.nonEmpty) {
          // Use public Expressions.column() factory - returns NamedReference
          val clustering = partitionKeyPaths.map(path => Expressions.column(path): Expression).toArray
          Distributions.clustered(clustering)
        } else {
          Distributions.unspecified()
Member: warning/error log - a PK definition should never have 0 columns. At least a log would be good in case this happens in the wild (in theory it could, for extremely old containers, but containers have used an artificial PK for years).

Member (Author): Added error logging for the edge cases where the partition key definition is null or contains zero paths.
        }
      } else {
        Distributions.unspecified()
      }
    }

    override def requiredOrdering(): Array[SortOrder] = {
      if (writeConfig.bulkEnabled && writeConfig.bulkTransactional) {
        // For transactional writes, order by all partition key columns (ascending)
        val partitionKeyPaths = getPartitionKeyColumnNames()
        if (partitionKeyPaths.nonEmpty) {
          partitionKeyPaths.map { path =>
            // Use public Expressions.sort() factory for creating SortOrder
            Expressions.sort(
              Expressions.column(path),
              SortDirection.ASCENDING,
              NullOrdering.NULLS_FIRST
            )
          }.toArray
        } else {
          Array.empty[SortOrder]
        }
      } else {
        Array.empty[SortOrder]
      }
    }

    private def getPartitionKeyColumnNames(): Seq[String] = {
      try {
        Loan(
          List[Option[CosmosClientCacheItem]](
            Some(createClientForPartitionKeyLookup())
          ))
          .to(clientCacheItems => {
            val container = ThroughputControlHelper.getContainer(
              userConfigMap,
              containerConfig,
              clientCacheItems(0).get,
              None
            )

            // Simplified retrieval using SparkBridgeInternal directly
            val containerProperties = SparkBridgeInternal.getContainerPropertiesFromCollectionCache(container)
            val partitionKeyDefinition = containerProperties.getPartitionKeyDefinition

            extractPartitionKeyPaths(partitionKeyDefinition)
          })
      } catch {
        case ex: Exception =>
          log.logWarning(s"Failed to get partition key definition for transactional writes: ${ex.getMessage}")
          Seq.empty[String]
      }
    }

    private def createClientForPartitionKeyLookup(): CosmosClientCacheItem = {
      CosmosClientCache(
        CosmosClientConfiguration(
          userConfigMap,
          ReadConsistencyStrategy.EVENTUAL,
          sparkEnvironmentInfo
        ),
        Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches),
        "ItemsWriterBuilder-PKLookup"
      )
    }

    private def extractPartitionKeyPaths(partitionKeyDefinition: com.azure.cosmos.models.PartitionKeyDefinition): Seq[String] = {
      if (partitionKeyDefinition != null && partitionKeyDefinition.getPaths != null) {
        val paths = partitionKeyDefinition.getPaths.asScala
        if (paths.isEmpty) {
          log.logError("Partition key definition has 0 columns - this should not happen for modern containers")
        }
        paths.map(path => {
          // Remove leading '/' from partition key path (e.g., "/pk" -> "pk")
          if (path.startsWith("/")) path.substring(1) else path
        }).toSeq
      } else {
        log.logError("Partition key definition is null - this should not happen for modern containers")
        Seq.empty[String]
      }
    }
  }
}
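The net effect of the declared distribution and ordering, roughly: before invoking the writers, Spark's planner inserts a shuffle that co-locates rows sharing the same partition key values and a per-partition sort on those columns, so each Cosmos DB logical partition is written by a single task in key order. A rough DataFrame-level analogy (illustrative only, not code from this PR, and not exactly what the planner emits):

    // Rough analogy of what the declared distribution/ordering asks Spark to do.
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.col

    def shapeForTransactionalWrite(df: DataFrame, pkColumns: Seq[String]): DataFrame = {
      val cols = pkColumns.map(col)
      df.repartition(cols: _*)            // ~ Distributions.clustered(...)
        .sortWithinPartitions(cols: _*)   // ~ requiredOrdering: ascending, nulls first
    }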
Member: NIT: info log

Member (Author): Changed to info log level (in the constructor).