
Commit aedade2: Add Chronos2Pipeline.embed

1 parent: c23d34c

2 files changed: +120, -47 lines

src/chronos/chronos2/model.py

Lines changed: 71 additions & 47 deletions
@@ -547,6 +547,74 @@ def _compute_loss(
 
         return loss
 
+    def encode(
+        self,
+        context: torch.Tensor,
+        context_mask: torch.Tensor | None = None,
+        group_ids: torch.Tensor | None = None,
+        future_covariates: torch.Tensor | None = None,
+        future_covariates_mask: torch.Tensor | None = None,
+        num_output_patches: int = 1,
+        future_target: torch.Tensor | None = None,
+        future_target_mask: torch.Tensor | None = None,
+        output_attentions: bool = False,
+    ):
+        self._validate_input(
+            context=context,
+            context_mask=context_mask,
+            future_covariates=future_covariates,
+            future_covariates_mask=future_covariates_mask,
+            group_ids=group_ids,
+            num_output_patches=num_output_patches,
+            future_target=future_target,
+            future_target_mask=future_target_mask,
+        )
+
+        batch_size = context.shape[0]
+        patched_context, attention_mask, loc_scale = self._prepare_patched_context(
+            context=context, context_mask=context_mask
+        )
+        num_context_patches = attention_mask.shape[-1]
+
+        # get input embeddings of shape (batch, num_context_patches, d_model)
+        input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
+        # append [REG] special token embedding, if needed
+        if self.chronos_config.use_reg_token:
+            reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
+            reg_embeds = self.shared(reg_input_ids)
+            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
+            attention_mask = torch.cat(
+                [attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
+            )
+
+        patched_future, patched_future_covariates_mask = self._prepare_patched_future(
+            future_covariates=future_covariates,
+            future_covariates_mask=future_covariates_mask,
+            loc_scale=loc_scale,
+            num_output_patches=num_output_patches,
+            batch_size=batch_size,
+        )
+        future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
+
+        # get future embeddings of shape (batch, num_output_patches, d_model)
+        future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
+
+        # concatenate context and future embeddings and masks
+        input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
+        attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
+
+        if group_ids is None:
+            # by default, each time series is treated independently, i.e., no mixing across the batch
+            group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
+
+        encoder_outputs: Chronos2EncoderOutput = self.encoder(
+            attention_mask=attention_mask,
+            inputs_embeds=input_embeds,
+            group_ids=group_ids,
+            output_attentions=output_attentions,
+        )
+        return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
+
     def forward(
         self,
         context: torch.Tensor,
@@ -625,63 +693,19 @@ def forward(
             - enc_time_self_attn_weights: Time self attention weights, if output_attentions=True
             - enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
         """
-
-        self._validate_input(
+        batch_size = context.shape[0]
+        encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
             context=context,
             context_mask=context_mask,
+            group_ids=group_ids,
             future_covariates=future_covariates,
             future_covariates_mask=future_covariates_mask,
-            group_ids=group_ids,
             num_output_patches=num_output_patches,
             future_target=future_target,
             future_target_mask=future_target_mask,
-        )
-
-        batch_size = context.shape[0]
-        patched_context, attention_mask, loc_scale = self._prepare_patched_context(
-            context=context, context_mask=context_mask
-        )
-        num_context_patches = attention_mask.shape[-1]
-
-        # get input embeddings of shape (batch, num_context_patches, d_model)
-        input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
-        # append [REG] special token embedding, if needed
-        if self.chronos_config.use_reg_token:
-            reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
-            reg_embeds = self.shared(reg_input_ids)
-            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
-            attention_mask = torch.cat(
-                [attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
-            )
-
-        patched_future, patched_future_covariates_mask = self._prepare_patched_future(
-            future_covariates=future_covariates,
-            future_covariates_mask=future_covariates_mask,
-            loc_scale=loc_scale,
-            num_output_patches=num_output_patches,
-            batch_size=batch_size,
-        )
-        future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
-
-        # get future embeddings of shape (batch, num_output_patches, d_model)
-        future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
-
-        # concatenate context and future embeddings and masks
-        input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
-        attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
-
-        if group_ids is None:
-            # by default, each time series is treated independently, i.e., no mixing across the batch
-            group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
-
-        encoder_outputs: Chronos2EncoderOutput = self.encoder(
-            attention_mask=attention_mask,
-            inputs_embeds=input_embeds,
-            group_ids=group_ids,
             output_attentions=output_attentions,
         )
         hidden_states: torch.Tensor = encoder_outputs[0]
-
         assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
 
         # slice the last num_output_patches hidden states to be input into the output_patch_embedding
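
The new model-level encode runs everything up to and including the shared encoder and returns the raw encoder outputs together with the (loc, scale) instance-normalization statistics, the patched future-covariates mask, and the number of context patches; forward now simply delegates to it. A minimal sketch of calling encode directly is shown below. The import path and the "amazon/chronos-2" checkpoint id are assumptions for illustration and are not part of this diff; only the encode signature and the hidden-state shape come from the code above.

import torch

from chronos import Chronos2Pipeline  # import path assumed for illustration

# Load a pipeline and reach into the underlying model (checkpoint id is hypothetical).
pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2")
model = pipeline.model

# A small batch of 4 univariate contexts of length 512.
context = torch.randn(4, 512).to(device=model.device, dtype=torch.float32)

encoder_outputs, (loc, scale), future_cov_mask, num_context_patches = model.encode(context=context)

# Per the assert in forward(): (batch, num_context_patches + 1 + num_output_patches, d_model),
# i.e. context patches, the [REG] token (when use_reg_token is set), and one future patch by default.
hidden_states = encoder_outputs[0]
print(hidden_states.shape, loc.shape, scale.shape)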

src/chronos/chronos2/pipeline.py

Lines changed: 49 additions & 0 deletions
@@ -994,6 +994,55 @@ def predict_fev(
 
         return predictions_per_window, inference_time_s
 
+    @torch.no_grad()
+    def embed(
+        self, inputs: TensorOrArray | Sequence[TensorOrArray], batch_size: int = 256, context_length: int | None = None
+    ) -> tuple[list[torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]:
+        if context_length is None:
+            context_length = self.model_context_length
+
+        if context_length > self.model_context_length:
+            warnings.warn(
+                f"The specified context_length {context_length} is greater than the model's default context length {self.model_context_length}. "
+                f"Resetting context_length to {self.model_context_length}."
+            )
+            context_length = self.model_context_length
+
+        test_dataset = Chronos2Dataset.convert_inputs(
+            inputs=inputs,
+            context_length=context_length,
+            prediction_length=0,
+            batch_size=batch_size,
+            output_patch_size=self.model_output_patch_size,
+            mode=DatasetMode.TEST,
+        )
+        test_loader = DataLoader(
+            test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False
+        )
+        all_embeds: list[torch.Tensor] = []
+        all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []
+        for batch in test_loader:
+            assert batch["future_target"] is None
+            batch_context = batch["context"]
+            batch_group_ids = batch["group_ids"]
+            batch_target_idx_ranges = batch["target_idx_ranges"]
+
+            encoder_outputs, (locs, scales), *_ = self.model.encode(
+                context=batch_context.to(device=self.model.device, dtype=torch.float32),
+                group_ids=batch_group_ids.to(self.model.device),
+            )
+            batch_embeds = [encoder_outputs[0][start:end].cpu() for (start, end) in batch_target_idx_ranges]
+            batch_loc_scales = list(
+                zip(
+                    [locs[start:end].cpu() for (start, end) in batch_target_idx_ranges],
+                    [scales[start:end].cpu() for (start, end) in batch_target_idx_ranges],
+                )
+            )
+            all_embeds.extend(batch_embeds)
+            all_loc_scales.extend(batch_loc_scales)
+
+        return all_embeds, all_loc_scales
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         """
