2121 TimestampParameter ,
2222 TinyIntParameter ,
2323 VoidParameter ,
24+ MapParameter ,
25+ ArrayParameter ,
2426)
2527from databricks .sql .parameters .native import (
2628 TDbsqlParameter ,
29+ TSparkParameter ,
2730 TSparkParameterValue ,
31+ TSparkParameterValueArg ,
2832 dbsql_parameter_from_primitive ,
2933)
3034from databricks .sql .thrift_api .TCLIService import ttypes
@@ -112,6 +116,8 @@ class Primitive(Enum):
112116 DOUBLE = 3.14
113117 FLOAT = 3.15
114118 SMALLINT = 51
119+ ARRAY = [1 , 2 , 3 ]
120+ MAP = {"a" : 1 , "b" : 2 }
115121
116122
117123class TestDbsqlParameter :
@@ -131,6 +137,8 @@ class TestDbsqlParameter:
131137 (TimestampParameter , Primitive .TIMESTAMP , "TIMESTAMP" ),
132138 (TimestampNTZParameter , Primitive .TIMESTAMP , "TIMESTAMP_NTZ" ),
133139 (TinyIntParameter , Primitive .INT , "TINYINT" ),
140+ (MapParameter , Primitive .MAP , "MAP" ),
141+ (ArrayParameter , Primitive .ARRAY , "ARRAY" ),
134142 ),
135143 )
136144 def test_cast_expression (
@@ -165,6 +173,18 @@ def test_tspark_param_value(self, t: TDbsqlParameter, prim):
165173 assert output == None
166174 else :
167175 assert output == TSparkParameterValue (stringValue = str (prim .value ))
176+
177+ @pytest .mark .parametrize (
178+ "base_type,input,expected_output" ,[
179+ (ArrayParameter , [1 ,2 ,3 ], TSparkParameter (ordinal = True , name = None , type = 'ARRAY' , value = None , arguments = [TSparkParameterValueArg (type = 'INT' , value = '1' , arguments = None ), TSparkParameterValueArg (type = 'INT' , value = '2' , arguments = None ), TSparkParameterValueArg (type = 'INT' , value = '3' , arguments = None )])),
180+ (MapParameter , {"a" : 1 , "b" : 2 }, TSparkParameter (ordinal = True , name = None , type = 'MAP' , value = None , arguments = [TSparkParameterValueArg (type = 'STRING' , value = 'a' , arguments = None ), TSparkParameterValueArg (type = 'INT' , value = '1' , arguments = None ), TSparkParameterValueArg (type = 'STRING' , value = 'b' , arguments = None ), TSparkParameterValueArg (type = 'INT' , value = '2' , arguments = None )])),
181+ (ArrayParameter ,[{"a" :1 ,"b" :2 },{"c" :3 ,"d" :4 }], TSparkParameter (ordinal = True , name = None , type = 'ARRAY' , value = None , arguments = [TSparkParameterValueArg (type = 'MAP' , value = None , arguments = [TSparkParameterValueArg (type = 'STRING' , value = 'a' , arguments = None ), TSparkParameterValueArg (type = 'INT' , value = '1' , arguments = None ), TSparkParameterValueArg (type = 'STRING' , value = 'b' , arguments = None ), TSparkParameterValueArg (type = 'INT' , value = '2' , arguments = None )]), TSparkParameterValueArg (type = 'MAP' , value = None , arguments = [TSparkParameterValueArg (type = 'STRING' , value = 'c' , arguments = None ), TSparkParameterValueArg (type = 'INT' , value = '3' , arguments = None ), TSparkParameterValueArg (type = 'STRING' , value = 'd' , arguments = None ), TSparkParameterValueArg (type = 'INT' , value = '4' , arguments = None )])])),
182+ ]
183+ )
184+ def test_complex_type_tspark_param (self ,base_type ,input ,expected_output ):
185+ p = base_type (input )
186+ tsp = p .as_tspark_param ()
187+ assert tsp == expected_output
168188
169189 def test_tspark_param_named (self ):
170190 p = dbsql_parameter_from_primitive (Primitive .INT .value , name = "p" )
@@ -192,6 +212,8 @@ def test_tspark_param_ordinal(self):
192212 (FloatParameter , Primitive .FLOAT ),
193213 (VoidParameter , Primitive .NONE ),
194214 (TimestampParameter , Primitive .TIMESTAMP ),
215+ (MapParameter , Primitive .MAP ),
216+ (ArrayParameter , Primitive .ARRAY ),
195217 ),
196218 )
197219 def test_inference (self , _type : TDbsqlParameter , prim : Primitive ):
0 commit comments