mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Mask past-only covariates during loss computation (#379)
*Issue #, if available:* *Description of changes:* This PR masks rows corresponding to all covariates in the future target. Specifically, this is to avoid the contribution of past-only covariates in loss computation. The previous setup was correct from the perspective of pretraining but I think this makes more sense for fine-tuning. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
111972a6cc
commit
0f3c5652e5
1 changed files with 4 additions and 1 deletions
|
|
@ -477,6 +477,7 @@ class Chronos2Dataset(IterableDataset):
|
|||
task_n_covariates,
|
||||
task_n_future_covariates,
|
||||
) = self.tasks[task_idx]
|
||||
task_past_tensor, task_future_tensor = task_past_tensor.clone(), task_future_tensor.clone()
|
||||
task_n_past_only_covariates = task_n_covariates - task_n_future_covariates
|
||||
|
||||
full_length = task_past_tensor.shape[-1]
|
||||
|
|
@ -502,7 +503,9 @@ class Chronos2Dataset(IterableDataset):
|
|||
# the task_context_tensor by slicing the appropriate indices which we do below
|
||||
if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]:
|
||||
# the first task_n_targets elements in task_context_tensor are the targets
|
||||
task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length]
|
||||
task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone()
|
||||
# mask out all rows corresponding to covariates
|
||||
task_future_target[task_n_targets:] = torch.nan
|
||||
|
||||
if task_n_future_covariates > 0:
|
||||
# the last task_n_future_covariates elements in task_context_tensor are the known covariates
|
||||
|
|
|
|||
Loading…
Reference in a new issue