Skip to content

Commit 0f3c565

Browse files
authored
Mask past-only covariates during loss computation (#379)
*Issue #, if available:* *Description of changes:* This PR masks rows corresponding to all covariates in the future target. Specifically, this is to avoid the contribution of past-only covariates in loss computation. The previous setup was correct from the perspective of pretraining but I think this makes more sense for fine-tuning. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent 111972a commit 0f3c565

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/chronos/chronos2/dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor |
477477
task_n_covariates,
478478
task_n_future_covariates,
479479
) = self.tasks[task_idx]
480+
task_past_tensor, task_future_tensor = task_past_tensor.clone(), task_future_tensor.clone()
480481
task_n_past_only_covariates = task_n_covariates - task_n_future_covariates
481482

482483
full_length = task_past_tensor.shape[-1]
@@ -502,7 +503,9 @@ def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor |
502503
# the task_context_tensor by slicing the appropriate indices which we do below
503504
if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]:
504505
# the first task_n_targets elements in task_context_tensor are the targets
505-
task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length]
506+
task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone()
507+
# mask out all rows corresponding to covariates
508+
task_future_target[task_n_targets:] = torch.nan
506509

507510
if task_n_future_covariates > 0:
508511
# the last task_n_future_covariates elements in task_context_tensor are the known covariates

0 commit comments

Comments
 (0)