From 0f3c5652e52ee9382d578efc0c9b2add6fd2c90a Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Mon, 17 Nov 2025 18:02:55 +0100 Subject: [PATCH] Mask past-only covariates during loss computation (#379) *Issue #, if available:* *Description of changes:* This PR masks rows corresponding to all covariates in the future target. Specifically, this is to avoid the contribution of past-only covariates in loss computation. The previous setup was correct from the perspective of pretraining but I think this makes more sense for fine-tuning. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- src/chronos/chronos2/dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index 8d6fd3f..cb75571 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -477,6 +477,7 @@ class Chronos2Dataset(IterableDataset): task_n_covariates, task_n_future_covariates, ) = self.tasks[task_idx] + task_past_tensor, task_future_tensor = task_past_tensor.clone(), task_future_tensor.clone() task_n_past_only_covariates = task_n_covariates - task_n_future_covariates full_length = task_past_tensor.shape[-1] @@ -502,7 +503,9 @@ class Chronos2Dataset(IterableDataset): # the task_context_tensor by slicing the appropriate indices which we do below if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]: # the first task_n_targets elements in task_context_tensor are the targets - task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length] + task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone() + # mask out all rows corresponding to covariates + task_future_target[task_n_targets:] = torch.nan if task_n_future_covariates > 0: # the last task_n_future_covariates elements in task_context_tensor are the known covariates