diff --git a/sqlalchemy_example.py b/sqlalchemy_example.py
index 1bdd14e..d76716f 100644
--- a/sqlalchemy_example.py
+++ b/sqlalchemy_example.py
@@ -17,11 +17,12 @@
 from datetime import date, datetime, time, timedelta, timezone
 from decimal import Decimal
 from uuid import UUID
+import json
 
 # By convention, backend-specific SQLA types are defined in uppercase
-# This dialect exposes Databricks SQL's TIMESTAMP and TINYINT types
+# This dialect exposes Databricks SQL's TIMESTAMP, TINYINT, and VARIANT types
 # as these are not covered by the generic, camelcase types shown below
-from databricks.sqlalchemy import TIMESTAMP, TINYINT
+from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksVariant
 
 # Beside the CamelCase types shown below, line comments reflect
 # the underlying Databricks SQL / Delta table type
@@ -82,6 +83,12 @@ class SampleObject(Base):
     datetime_col_ntz = Column(DateTime)
     time_col = Column(Time)
     uuid_col = Column(Uuid)
+    variant_col = Column(DatabricksVariant)
+
+Base.metadata.drop_all(engine)
+
+# Output SQL is:
+# DROP TABLE pysql_sqlalchemy_example_table
 
 # This generates a CREATE TABLE statement against the catalog and schema
 # specified in the connection string
@@ -100,6 +107,7 @@ class SampleObject(Base):
 #   datetime_col_ntz TIMESTAMP_NTZ,
 #   time_col STRING,
 #   uuid_col STRING,
+#   variant_col VARIANT,
 #   PRIMARY KEY (bigint_col)
 # ) USING DELTA
 
@@ -120,6 +128,23 @@ class SampleObject(Base):
     "datetime_col_ntz": datetime(1990, 12, 4, 6, 33, 41),
     "time_col": time(23, 59, 59),
     "uuid_col": UUID(int=255),
+    "variant_col": {
+        "name": "John Doe",
+        "age": 30,
+        "address": {
+            "street": "123 Main St",
+            "city": "San Francisco",
+            "state": "CA",
+            "zip": "94105"
+        },
+        "hobbies": ["reading", "hiking", "cooking"],
+        "is_active": True,
+        "metadata": {
+            "created_at": "2024-01-15T10:30:00Z",
+            "version": 1.2,
+            "tags": ["premium", "verified"]
+        }
+    },
 }
 
 sa_obj = SampleObject(**sample_object)
@@ -140,7 +165,8 @@ class SampleObject(Base):
 #     datetime_col,
 #     datetime_col_ntz,
 #     time_col,
-#     uuid_col
+#     uuid_col,
+#     variant_col
 #   )
 # VALUES
 #   (
@@ -154,7 +180,8 @@ class SampleObject(Base):
 #     :datetime_col,
 #     :datetime_col_ntz,
 #     :time_col,
-#     :uuid_col
+#     :uuid_col,
+#     PARSE_JSON(:variant_col)
 #   )
 
 # Here we build a SELECT query using ORM
@@ -165,6 +192,7 @@ class SampleObject(Base):
 
 # Finally, we read out the input data and compare it to the output
 compare = {key: getattr(result, key) for key in sample_object.keys()}
+compare['variant_col'] = json.loads(compare['variant_col'])
 assert compare == sample_object
 
 # Then we drop the demonstration table
diff --git a/src/databricks/sqlalchemy/__init__.py b/src/databricks/sqlalchemy/__init__.py
index 81d35d6..af2ebb2 100644
--- a/src/databricks/sqlalchemy/__init__.py
+++ b/src/databricks/sqlalchemy/__init__.py
@@ -5,6 +5,14 @@
     TIMESTAMP_NTZ,
     DatabricksArray,
     DatabricksMap,
+    DatabricksVariant,
 )
 
-__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"]
+__all__ = [
+    "TINYINT",
+    "TIMESTAMP",
+    "TIMESTAMP_NTZ",
+    "DatabricksArray",
+    "DatabricksMap",
+    "DatabricksVariant",
+]
diff --git a/src/databricks/sqlalchemy/_parse.py b/src/databricks/sqlalchemy/_parse.py
index 1541e28..37a6cc4 100644
--- a/src/databricks/sqlalchemy/_parse.py
+++ b/src/databricks/sqlalchemy/_parse.py
@@ -318,6 +318,7 @@ def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[str]:
     "map": sqlalchemy.types.String,
     "struct": sqlalchemy.types.String,
     "uniontype": sqlalchemy.types.String,
+    "variant": type_overrides.DatabricksVariant,
     "decimal": sqlalchemy.types.Numeric,
     "timestamp": type_overrides.TIMESTAMP,
     "timestamp_ntz": type_overrides.TIMESTAMP_NTZ,
diff --git a/src/databricks/sqlalchemy/_types.py b/src/databricks/sqlalchemy/_types.py
index bc996bb..c9ca9c8 100644
--- a/src/databricks/sqlalchemy/_types.py
+++ b/src/databricks/sqlalchemy/_types.py
@@ -9,6 +9,9 @@
 from databricks.sql.utils import ParamEscaper
 
+from sqlalchemy.sql import expression
+import json
+
 
 def process_literal_param_hack(value: Any):
     """This method is supposed to accept a Python type and return a string
     representation of that type.
@@ -397,3 +400,60 @@ def compile_databricks_map(type_, compiler, **kw):
     key_type = compiler.process(type_.key_type, **kw)
     value_type = compiler.process(type_.value_type, **kw)
     return f"MAP<{key_type},{value_type}>"
+
+
+class DatabricksVariant(UserDefinedType):
+    """
+    A custom variant type for storing semi-structured data including STRUCT, ARRAY, MAP, and scalar types.
+    Note: VARIANT MAP types can only have STRING keys.
+
+    Examples:
+        DatabricksVariant() -> VARIANT
+
+    Usage:
+        Column('data', DatabricksVariant())
+    """
+
+    cache_ok = True
+
+    def __init__(self):
+        self.pe = ParamEscaper()
+
+    def bind_processor(self, dialect):
+        """Process values before sending them to the database."""
+
+        def process(value):
+            if value is None:
+                return None
+            try:
+                return json.dumps(value, ensure_ascii=False, separators=(",", ":"))
+            except (TypeError, ValueError) as e:
+                raise ValueError(f"Cannot serialize value {value} to JSON: {e}")
+
+        return process
+
+    def bind_expression(self, bindvalue):
+        """Wrap bound parameters with PARSE_JSON() in the emitted SQL."""
+        return expression.func.PARSE_JSON(bindvalue)
+
+    def literal_processor(self, dialect):
+        """Process literal values for SQL generation.
+        For VARIANT columns, use PARSE_JSON() to properly insert data.
+        """
+
+        def process(value):
+            if value is None:
+                return "NULL"
+            try:
+                return self.pe.escape_string(
+                    json.dumps(value, ensure_ascii=False, separators=(",", ":"))
+                )
+            except (TypeError, ValueError) as e:
+                raise ValueError(f"Cannot serialize value {value} to JSON: {e}")
+
+        return process
+
+
+@compiles(DatabricksVariant, "databricks")
+def compile_variant(type_, compiler, **kw):
+    return "VARIANT"
diff --git a/tests/test_local/e2e/test_complex_types.py b/tests/test_local/e2e/test_complex_types.py
index 22cc57a..07cd637 100644
--- a/tests/test_local/e2e/test_complex_types.py
+++ b/tests/test_local/e2e/test_complex_types.py
@@ -11,14 +11,14 @@
     DateTime,
 )
 from collections.abc import Sequence
-from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap
+from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap, DatabricksVariant
 from sqlalchemy.orm import DeclarativeBase, Session
 from sqlalchemy import select
 from datetime import date, datetime, time, timedelta, timezone
 import pandas as pd
 import numpy as np
 import decimal
-
+import json
 
 class TestComplexTypes(TestSetup):
     def _parse_to_common_type(self, value):
@@ -46,7 +46,7 @@ def _parse_to_common_type(self, value):
         ):
             return tuple(value)
         elif isinstance(value, dict):
-            return tuple(value.items())
+            return tuple(sorted(value.items()))
         elif isinstance(value, np.generic):
             return value.item()
         elif isinstance(value, decimal.Decimal):
@@ -152,6 +152,35 @@ class MapTable(Base):
 
         return MapTable, sample_data
 
+    def sample_variant_table(self) -> tuple[DeclarativeBase, dict]:
+        class Base(DeclarativeBase):
+            pass
+
+        class VariantTable(Base):
+            __tablename__ = "sqlalchemy_variant_table"
+
+            int_col = Column(Integer, primary_key=True)
+            variant_simple_col = Column(DatabricksVariant())
+            variant_nested_col = Column(DatabricksVariant())
+            variant_array_col = Column(DatabricksVariant())
+            variant_mixed_col = Column(DatabricksVariant())
+
+        sample_data = {
+            "int_col": 1,
+            "variant_simple_col": {"key": "value", "number": 42},
+            "variant_nested_col": {"user": {"name": "John", "age": 30}, "active": True},
+            "variant_array_col": [1, 2, 3, "hello", {"nested": "data"}],
+            "variant_mixed_col": {
+                "string": "test",
+                "number": 123,
+                "boolean": True,
+                "array": [1, 2, 3],
+                "object": {"nested": "value"}
+            }
+        }
+
+        return VariantTable, sample_data
+
     def test_insert_array_table_sqlalchemy(self):
         table, sample_data = self.sample_array_table()
 
@@ -209,3 +238,81 @@ def test_map_table_creation_pandas(self):
             stmt = select(table)
             df_result = pd.read_sql(stmt, engine)
             assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
+
+    def test_insert_variant_table_sqlalchemy(self):
+        table, sample_data = self.sample_variant_table()
+
+        with self.table_context(table) as engine:
+
+            sa_obj = table(**sample_data)
+            session = Session(engine)
+            session.add(sa_obj)
+            session.commit()
+
+            stmt = select(table).where(table.int_col == 1)
+            result = session.scalar(stmt)
+            compare = {key: getattr(result, key) for key in sample_data.keys()}
+            # Parse JSON values back to original format for comparison
+            for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
+                if compare[key] is not None:
+                    compare[key] = json.loads(compare[key])
+
+            assert compare == sample_data
+
+    def test_variant_table_creation_pandas(self):
+        table, sample_data = self.sample_variant_table()
+
+        with self.table_context(table) as engine:
+
+            df = pd.DataFrame([sample_data])
+            dtype_mapping = {
"variant_simple_col": DatabricksVariant, + "variant_nested_col": DatabricksVariant, + "variant_array_col": DatabricksVariant, + "variant_mixed_col": DatabricksVariant + } + df.to_sql(table.__tablename__, engine, if_exists="append", index=False, dtype=dtype_mapping) + + stmt = select(table) + df_result = pd.read_sql(stmt, engine) + result_dict = df_result.iloc[0].to_dict() + # Parse JSON values back to original format for comparison + for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: + if result_dict[key] is not None: + result_dict[key] = json.loads(result_dict[key]) + + assert result_dict == sample_data + + def test_variant_literal_processor(self): + table, sample_data = self.sample_variant_table() + + with self.table_context(table) as engine: + stmt = table.__table__.insert().values(**sample_data) + + try: + compiled = stmt.compile( + dialect=engine.dialect, + compile_kwargs={"literal_binds": True} + ) + sql_str = str(compiled) + + # Assert that JSON actually got inlined + assert '{"key":"value","number":42}' in sql_str + except NotImplementedError: + raise + + with engine.begin() as conn: + conn.execute(stmt) + + session = Session(engine) + stmt_select = select(table).where(table.int_col == sample_data["int_col"]) + result = session.scalar(stmt_select) + + compare = {key: getattr(result, key) for key in sample_data.keys()} + + # Parse JSON values back to original Python objects + for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: + if compare[key] is not None: + compare[key] = json.loads(compare[key]) + + assert compare == sample_data diff --git a/tests/test_local/test_ddl.py b/tests/test_local/test_ddl.py index 0b04c2e..9b19acf 100644 --- a/tests/test_local/test_ddl.py +++ b/tests/test_local/test_ddl.py @@ -7,7 +7,7 @@ SetColumnComment, SetTableComment, ) -from databricks.sqlalchemy import DatabricksArray, DatabricksMap +from databricks.sqlalchemy import DatabricksArray, DatabricksMap, DatabricksVariant class DDLTestBase: @@ -103,7 +103,8 @@ def metadata(self) -> MetaData: metadata = MetaData() col1 = Column("array_array_string", DatabricksArray(DatabricksArray(String))) col2 = Column("map_string_string", DatabricksMap(String, String)) - table = Table("complex_type", metadata, col1, col2) + col3 = Column("variant_col", DatabricksVariant()) + table = Table("complex_type", metadata, col1, col2, col3) return metadata def test_create_table_with_complex_type(self, metadata): @@ -112,3 +113,4 @@ def test_create_table_with_complex_type(self, metadata): assert "array_array_string ARRAY>" in output assert "map_string_string MAP" in output + assert "variant_col VARIANT" in output diff --git a/tests/test_local/test_types.py b/tests/test_local/test_types.py index b91217e..4f99e4b 100644 --- a/tests/test_local/test_types.py +++ b/tests/test_local/test_types.py @@ -4,7 +4,7 @@ import sqlalchemy from databricks.sqlalchemy.base import DatabricksDialect -from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ +from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ, DatabricksVariant class DatabricksDataType(enum.Enum): @@ -28,6 +28,7 @@ class DatabricksDataType(enum.Enum): ARRAY = enum.auto() MAP = enum.auto() STRUCT = enum.auto() + VARIANT = enum.auto() # Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types. 
@@ -131,6 +132,7 @@ def test_numeric_renders_as_decimal_with_precision_and_scale(self):
     TINYINT: DatabricksDataType.TINYINT,
     TIMESTAMP: DatabricksDataType.TIMESTAMP,
     TIMESTAMP_NTZ: DatabricksDataType.TIMESTAMP_NTZ,
+    DatabricksVariant: DatabricksDataType.VARIANT,
 }
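A minimal standalone sketch of the round-trip this change enables, for reference only (not part of the patch). The connection string, table name, and column names below are placeholders; it assumes only what the patch itself provides, namely that a DatabricksVariant column compiles to VARIANT, that bound values are JSON-serialized and wrapped in PARSE_JSON(), and that values read back are JSON strings.

from sqlalchemy import Column, Integer, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Session
from databricks.sqlalchemy import DatabricksVariant
import json

# Placeholder connection string: supply a real host, access token, http_path, catalog, and schema
engine = create_engine(
    "databricks://token:***@***.cloud.databricks.com?http_path=***&catalog=***&schema=***"
)


class Base(DeclarativeBase):
    pass


class Event(Base):
    __tablename__ = "variant_sketch_table"

    id = Column(Integer, primary_key=True)
    # Compiles to VARIANT; bound values are JSON-serialized and wrapped in PARSE_JSON()
    payload = Column(DatabricksVariant)


Base.metadata.create_all(engine)

payload = {"user": {"name": "Jane"}, "tags": ["a", "b"], "active": True}
with Session(engine) as session:
    session.add(Event(id=1, payload=payload))
    session.commit()

    row = session.scalar(select(Event).where(Event.id == 1))
    # VARIANT values come back as JSON strings, so decode before comparing
    assert json.loads(row.payload) == payload

Base.metadata.drop_all(engine)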