@@ -340,6 +340,35 @@ def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: tor
340340 validate_tensor (quantiles_item , (3 , expected_num_quantiles , 7 ), dtype = torch .float32 )
341341
342342
@pytest.mark.parametrize(
    "inputs, expected_output_shapes",
    [
        # NOTE: d_model for the dummy model is 6
        # Homogeneous univariate task
        (torch.rand(4, 1, 16), [(1, 3, 6)] * 4),
        # Homogeneous multivariate task
        (torch.rand(4, 3, 37), [(3, 5, 6)] * 4),
        # Heterogeneous tasks with different history lengths
        (
            [torch.rand(100), torch.rand(2, 150), torch.rand(120)],
            [(1, 12, 6), (2, 12, 6), (1, 12, 6)],
        ),
    ],
)
def test_when_input_is_valid_then_pipeline_can_embed(pipeline, inputs, expected_output_shapes):
    """Check that ``pipeline.embed`` accepts valid batched or ragged inputs.

    For each input task, ``embed`` must return one embedding tensor of the
    expected shape plus a ``(loc, scale)`` pair whose tensors have one row
    per variate (first dim of the expected shape) and a single column.
    """
    embeds, loc_scales = pipeline.embed(inputs)

    # Separate asserts (instead of one compound boolean) so a failure
    # pinpoints which condition broke.
    assert isinstance(embeds, list)
    assert len(embeds) == len(expected_output_shapes)
    assert len(loc_scales) == len(expected_output_shapes)

    for embed, loc_scale, expected_shape in zip(embeds, loc_scales, expected_output_shapes):
        validate_tensor(embed, expected_shape, dtype=torch.float32)
        # loc/scale appear to be per-variate scalars, hence (num_variates, 1)
        validate_tensor(loc_scale[0], (expected_shape[0], 1), dtype=torch.float32)
        validate_tensor(loc_scale[1], (expected_shape[0], 1), dtype=torch.float32)
370+
371+
343372@pytest .mark .parametrize (
344373 "task_kwargs" ,
345374 [
0 commit comments