Mask past-only covariates during loss computation (#379)

*Issue #, if available:*

*Description of changes:* This PR masks rows corresponding to all
covariates in the future target. Specifically, this is to avoid the
contribution of past-only covariates in loss computation. The previous
setup was correct from the perspective of pretraining but I think this
makes more sense for fine-tuning.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Abdul Fatir 2025-11-17 18:02:55 +01:00 committed by GitHub
parent 111972a6cc
commit 0f3c5652e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -477,6 +477,7 @@ class Chronos2Dataset(IterableDataset):
task_n_covariates,
task_n_future_covariates,
) = self.tasks[task_idx]
task_past_tensor, task_future_tensor = task_past_tensor.clone(), task_future_tensor.clone()
task_n_past_only_covariates = task_n_covariates - task_n_future_covariates
full_length = task_past_tensor.shape[-1]
@ -502,7 +503,9 @@ class Chronos2Dataset(IterableDataset):
# the task_context_tensor by slicing the appropriate indices which we do below
if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]:
# the first task_n_targets elements in task_context_tensor are the targets
task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length]
task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone()
# mask out all rows corresponding to covariates
task_future_target[task_n_targets:] = torch.nan
if task_n_future_covariates > 0:
# the last task_n_future_covariates elements in task_context_tensor are the known covariates