diff --git a/poetry.lock b/poetry.lock index 4d7ae9e..1e05d03 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "astroid" @@ -195,31 +195,31 @@ files = [ [[package]] name = "databricks-sql-connector" -version = "4.0.0" +version = "4.0.4" description = "Databricks SQL Connector for Python" optional = false python-versions = "<4.0.0,>=3.8.0" groups = ["main"] files = [ - {file = "databricks_sql_connector-4.0.0-py3-none-any.whl", hash = "sha256:798ebc740e992eaf435754510d1035872d3ebbc8c5cb597aa939217220463236"}, - {file = "databricks_sql_connector-4.0.0.tar.gz", hash = "sha256:3634fe3d19ee4641cdf76a77854573d9fe234ccdebd20230aaf94053397bc693"}, + {file = "databricks_sql_connector-4.0.4-py3-none-any.whl", hash = "sha256:240d0e44a0ff973c3b38ba1e63b448f29824ef0eb11ae7c5e4a9f117c88d7340"}, + {file = "databricks_sql_connector-4.0.4.tar.gz", hash = "sha256:2536e173fab7199ce6956da90c2deca4e10742135059f2e2f35303b2a76f6812"}, ] [package.dependencies] lz4 = ">=4.0.2,<5.0.0" -numpy = [ - {version = ">=1.16.6,<2.0.0", markers = "python_version >= \"3.8\" and python_version < \"3.11\""}, - {version = ">=1.23.4,<2.0.0", markers = "python_version >= \"3.11\""}, -] oauthlib = ">=3.1.0,<4.0.0" openpyxl = ">=3.0.10,<4.0.0" -pandas = {version = ">=1.2.5,<2.3.0", markers = "python_version >= \"3.8\""} +pandas = [ + {version = ">=1.2.5,<2.3.0", markers = "python_version >= \"3.8\" and python_version < \"3.13\""}, + {version = ">=2.2.3,<2.3.0", markers = "python_version >= \"3.13\""}, +] +python-dateutil = ">=2.8.0,<3.0.0" requests = ">=2.18.1,<3.0.0" thrift = ">=0.16.0,<0.21.0" urllib3 = ">=1.26" [package.extras] -pyarrow = ["pyarrow (>=14.0.1)"] +pyarrow = ["pyarrow (>=14.0.1) ; python_version >= \"3.8\" and python_version < \"3.13\"", "pyarrow (>=18.0.0) ; python_version >= \"3.13\""] [[package]] name = "dill" @@ -667,6 +667,7 @@ description = "Powerful data structures for data analysis, time series, and stat optional = false python-versions = ">=3.8" groups = ["main"] +markers = "python_version <= \"3.12\"" files = [ {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, @@ -728,6 +729,90 @@ sql-other = ["SQLAlchemy (>=1.4.16)"] test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.6.3)"] +[[package]] +name = "pandas" +version = "2.2.3" +description = "Powerful data structures for data analysis, time series, and statistics" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.13\"" +files = [ + {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d9c45366def9a3dd85a6454c0e7908f2b3b8e9c138f5dc38fed7ce720d8453ed"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:86976a1c5b25ae3f8ccae3a5306e443569ee3c3faf444dfd0f41cda24667ad57"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b8661b0238a69d7aafe156b7fa86c44b881387509653fdf857bebc5e4008ad42"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37e0aced3e8f539eccf2e099f65cdb9c8aa85109b0be6e93e2baff94264bdc6f"}, + {file = "pandas-2.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:56534ce0746a58afaf7942ba4863e0ef81c9c50d3f0ae93e9497d6a41a057645"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66108071e1b935240e74525006034333f98bcdb87ea116de573a6a0dccb6c039"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c2875855b0ff77b2a64a0365e24455d9990730d6431b9e0ee18ad8acee13dbd"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd8d0c3be0515c12fed0bdbae072551c8b54b7192c7b1fda0ba56059a0179698"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c124333816c3a9b03fbeef3a9f230ba9a737e9e5bb4060aa2107a86cc0a497fc"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:63cc132e40a2e084cf01adf0775b15ac515ba905d7dcca47e9a251819c575ef3"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:29401dbfa9ad77319367d36940cd8a0b3a11aba16063e39632d98b0e931ddf32"}, + {file = "pandas-2.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:3fc6873a41186404dad67245896a6e440baacc92f5b716ccd1bc9ed2995ab2c5"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a"}, + {file = "pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3508d914817e153ad359d7e069d752cdd736a247c322d932eb89e6bc84217f28"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22a9d949bfc9a502d320aa04e5d02feab689d61da4e7764b62c30b991c42c5f0"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:800250ecdadb6d9c78eae4990da62743b857b470883fa27f652db8bdde7f6659"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6374c452ff3ec675a8f46fd9ab25c4ad0ba590b71cf0656f8b6daa5202bca3fb"}, + {file = 
"pandas-2.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:61c5ad4043f791b61dd4752191d9f07f0ae412515d59ba8f005832a532f8736d"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b71f27954685ee685317063bf13c7709a7ba74fc996b84fc6821c59b0f06468"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:38cf8125c40dae9d5acc10fa66af8ea6fdf760b2714ee482ca691fc66e6fcb18"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba96630bc17c875161df3818780af30e43be9b166ce51c9a18c1feae342906c2"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db71525a1538b30142094edb9adc10be3f3e176748cd7acc2240c2f2e5aa3a4"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:15c0e1e02e93116177d29ff83e8b1619c93ddc9c49083f237d4312337a61165d"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc6b93f9b966093cb0fd62ff1a7e4c09e6d546ad7c1de191767baffc57628f39"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5dbca4c1acd72e8eeef4753eeca07de9b1db4f398669d5994086f788a5d7cc30"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8cd6d7cc958a3910f934ea8dbdf17b2364827bb4dafc38ce6eef6bb3d65ff09c"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99df71520d25fade9db7c1076ac94eb994f4d2673ef2aa2e86ee039b6746d20c"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:31d0ced62d4ea3e231a9f228366919a5ea0b07440d9d4dac345376fd8e1477ea"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7eee9e7cea6adf3e3d24e304ac6b8300646e2a5d1cd3a3c2abed9101b0846761"}, + {file = "pandas-2.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:4850ba03528b6dd51d6c5d273c46f183f39a9baf3f0143e566b89450965b105e"}, + {file = "pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667"}, +] + +[package.dependencies] +numpy = {version = ">=1.26.0", markers = "python_version >= \"3.12\""} +python-dateutil = ">=2.8.2" +pytz = ">=2020.1" +tzdata = ">=2022.7" + +[package.extras] +all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] +aws = ["s3fs (>=2022.11.0)"] +clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] +compression = ["zstandard (>=0.19.0)"] +computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] +consortium-standard = ["dataframe-api-compat (>=0.1.7)"] +excel = ["odfpy (>=1.4.1)", "openpyxl 
(>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] +feather = ["pyarrow (>=10.0.1)"] +fss = ["fsspec (>=2022.11.0)"] +gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] +hdf5 = ["tables (>=3.8.0)"] +html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] +mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] +parquet = ["pyarrow (>=10.0.1)"] +performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] +plot = ["matplotlib (>=3.6.3)"] +postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] +pyarrow = ["pyarrow (>=10.0.1)"] +spss = ["pyreadstat (>=1.2.0)"] +sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] +test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.9.2)"] + [[package]] name = "pathspec" version = "0.12.1" diff --git a/src/databricks/sqlalchemy/__init__.py b/src/databricks/sqlalchemy/__init__.py index 2a17ac3..81d35d6 100644 --- a/src/databricks/sqlalchemy/__init__.py +++ b/src/databricks/sqlalchemy/__init__.py @@ -1,4 +1,10 @@ from databricks.sqlalchemy.base import DatabricksDialect -from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ +from databricks.sqlalchemy._types import ( + TINYINT, + TIMESTAMP, + TIMESTAMP_NTZ, + DatabricksArray, + DatabricksMap, +) -__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ"] +__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"] diff --git a/src/databricks/sqlalchemy/_types.py b/src/databricks/sqlalchemy/_types.py index 5fc14a7..bc996bb 100644 --- a/src/databricks/sqlalchemy/_types.py +++ b/src/databricks/sqlalchemy/_types.py @@ -5,6 +5,7 @@ import sqlalchemy from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.ext.compiler import compiles +from sqlalchemy.types import TypeDecorator, UserDefinedType from databricks.sql.utils import ParamEscaper @@ -26,6 +27,11 @@ def process_literal_param_hack(value: Any): return value +def identity_processor(value): + """This method returns the value itself, when no other processor is provided""" + return value + + @compiles(sqlalchemy.types.Enum, "databricks") @compiles(sqlalchemy.types.String, "databricks") @compiles(sqlalchemy.types.Text, "databricks") @@ -321,3 +327,73 @@ class TINYINT(sqlalchemy.types.TypeDecorator): @compiles(TINYINT, "databricks") def compile_tinyint(type_, compiler, **kw): return "TINYINT" + + +class DatabricksArray(UserDefinedType): + """ + A custom array type that can wrap any other SQLAlchemy type. + + Examples: + DatabricksArray(String) -> ARRAY + DatabricksArray(Integer) -> ARRAY + DatabricksArray(CustomType) -> ARRAY + """ + + def __init__(self, item_type): + self.item_type = item_type() if isinstance(item_type, type) else item_type + + def bind_processor(self, dialect): + item_processor = self.item_type.bind_processor(dialect) + if item_processor is None: + item_processor = identity_processor + + def process(value): + return [item_processor(val) for val in value] + + return process + + +@compiles(DatabricksArray, "databricks") +def compile_databricks_array(type_, compiler, **kw): + inner = compiler.process(type_.item_type, **kw) + + return f"ARRAY<{inner}>" + + +class DatabricksMap(UserDefinedType): + """ + A custom map type that can wrap any other SQLAlchemy types for both key and value. 
diff --git a/tests/test_local/e2e/__init__.py b/tests/test_local/e2e/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_local/e2e/test_complex_types.py b/tests/test_local/e2e/test_complex_types.py
new file mode 100644
index 0000000..22cc57a
--- /dev/null
+++ b/tests/test_local/e2e/test_complex_types.py
@@ -0,0 +1,211 @@
+from .test_setup import TestSetup
+from sqlalchemy import (
+    Column,
+    BigInteger,
+    String,
+    Integer,
+    Numeric,
+    Boolean,
+    Date,
+    TIMESTAMP,
+    DateTime,
+)
+from collections.abc import Sequence
+from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap
+from sqlalchemy.orm import DeclarativeBase, Session
+from sqlalchemy import select
+from datetime import date, datetime, time, timedelta, timezone
+import pandas as pd
+import numpy as np
+import decimal
+
+
+class TestComplexTypes(TestSetup):
+    def _parse_to_common_type(self, value):
+        """
+        Converts :value into a common Python datatype for comparison.
+
+        Conversion notes:
+        A MAP column is returned by the server as a list of tuples
+        Ex:
+            {"a": 1, "b": 2} -> [("a", 1), ("b", 2)]
+
+        An ARRAY column is returned by the server as a numpy array
+        Ex:
+            ["a", "b", "c"] -> np.array(["a", "b", "c"], dtype=object)
+
+        A primitive value is returned by the server as a numpy scalar
+        Ex:
+            1 -> np.int64(1)
+            2 -> np.int32(2)
+        """
+        if value is None:
+            return None
+        elif isinstance(value, (Sequence, np.ndarray)) and not isinstance(
+            value, (str, bytes)
+        ):
+            return tuple(value)
+        elif isinstance(value, dict):
+            return tuple(value.items())
+        elif isinstance(value, np.generic):
+            return value.item()
+        elif isinstance(value, decimal.Decimal):
+            return float(value)
+        else:
+            return value
+
+    def _recursive_compare(self, actual, expected):
+        """
+        Recursively compares :actual and :expected, ensuring the data matches down to the leaf level.
+
+        Note: a complex datatype like MAP is not returned as a dictionary but as a list of tuples
+        """
+        actual_parsed = self._parse_to_common_type(actual)
+        expected_parsed = self._parse_to_common_type(expected)
+
+        # Check if types are the same
+        if type(actual_parsed) != type(expected_parsed):
+            return False
+
+        # Handle lists or tuples
+        if isinstance(actual_parsed, (list, tuple)):
+            if len(actual_parsed) != len(expected_parsed):
+                return False
+            return all(
+                self._recursive_compare(o1, o2)
+                for o1, o2 in zip(actual_parsed, expected_parsed)
+            )
+
+        return actual_parsed == expected_parsed
+
+    def sample_array_table(self) -> tuple[DeclarativeBase, dict]:
+        class Base(DeclarativeBase):
+            pass
+
+        class ArrayTable(Base):
+            __tablename__ = "sqlalchemy_array_table"
+
+            int_col = Column(Integer, primary_key=True)
+            array_int_col = Column(DatabricksArray(Integer))
+            array_bigint_col = Column(DatabricksArray(BigInteger))
+            array_numeric_col = Column(DatabricksArray(Numeric(10, 2)))
+            array_string_col = Column(DatabricksArray(String))
+            array_boolean_col = Column(DatabricksArray(Boolean))
+            array_date_col = Column(DatabricksArray(Date))
+            array_datetime_col = Column(DatabricksArray(TIMESTAMP))
+            array_datetime_col_ntz = Column(DatabricksArray(DateTime))
+            array_tinyint_col = Column(DatabricksArray(TINYINT))
+
+        sample_data = {
+            "int_col": 1,
+            "array_int_col": [1, 2],
+            "array_bigint_col": [1234567890123456789, 2345678901234567890],
+            "array_numeric_col": [1.1, 2.2],
+            "array_string_col": ["a", "b"],
+            "array_boolean_col": [True, False],
+            "array_date_col": [date(2020, 12, 25), date(2021, 1, 2)],
+            "array_datetime_col": [
+                datetime(1991, 8, 3, 21, 30, 5, tzinfo=timezone(timedelta(hours=-8))),
+                datetime(1991, 8, 3, 21, 30, 5, tzinfo=timezone(timedelta(hours=-8))),
+            ],
+            "array_datetime_col_ntz": [
+                datetime(1990, 12, 4, 6, 33, 41),
+                datetime(1990, 12, 4, 6, 33, 41),
+            ],
+            "array_tinyint_col": [-100, 100],
+        }
+
+        return ArrayTable, sample_data
+
+    def sample_map_table(self) -> tuple[DeclarativeBase, dict]:
+        class Base(DeclarativeBase):
+            pass
+
+        class MapTable(Base):
+            __tablename__ = "sqlalchemy_map_table"
+
+            int_col = Column(Integer, primary_key=True)
+            map_int_col = Column(DatabricksMap(Integer, Integer))
+            map_bigint_col = Column(DatabricksMap(Integer, BigInteger))
+            map_numeric_col = Column(DatabricksMap(Integer, Numeric(10, 2)))
+            map_string_col = Column(DatabricksMap(Integer, String))
+            map_boolean_col = Column(DatabricksMap(Integer, Boolean))
+            map_date_col = Column(DatabricksMap(Integer, Date))
+            map_datetime_col = Column(DatabricksMap(Integer, TIMESTAMP))
+            map_datetime_col_ntz = Column(DatabricksMap(Integer, DateTime))
+            map_tinyint_col = Column(DatabricksMap(Integer, TINYINT))
+
+        sample_data = {
+            "int_col": 1,
+            "map_int_col": {1: 1},
+            "map_bigint_col": {1: 1234567890123456789},
+            "map_numeric_col": {1: 1.1},
+            "map_string_col": {1: "a"},
+            "map_boolean_col": {1: True},
+            "map_date_col": {1: date(2020, 12, 25)},
+            "map_datetime_col": {
+                1: datetime(1991, 8, 3, 21, 30, 5, tzinfo=timezone(timedelta(hours=-8)))
+            },
+            "map_datetime_col_ntz": {1: datetime(1990, 12, 4, 6, 33, 41)},
+            "map_tinyint_col": {1: -100},
+        }
+
+        return MapTable, sample_data
+
+    def test_insert_array_table_sqlalchemy(self):
+        table, sample_data = self.sample_array_table()
+
+        with self.table_context(table) as engine:
+            sa_obj = table(**sample_data)
+            session = Session(engine)
+            session.add(sa_obj)
+            session.commit()
+
+            stmt = select(table).where(table.int_col == 1)
+
+            result = session.scalar(stmt)
+
+            compare = {key: getattr(result, key) for key in sample_data.keys()}
+            assert self._recursive_compare(compare, sample_data)
+
+    def test_insert_map_table_sqlalchemy(self):
+        table, sample_data = self.sample_map_table()
+
+        with self.table_context(table) as engine:
+            sa_obj = table(**sample_data)
+            session = Session(engine)
+            session.add(sa_obj)
+            session.commit()
+
+            stmt = select(table).where(table.int_col == 1)
+
+            result = session.scalar(stmt)
+
+            compare = {key: getattr(result, key) for key in sample_data.keys()}
+            assert self._recursive_compare(compare, sample_data)
+
+    def test_array_table_creation_pandas(self):
+        table, sample_data = self.sample_array_table()
+
+        with self.table_context(table) as engine:
+            # Insert the data into the table
+            df = pd.DataFrame([sample_data])
+            df.to_sql(table.__tablename__, engine, if_exists="append", index=False)
+
+            # Read the data from the table
+            stmt = select(table)
+            df_result = pd.read_sql(stmt, engine)
+            assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
+
+    def test_map_table_creation_pandas(self):
+        table, sample_data = self.sample_map_table()
+
+        with self.table_context(table) as engine:
+            # Insert the data into the table
+            df = pd.DataFrame([sample_data])
+            df.to_sql(table.__tablename__, engine, if_exists="append", index=False)
+
+            # Read the data from the table
+            stmt = select(table)
+            df_result = pd.read_sql(stmt, engine)
+            assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
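A side note on the shapes these tests normalize (not part of the diff): per the docstrings above, a MAP column comes back from the server as a list of key/value tuples and an ARRAY column as a numpy array. A small self-contained illustration, using hypothetical values shaped like those results:

    import numpy as np

    # Hypothetical values, shaped like what the connector returns for the tables above.
    fetched_map = [(1, "a")]                            # a MAP column result
    fetched_array = np.array(["a", "b"], dtype=object)  # an ARRAY column result

    assert dict(fetched_map) == {1: "a"}      # restore the mapping
    assert list(fetched_array) == ["a", "b"]  # back to a plain list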
diff --git a/tests/test_local/e2e/test_setup.py b/tests/test_local/e2e/test_setup.py
new file mode 100644
index 0000000..94a37aa
--- /dev/null
+++ b/tests/test_local/e2e/test_setup.py
@@ -0,0 +1,31 @@
+import pytest
+from sqlalchemy import create_engine, Engine
+from contextlib import contextmanager
+from sqlalchemy.orm import DeclarativeBase, Session
+
+
+class TestSetup:
+    @pytest.fixture(autouse=True)
+    def get_details(self, connection_details):
+        self.arguments = connection_details.copy()
+
+    def db_engine(self) -> Engine:
+        HOST = self.arguments["host"]
+        HTTP_PATH = self.arguments["http_path"]
+        ACCESS_TOKEN = self.arguments["access_token"]
+        CATALOG = self.arguments["catalog"]
+        SCHEMA = self.arguments["schema"]
+
+        connect_args = {"_user_agent_entry": "SQLAlchemy e2e Tests"}
+
+        conn_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}"
+        return create_engine(conn_string, connect_args=connect_args)
+
+    @contextmanager
+    def table_context(self, table: DeclarativeBase):
+        engine = self.db_engine()
+        table.metadata.create_all(engine)
+        try:
+            yield engine
+        finally:
+            table.metadata.drop_all(engine)
diff --git a/tests/test_local/test_ddl.py b/tests/test_local/test_ddl.py
index f596dff..0b04c2e 100644
--- a/tests/test_local/test_ddl.py
+++ b/tests/test_local/test_ddl.py
@@ -1,5 +1,5 @@
 import pytest
-from sqlalchemy import Column, MetaData, String, Table, create_engine
+from sqlalchemy import Column, MetaData, String, Table, Numeric, Integer, create_engine
 from sqlalchemy.schema import (
     CreateTable,
     DropColumnComment,
@@ -7,6 +7,7 @@
     SetColumnComment,
     SetTableComment,
 )
+from databricks.sqlalchemy import DatabricksArray, DatabricksMap
 
 
 class DDLTestBase:
@@ -94,3 +95,20 @@ def test_alter_table_drop_comment(self, table_with_comment):
         stmt = DropTableComment(table_with_comment)
         output = self.compile(stmt)
         assert output == "COMMENT ON TABLE martin IS NULL"
+
+
+class TestTableComplexTypeDDL(DDLTestBase):
+    @pytest.fixture(scope="class")
+    def metadata(self) -> MetaData:
+        metadata = MetaData()
+        col1 = Column("array_array_string", DatabricksArray(DatabricksArray(String)))
+        col2 = Column("map_string_string", DatabricksMap(String, String))
+        table = Table("complex_type", metadata, col1, col2)
+        return metadata
+
+    def test_create_table_with_complex_type(self, metadata):
+        stmt = CreateTable(metadata.tables["complex_type"])
+        output = self.compile(stmt)
+
+        assert "array_array_string ARRAY<ARRAY<STRING>>" in output
+        assert "map_string_string MAP<STRING,STRING>" in output
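For reference (not part of the diff), the same compilation behavior can be exercised standalone through SQLAlchemy's TypeEngine.compile, as the parsing tests below do; a minimal sketch, assuming the package is importable:

    from sqlalchemy import String

    from databricks.sqlalchemy import DatabricksArray, DatabricksDialect, DatabricksMap

    dialect = DatabricksDialect()

    # Each nesting level is compiled recursively by the @compiles hooks.
    print(DatabricksArray(DatabricksArray(String)).compile(dialect=dialect))
    # ARRAY<ARRAY<STRING>>
    print(DatabricksMap(String, DatabricksArray(String)).compile(dialect=dialect))
    # MAP<STRING,ARRAY<STRING>>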
diff --git a/tests/test_local/test_parsing.py b/tests/test_local/test_parsing.py
index c8ab443..026b6a4 100644
--- a/tests/test_local/test_parsing.py
+++ b/tests/test_local/test_parsing.py
@@ -9,7 +9,28 @@
     get_comment_from_dte_output,
     DatabricksSqlAlchemyParseException,
 )
+from sqlalchemy import (
+    BigInteger,
+    Boolean,
+    Date,
+    DateTime,
+    Integer,
+    Numeric,
+    String,
+    Time,
+    Uuid,
+)
+
+from databricks.sqlalchemy import (
+    DatabricksArray,
+    TIMESTAMP,
+    TINYINT,
+    DatabricksMap,
+    TIMESTAMP_NTZ,
+)
+from databricks.sqlalchemy import DatabricksDialect
+dialect = DatabricksDialect()
 
 # These are outputs from DESCRIBE TABLE EXTENDED
 @pytest.mark.parametrize(
@@ -158,3 +179,76 @@ def test_filter_dict_by_value(match, output):
 
 def test_get_comment_from_dte_output():
     assert get_comment_from_dte_output(FMT_SAMPLE_DT_OUTPUT) == "some comment"
+
+
+def get_databricks_non_compound_types():
+    return [
+        Integer(),
+        String(),
+        Boolean(),
+        Date(),
+        DateTime(),
+        Time(),
+        Uuid(),
+        Numeric(),
+        TINYINT(),
+        TIMESTAMP(),
+        TIMESTAMP_NTZ(),
+        BigInteger(),
+    ]
+
+
+def get_databricks_compound_types():
+    return [DatabricksArray(String), DatabricksMap(String, String)]
+
+
+@pytest.mark.parametrize("internal_type", get_databricks_non_compound_types())
+def test_array_parsing(internal_type):
+    array_type = DatabricksArray(internal_type)
+
+    actual_parsed = array_type.compile(dialect=dialect)
+    expected_parsed = "ARRAY<{}>".format(internal_type.compile(dialect=dialect))
+    assert actual_parsed == expected_parsed
+
+
+@pytest.mark.parametrize("internal_type_1", get_databricks_non_compound_types())
+@pytest.mark.parametrize("internal_type_2", get_databricks_non_compound_types())
+def test_map_parsing(internal_type_1, internal_type_2):
+    map_type = DatabricksMap(internal_type_1, internal_type_2)
+
+    actual_parsed = map_type.compile(dialect=dialect)
+    expected_parsed = "MAP<{},{}>".format(
+        internal_type_1.compile(dialect=dialect),
+        internal_type_2.compile(dialect=dialect),
+    )
+    assert actual_parsed == expected_parsed
+
+
+@pytest.mark.parametrize(
+    "internal_type",
+    get_databricks_non_compound_types() + get_databricks_compound_types(),
+)
+def test_multilevel_array_type_parsing(internal_type):
+    array_type = DatabricksArray(DatabricksArray(DatabricksArray(internal_type)))
+
+    actual_parsed = array_type.compile(dialect=dialect)
+    expected_parsed = "ARRAY<ARRAY<ARRAY<{}>>>".format(
+        internal_type.compile(dialect=dialect)
+    )
+    assert actual_parsed == expected_parsed
+
+
+@pytest.mark.parametrize(
+    "internal_type",
+    get_databricks_non_compound_types() + get_databricks_compound_types(),
+)
+def test_multilevel_map_type_parsing(internal_type):
+    map_type = DatabricksMap(
+        String, DatabricksMap(String, DatabricksMap(String, internal_type))
+    )
+
+    actual_parsed = map_type.compile(dialect=dialect)
+    expected_parsed = "MAP<STRING,MAP<STRING,MAP<STRING,{}>>>".format(
+        internal_type.compile(dialect=dialect)
+    )
+    assert actual_parsed == expected_parsed