Skip to content

Commit b0f9622

Browse files
committed
Add test
1 parent 55a18cd commit b0f9622

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

test/test_chronos2.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,35 @@ def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: tor
340340
validate_tensor(quantiles_item, (3, expected_num_quantiles, 7), dtype=torch.float32)
341341

342342

343+
@pytest.mark.parametrize(
344+
"inputs, expected_output_shapes",
345+
[
346+
# NOTE: d_model for the dummy model is 6
347+
# Homogenous univariate task
348+
(torch.rand(4, 1, 16), [(1, 3, 6)] * 4),
349+
# Homogenous multivariate task
350+
(torch.rand(4, 3, 37), [(3, 5, 6)] * 4),
351+
# Heterogenous tasks with different history lengths
352+
(
353+
[torch.rand(100), torch.rand(2, 150), torch.rand(120)],
354+
[(1, 12, 6), (2, 12, 6), (1, 12, 6)],
355+
),
356+
],
357+
)
358+
def test_when_input_is_valid_then_pipeline_can_embed(pipeline, inputs, expected_output_shapes):
359+
embeds, loc_scales = pipeline.embed(inputs)
360+
361+
assert (
362+
isinstance(embeds, list)
363+
and len(embeds) == len(expected_output_shapes)
364+
and len(loc_scales) == len(expected_output_shapes)
365+
)
366+
for embed, loc_scale, expected_shape in zip(embeds, loc_scales, expected_output_shapes):
367+
validate_tensor(embed, expected_shape, dtype=torch.float32)
368+
validate_tensor(loc_scale[0], (expected_shape[0], 1), dtype=torch.float32)
369+
validate_tensor(loc_scale[1], (expected_shape[0], 1), dtype=torch.float32)
370+
371+
343372
@pytest.mark.parametrize(
344373
"task_kwargs",
345374
[

0 commit comments

Comments
 (0)