105 changes: 95 additions & 10 deletions poetry.lock

Some generated files are not rendered by default.

10 changes: 8 additions & 2 deletions src/databricks/sqlalchemy/__init__.py
@@ -1,4 +1,10 @@
 from databricks.sqlalchemy.base import DatabricksDialect
-from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ
+from databricks.sqlalchemy._types import (
+    TINYINT,
+    TIMESTAMP,
+    TIMESTAMP_NTZ,
+    DatabricksArray,
+    DatabricksMap,
+)

-__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ"]
+__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"]
76 changes: 76 additions & 0 deletions src/databricks/sqlalchemy/_types.py
@@ -5,6 +5,7 @@
 import sqlalchemy
 from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.types import TypeDecorator, UserDefinedType

 from databricks.sql.utils import ParamEscaper

@@ -26,6 +27,11 @@ def process_literal_param_hack(value: Any):
     return value


+def identity_processor(value):
+    """Return the value unchanged when no other bind processor is provided."""
+    return value


 @compiles(sqlalchemy.types.Enum, "databricks")
 @compiles(sqlalchemy.types.String, "databricks")
 @compiles(sqlalchemy.types.Text, "databricks")
@@ -321,3 +327,73 @@ class TINYINT(sqlalchemy.types.TypeDecorator):
 @compiles(TINYINT, "databricks")
 def compile_tinyint(type_, compiler, **kw):
     return "TINYINT"


+class DatabricksArray(UserDefinedType):
+    """
+    A custom array type that can wrap any other SQLAlchemy type.
+
+    Examples:
+        DatabricksArray(String) -> ARRAY<STRING>
+        DatabricksArray(Integer) -> ARRAY<INT>
+        DatabricksArray(CustomType) -> ARRAY<CUSTOM_TYPE>
+    """
+
+    def __init__(self, item_type):
+        self.item_type = item_type() if isinstance(item_type, type) else item_type
+
+    def bind_processor(self, dialect):
+        item_processor = self.item_type.bind_processor(dialect)
+        if item_processor is None:
+            item_processor = identity_processor
+
+        def process(value):
+            return [item_processor(val) for val in value]
+
+        return process
+
+
+@compiles(DatabricksArray, "databricks")
+def compile_databricks_array(type_, compiler, **kw):
+    inner = compiler.process(type_.item_type, **kw)

+    return f"ARRAY<{inner}>"


+class DatabricksMap(UserDefinedType):
+    """
+    A custom map type that can wrap any other SQLAlchemy type for both key and value.
+
+    Examples:
+        DatabricksMap(String, String) -> MAP<STRING,STRING>
+        DatabricksMap(Integer, String) -> MAP<INT,STRING>
+        DatabricksMap(String, DatabricksArray(Integer)) -> MAP<STRING,ARRAY<INT>>
+    """
+
+    def __init__(self, key_type, value_type):
+        self.key_type = key_type() if isinstance(key_type, type) else key_type
+        self.value_type = value_type() if isinstance(value_type, type) else value_type
+
+    def bind_processor(self, dialect):
+        key_processor = self.key_type.bind_processor(dialect)
+        value_processor = self.value_type.bind_processor(dialect)

+        if key_processor is None:
+            key_processor = identity_processor
+        if value_processor is None:
+            value_processor = identity_processor
+        def process(value):
+            return {
+                key_processor(key): value_processor(val)
+                for key, val in value.items()
+            }

+        return process


+@compiles(DatabricksMap, "databricks")
+def compile_databricks_map(type_, compiler, **kw):
+    key_type = compiler.process(type_.key_type, **kw)
+    value_type = compiler.process(type_.value_type, **kw)
+    return f"MAP<{key_type},{value_type}>"