
Commit 7271866

Basic working prototype
1 parent cfc58d6 commit 7271866

File tree: 6 files changed, +7673 -2657 lines changed


src/databricks/sql/client.py

Lines changed: 3 additions & 141 deletions
@@ -48,7 +48,9 @@

 from databricks.sql.thrift_api.TCLIService.ttypes import (
     TSparkParameter,
-    TOperationState, TSparkParameterValue, TSparkParameterValueArg,
+    TOperationState,
+    TSparkParameterValue,
+    TSparkParameterValueArg,
 )


@@ -807,146 +809,6 @@ def execute(
             transformed_operation, normalized_parameters, param_structure
         )

-        # temp_prepared_operation="""INSERT INTO ___________________first.jprakash.complex_types (
-        # user_id, name, emails, preferences, address, recent_orders
-        # ) VALUES (
-        # :user_id, :name, :emails, :preferences,:address, :recent_orders
-        # )"""
-        #
-        # temp_prepared_params=[
-        #     TSparkParameter(
-        #         name="user_id",
-        #         type="STRING",
-        #         value=TSparkParameterValue(stringValue="11")
-        #     ),
-        #     TSparkParameter(
-        #         name="name",
-        #         type="STRING",
-        #         value=TSparkParameterValue(stringValue="John Doe"),
-        #     ),
-        #     TSparkParameter(
-        #         name="emails",
-        #         # type="ARRAY",
-        #         arguments=[
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="john.doe@example.com"
-        #             ),
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="jd@example.org"
-        #             )
-        #         ]
-        #     ),
-        #     TSparkParameter(
-        #         name="preferences",
-        #         type="MAP",
-        #         arguments=[
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="theme"
-        #             ),
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="dark"
-        #             ),
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="language"
-        #             ),
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="en"
-        #             ),
-        #         ]
-        #
-        #     ),
-        #     TSparkParameter(
-        #         name="address",
-        #         type="NAMED_STRUCT",
-        #         arguments=[
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="street"
-        #             ),
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="123 Main St"
-        #             ),
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="city"
-        #             ),
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="Metropolis"
-        #             ),
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="zip"
-        #             ),
-        #             TSparkParameterValueArg(
-        #                 type="STRING",
-        #                 value="12345"
-        #             ),
-        #         ]
-        #     ),
-        #     # TSparkParameter(
-        #     #     name="address",
-        #     #     type="STRUCT",
-        #     #     arguments=[
-        #     #         TSparkParameterValueArg(
-        #     #             type="STRING",
-        #     #             value="123 Main St"
-        #     #         ),
-        #     #         TSparkParameterValueArg(
-        #     #             type="STRING",
-        #     #             value="Metropolis"
-        #     #         ),
-        #     #         TSparkParameterValueArg(
-        #     #             type="STRING",
-        #     #             value="12345"
-        #     #         ),
-        #     #     ]
-        #     # ),
-        #     TSparkParameter(
-        #         name="recent_orders",
-        #         type="ARRAY",
-        #         arguments=[
-        #             TSparkParameterValueArg(
-        #                 type="NAMED_STRUCT",
-        #                 arguments=[
-        #                     TSparkParameterValueArg(type="STRING", value="order_id"),
-        #                     TSparkParameterValueArg(type="STRING", value="ord001"),
-        #                     TSparkParameterValueArg(type="STRING", value="amount"),
-        #                     TSparkParameterValueArg(type="DECIMAL(10,2)", value="199.99"),
-        #                     TSparkParameterValueArg(type="STRING", value="items"),
-        #                     TSparkParameterValueArg(type="ARRAY", arguments=[
-        #                         TSparkParameterValueArg(type="STRING", value="item1"),
-        #                         TSparkParameterValueArg(type="STRING", value="item2")
-        #                     ])
-        #                 ]
-        #             ),
-        #             TSparkParameterValueArg(
-        #                 type="NAMED_STRUCT",
-        #                 arguments=[
-        #                     TSparkParameterValueArg(type="STRING", value="order_id"),
-        #                     TSparkParameterValueArg(type="STRING", value="ord002"),
-        #                     TSparkParameterValueArg(type="STRING", value="amount"),
-        #                     TSparkParameterValueArg(type="DECIMAL(10,2)", value="49.95"),
-        #                     TSparkParameterValueArg(type="STRING", value="items"),
-        #                     TSparkParameterValueArg(type="ARRAY", arguments=[
-        #                         TSparkParameterValueArg(type="STRING", value="item3"),
-        #                     ])
-        #                 ]
-        #             ),
-        #         ]
-        #     )
-        # ]
-
-        print("LINE 947")
-        print(prepared_params)
-
         self._check_not_closed()
         self._close_and_clear_active_result_set()
         execute_response = self.thrift_backend.execute_command(
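
The block removed above hand-built `TSparkParameter` / `TSparkParameterValueArg` objects for a complex-types INSERT. With the inference added to `parameters/native.py` below, the same kind of statement could be issued through the public cursor API. A minimal sketch, with placeholder connection settings and table name; note that the `NAMED_STRUCT` cases from the deleted block are not covered by this prototype, only ARRAY and MAP:

```python
from databricks import sql

# Placeholders: supply a real workspace hostname, HTTP path, token, and table.
with sql.connect(
    server_hostname="<workspace-host>",
    http_path="<http-path>",
    access_token="<token>",
) as connection:
    with connection.cursor() as cursor:
        cursor.execute(
            "INSERT INTO <catalog>.<schema>.complex_types "
            "(user_id, name, emails, preferences) "
            "VALUES (:user_id, :name, :emails, :preferences)",
            parameters={
                "user_id": "11",
                "name": "John Doe",
                # list -> ArrayParameter -> TSparkParameter(type="ARRAY", arguments=[...])
                "emails": ["john.doe@example.com", "jd@example.org"],
                # dict -> MapParameter; keys and values flattened into alternating arguments
                "preferences": {"theme": "dark", "language": "en"},
            },
        )
```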

src/databricks/sql/parameters/native.py

Lines changed: 124 additions & 10 deletions
@@ -1,12 +1,13 @@
 import datetime
 import decimal
 from enum import Enum, auto
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Any

 from databricks.sql.exc import NotSupportedError
 from databricks.sql.thrift_api.TCLIService.ttypes import (
     TSparkParameter,
     TSparkParameterValue,
+    TSparkParameterValueArg,
 )

 import datetime
@@ -54,7 +55,17 @@ class DatabricksSupportedType(Enum):


 TAllowedParameterValue = Union[
-    str, int, float, datetime.datetime, datetime.date, bool, decimal.Decimal, None
+    str,
+    int,
+    float,
+    datetime.datetime,
+    datetime.date,
+    bool,
+    decimal.Decimal,
+    None,
+    list,
+    dict,
+    tuple,
 ]

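
Illustrative only: with `list`, `dict`, and `tuple` added to the alias, a parameter mapping like the following becomes type-compatible. This assumes `TAllowedParameterValue` is importable from `databricks.sql.parameters.native` (the file shown here); the ARRAY/MAP comments anticipate the inference changes further down this diff.

```python
from typing import Dict

from databricks.sql.parameters.native import TAllowedParameterValue

params: Dict[str, TAllowedParameterValue] = {
    "user_id": "11",
    "emails": ["john.doe@example.com", "jd@example.org"],  # list  -> bound as ARRAY
    "preferences": {"theme": "dark", "language": "en"},    # dict  -> bound as MAP
    "recent_orders": ("ord001", "ord002"),                 # tuple -> bound as ARRAY
}
```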

@@ -82,6 +93,7 @@ class DbsqlParameterBase:

     CAST_EXPR: str
     name: Optional[str]
+    value: Any

     def as_tspark_param(self, named: bool) -> TSparkParameter:
         """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
@@ -98,6 +110,10 @@ def as_tspark_param(self, named: bool) -> TSparkParameter:
     def _tspark_param_value(self):
         return TSparkParameterValue(stringValue=str(self.value))

+    def _tspark_value_arg(self):
+        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
+        return TSparkParameterValueArg(value=str(self.value), type=self._cast_expr())
+
     def _cast_expr(self):
         return self.CAST_EXPR
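
For scalar parameters this default `_tspark_value_arg` simply stringifies the value and tags it with the class's cast expression. A rough illustration using `StringParameter`, which this module already defines, assuming its `CAST_EXPR` resolves to the `STRING` type name like the other scalar classes:

```python
from databricks.sql.parameters.native import StringParameter

theme = StringParameter(value="dark", name="theme")
print(theme._tspark_value_arg())
# Expected shape: TSparkParameterValueArg(value="dark", type="STRING")
```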

@@ -428,6 +444,99 @@ def __init__(self, value: int, name: Optional[str] = None):
     CAST_EXPR = DatabricksSupportedType.TINYINT.name


+class ArrayParameter(DbsqlParameterBase):
+    """Wrap a Python `Sequence` that will be bound to a Databricks SQL ARRAY type."""
+
+    def __init__(self, value: Sequence[Any], name: Optional[str] = None):
+        """
+        :value:
+            The value to bind for this parameter. This will be casted to an ARRAY.
+        :name:
+            If None, your query must contain a `?` marker. Like:
+
+            ```sql
+            SELECT * FROM table WHERE field = ?
+            ```
+            If not None, your query should contain a named parameter marker. Like:
+            ```sql
+            SELECT * FROM table WHERE field = :my_param
+            ```
+
+            The `name` argument to this function would be `my_param`.
+        """
+        self.name = name
+        self.value = [dbsql_parameter_from_primitive(val) for val in value]
+
+    def as_tspark_param(self, named: bool) -> TSparkParameter:
+        """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
+
+        tsp = TSparkParameter(type=self._cast_expr())
+        tsp.arguments = [val._tspark_value_arg() for val in self.value]
+
+        if named:
+            tsp.name = self.name
+            tsp.ordinal = False
+        elif not named:
+            tsp.ordinal = True
+        return tsp
+
+    def _tspark_value_arg(self):
+        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
+        tva = TSparkParameterValueArg(type=self._cast_expr())
+        tva.arguments = [val._tspark_value_arg() for val in self.value]
+        return tva
+
+    CAST_EXPR = DatabricksSupportedType.ARRAY.name
+
+
+class MapParameter(DbsqlParameterBase):
+    """Wrap a Python `dict` that will be bound to a Databricks SQL MAP type."""
+
+    def __init__(self, value: dict, name: Optional[str] = None):
+        """
+        :value:
+            The value to bind for this parameter. This will be casted to a MAP.
+        :name:
+            If None, your query must contain a `?` marker. Like:
+
+            ```sql
+            SELECT * FROM table WHERE field = ?
+            ```
+            If not None, your query should contain a named parameter marker. Like:
+            ```sql
+            SELECT * FROM table WHERE field = :my_param
+            ```
+
+            The `name` argument to this function would be `my_param`.
+        """
+        self.name = name
+        self.value = [
+            dbsql_parameter_from_primitive(item)
+            for key, val in value.items()
+            for item in (key, val)
+        ]
+
+    def as_tspark_param(self, named: bool) -> TSparkParameter:
+        """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
+
+        tsp = TSparkParameter(type=self._cast_expr())
+        tsp.arguments = [val._tspark_value_arg() for val in self.value]
+        if named:
+            tsp.name = self.name
+            tsp.ordinal = False
+        elif not named:
+            tsp.ordinal = True
+        return tsp
+
+    def _tspark_value_arg(self):
+        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
+        tva = TSparkParameterValueArg(type=self._cast_expr())
+        tva.arguments = [val._tspark_value_arg() for val in self.value]
+        return tva
+
+    CAST_EXPR = DatabricksSupportedType.MAP.name
+
+
 class DecimalParameter(DbsqlParameterBase):
     """Wrap a Python `Decimal` that will be bound to a Databricks SQL DECIMAL type."""
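
A quick sketch of how the two new classes serialize, based on the code above: each element (and, for `MapParameter`, each key and value in turn) is wrapped via `dbsql_parameter_from_primitive` and emitted as a `TSparkParameterValueArg`, so nesting falls out of the recursion in `_tspark_value_arg`. The expected shapes in the comments are approximate.

```python
from databricks.sql.parameters.native import ArrayParameter, MapParameter

emails = ArrayParameter(["john.doe@example.com", "jd@example.org"], name="emails")
tsp = emails.as_tspark_param(named=True)
# Roughly: TSparkParameter(name="emails", type="ARRAY", ordinal=False,
#     arguments=[TSparkParameterValueArg(type="STRING", value="john.doe@example.com"),
#                TSparkParameterValueArg(type="STRING", value="jd@example.org")])

prefs = MapParameter({"theme": "dark", "language": "en"}, name="preferences")
# Keys and values are flattened into alternating arguments:
# "theme", "dark", "language", "en", each as a STRING-typed TSparkParameterValueArg.

orders = ArrayParameter([["item1", "item2"], ["item3"]], name="items")
# Outer ARRAY whose arguments are themselves ARRAY-typed TSparkParameterValueArg objects.
```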

@@ -543,23 +652,26 @@ def dbsql_parameter_from_primitive(
     # havoc. We can't use TYPE_INFERRENCE_MAP because mypy doesn't trust
     # its logic

-    if type(value) is int:
+    if isinstance(value, int):
         return dbsql_parameter_from_int(value, name=name)
-    elif type(value) is str:
+    elif isinstance(value, str):
         return StringParameter(value=value, name=name)
-    elif type(value) is float:
+    elif isinstance(value, float):
         return FloatParameter(value=value, name=name)
-    elif type(value) is datetime.datetime:
+    elif isinstance(value, datetime.datetime):
         return TimestampParameter(value=value, name=name)
-    elif type(value) is datetime.date:
+    elif isinstance(value, datetime.date):
         return DateParameter(value=value, name=name)
-    elif type(value) is bool:
+    elif isinstance(value, bool):
         return BooleanParameter(value=value, name=name)
-    elif type(value) is decimal.Decimal:
+    elif isinstance(value, decimal.Decimal):
         return DecimalParameter(value=value, name=name)
+    elif isinstance(value, dict):
+        return MapParameter(value=value, name=name)
+    elif isinstance(value, Sequence) and not isinstance(value, str):
+        return ArrayParameter(value=value, name=name)
     elif value is None:
         return VoidParameter(value=value, name=name)
-
     else:
         raise NotSupportedError(
             f"Could not infer parameter type from value: {value} - {type(value)} \n"
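
With the new branches in place, inference can be exercised directly. A small sketch; note that strings are explicitly excluded from the `Sequence` branch, so they still map to `StringParameter`:

```python
from databricks.sql.parameters.native import dbsql_parameter_from_primitive

dbsql_parameter_from_primitive({"theme": "dark"})         # -> MapParameter
dbsql_parameter_from_primitive(["a", "b"], name="tags")   # -> ArrayParameter(name="tags")
dbsql_parameter_from_primitive(("a", "b"))                # -> ArrayParameter (tuples are Sequences)
dbsql_parameter_from_primitive("a")                       # -> StringParameter, not ArrayParameter
```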
@@ -581,6 +693,8 @@ def dbsql_parameter_from_primitive(
     TimestampNTZParameter,
     TinyIntParameter,
     DecimalParameter,
+    ArrayParameter,
+    MapParameter,
 ]


0 commit comments
