55from typing import Dict , List , Type , Union
66from unittest .mock import patch
77
8+ import time
9+ import numpy as np
810import pytest
911import pytz
12+ from numpy .random .mtrand import Sequence
1013
1114from databricks .sql .parameters .native import (
1215 BigIntegerParameter ,
@@ -112,6 +115,7 @@ class TestParameterizedQueries(PySQLPytestTestCase):
112115 Primitive .NONE : "null_col" ,
113116 }
114117
118+
115119 def _get_inline_table_column (self , value ):
116120 return self .inline_type_map [Primitive (value )]
117121
@@ -142,7 +146,9 @@ def inline_table(self, connection_details):
142146 date_col DATE,
143147 timestamp_col TIMESTAMP,
144148 array_col ARRAY<STRING>,
145- map_col MAP<STRING, INT>
149+ map_col MAP<STRING, INT>,
150+ array_map_col ARRAY<MAP<STRING,INT>>,
151+ map_array_col MAP<INT,ARRAY<STRING>>
146152 ) USING DELTA
147153 """
148154
@@ -163,7 +169,7 @@ def patch_server_supports_native_params(self, supports_native_params: bool = Tru
163169 finally :
164170 pass
165171
166- def _inline_roundtrip (self , params : dict , paramstyle : ParamStyle ):
172+ def _inline_roundtrip (self , params : dict , paramstyle : ParamStyle , target_column ):
167173 """This INSERT, SELECT, DELETE dance is necessary because simply selecting
168174 ```
169175 "SELECT %(param)s"
@@ -174,7 +180,6 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle):
174180 :paramstyle:
175181 This is a no-op but is included to make the test-code easier to read.
176182 """
177- target_column = self ._get_inline_table_column (params .get ("p" ))
178183 INSERT_QUERY = f"INSERT INTO ___________________first.jprakash.pysql_e2e_inline_param_test_table (`{ target_column } `) VALUES (%(p)s)"
179184 SELECT_QUERY = f"SELECT { target_column } `col` FROM ___________________first.jprakash.pysql_e2e_inline_param_test_table LIMIT 1"
180185 DELETE_QUERY = "DELETE FROM ___________________first.jprakash.pysql_e2e_inline_param_test_table"
@@ -220,7 +225,7 @@ def _get_one_result(
220225 if approach == ParameterApproach .INLINE :
221226 # inline mode always uses ParamStyle.PYFORMAT
222227 # inline mode doesn't support positional parameters
223- return self ._inline_roundtrip (params , paramstyle = ParamStyle .PYFORMAT )
228+ return self ._inline_roundtrip (params , paramstyle = ParamStyle .PYFORMAT , target_column = self . _get_inline_table_column ( params . get ( "p" )) )
224229 elif approach == ParameterApproach .NATIVE :
225230 # native mode can use either ParamStyle.NAMED or ParamStyle.PYFORMAT
226231 # native mode can use either ParameterStructure.NAMED or ParameterStructure.POSITIONAL
@@ -250,6 +255,57 @@ def _eq(self, actual, expected: Primitive):
250255
251256 return actual_parsed == expected_parsed
252257
258+ def _parse_to_common_type (self , value ):
259+ """
260+ Function to convert the :value passed into a common python datatype for comparison
261+
262+ Convertion fyi
263+ MAP Datatype on server is returned as a list of tuples
264+ Ex:
265+ {"a":1,"b":2} -> [("a",1),("b",2)]
266+
267+ ARRAY Datatype on server is returned as a numpy array
268+ Ex:
269+ ["a","b","c"] -> np.array(["a","b","c"],dtype=object)
270+
271+ Primitive datatype on server is returned as a numpy primitive
272+ Ex:
273+ 1 -> np.int64(1)
274+ 2 -> np.int32(2)
275+ """
276+ if value is None :
277+ return None
278+ elif isinstance (value , (Sequence ,np .ndarray )) and not isinstance (value , (str , bytes )):
279+ return tuple (value )
280+ elif isinstance (value ,dict ):
281+ return tuple (value .items ())
282+ elif isinstance (value , np .generic ):
283+ return value .item ()
284+ else :
285+ return value
286+
287+ def _recursive_compare (self ,actual , expected ):
288+ """
289+ Function to compare the :actual and :expected values, recursively checks and ensures that all the data matches till the leaf level
290+
291+ Note: Complex datatype like MAP is not returned as a dictionary but as a list of tuples
292+ """
293+ actual_parsed = self ._parse_to_common_type (actual )
294+ expected_parsed = self ._parse_to_common_type (expected )
295+
296+ # Check if types are the same
297+ if type (actual_parsed ) != type (expected_parsed ):
298+ return False
299+
300+ # Handle lists or tuples
301+ if isinstance (actual_parsed , (list , tuple )):
302+ if len (actual_parsed ) != len (expected_parsed ):
303+ return False
304+ return all (self ._recursive_compare (o1 , o2 ) for o1 , o2 in zip (actual_parsed , expected_parsed ))
305+
306+ return actual_parsed == expected_parsed
307+
308+
253309 @pytest .mark .parametrize ("primitive" , Primitive )
254310 @pytest .mark .parametrize (
255311 "approach,paramstyle,parameter_structure" , approach_paramstyle_combinations
@@ -379,6 +435,43 @@ def test_readme_example(self):
379435 assert len (result ) == 10
380436 assert result [0 ].p == "foo"
381437
438+ @pytest .mark .parametrize (
439+ "col_name,data" ,[
440+ ("array_map_col" , [{"a" : 1 , "b" : 2 }, {"c" : 3 , "d" : 4 }]),
441+ ("map_array_col" , {1 : ["a" , "b" ], 2 : ["c" , "d" ]}),
442+ ])
443+ def test_inline_recursive_complex_type (self ,col_name ,data ):
444+ params = {'p' :data }
445+ result = self ._inline_roundtrip (params = params ,paramstyle = ParamStyle .PYFORMAT ,target_column = col_name )
446+ assert self ._recursive_compare (result .col , data )
447+
448+
449+ @pytest .mark .parametrize (
450+ "description,data" ,
451+ [
452+ ("ARRAY<MAP<STRING,INT>>" ,[{"a" : 1 , "b" : 2 }, {"c" : 3 , "d" : 4 }]),
453+ ("MAP<INT,ARRAY<STRING>>" ,{1 : ["a" , "b" ], 2 : ["c" , "d" ]}),
454+ ("ARRAY<ARRAY<INT>>" ,[[1 ,2 ,3 ],[1 ,2 ,3 ]]),
455+ ("ARRAY<ARRAY<ARRAY<INT>>>" ,[[[1 ,2 ,3 ],[1 ,2 ,3 ]],[[1 ,2 ,3 ],[1 ,2 ,3 ]]]),
456+ ("MAP<STRING,MAP<STRING,STRING>>" ,{"a" :{"b" :"c" ,"d" :"e" },"f" :{"g" :"h" ,"i" :"j" }})
457+ ],
458+ )
459+ @pytest .mark .parametrize (
460+ "paramstyle,parameter_structure" ,
461+ [
462+ (ParamStyle .NONE , ParameterStructure .POSITIONAL ),
463+ (ParamStyle .PYFORMAT , ParameterStructure .NAMED ),
464+ (ParamStyle .NAMED , ParameterStructure .NAMED ),
465+ ]
466+ )
467+ def test_native_recursive_complex_type (self ,description ,data ,paramstyle ,parameter_structure ):
468+ if (paramstyle == ParamStyle .NONE ):
469+ params = [data ]
470+ else :
471+ params = {'p' :data }
472+ result = self ._native_roundtrip (parameters = params ,paramstyle = paramstyle ,parameter_structure = parameter_structure )
473+ assert self ._recursive_compare (result .col , data )
474+
382475
383476class TestInlineParameterSyntax (PySQLPytestTestCase ):
384477 """The inline parameter approach uses pyformat markers"""
0 commit comments