
Commit 5354788

Merge branch 'main' into variant

2 parents d0e39ec + c123af3

18 files changed, +819 -68 lines

.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion

@@ -2,4 +2,4 @@
 # the repo. Unless a later match takes precedence, these
 # users will be requested for review when someone opens a
 # pull request.
-* @deeksha-db @samikshya-db @jprakash-db @yunbodeng-db @jackyhu-db @benc-db
+* @deeksha-db @samikshya-db @jprakash-db @jackyhu-db @madhav-db @gopalldb @jayantsing-db @vikrantpuppala @shivam2680

.github/workflows/code-quality-checks.yml

Lines changed: 3 additions & 7 deletions

@@ -1,11 +1,7 @@
 name: Code Quality Checks
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    branches:
-      - main
+
+on: [pull_request]
+
 jobs:
   run-unit-tests:
     runs-on: ubuntu-latest

.github/workflows/integration.yml

Lines changed: 5 additions & 4 deletions

@@ -1,9 +1,10 @@
 name: Integration Tests
+
 on:
-  push:
-    paths-ignore:
-      - "**.MD"
-      - "**.md"
+  push:
+    branches:
+      - main
+  pull_request:
 
 jobs:
   run-e2e-tests:

CHANGELOG.md

Lines changed: 9 additions & 0 deletions

@@ -1,5 +1,14 @@
 # Release History
 
+# 4.0.4 (2025-06-16)
+
+- Update thrift client library after cleaning up unused fields and structs (databricks/databricks-sql-python#553 by @vikrantpuppala)
+- Refactor decimal conversion in PyArrow tables to use direct casting (databricks/databricks-sql-python#544 by @jayantsing-db)
+- Fix: `fetchall_arrow` to always return results in `arrow` format (databricks/databricks-sql-python#551 by @shivam2680)
+- Enhance cursor close handling and context manager exception management to prevent server-side resource leaks (databricks/databricks-sql-python#554 by @madhav-db)
+- Added additional logging to enhance debugging (databricks/databricks-sql-python#556 by @saishreeeee)
+- Feature: Added support for complex data types such as Arrays and Maps [Private Preview] (databricks/databricks-sql-python#559 by @jprakash-db)
+
 # 4.0.3 (2025-04-22)
 
 - Fix: Removed `packaging` dependency in favour of default libraries, for `urllib3` version checks (databricks/databricks-sql-python#547 by @jprakash-db)

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "databricks-sql-connector"
-version = "4.0.3"
+version = "4.0.4"
 description = "Databricks SQL Connector for Python"
 authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
 license = "Apache-2.0"

src/databricks/sql/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def __repr__(self):
 DATE = DBAPITypeObject("date")
 ROWID = DBAPITypeObject()
 
-__version__ = "4.0.3"
+__version__ = "4.0.4"
 USER_AGENT_NAME = "PyDatabricksSqlConnector"
 
 # These two functions are pyhive legacy

src/databricks/sql/client.py

Lines changed: 39 additions & 3 deletions

@@ -214,6 +214,12 @@ def read(self) -> Optional[OAuthToken]:
         # use_cloud_fetch
         #   Enable use of cloud fetch to extract large query results in parallel via cloud storage
 
+        logger.debug(
+            "Connection.__init__(server_hostname=%s, http_path=%s)",
+            server_hostname,
+            http_path,
+        )
+
         if access_token:
             access_token_kv = {"access_token": access_token}
             kwargs = {**kwargs, **access_token_kv}
@@ -315,7 +321,13 @@ def __enter__(self) -> "Connection":
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
-        self.close()
+        try:
+            self.close()
+        except BaseException as e:
+            logger.warning(f"Exception during connection close in __exit__: {e}")
+            if exc_type is None:
+                raise
+        return False
 
     def __del__(self):
         if self.open:
@@ -456,7 +468,14 @@ def __enter__(self) -> "Cursor":
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
-        self.close()
+        try:
+            logger.debug("Cursor context manager exiting, calling close()")
+            self.close()
+        except BaseException as e:
+            logger.warning(f"Exception during cursor close in __exit__: {e}")
+            if exc_type is None:
+                raise
+        return False
 
     def __iter__(self):
         if self.active_result_set:
@@ -787,6 +806,9 @@ def execute(
 
         :returns self
         """
+        logger.debug(
+            "Cursor.execute(operation=%s, parameters=%s)", operation, parameters
+        )
 
         param_approach = self._determine_parameter_approach(parameters)
         if param_approach == ParameterApproach.NONE:
@@ -1163,7 +1185,21 @@ def cancel(self) -> None:
     def close(self) -> None:
         """Close cursor"""
         self.open = False
-        self.active_op_handle = None
+
+        # Close active operation handle if it exists
+        if self.active_op_handle:
+            try:
+                self.thrift_backend.close_command(self.active_op_handle)
+            except RequestError as e:
+                if isinstance(e.args[1], CursorAlreadyClosedError):
+                    logger.info("Operation was canceled by a prior request")
+                else:
+                    logging.warning(f"Error closing operation handle: {e}")
+            except Exception as e:
+                logging.warning(f"Error closing operation handle: {e}")
+            finally:
+                self.active_op_handle = None
+
         if self.active_result_set:
             self._close_and_clear_active_result_set()

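Taken together, the client.py changes mean a `Cursor` now releases its server-side operation handle on close, and both context managers log close-time failures instead of letting them mask an in-flight exception. A minimal sketch of the resulting behavior, with hypothetical connection values:

```python
from databricks import sql

# Hypothetical workspace values; substitute your own.
with sql.connect(
    server_hostname="example.cloud.databricks.com",
    http_path="/sql/1.0/warehouses/abc123",
    access_token="dapi-example-token",
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1")
        print(cursor.fetchall())
# On leaving each `with` block, __exit__ calls close(). If close() raises
# while no other exception is propagating (exc_type is None), the error is
# logged and re-raised. If the block is already unwinding an exception,
# the close error is only logged, so it cannot mask the original failure.
# Returning False means a propagating exception is never suppressed.
```
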
src/databricks/sql/parameters/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -12,4 +12,6 @@
     TimestampNTZParameter,
     TinyIntParameter,
     DecimalParameter,
+    MapParameter,
+    ArrayParameter,
 )

src/databricks/sql/parameters/native.py

Lines changed: 125 additions & 11 deletions

@@ -1,12 +1,13 @@
 import datetime
 import decimal
 from enum import Enum, auto
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Any
 
 from databricks.sql.exc import NotSupportedError
 from databricks.sql.thrift_api.TCLIService.ttypes import (
     TSparkParameter,
     TSparkParameterValue,
+    TSparkParameterValueArg,
 )
 
 import datetime
@@ -54,7 +55,17 @@ class DatabricksSupportedType(Enum):
 
 
 TAllowedParameterValue = Union[
-    str, int, float, datetime.datetime, datetime.date, bool, decimal.Decimal, None
+    str,
+    int,
+    float,
+    datetime.datetime,
+    datetime.date,
+    bool,
+    decimal.Decimal,
+    None,
+    list,
+    dict,
+    tuple,
 ]
 
 
@@ -82,6 +93,7 @@ class DbsqlParameterBase:
 
     CAST_EXPR: str
     name: Optional[str]
+    value: Any
 
     def as_tspark_param(self, named: bool) -> TSparkParameter:
         """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
@@ -98,6 +110,10 @@ def as_tspark_param(self, named: bool) -> TSparkParameter:
     def _tspark_param_value(self):
         return TSparkParameterValue(stringValue=str(self.value))
 
+    def _tspark_value_arg(self):
+        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
+        return TSparkParameterValueArg(value=str(self.value), type=self._cast_expr())
+
     def _cast_expr(self):
         return self.CAST_EXPR
 
@@ -428,6 +444,99 @@ def __init__(self, value: int, name: Optional[str] = None):
     CAST_EXPR = DatabricksSupportedType.TINYINT.name
 
 
+class ArrayParameter(DbsqlParameterBase):
+    """Wrap a Python `Sequence` that will be bound to a Databricks SQL ARRAY type."""
+
+    def __init__(self, value: Sequence[Any], name: Optional[str] = None):
+        """
+        :value:
+            The value to bind for this parameter. This will be cast to an ARRAY.
+        :name:
+            If None, your query must contain a `?` marker. Like:
+
+            ```sql
+            SELECT * FROM table WHERE field = ?
+            ```
+            If not None, your query should contain a named parameter marker. Like:
+            ```sql
+            SELECT * FROM table WHERE field = :my_param
+            ```
+
+            The `name` argument to this function would be `my_param`.
+        """
+        self.name = name
+        self.value = [dbsql_parameter_from_primitive(val) for val in value]
+
+    def as_tspark_param(self, named: bool = False) -> TSparkParameter:
+        """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
+
+        tsp = TSparkParameter(type=self._cast_expr())
+        tsp.arguments = [val._tspark_value_arg() for val in self.value]
+
+        if named:
+            tsp.name = self.name
+            tsp.ordinal = False
+        else:
+            tsp.ordinal = True
+        return tsp
+
+    def _tspark_value_arg(self):
+        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
+        tva = TSparkParameterValueArg(type=self._cast_expr())
+        tva.arguments = [val._tspark_value_arg() for val in self.value]
+        return tva
+
+    CAST_EXPR = DatabricksSupportedType.ARRAY.name
+
+
+class MapParameter(DbsqlParameterBase):
+    """Wrap a Python `dict` that will be bound to a Databricks SQL MAP type."""
+
+    def __init__(self, value: dict, name: Optional[str] = None):
+        """
+        :value:
+            The value to bind for this parameter. This will be cast to a MAP.
+        :name:
+            If None, your query must contain a `?` marker. Like:
+
+            ```sql
+            SELECT * FROM table WHERE field = ?
+            ```
+            If not None, your query should contain a named parameter marker. Like:
+            ```sql
+            SELECT * FROM table WHERE field = :my_param
+            ```
+
+            The `name` argument to this function would be `my_param`.
+        """
+        self.name = name
+        self.value = [
+            dbsql_parameter_from_primitive(item)
+            for key, val in value.items()
+            for item in (key, val)
+        ]
+
+    def as_tspark_param(self, named: bool = False) -> TSparkParameter:
+        """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
+
+        tsp = TSparkParameter(type=self._cast_expr())
+        tsp.arguments = [val._tspark_value_arg() for val in self.value]
+        if named:
+            tsp.name = self.name
+            tsp.ordinal = False
+        else:
+            tsp.ordinal = True
+        return tsp
+
+    def _tspark_value_arg(self):
+        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
+        tva = TSparkParameterValueArg(type=self._cast_expr())
+        tva.arguments = [val._tspark_value_arg() for val in self.value]
+        return tva
+
+    CAST_EXPR = DatabricksSupportedType.MAP.name
+
+
 class DecimalParameter(DbsqlParameterBase):
     """Wrap a Python `Decimal` that will be bound to a Databricks SQL DECIMAL type."""
 
@@ -543,23 +652,26 @@ def dbsql_parameter_from_primitive(
     # havoc. We can't use TYPE_INFERRENCE_MAP because mypy doesn't trust
     # its logic
 
-    if type(value) is int:
+    if isinstance(value, bool):
+        return BooleanParameter(value=value, name=name)
+    elif isinstance(value, int):
         return dbsql_parameter_from_int(value, name=name)
-    elif type(value) is str:
+    elif isinstance(value, str):
         return StringParameter(value=value, name=name)
-    elif type(value) is float:
+    elif isinstance(value, float):
         return FloatParameter(value=value, name=name)
-    elif type(value) is datetime.datetime:
+    elif isinstance(value, datetime.datetime):
        return TimestampParameter(value=value, name=name)
-    elif type(value) is datetime.date:
+    elif isinstance(value, datetime.date):
        return DateParameter(value=value, name=name)
-    elif type(value) is bool:
-        return BooleanParameter(value=value, name=name)
-    elif type(value) is decimal.Decimal:
+    elif isinstance(value, decimal.Decimal):
        return DecimalParameter(value=value, name=name)
+    elif isinstance(value, dict):
+        return MapParameter(value=value, name=name)
+    elif isinstance(value, Sequence) and not isinstance(value, str):
+        return ArrayParameter(value=value, name=name)
     elif value is None:
        return VoidParameter(value=value, name=name)
-
     else:
        raise NotSupportedError(
            f"Could not infer parameter type from value: {value} - {type(value)} \n"
@@ -581,6 +693,8 @@ def dbsql_parameter_from_primitive(
     TimestampNTZParameter,
     TinyIntParameter,
     DecimalParameter,
+    ArrayParameter,
+    MapParameter,
 ]

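Two details of the native.py changes are worth calling out. First, the rewritten inference chain checks `bool` before `int` because `bool` is a subclass of `int` in Python, so an `isinstance(value, int)` test placed first would capture `True` and `False`. Second, `MapParameter` flattens the input dict into an alternating key/value list before converting each element. Below is a minimal usage sketch, assuming the connector's usual `sql.connect` entry point; the connection values and queries are illustrative only:

```python
from databricks import sql
from databricks.sql.parameters import ArrayParameter, MapParameter

# Hypothetical workspace values; substitute your own.
with sql.connect(
    server_hostname="example.cloud.databricks.com",
    http_path="/sql/1.0/warehouses/abc123",
    access_token="dapi-example-token",
) as connection:
    with connection.cursor() as cursor:
        # Explicit wrappers, bound positionally to the `?` markers.
        cursor.execute(
            "SELECT ?, ?",
            [ArrayParameter([1, 2, 3]), MapParameter({"color": "red"})],
        )
        print(cursor.fetchall())

        # Inference: dbsql_parameter_from_primitive turns a dict into a
        # MapParameter and a non-string Sequence into an ArrayParameter,
        # so plain Python values should work as well.
        cursor.execute("SELECT ?", [["a", "b", "c"]])
        print(cursor.fetchall())
```

Per the changelog entry for #559, these complex types are in Private Preview, so server-side support may vary by workspace.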