*Issue #, if available:* #354
*Description of changes:* This PR adds `Chronos2Pipeline.embed` to let
users easily extract embeddings from the last encoder layer. The API
and behavior are similar to what Chronos and Chronos-Bolt provide.
By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
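The docstring below notes that when inputs are a list of series with different history_lengths, left-padding is applied. As a rough illustration of that input contract, here is a minimal sketch using plain `numpy`; the helper name `left_pad` and the padding value are hypothetical and not part of the pipeline's actual implementation:

```python
import numpy as np

def left_pad(series_list, pad_value=0.0):
    """Left-pad 1-D (history_length,) or 2-D (n_variates, history_length)
    arrays to a common history_length.

    Hypothetical helper for illustration only; the pipeline handles
    padding internally.
    """
    # Promote 1-D series to a single variate so all shapes are 2-D.
    arrays = [np.atleast_2d(np.asarray(s, dtype=float)) for s in series_list]
    max_len = max(a.shape[-1] for a in arrays)
    padded = []
    for a in arrays:
        pad = max_len - a.shape[-1]
        # Pad only on the left of the time axis.
        padded.append(np.pad(a, ((0, 0), (pad, 0)), constant_values=pad_value))
    return padded

# A univariate series of length 5 and a bivariate series of length 8.
series = [np.arange(5.0), np.ones((2, 8))]
out = left_pad(series)
print([a.shape for a in out])  # → [(1, 8), (2, 8)]
```

After padding, every element shares the same history_length, so the batch can be stacked along the time axis for the encoder.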
The time series to get embeddings for can be one of:

- A 3-dimensional `torch.Tensor` or `np.ndarray` of shape (batch, n_variates, history_length). When
  `n_variates > 1`, information will be shared among the different variates of each time series in the batch.
- A list of `torch.Tensor` or `np.ndarray` where each element can either be 1-dimensional of shape
  (history_length,) or 2-dimensional of shape (n_variates, history_length). The history_lengths may
  differ across elements; left-padding will be applied, if needed.
batch_size
    The batch size used for generating embeddings. Note that the batch size here means the total
    number of time series input into the model. If your data has multiple variates, the effective
    number of time series tasks in a batch will be lower than this value. By default, 256.
context_length
    The maximum context length used during inference. By default, set to the model's default
    context length.

Returns
-------
embeddings
    A list of `torch.Tensor` where each element has shape (n_variates, num_patches + 2, d_model),
    with the number of elements equal to the number of target time series (univariate or
    multivariate) in the `inputs`. The extra +2 is due to the embeddings of the [REG] token and a
    masked output patch token.
loc_scale
    A list of tuples with the mean and standard deviation of each time series.
"""
if context_length is None:
    context_length = self.model_context_length

if context_length > self.model_context_length:
    warnings.warn(
        f"The specified context_length {context_length} is greater than the model's default context length {self.model_context_length}. "
        f"Resetting context_length to {self.model_context_length}."