Merged
118 changes: 71 additions & 47 deletions src/chronos/chronos2/model.py
@@ -547,6 +547,74 @@ def _compute_loss(

return loss

def encode(
self,
context: torch.Tensor,
context_mask: torch.Tensor | None = None,
group_ids: torch.Tensor | None = None,
future_covariates: torch.Tensor | None = None,
future_covariates_mask: torch.Tensor | None = None,
num_output_patches: int = 1,
future_target: torch.Tensor | None = None,
future_target_mask: torch.Tensor | None = None,
output_attentions: bool = False,
):
Comment on lines +550 to +561

Contributor:
I wish the diff would be more helpful here: is the body of this simply moved from forward?

Contributor (Author):
Yes, the first (encoding) portion from forward has been factored out into encode.
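For readers skimming the thread, the refactor reduces to the following shape (a sketch, not the literal diff; argument lists abbreviated):

```python
# Before: forward() validated inputs, patched context/future, built the
# embeddings, and ran the encoder inline.
# After: that prefix lives in encode(), and forward() delegates to it.
def forward(self, context, **kwargs):
    encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
        context=context, **kwargs
    )
    hidden_states = encoder_outputs[0]
    # ... decoding via output_patch_embedding continues as before ...
```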

self._validate_input(
context=context,
context_mask=context_mask,
future_covariates=future_covariates,
future_covariates_mask=future_covariates_mask,
group_ids=group_ids,
num_output_patches=num_output_patches,
future_target=future_target,
future_target_mask=future_target_mask,
)

batch_size = context.shape[0]
patched_context, attention_mask, loc_scale = self._prepare_patched_context(
context=context, context_mask=context_mask
)
num_context_patches = attention_mask.shape[-1]

# get input embeddings of shape (batch, num_context_patches, d_model)
input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
# append [REG] special token embedding, if needed
if self.chronos_config.use_reg_token:
reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
reg_embeds = self.shared(reg_input_ids)
input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
attention_mask = torch.cat(
[attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
)

patched_future, patched_future_covariates_mask = self._prepare_patched_future(
future_covariates=future_covariates,
future_covariates_mask=future_covariates_mask,
loc_scale=loc_scale,
num_output_patches=num_output_patches,
batch_size=batch_size,
)
future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)

# get future embeddings of shape (batch, num_output_patches, d_model)
future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)

# concatenate context and future embeddings and masks
input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)

if group_ids is None:
# by default, each time series is treated independently, i.e., no mixing across the batch
group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)

encoder_outputs: Chronos2EncoderOutput = self.encoder(
attention_mask=attention_mask,
inputs_embeds=input_embeds,
group_ids=group_ids,
output_attentions=output_attentions,
)
return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches

def forward(
self,
context: torch.Tensor,
@@ -625,63 +693,19 @@ def forward(
- enc_time_self_attn_weights: Time self attention weights, if output_attentions=True
- enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
"""

self._validate_input(
batch_size = context.shape[0]
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
context=context,
context_mask=context_mask,
group_ids=group_ids,
future_covariates=future_covariates,
future_covariates_mask=future_covariates_mask,
group_ids=group_ids,
num_output_patches=num_output_patches,
future_target=future_target,
future_target_mask=future_target_mask,
)

batch_size = context.shape[0]
patched_context, attention_mask, loc_scale = self._prepare_patched_context(
context=context, context_mask=context_mask
)
num_context_patches = attention_mask.shape[-1]

# get input embeddings of shape (batch, num_context_patches, d_model)
input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
# append [REG] special token embedding, if needed
if self.chronos_config.use_reg_token:
reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
reg_embeds = self.shared(reg_input_ids)
input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
attention_mask = torch.cat(
[attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
)

patched_future, patched_future_covariates_mask = self._prepare_patched_future(
future_covariates=future_covariates,
future_covariates_mask=future_covariates_mask,
loc_scale=loc_scale,
num_output_patches=num_output_patches,
batch_size=batch_size,
)
future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)

# get future embeddings of shape (batch, num_output_patches, d_model)
future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)

# concatenate context and future embeddings and masks
input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)

if group_ids is None:
# by default, each time series is treated independently, i.e., no mixing across the batch
group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)

encoder_outputs: Chronos2EncoderOutput = self.encoder(
attention_mask=attention_mask,
inputs_embeds=input_embeds,
group_ids=group_ids,
output_attentions=output_attentions,
)
hidden_states: torch.Tensor = encoder_outputs[0]

assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)

# slice the last num_output_patches hidden states to be input into the output_patch_embedding
75 changes: 75 additions & 0 deletions src/chronos/chronos2/pipeline.py
@@ -988,6 +988,81 @@ def predict_fev(

return predictions_per_window, inference_time_s

@torch.no_grad()
def embed(
self, inputs: TensorOrArray | Sequence[TensorOrArray], batch_size: int = 256, context_length: int | None = None
) -> tuple[list[torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]:
"""
Get encoder embeddings for the given time series.

Parameters
----------
inputs
The time series to get embeddings for, can be one of:
- A 3-dimensional `torch.Tensor` or `np.ndarray` of shape (batch, n_variates, history_length). When `n_variates > 1`, information
will be shared among the different variates of each time series in the batch.
- A list of `torch.Tensor` or `np.ndarray` where each element can either be 1-dimensional of shape (history_length,)
or 2-dimensional of shape (n_variates, history_length). The history lengths may differ across elements; left-padding
is applied where needed.
batch_size
The batch size used for generating embeddings. Note that batch size here counts the total number of univariate series (i.e., all variates and covariates) that are input into the model.
Contributor:
I'm not sure this is clear to me: does the batch_size refer to the .shape[0] of the tensors being processed? Or does it span the variates dimension as well (.shape[1])? I suppose it's the latter, given the docstring for the dataset class:

    batch_size
        The batch size for training the model. Note that the batch size here means the number of time series, including target(s) and
        covariates, that are input into the model. If your data has multiple target and/or covariates, the effective number of time series
        tasks in a batch will be lower than this value.

I see this is pretty much the description of batch_size everywhere (here, predict methods, dataset class). Maybe the confusion comes from "total number of time series" instead of "total number of variates", or something like that. But this could also be addressed separately.

Contributor (Author):
Internally, there's no notion of a variate dimension in the model: only batch and time (patch) axes. The batch_size here refers to the maximum items x (co)-variates per batch. Open to suggestions on a better docstring.
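A minimal illustration of that accounting (hypothetical shapes, not part of the PR):

```python
import torch

# Three tasks: a 3-variate series, a univariate series, and a 2-variate series.
inputs = [torch.rand(3, 100), torch.rand(100), torch.rand(2, 100)]

# Each task contributes n_variates univariate rows to the batch budget,
# so these 3 tasks occupy 3 + 1 + 2 = 6 of the batch_size slots.
rows = sum(t.shape[0] if t.ndim == 2 else 1 for t in inputs)
assert rows == 6
```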

If your data has multiple variates, the effective number of time series tasks in a batch will be lower than this value, by default 256
context_length
The maximum context length used for inference, by default set to the model's default context length

Returns
-------
embeddings
a list of `torch.Tensor` where each element has shape (n_variates, num_patches + 2, d_model); the number of elements equals the number
of target time series (univariate or multivariate) in the `inputs`. The extra +2 is due to the embeddings of the [REG] token and a masked output patch token.
loc_scale
a list of tuples with the mean and standard deviation of each time series.
"""
if context_length is None:
context_length = self.model_context_length

if context_length > self.model_context_length:
warnings.warn(
f"The specified context_length {context_length} is greater than the model's default context length {self.model_context_length}. "
f"Resetting context_length to {self.model_context_length}."
)
context_length = self.model_context_length

test_dataset = Chronos2Dataset.convert_inputs(
inputs=inputs,
context_length=context_length,
prediction_length=0,
batch_size=batch_size,
output_patch_size=self.model_output_patch_size,
mode=DatasetMode.TEST,
)
test_loader = DataLoader(
test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False
)
all_embeds: list[torch.Tensor] = []
all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []
for batch in test_loader:
assert batch["future_target"] is None
batch_context = batch["context"]
batch_group_ids = batch["group_ids"]
batch_target_idx_ranges = batch["target_idx_ranges"]

encoder_outputs, (locs, scales), *_ = self.model.encode(
context=batch_context.to(device=self.model.device, dtype=torch.float32),
group_ids=batch_group_ids.to(self.model.device),
)
batch_embeds = [encoder_outputs[0][start:end].cpu() for (start, end) in batch_target_idx_ranges]
batch_loc_scales = list(
zip(
[locs[start:end].cpu() for (start, end) in batch_target_idx_ranges],
[scales[start:end].cpu() for (start, end) in batch_target_idx_ranges],
)
)
all_embeds.extend(batch_embeds)
all_loc_scales.extend(batch_loc_scales)

return all_embeds, all_loc_scales

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
"""
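As a usage sketch of the new API (the import path and model id below are assumptions, not taken from this PR):

```python
import torch
from chronos import Chronos2Pipeline  # import path assumed

pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2")  # hypothetical model id

# One 3-variate task and one univariate task with different history lengths.
inputs = [torch.rand(3, 120), torch.rand(200)]
embeddings, loc_scales = pipeline.embed(inputs, batch_size=256)

# One embedding per task, of shape (n_variates, num_patches + 2, d_model).
for emb, (loc, scale) in zip(embeddings, loc_scales):
    print(emb.shape, loc.shape, scale.shape)
```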
29 changes: 29 additions & 0 deletions test/test_chronos2.py
@@ -340,6 +340,35 @@ def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: tor
validate_tensor(quantiles_item, (3, expected_num_quantiles, 7), dtype=torch.float32)


@pytest.mark.parametrize(
"inputs, expected_output_shapes",
[
# NOTE: d_model for the dummy model is 6
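# Embedding shapes are (n_variates, num_patches + 2, d_model); the +2 comes
# from the [REG] token and the masked output patch token.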
# Homogeneous univariate task
(torch.rand(4, 1, 16), [(1, 3, 6)] * 4),
# Homogeneous multivariate task
(torch.rand(4, 3, 37), [(3, 5, 6)] * 4),
# Heterogeneous tasks with different history lengths
(
[torch.rand(100), torch.rand(2, 150), torch.rand(120)],
[(1, 12, 6), (2, 12, 6), (1, 12, 6)],
),
],
)
def test_when_input_is_valid_then_pipeline_can_embed(pipeline, inputs, expected_output_shapes):
embeds, loc_scales = pipeline.embed(inputs)

assert (
isinstance(embeds, list)
and len(embeds) == len(expected_output_shapes)
and len(loc_scales) == len(expected_output_shapes)
)
for embed, loc_scale, expected_shape in zip(embeds, loc_scales, expected_output_shapes):
validate_tensor(embed, expected_shape, dtype=torch.float32)
validate_tensor(loc_scale[0], (expected_shape[0], 1), dtype=torch.float32)
validate_tensor(loc_scale[1], (expected_shape[0], 1), dtype=torch.float32)


@pytest.mark.parametrize(
"task_kwargs",
[