
Commit 111972a

Add Chronos2Pipeline.embed (#361)
*Issue #, if available:* #354

*Description of changes:* This PR adds `Chronos2Pipeline.embed` to enable users to extract embeddings from the last encoder layer in an easy way. The API and behavior are similar to what Chronos and Chronos-Bolt provide.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent e48f480 commit 111972a
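
For orientation, here is a minimal usage sketch of the new API (not part of the commit itself). It assumes `Chronos2Pipeline` is exported from the top-level `chronos` package and that a published checkpoint id such as `amazon/chronos-2` is available; substitute your own model path if it differs.

```python
import torch

from chronos import Chronos2Pipeline  # assumed top-level export

# Hypothetical checkpoint id; replace with the model path you actually use.
pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2")

# A batch of 4 univariate series, each with 100 observations.
inputs = torch.rand(4, 1, 100)

embeddings, loc_scales = pipeline.embed(inputs)

# One tensor per input series, shaped (n_variates, num_patches + 2, d_model).
print(embeddings[0].shape)
# Per-series location/scale statistics used for internal normalization.
loc, scale = loc_scales[0]
```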

File tree

3 files changed (+175, -47 lines)

src/chronos/chronos2/model.py

Lines changed: 71 additions & 47 deletions
@@ -547,6 +547,74 @@ def _compute_loss(
 
         return loss
 
+    def encode(
+        self,
+        context: torch.Tensor,
+        context_mask: torch.Tensor | None = None,
+        group_ids: torch.Tensor | None = None,
+        future_covariates: torch.Tensor | None = None,
+        future_covariates_mask: torch.Tensor | None = None,
+        num_output_patches: int = 1,
+        future_target: torch.Tensor | None = None,
+        future_target_mask: torch.Tensor | None = None,
+        output_attentions: bool = False,
+    ):
+        self._validate_input(
+            context=context,
+            context_mask=context_mask,
+            future_covariates=future_covariates,
+            future_covariates_mask=future_covariates_mask,
+            group_ids=group_ids,
+            num_output_patches=num_output_patches,
+            future_target=future_target,
+            future_target_mask=future_target_mask,
+        )
+
+        batch_size = context.shape[0]
+        patched_context, attention_mask, loc_scale = self._prepare_patched_context(
+            context=context, context_mask=context_mask
+        )
+        num_context_patches = attention_mask.shape[-1]
+
+        # get input embeddings of shape (batch, num_context_patches, d_model)
+        input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
+        # append [REG] special token embedding, if needed
+        if self.chronos_config.use_reg_token:
+            reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
+            reg_embeds = self.shared(reg_input_ids)
+            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
+            attention_mask = torch.cat(
+                [attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
+            )
+
+        patched_future, patched_future_covariates_mask = self._prepare_patched_future(
+            future_covariates=future_covariates,
+            future_covariates_mask=future_covariates_mask,
+            loc_scale=loc_scale,
+            num_output_patches=num_output_patches,
+            batch_size=batch_size,
+        )
+        future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
+
+        # get future embeddings of shape (batch, num_output_patches, d_model)
+        future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
+
+        # concatenate context and future embeddings and masks
+        input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
+        attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
+
+        if group_ids is None:
+            # by default, each time series is treated independently, i.e., no mixing across the batch
+            group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
+
+        encoder_outputs: Chronos2EncoderOutput = self.encoder(
+            attention_mask=attention_mask,
+            inputs_embeds=input_embeds,
+            group_ids=group_ids,
+            output_attentions=output_attentions,
+        )
+        return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
+
     def forward(
         self,
         context: torch.Tensor,
@@ -625,63 +693,19 @@ def forward(
         - enc_time_self_attn_weights: Time self attention weights, if output_attentions=True
         - enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
         """
-
-        self._validate_input(
+        batch_size = context.shape[0]
+        encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
             context=context,
             context_mask=context_mask,
+            group_ids=group_ids,
             future_covariates=future_covariates,
             future_covariates_mask=future_covariates_mask,
-            group_ids=group_ids,
             num_output_patches=num_output_patches,
             future_target=future_target,
             future_target_mask=future_target_mask,
-        )
-
-        batch_size = context.shape[0]
-        patched_context, attention_mask, loc_scale = self._prepare_patched_context(
-            context=context, context_mask=context_mask
-        )
-        num_context_patches = attention_mask.shape[-1]
-
-        # get input embeddings of shape (batch, num_context_patches, d_model)
-        input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
-        # append [REG] special token embedding, if needed
-        if self.chronos_config.use_reg_token:
-            reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
-            reg_embeds = self.shared(reg_input_ids)
-            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
-            attention_mask = torch.cat(
-                [attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
-            )
-
-        patched_future, patched_future_covariates_mask = self._prepare_patched_future(
-            future_covariates=future_covariates,
-            future_covariates_mask=future_covariates_mask,
-            loc_scale=loc_scale,
-            num_output_patches=num_output_patches,
-            batch_size=batch_size,
-        )
-        future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
-
-        # get future embeddings of shape (batch, num_output_patches, d_model)
-        future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
-
-        # concatenate context and future embeddings and masks
-        input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
-        attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
-
-        if group_ids is None:
-            # by default, each time series is treated independently, i.e., no mixing across the batch
-            group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
-
-        encoder_outputs: Chronos2EncoderOutput = self.encoder(
-            attention_mask=attention_mask,
-            inputs_embeds=input_embeds,
-            group_ids=group_ids,
             output_attentions=output_attentions,
         )
         hidden_states: torch.Tensor = encoder_outputs[0]
-
         assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
 
         # slice the last num_output_patches hidden states to be input into the output_patch_embedding
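
The assert reproduced above fixes the layout of the encoder output along the sequence dimension: the context patch embeddings come first, followed by the `[REG]` special token (when `use_reg_token` is set) and then the `num_output_patches` masked output-patch positions. This is also why `Chronos2Pipeline.embed` documents a `num_patches + 2` second dimension. Below is a hedged slicing sketch; the variable names are illustrative and the random tensor merely stands in for one element of `embed`'s output, assuming the default `num_output_patches=1`.

```python
import torch

# Stand-in for one element returned by Chronos2Pipeline.embed:
# shape (n_variates, num_patches + 2, d_model); random values for illustration only.
emb = torch.rand(3, 7, 6)

context_patch_emb = emb[:, :-2, :]  # embeddings of the context patches
reg_emb = emb[:, -2, :]             # embedding of the [REG] special token
output_patch_emb = emb[:, -1, :]    # embedding of the masked output patch
```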

src/chronos/chronos2/pipeline.py

Lines changed: 75 additions & 0 deletions
@@ -988,6 +988,81 @@ def predict_fev(
 
         return predictions_per_window, inference_time_s
 
+    @torch.no_grad()
+    def embed(
+        self, inputs: TensorOrArray | Sequence[TensorOrArray], batch_size: int = 256, context_length: int | None = None
+    ) -> tuple[list[torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]:
+        """
+        Get encoder embeddings for the given time series.
+
+        Parameters
+        ----------
+        inputs
+            The time series to get embeddings for, can be one of:
+            - A 3-dimensional `torch.Tensor` or `np.ndarray` of shape (batch, n_variates, history_length). When `n_variates > 1`, information
+              will be shared among the different variates of each time series in the batch.
+            - A list of `torch.Tensor` or `np.ndarray` where each element can either be 1-dimensional of shape (history_length,)
+              or 2-dimensional of shape (n_variates, history_length). The history_lengths may be different across elements; left-padding
+              will be applied, if needed.
+        batch_size
+            The batch size used for generating embeddings. Note that the batch size here means the total number of time series which are input into the model.
+            If your data has multiple variates, the effective number of time series tasks in a batch will be lower than this value, by default 256
+        context_length
+            The maximum context length used during inference, by default set to the model's default context length
+
+        Returns
+        -------
+        embeddings
+            a list of `torch.Tensor` where each element has shape (n_variates, num_patches + 2, d_model) and the number of elements is equal to the number
+            of target time series (univariate or multivariate) in the `inputs`. The extra +2 is due to embeddings of the [REG] token and a masked output patch token.
+        loc_scale
+            a list of tuples with the mean and standard deviation of each time series.
+        """
+        if context_length is None:
+            context_length = self.model_context_length
+
+        if context_length > self.model_context_length:
+            warnings.warn(
+                f"The specified context_length {context_length} is greater than the model's default context length {self.model_context_length}. "
+                f"Resetting context_length to {self.model_context_length}."
+            )
+            context_length = self.model_context_length
+
+        test_dataset = Chronos2Dataset.convert_inputs(
+            inputs=inputs,
+            context_length=context_length,
+            prediction_length=0,
+            batch_size=batch_size,
+            output_patch_size=self.model_output_patch_size,
+            mode=DatasetMode.TEST,
+        )
+        test_loader = DataLoader(
+            test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False
+        )
+        all_embeds: list[torch.Tensor] = []
+        all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []
+        for batch in test_loader:
+            assert batch["future_target"] is None
+            batch_context = batch["context"]
+            batch_group_ids = batch["group_ids"]
+            batch_target_idx_ranges = batch["target_idx_ranges"]
+
+            encoder_outputs, (locs, scales), *_ = self.model.encode(
+                context=batch_context.to(device=self.model.device, dtype=torch.float32),
+                group_ids=batch_group_ids.to(self.model.device),
+            )
+            batch_embeds = [encoder_outputs[0][start:end].cpu() for (start, end) in batch_target_idx_ranges]
+            batch_loc_scales = list(
+                zip(
+                    [locs[start:end].cpu() for (start, end) in batch_target_idx_ranges],
+                    [scales[start:end].cpu() for (start, end) in batch_target_idx_ranges],
+                )
+            )
+            all_embeds.extend(batch_embeds)
+            all_loc_scales.extend(batch_loc_scales)
+
+        return all_embeds, all_loc_scales
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         """

test/test_chronos2.py

Lines changed: 29 additions & 0 deletions
@@ -340,6 +340,35 @@ def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: tor
     validate_tensor(quantiles_item, (3, expected_num_quantiles, 7), dtype=torch.float32)
 
 
+@pytest.mark.parametrize(
+    "inputs, expected_output_shapes",
+    [
+        # NOTE: d_model for the dummy model is 6
+        # Homogeneous univariate task
+        (torch.rand(4, 1, 16), [(1, 3, 6)] * 4),
+        # Homogeneous multivariate task
+        (torch.rand(4, 3, 37), [(3, 5, 6)] * 4),
+        # Heterogeneous tasks with different history lengths
+        (
+            [torch.rand(100), torch.rand(2, 150), torch.rand(120)],
+            [(1, 12, 6), (2, 12, 6), (1, 12, 6)],
+        ),
+    ],
+)
+def test_when_input_is_valid_then_pipeline_can_embed(pipeline, inputs, expected_output_shapes):
+    embeds, loc_scales = pipeline.embed(inputs)
+
+    assert (
+        isinstance(embeds, list)
+        and len(embeds) == len(expected_output_shapes)
+        and len(loc_scales) == len(expected_output_shapes)
+    )
+    for embed, loc_scale, expected_shape in zip(embeds, loc_scales, expected_output_shapes):
+        validate_tensor(embed, expected_shape, dtype=torch.float32)
+        validate_tensor(loc_scale[0], (expected_shape[0], 1), dtype=torch.float32)
+        validate_tensor(loc_scale[1], (expected_shape[0], 1), dtype=torch.float32)
+
+
 @pytest.mark.parametrize(
     "task_kwargs",
     [
