
Commit 5494793

add new connection param
1 parent ffc5568 commit 5494793

4 files changed: +28 -15 lines

src/databricks/sql/backend/thrift_backend.py

Lines changed: 2 additions & 4 deletions
@@ -764,16 +764,14 @@ def _col_to_description(col, field=None, session_id_hex=None):
         else:
             precision, scale = None, None

-        # Extract variant/measure type from field if available
+        # Extract variant type from field if available
         if field is not None:
             try:
-                # Check for variant/measure type in metadata
+                # Check for variant type in metadata
                 if field.metadata and b"Spark:DataType:SqlName" in field.metadata:
                     sql_type = field.metadata.get(b"Spark:DataType:SqlName")
                     if sql_type == b"VARIANT":
                         cleaned_type = "variant"
-                    if sql_type and b"measure" in sql_type:
-                        cleaned_type += " measure"
             except Exception as e:
                 logger.debug(f"Could not extract variant type from field: {e}")
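
For readers skimming the diff, here is a minimal standalone sketch of the extraction logic after this change (simplified from _col_to_description; `field` is assumed to be an Arrow field whose metadata maps bytes keys to bytes values, and the helper name is hypothetical):

    import logging

    logger = logging.getLogger(__name__)

    def resolve_cleaned_type(cleaned_type: str, field) -> str:
        # Override the reported column type with "variant" when the Arrow
        # field metadata marks the column as a Spark VARIANT. The "measure"
        # suffix handling is gone as of this commit.
        if field is not None:
            try:
                if field.metadata and b"Spark:DataType:SqlName" in field.metadata:
                    sql_type = field.metadata.get(b"Spark:DataType:SqlName")
                    if sql_type == b"VARIANT":
                        cleaned_type = "variant"
            except Exception as e:
                logger.debug(f"Could not extract variant type from field: {e}")
        return cleaned_type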

src/databricks/sql/client.py

Lines changed: 15 additions & 0 deletions
@@ -200,6 +200,12 @@ def read(self) -> Optional[OAuthToken]:
             STRUCT is returned as Dict[str, Any]
             ARRAY is returned as numpy.ndarray
             When False, complex types are returned as strings. These are generally deserializable as JSON.
+        :param enable_metric_view_metadata: `bool`, optional (default is False)
+            When True, enables metric view metadata support by setting the
+            spark.sql.thriftserver.metadata.metricview.enabled session configuration.
+            This allows:
+            1. cursor.tables() to return the METRIC_VIEW table type
+            2. cursor.columns() to return the "measure" column type
         """

         # Internal arguments in **kwargs:

@@ -248,6 +254,15 @@ def read(self) -> Optional[OAuthToken]:
         access_token_kv = {"access_token": access_token}
         kwargs = {**kwargs, **access_token_kv}

+        # Handle enable_metric_view_metadata parameter
+        enable_metric_view_metadata = kwargs.get("enable_metric_view_metadata", False)
+        if enable_metric_view_metadata:
+            if session_configuration is None:
+                session_configuration = {}
+            session_configuration[
+                "spark.sql.thriftserver.metadata.metricview.enabled"
+            ] = "true"
+
         self.disable_pandas = kwargs.get("_disable_pandas", False)
         self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
         self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
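
A usage sketch for the new parameter (the hostname, HTTP path, token, and catalog below are placeholders, not values from this commit):

    import databricks.sql

    # Opting in: the connector injects
    # spark.sql.thriftserver.metadata.metricview.enabled=true into the
    # session configuration before the session is opened.
    with databricks.sql.connect(
        server_hostname="example.cloud.databricks.com",  # placeholder
        http_path="/sql/1.0/warehouses/abc123",          # placeholder
        access_token="dapi...",                          # placeholder
        enable_metric_view_metadata=True,
    ) as connection:
        with connection.cursor() as cursor:
            # With the flag set, tables() can report the METRIC_VIEW table
            # type and columns() can report the "measure" column type.
            cursor.tables(catalog_name="samples")        # placeholder catalog
            print(cursor.fetchall())

Since the parameter only populates that one key, passing session_configuration={"spark.sql.thriftserver.metadata.metricview.enabled": "true"} directly should be equivalent.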

tests/unit/test_session.py

Lines changed: 11 additions & 0 deletions
@@ -163,6 +163,17 @@ def test_configuration_passthrough(self, mock_client_class):
         call_kwargs = mock_client_class.return_value.open_session.call_args[1]
         assert call_kwargs["session_configuration"] == mock_session_config

+    @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
+    def test_enable_metric_view_metadata_parameter(self, mock_client_class):
+        """Test that enable_metric_view_metadata parameter sets the correct session configuration."""
+        databricks.sql.connect(
+            enable_metric_view_metadata=True, **self.DUMMY_CONNECTION_ARGS
+        )
+
+        call_kwargs = mock_client_class.return_value.open_session.call_args[1]
+        expected_config = {"spark.sql.thriftserver.metadata.metricview.enabled": "true"}
+        assert call_kwargs["session_configuration"] == expected_config
+
     @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
     def test_initial_namespace_passthrough(self, mock_client_class):
         mock_cat = Mock()
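
To run just the new test locally (assuming the repository's usual pytest setup):

    pytest tests/unit/test_session.py -k test_enable_metric_view_metadata_parameter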

tests/unit/test_thrift_backend.py

Lines changed: 0 additions & 11 deletions
@@ -2402,17 +2402,6 @@ def test_hive_schema_to_description(self):
             ],
             [("regular_col", "string"), ("variant_col", "variant")],
         ),
-        (
-            [
-                ("measure_col", ttypes.TTypeId.DOUBLE_TYPE),
-                ("int_measure_col", ttypes.TTypeId.INT_TYPE),
-            ],
-            [
-                ("measure_col", {b"Spark:DataType:SqlName": b"double measure"}),
-                ("int_measure_col", {b"Spark:DataType:SqlName": b"int measure"}),
-            ],
-            [("measure_col", "double measure"), ("int_measure_col", "int measure")],
-        ),
         (
             [("regular_col", ttypes.TTypeId.STRING_TYPE)],
             None,  # No arrow schema
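
Note: the deleted parameter sets exercised the "measure" suffix handling that this commit removes from _col_to_description; with the change, measure columns appear intended to surface through cursor.columns() (when enable_metric_view_metadata is set) rather than through result-set descriptions.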
