Add Chronos2Pipeline.embed #361
@@ -988,6 +988,81 @@ def predict_fev(

        return predictions_per_window, inference_time_s

    @torch.no_grad()
    def embed(
        self, inputs: TensorOrArray | Sequence[TensorOrArray], batch_size: int = 256, context_length: int | None = None
    ) -> tuple[list[torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]:
        """
        Get encoder embeddings for the given time series.

        Parameters
        ----------
        inputs
            The time series to get embeddings for, can be one of:
            - A 3-dimensional `torch.Tensor` or `np.ndarray` of shape (batch, n_variates, history_length). When `n_variates > 1`,
              information will be shared among the different variates of each time series in the batch.
            - A list of `torch.Tensor` or `np.ndarray` where each element can either be 1-dimensional of shape (history_length,)
              or 2-dimensional of shape (n_variates, history_length). The history_lengths may be different across elements;
              left-padding will be applied, if needed.
        batch_size
            The batch size used for generating embeddings. Note that the batch size here means the total number of time series
            which are input into the model.
Contributor:
    I'm not sure this is clear to me: does the … (quotes chronos-forecasting/src/chronos/chronos2/dataset.py, lines 402 to 405 in e48f480) … I see this is pretty much the description of …

Author (Contributor):
    Internally, there's no notion of a variate dimension in the model: only batch and time (patch) axes. The batch_size here refers to the maximum …
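To make the point in this thread concrete, here is a small illustrative sketch (the shapes and numbers are hypothetical, not taken from the PR): internally each variate occupies one row of the batch, so `batch_size` bounds the number of rows fed to the model rather than the number of (possibly multivariate) series.

```python
# Hypothetical illustration (not from the PR) of the batch_size semantics described above:
# each variate becomes one row internally, so batch_size counts rows, not series.
import torch

series = [torch.randn(5, 512) for _ in range(100)]  # 100 series, 5 variates each
rows_per_series = series[0].shape[0]                 # each series contributes 5 rows
batch_size = 256                                     # maximum rows per forward pass
series_per_batch = batch_size // rows_per_series     # -> 51 multivariate series per batch
print(rows_per_series, series_per_batch)
```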
            If your data has multiple variates, the effective number of time series tasks in a batch will be lower than
            this value, by default 256
        context_length
            The maximum context length used during inference, by default set to the model's default context length

        Returns
        -------
        embeddings
            a list of `torch.Tensor` where each element has shape (n_variates, num_patches + 2, d_model) and the number
            of elements is equal to the number of target time series (univariate or multivariate) in the `inputs`.
            The extra +2 is due to embeddings of the [REG] token and a masked output patch token.
        loc_scale
            a list of tuples with the mean and standard deviation of each time series.
        """
        if context_length is None:
            context_length = self.model_context_length

        if context_length > self.model_context_length:
            warnings.warn(
                f"The specified context_length {context_length} is greater than the model's default context length {self.model_context_length}. "
                f"Resetting context_length to {self.model_context_length}."
            )
            context_length = self.model_context_length

        test_dataset = Chronos2Dataset.convert_inputs(
            inputs=inputs,
            context_length=context_length,
            prediction_length=0,
            batch_size=batch_size,
            output_patch_size=self.model_output_patch_size,
            mode=DatasetMode.TEST,
        )
        test_loader = DataLoader(
            test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False
        )
        all_embeds: list[torch.Tensor] = []
        all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []
        for batch in test_loader:
            assert batch["future_target"] is None
            batch_context = batch["context"]
            batch_group_ids = batch["group_ids"]
            batch_target_idx_ranges = batch["target_idx_ranges"]

            encoder_outputs, (locs, scales), *_ = self.model.encode(
                context=batch_context.to(device=self.model.device, dtype=torch.float32),
                group_ids=batch_group_ids.to(self.model.device),
            )
            batch_embeds = [encoder_outputs[0][start:end].cpu() for (start, end) in batch_target_idx_ranges]
            batch_loc_scales = list(
                zip(
                    [locs[start:end].cpu() for (start, end) in batch_target_idx_ranges],
                    [scales[start:end].cpu() for (start, end) in batch_target_idx_ranges],
                )
            )
            all_embeds.extend(batch_embeds)
            all_loc_scales.extend(batch_loc_scales)

        return all_embeds, all_loc_scales

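For context, a usage sketch of the new `embed` method based on the docstring above; the checkpoint id is a placeholder and the import path is assumed, so adjust both to your setup.

```python
import torch
from chronos import Chronos2Pipeline  # import path assumed; adjust if your install differs

# Placeholder checkpoint id -- substitute the actual Chronos-2 checkpoint you use.
pipeline = Chronos2Pipeline.from_pretrained("your-org/your-chronos2-checkpoint")

# Mixed univariate / multivariate inputs with different history lengths;
# shorter series are left-padded internally.
inputs = [
    torch.randn(512),     # univariate series of length 512
    torch.randn(3, 400),  # 3-variate series of length 400
]

embeddings, loc_scale = pipeline.embed(inputs, batch_size=256)

# One tensor per input series with shape (n_variates, num_patches + 2, d_model);
# the +2 accounts for the [REG] token and the masked output patch token.
print(embeddings[0].shape, embeddings[1].shape)

# loc_scale holds the per-series (mean, std) used for normalization.
loc, scale = loc_scale[0]
```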
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
Contributor:
    I wish the diff would be more helpful here: is the body of this simply moved from `forward`?

Author:
    Yes, the first (encoding) portion from `forward` has been factored out into `encode`.
forwardhas been factored out intoencode.