Skip to content

Commit 9fd3278

Browse files
committed
Mask past-only covariates during loss computation
1 parent 111972a commit 9fd3278

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/chronos/chronos2/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,8 @@ def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor |
503503
if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]:
504504
# the first task_n_targets elements in task_context_tensor are the targets
505505
task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length]
506+
# mask out all rows corresponding to covariates
507+
task_future_target[task_n_targets:] = torch.nan
506508

507509
if task_n_future_covariates > 0:
508510
# the last task_n_future_covariates elements in task_context_tensor are the known covariates

0 commit comments

Comments
 (0)