
Commit 55a18cd

committed
Add docstring
1 parent 5fb8337 commit 55a18cd

File tree

1 file changed: +26 −0 lines changed


src/chronos/chronos2/pipeline.py

Lines changed: 26 additions & 0 deletions
@@ -992,6 +992,32 @@ def predict_fev(
     def embed(
         self, inputs: TensorOrArray | Sequence[TensorOrArray], batch_size: int = 256, context_length: int | None = None
     ) -> tuple[list[torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]:
+        """
+        Get encoder embeddings for the given time series.
+
+        Parameters
+        ----------
+        inputs
+            The time series to get embeddings for; can be one of:
+            - A 3-dimensional `torch.Tensor` or `np.ndarray` of shape (batch, n_variates, history_length). When
+              `n_variates > 1`, information will be shared among the different variates of each time series in the batch.
+            - A list of `torch.Tensor` or `np.ndarray` where each element can either be 1-dimensional of shape
+              (history_length,) or 2-dimensional of shape (n_variates, history_length). The history_lengths may differ
+              across elements; left-padding will be applied if needed.
+        batch_size
+            The batch size used for generating embeddings. Note that the batch size here means the total number of
+            time series input into the model. If your data has multiple variates, the effective number of time series
+            tasks in a batch will be lower than this value. By default 256.
+        context_length
+            The maximum context length used during inference; by default set to the model's default context length.
+
+        Returns
+        -------
+        embeddings
+            A list of `torch.Tensor` where each element has shape (n_variates, num_patches + 2, d_model), and the
+            number of elements is equal to the number of target time series (univariate or multivariate) in `inputs`.
+            The extra +2 is due to embeddings of the [REG] token and a masked output patch token.
+        loc_scale
+            A list of tuples with the mean and standard deviation of each time series.
+        """
         if context_length is None:
             context_length = self.model_context_length
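The docstring notes that list inputs with differing history lengths are left-padded into a single batch. A minimal numpy sketch of that input convention (the helper `left_pad_inputs` is hypothetical, not the library's actual implementation, and assumes all elements share the same number of variates):

```python
import numpy as np

def left_pad_inputs(series, pad_value=np.nan):
    """Stack a list of 1-D (history_length,) or 2-D (n_variates, history_length)
    arrays into one 3-D batch of shape (batch, n_variates, max_history_length),
    left-padding shorter series with `pad_value`. Hypothetical sketch only."""
    arrays = [np.atleast_2d(np.asarray(s, dtype=float)) for s in series]
    n_variates = arrays[0].shape[0]
    if any(a.shape[0] != n_variates for a in arrays):
        raise ValueError("all elements must have the same number of variates")
    max_len = max(a.shape[1] for a in arrays)
    batch = np.full((len(arrays), n_variates, max_len), pad_value)
    for i, a in enumerate(arrays):
        batch[i, :, max_len - a.shape[1]:] = a  # pad on the left, align right
    return batch

# A 5-step and a 3-step univariate series end up right-aligned in one batch:
batch = left_pad_inputs([np.arange(5.0), np.arange(3.0)])
print(batch.shape)  # (2, 1, 5)
```

Left-padding keeps the most recent observations aligned at the right edge of the context window, which is what an encoder consuming fixed-length contexts expects.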
