From d697cfce83bfbce339c00d087c9d33c8a00494ac Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Wed, 3 Sep 2025 14:15:19 +0530 Subject: [PATCH 1/4] Added support for Variant datatype in SQLAlchemy --- src/databricks/sqlalchemy/__init__.py | 3 +- src/databricks/sqlalchemy/_parse.py | 1 + src/databricks/sqlalchemy/_types.py | 45 +++++++++++ tests/test_local/e2e/test_complex_types.py | 88 +++++++++++++++++++++- tests/test_local/test_ddl.py | 6 +- tests/test_local/test_types.py | 4 +- 6 files changed, 141 insertions(+), 6 deletions(-) diff --git a/src/databricks/sqlalchemy/__init__.py b/src/databricks/sqlalchemy/__init__.py index 81d35d6..e0c59b8 100644 --- a/src/databricks/sqlalchemy/__init__.py +++ b/src/databricks/sqlalchemy/__init__.py @@ -5,6 +5,7 @@ TIMESTAMP_NTZ, DatabricksArray, DatabricksMap, + DatabricksVariant, ) -__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"] +__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap", "DatabricksVariant"] diff --git a/src/databricks/sqlalchemy/_parse.py b/src/databricks/sqlalchemy/_parse.py index 1541e28..37a6cc4 100644 --- a/src/databricks/sqlalchemy/_parse.py +++ b/src/databricks/sqlalchemy/_parse.py @@ -318,6 +318,7 @@ def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[st "map": sqlalchemy.types.String, "struct": sqlalchemy.types.String, "uniontype": sqlalchemy.types.String, + "variant": type_overrides.DatabricksVariant, "decimal": sqlalchemy.types.Numeric, "timestamp": type_overrides.TIMESTAMP, "timestamp_ntz": type_overrides.TIMESTAMP_NTZ, diff --git a/src/databricks/sqlalchemy/_types.py b/src/databricks/sqlalchemy/_types.py index bc996bb..718cb39 100644 --- a/src/databricks/sqlalchemy/_types.py +++ b/src/databricks/sqlalchemy/_types.py @@ -9,6 +9,7 @@ from databricks.sql.utils import ParamEscaper +from sqlalchemy.sql import expression def process_literal_param_hack(value: Any): """This method is supposed to accept a Python type and return a string representation of that type. @@ -397,3 +398,47 @@ def compile_databricks_map(type_, compiler, **kw): key_type = compiler.process(type_.key_type, **kw) value_type = compiler.process(type_.value_type, **kw) return f"MAP<{key_type},{value_type}>" + +class DatabricksVariant(UserDefinedType): + """ + A custom variant type for storing semi-structured data including STRUCT, ARRAY, MAP, and scalar types. + Note: VARIANT MAP types can only have STRING keys. + + Examples: + DatabricksVariant() -> VARIANT + + Usage: + Column('data', DatabricksVariant()) + """ + cache_ok = True + + def __init__(self): + self.pe = ParamEscaper() + + def bind_processor(self, dialect): + """Process values before sending to database. + """ + + def process(value): + return value + + return process + + def bind_expression(self, bindvalue): + """Wrap with PARSE_JSON() in SQL""" + return expression.func.PARSE_JSON(bindvalue) + + def literal_processor(self, dialect): + """Process literal values for SQL generation. + For VARIANT columns, use PARSE_JSON() to properly insert data. 
+ """ + def process(value): + if value is None: + return "NULL" + return self.pe.escape_string(value) + + return f"PARSE_JSON('{process}')" + +@compiles(DatabricksVariant, "databricks") +def compile_variant(type_, compiler, **kw): + return "VARIANT" diff --git a/tests/test_local/e2e/test_complex_types.py b/tests/test_local/e2e/test_complex_types.py index 22cc57a..9a8d7e7 100644 --- a/tests/test_local/e2e/test_complex_types.py +++ b/tests/test_local/e2e/test_complex_types.py @@ -11,13 +11,14 @@ DateTime, ) from collections.abc import Sequence -from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap +from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap, DatabricksVariant from sqlalchemy.orm import DeclarativeBase, Session from sqlalchemy import select from datetime import date, datetime, time, timedelta, timezone import pandas as pd import numpy as np import decimal +import json class TestComplexTypes(TestSetup): @@ -46,7 +47,7 @@ def _parse_to_common_type(self, value): ): return tuple(value) elif isinstance(value, dict): - return tuple(value.items()) + return tuple(sorted(value.items())) elif isinstance(value, np.generic): return value.item() elif isinstance(value, decimal.Decimal): @@ -152,6 +153,35 @@ class MapTable(Base): return MapTable, sample_data + def sample_variant_table(self) -> tuple[DeclarativeBase, dict]: + class Base(DeclarativeBase): + pass + + class VariantTable(Base): + __tablename__ = "sqlalchemy_variant_table" + + int_col = Column(Integer, primary_key=True) + variant_simple_col = Column(DatabricksVariant()) + variant_nested_col = Column(DatabricksVariant()) + variant_array_col = Column(DatabricksVariant()) + variant_mixed_col = Column(DatabricksVariant()) + + sample_data = { + "int_col": 1, + "variant_simple_col": {"key": "value", "number": 42}, + "variant_nested_col": {"user": {"name": "John", "age": 30}, "active": True}, + "variant_array_col": [1, 2, 3, "hello", {"nested": "data"}], + "variant_mixed_col": { + "string": "test", + "number": 123, + "boolean": True, + "array": [1, 2, 3], + "object": {"nested": "value"} + } + } + + return VariantTable, sample_data + def test_insert_array_table_sqlalchemy(self): table, sample_data = self.sample_array_table() @@ -209,3 +239,57 @@ def test_map_table_creation_pandas(self): stmt = select(table) df_result = pd.read_sql(stmt, engine) assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data) + + def test_insert_variant_table_sqlalchemy(self): + table, sample_data = self.sample_variant_table() + + with self.table_context(table) as engine: + # Pre-serialize variant data for SQLAlchemy + variant_data = sample_data.copy() + for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: + variant_data[key] = None if sample_data[key] is None else json.dumps(sample_data[key]) + + sa_obj = table(**variant_data) + session = Session(engine) + session.add(sa_obj) + session.commit() + + stmt = select(table).where(table.int_col == 1) + + result = session.scalar(stmt) + + compare = {key: getattr(result, key) for key in sample_data.keys()} + # Parse JSON values back to original format for comparison + for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: + if compare[key] is not None: + compare[key] = json.loads(compare[key]) + assert self._recursive_compare(compare, sample_data) + + def test_variant_table_creation_pandas(self): + table, sample_data = self.sample_variant_table() + + with 
self.table_context(table) as engine:
+            # Pre-serialize variant data for pandas
+            variant_data = sample_data.copy()
+            for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
+                variant_data[key] = None if sample_data[key] is None else json.dumps(sample_data[key])
+
+            # Insert the data into the table
+            df = pd.DataFrame([variant_data])
+            dtype_mapping = {
+                "variant_simple_col": DatabricksVariant,
+                "variant_nested_col": DatabricksVariant,
+                "variant_array_col": DatabricksVariant,
+                "variant_mixed_col": DatabricksVariant
+            }
+            df.to_sql(table.__tablename__, engine, if_exists="append", index=False, dtype=dtype_mapping)
+
+            # Read the data from the table
+            stmt = select(table)
+            df_result = pd.read_sql(stmt, engine)
+            result_dict = df_result.iloc[0].to_dict()
+            # Parse JSON values back to original format for comparison
+            for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
+                if result_dict[key] is not None:
+                    result_dict[key] = json.loads(result_dict[key])
+            assert self._recursive_compare(result_dict, sample_data)
\ No newline at end of file
diff --git a/tests/test_local/test_ddl.py b/tests/test_local/test_ddl.py
index 0b04c2e..9b19acf 100644
--- a/tests/test_local/test_ddl.py
+++ b/tests/test_local/test_ddl.py
@@ -7,7 +7,7 @@
     SetColumnComment,
     SetTableComment,
 )
-from databricks.sqlalchemy import DatabricksArray, DatabricksMap
+from databricks.sqlalchemy import DatabricksArray, DatabricksMap, DatabricksVariant
 
 
 class DDLTestBase:
@@ -103,7 +103,8 @@ def metadata(self) -> MetaData:
         metadata = MetaData()
         col1 = Column("array_array_string", DatabricksArray(DatabricksArray(String)))
         col2 = Column("map_string_string", DatabricksMap(String, String))
-        table = Table("complex_type", metadata, col1, col2)
+        col3 = Column("variant_col", DatabricksVariant())
+        table = Table("complex_type", metadata, col1, col2, col3)
         return metadata
 
     def test_create_table_with_complex_type(self, metadata):
@@ -112,3 +113,4 @@ def test_create_table_with_complex_type(self, metadata):
 
         assert "array_array_string ARRAY<ARRAY<STRING>>" in output
         assert "map_string_string MAP<STRING,STRING>" in output
+        assert "variant_col VARIANT" in output
diff --git a/tests/test_local/test_types.py b/tests/test_local/test_types.py
index b91217e..4f99e4b 100644
--- a/tests/test_local/test_types.py
+++ b/tests/test_local/test_types.py
@@ -4,7 +4,7 @@
 import sqlalchemy
 
 from databricks.sqlalchemy.base import DatabricksDialect
-from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ
+from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ, DatabricksVariant
 
 
 class DatabricksDataType(enum.Enum):
@@ -28,6 +28,7 @@ class DatabricksDataType(enum.Enum):
     ARRAY = enum.auto()
     MAP = enum.auto()
     STRUCT = enum.auto()
+    VARIANT = enum.auto()
 
 
 # Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types.
@@ -131,6 +132,7 @@ def test_numeric_renders_as_decimal_with_precision_and_scale(self):
     TINYINT: DatabricksDataType.TINYINT,
     TIMESTAMP: DatabricksDataType.TIMESTAMP,
     TIMESTAMP_NTZ: DatabricksDataType.TIMESTAMP_NTZ,
+    DatabricksVariant: DatabricksDataType.VARIANT,
 }

From 5ec04bc2865c1b1598a7abfe967c8be89ebc7be8 Mon Sep 17 00:00:00 2001
From: Madhavendra Rathore
Date: Thu, 4 Sep 2025 11:55:48 +0530
Subject: [PATCH 2/4] Allow users to directly pass objects for the variant type.
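
The bind processor now serializes Python values to JSON itself, and
literal_processor does the same for inlined literals, so callers no longer
need to pre-serialize with json.dumps() before writing to a VARIANT column.
A minimal sketch of the intended usage (the Event model, Base, and session
below are illustrative, not part of this patch):

    from sqlalchemy import Column, Integer
    from databricks.sqlalchemy import DatabricksVariant

    class Event(Base):
        __tablename__ = "events"
        id = Column(Integer, primary_key=True)
        payload = Column(DatabricksVariant())

    # A plain dict is accepted directly: bind_processor() turns it into a
    # JSON string and bind_expression() wraps the parameter in PARSE_JSON().
    session.add(Event(id=1, payload={"source": "api", "retries": 3}))
    session.commit()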
--- src/databricks/sqlalchemy/_types.py | 13 +++++++++++-- tests/test_local/e2e/test_complex_types.py | 20 ++++---------------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/databricks/sqlalchemy/_types.py b/src/databricks/sqlalchemy/_types.py index 718cb39..ab284ec 100644 --- a/src/databricks/sqlalchemy/_types.py +++ b/src/databricks/sqlalchemy/_types.py @@ -10,6 +10,7 @@ from databricks.sql.utils import ParamEscaper from sqlalchemy.sql import expression +import json def process_literal_param_hack(value: Any): """This method is supposed to accept a Python type and return a string representation of that type. @@ -420,7 +421,12 @@ def bind_processor(self, dialect): """ def process(value): - return value + if value is None: + return None + try: + return json.dumps(value, ensure_ascii=False, separators=(',', ':')) + except (TypeError, ValueError) as e: + raise ValueError(f"Cannot serialize value {value} to JSON: {e}") return process @@ -435,7 +441,10 @@ def literal_processor(self, dialect): def process(value): if value is None: return "NULL" - return self.pe.escape_string(value) + try: + return self.pe.escape_string(json.dumps(value, ensure_ascii=False, separators=(',', ':'))) + except (TypeError, ValueError) as e: + raise ValueError(f"Cannot serialize value {value} to JSON: {e}") return f"PARSE_JSON('{process}')" diff --git a/tests/test_local/e2e/test_complex_types.py b/tests/test_local/e2e/test_complex_types.py index 9a8d7e7..9932fcd 100644 --- a/tests/test_local/e2e/test_complex_types.py +++ b/tests/test_local/e2e/test_complex_types.py @@ -20,7 +20,6 @@ import decimal import json - class TestComplexTypes(TestSetup): def _parse_to_common_type(self, value): """ @@ -244,38 +243,28 @@ def test_insert_variant_table_sqlalchemy(self): table, sample_data = self.sample_variant_table() with self.table_context(table) as engine: - # Pre-serialize variant data for SQLAlchemy - variant_data = sample_data.copy() - for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: - variant_data[key] = None if sample_data[key] is None else json.dumps(sample_data[key]) - - sa_obj = table(**variant_data) + + sa_obj = table(**sample_data) session = Session(engine) session.add(sa_obj) session.commit() stmt = select(table).where(table.int_col == 1) - result = session.scalar(stmt) - compare = {key: getattr(result, key) for key in sample_data.keys()} # Parse JSON values back to original format for comparison for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: if compare[key] is not None: compare[key] = json.loads(compare[key]) + assert self._recursive_compare(compare, sample_data) def test_variant_table_creation_pandas(self): table, sample_data = self.sample_variant_table() with self.table_context(table) as engine: - # Pre-serialize variant data for pandas - variant_data = sample_data.copy() - for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: - variant_data[key] = None if sample_data[key] is None else json.dumps(sample_data[key]) - # Insert the data into the table - df = pd.DataFrame([variant_data]) + df = pd.DataFrame([sample_data]) dtype_mapping = { "variant_simple_col": DatabricksVariant, "variant_nested_col": DatabricksVariant, @@ -284,7 +273,6 @@ def test_variant_table_creation_pandas(self): } df.to_sql(table.__tablename__, engine, if_exists="append", index=False, dtype=dtype_mapping) - # Read the data from the table stmt = select(table) df_result = 
pd.read_sql(stmt, engine) result_dict = df_result.iloc[0].to_dict() From ed7cd9432b591f7b3854153312b72a1fe5f1b753 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Fri, 5 Sep 2025 12:21:11 +0530 Subject: [PATCH 3/4] Added variant to sqlalchemy_example and added a test for literal_processor for variant --- sqlalchemy_example.py | 36 +++++++++++++++++--- src/databricks/sqlalchemy/_types.py | 2 +- tests/test_local/e2e/test_complex_types.py | 39 ++++++++++++++++++++-- 3 files changed, 70 insertions(+), 7 deletions(-) diff --git a/sqlalchemy_example.py b/sqlalchemy_example.py index 1bdd14e..d76716f 100644 --- a/sqlalchemy_example.py +++ b/sqlalchemy_example.py @@ -17,11 +17,12 @@ from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal from uuid import UUID +import json # By convention, backend-specific SQLA types are defined in uppercase -# This dialect exposes Databricks SQL's TIMESTAMP and TINYINT types +# This dialect exposes Databricks SQL's TIMESTAMP, TINYINT, and VARIANT types # as these are not covered by the generic, camelcase types shown below -from databricks.sqlalchemy import TIMESTAMP, TINYINT +from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksVariant # Beside the CamelCase types shown below, line comments reflect # the underlying Databricks SQL / Delta table type @@ -82,6 +83,12 @@ class SampleObject(Base): datetime_col_ntz = Column(DateTime) time_col = Column(Time) uuid_col = Column(Uuid) + variant_col = Column(DatabricksVariant) + +Base.metadata.drop_all(engine) + +# Output SQL is: +# DROP TABLE pysql_sqlalchemy_example_table # This generates a CREATE TABLE statement against the catalog and schema # specified in the connection string @@ -100,6 +107,7 @@ class SampleObject(Base): # datetime_col_ntz TIMESTAMP_NTZ, # time_col STRING, # uuid_col STRING, +# variant_col VARIANT, # PRIMARY KEY (bigint_col) # ) USING DELTA @@ -120,6 +128,23 @@ class SampleObject(Base): "datetime_col_ntz": datetime(1990, 12, 4, 6, 33, 41), "time_col": time(23, 59, 59), "uuid_col": UUID(int=255), + "variant_col": { + "name": "John Doe", + "age": 30, + "address": { + "street": "123 Main St", + "city": "San Francisco", + "state": "CA", + "zip": "94105" + }, + "hobbies": ["reading", "hiking", "cooking"], + "is_active": True, + "metadata": { + "created_at": "2024-01-15T10:30:00Z", + "version": 1.2, + "tags": ["premium", "verified"] + } + }, } sa_obj = SampleObject(**sample_object) @@ -140,7 +165,8 @@ class SampleObject(Base): # datetime_col, # datetime_col_ntz, # time_col, -# uuid_col +# uuid_col, +# variant_col # ) # VALUES # ( @@ -154,7 +180,8 @@ class SampleObject(Base): # :datetime_col, # :datetime_col_ntz, # :time_col, -# :uuid_col +# :uuid_col, +# PARSE_JSON(:variant_col) # ) # Here we build a SELECT query using ORM @@ -165,6 +192,7 @@ class SampleObject(Base): # Finally, we read out the input data and compare it to the output compare = {key: getattr(result, key) for key in sample_object.keys()} +compare['variant_col'] = json.loads(compare['variant_col']) assert compare == sample_object # Then we drop the demonstration table diff --git a/src/databricks/sqlalchemy/_types.py b/src/databricks/sqlalchemy/_types.py index ab284ec..16a7081 100644 --- a/src/databricks/sqlalchemy/_types.py +++ b/src/databricks/sqlalchemy/_types.py @@ -446,7 +446,7 @@ def process(value): except (TypeError, ValueError) as e: raise ValueError(f"Cannot serialize value {value} to JSON: {e}") - return f"PARSE_JSON('{process}')" + return process @compiles(DatabricksVariant, 
"databricks") def compile_variant(type_, compiler, **kw): diff --git a/tests/test_local/e2e/test_complex_types.py b/tests/test_local/e2e/test_complex_types.py index 9932fcd..07cd637 100644 --- a/tests/test_local/e2e/test_complex_types.py +++ b/tests/test_local/e2e/test_complex_types.py @@ -257,7 +257,7 @@ def test_insert_variant_table_sqlalchemy(self): if compare[key] is not None: compare[key] = json.loads(compare[key]) - assert self._recursive_compare(compare, sample_data) + assert compare == sample_data def test_variant_table_creation_pandas(self): table, sample_data = self.sample_variant_table() @@ -280,4 +280,39 @@ def test_variant_table_creation_pandas(self): for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: if result_dict[key] is not None: result_dict[key] = json.loads(result_dict[key]) - assert self._recursive_compare(result_dict, sample_data) \ No newline at end of file + + assert result_dict == sample_data + + def test_variant_literal_processor(self): + table, sample_data = self.sample_variant_table() + + with self.table_context(table) as engine: + stmt = table.__table__.insert().values(**sample_data) + + try: + compiled = stmt.compile( + dialect=engine.dialect, + compile_kwargs={"literal_binds": True} + ) + sql_str = str(compiled) + + # Assert that JSON actually got inlined + assert '{"key":"value","number":42}' in sql_str + except NotImplementedError: + raise + + with engine.begin() as conn: + conn.execute(stmt) + + session = Session(engine) + stmt_select = select(table).where(table.int_col == sample_data["int_col"]) + result = session.scalar(stmt_select) + + compare = {key: getattr(result, key) for key in sample_data.keys()} + + # Parse JSON values back to original Python objects + for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: + if compare[key] is not None: + compare[key] = json.loads(compare[key]) + + assert compare == sample_data From 446c496017b0762145801b1f24129a88ddfec5bc Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Fri, 5 Sep 2025 12:28:10 +0530 Subject: [PATCH 4/4] Lint fix --- src/databricks/sqlalchemy/__init__.py | 9 ++++++++- src/databricks/sqlalchemy/_types.py | 22 ++++++++++++++-------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/databricks/sqlalchemy/__init__.py b/src/databricks/sqlalchemy/__init__.py index e0c59b8..af2ebb2 100644 --- a/src/databricks/sqlalchemy/__init__.py +++ b/src/databricks/sqlalchemy/__init__.py @@ -8,4 +8,11 @@ DatabricksVariant, ) -__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap", "DatabricksVariant"] +__all__ = [ + "TINYINT", + "TIMESTAMP", + "TIMESTAMP_NTZ", + "DatabricksArray", + "DatabricksMap", + "DatabricksVariant", +] diff --git a/src/databricks/sqlalchemy/_types.py b/src/databricks/sqlalchemy/_types.py index 16a7081..c9ca9c8 100644 --- a/src/databricks/sqlalchemy/_types.py +++ b/src/databricks/sqlalchemy/_types.py @@ -12,6 +12,7 @@ from sqlalchemy.sql import expression import json + def process_literal_param_hack(value: Any): """This method is supposed to accept a Python type and return a string representation of that type. 
But due to some weirdness in the way SQLAlchemy's literal rendering works, we have to return @@ -400,31 +401,32 @@ def compile_databricks_map(type_, compiler, **kw): value_type = compiler.process(type_.value_type, **kw) return f"MAP<{key_type},{value_type}>" + class DatabricksVariant(UserDefinedType): """ A custom variant type for storing semi-structured data including STRUCT, ARRAY, MAP, and scalar types. Note: VARIANT MAP types can only have STRING keys. - + Examples: DatabricksVariant() -> VARIANT - + Usage: Column('data', DatabricksVariant()) """ + cache_ok = True def __init__(self): self.pe = ParamEscaper() def bind_processor(self, dialect): - """Process values before sending to database. - """ + """Process values before sending to database.""" def process(value): if value is None: return None try: - return json.dumps(value, ensure_ascii=False, separators=(',', ':')) + return json.dumps(value, ensure_ascii=False, separators=(",", ":")) except (TypeError, ValueError) as e: raise ValueError(f"Cannot serialize value {value} to JSON: {e}") @@ -435,19 +437,23 @@ def bind_expression(self, bindvalue): return expression.func.PARSE_JSON(bindvalue) def literal_processor(self, dialect): - """Process literal values for SQL generation. + """Process literal values for SQL generation. For VARIANT columns, use PARSE_JSON() to properly insert data. """ + def process(value): if value is None: return "NULL" try: - return self.pe.escape_string(json.dumps(value, ensure_ascii=False, separators=(',', ':'))) + return self.pe.escape_string( + json.dumps(value, ensure_ascii=False, separators=(",", ":")) + ) except (TypeError, ValueError) as e: raise ValueError(f"Cannot serialize value {value} to JSON: {e}") - + return process + @compiles(DatabricksVariant, "databricks") def compile_variant(type_, compiler, **kw): return "VARIANT"
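
Usage note (illustrative, not part of the patches): a minimal end-to-end
sketch of the feature as it stands after PATCH 4/4. The connection-string
placeholders and the Document model are assumptions for illustration.
VARIANT values are returned as JSON strings, so the client decodes them,
as the tests in this series do.

    import json
    from sqlalchemy import create_engine, select, Column, Integer
    from sqlalchemy.orm import DeclarativeBase, Session
    from databricks.sqlalchemy import DatabricksVariant

    class Base(DeclarativeBase):
        pass

    class Document(Base):
        __tablename__ = "variant_demo"
        id = Column(Integer, primary_key=True)
        body = Column(DatabricksVariant())

    # Placeholders stand in for a real workspace and SQL warehouse.
    engine = create_engine(
        "databricks://token:<token>@<host>?http_path=<path>&catalog=<catalog>&schema=<schema>"
    )
    Base.metadata.create_all(engine)  # column renders as: body VARIANT

    with Session(engine) as session:
        session.add(Document(id=1, body={"tags": ["a", "b"], "count": 2}))
        session.commit()
        row = session.scalar(select(Document).where(Document.id == 1))
        # Reads come back as JSON text; decode on the client side.
        assert json.loads(row.body) == {"tags": ["a", "b"], "count": 2}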