diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index d542d47..afbc987 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -536,9 +536,9 @@ class Chronos2Dataset(IterableDataset): task = self.tasks[task_idx] task_past_tensor = task["context"].clone() task_future_tensor = task["future_covariates"].clone() - task_n_targets = task["n_targets"] - task_n_covariates = task["n_covariates"] - task_n_future_covariates = task["n_future_covariates"] + task_n_targets = int(task["n_targets"]) + task_n_covariates = int(task["n_covariates"]) + task_n_future_covariates = int(task["n_future_covariates"]) task_n_past_only_covariates = task_n_covariates - task_n_future_covariates full_length = task_past_tensor.shape[-1]