From 6c6915501cba77f8ae8645a9d0397c3fc98cbc46 Mon Sep 17 00:00:00 2001 From: "Xiaoxuan(Alex) Zhang" <574817827@qq.com> Date: Wed, 29 Oct 2025 14:00:56 +0100 Subject: [PATCH] fix validate_and_prepare_single_dict_task by separating "past only" and "future known" cov keys (#344) Closes #345 Hi @abdulfatir Here is a bugfix for the function "validate_and_prepare_single_dict_task", which had two issues: 1. One of its return values, "task_n_future_covariates", was the total number of covariates ("past only" + "future known") rather than the number of "future known" covariates, because it was computed as `task_n_future_covariates = len(task_future_covariates_list)` and `task_future_covariates_list` is filled by `for key in task_covariates_keys`. 2. The code did not guarantee that the last rows are actually the "future known" covariates, even though the keys are sorted. This PR fixes both issues by separating the "past only" and "future known" covariates from the "past_covariates" input and explicitly placing the "past only" covariate rows above the "future known" covariate rows, using a temporary list "task_ordered_covariate_keys". --- src/chronos/chronos2/dataset.py | 42 ++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index 61a8a0f..86bf8b9 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -105,9 +105,29 @@ def validate_and_prepare_single_dict_task( f"Found invalid type for `past_covariates` in element at index {idx}. " f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_past_covariates)}' ) + + # gather keys and ensure known-future keys come last to match downstream assumptions task_covariates_keys = sorted(task_past_covariates.keys()) + + task_future_covariates = task.get("future_covariates", {}) + if not isinstance(task_future_covariates, dict): + raise ValueError( + f"Found invalid type for `future_covariates` in element at index {idx}. 
" + f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_future_covariates)}' + ) + task_future_covariates_keys = sorted(task_future_covariates.keys()) + if not set(task_future_covariates_keys).issubset(task_covariates_keys): + raise ValueError( + f"Expected keys in `future_covariates` to be a subset of `past_covariates` {task_covariates_keys}, " + f"but found {task_future_covariates_keys} in element at index {idx}" + ) + + # create ordered keys: past-only first, then known-future (so known-future are the last rows) + task_past_only_keys = [k for k in task_covariates_keys if k not in task_future_covariates_keys] # past_only_keys + task_ordered_covariate_keys = task_past_only_keys + task_future_covariates_keys + task_past_covariates_list: list[torch.Tensor] = [] - for key in task_covariates_keys: + for key in task_ordered_covariate_keys: tensor = task_past_covariates[key] if isinstance(tensor, np.ndarray): # apply encoding to categorical variates @@ -140,21 +160,10 @@ def validate_and_prepare_single_dict_task( if task_past_covariates_list else torch.zeros((0, history_length), device=task_target.device) ) - # validate future_covariates - task_future_covariates = task.get("future_covariates", {}) - if not isinstance(task_future_covariates, dict): - raise ValueError( - f"Found invalid type for `future_covariates` in element at index {idx}. 
" - f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_future_covariates)}' - ) - task_future_covariates_keys = sorted(task_future_covariates.keys()) - if not set(task_future_covariates_keys).issubset(task_covariates_keys): - raise ValueError( - f"Expected keys in `future_covariates` to be a subset of `past_covariates` {task_covariates_keys}, " - f"but found {task_future_covariates_keys} in element at index {idx}" - ) + + # validate future_covariates (build rows in the same task_ordered_covariate_keys order) task_future_covariates_list: list[torch.Tensor] = [] - for key in task_covariates_keys: + for key in task_ordered_covariate_keys: # future values of past-only covariates are filled with NaNs tensor = task_future_covariates.get(key, torch.full((prediction_length,), fill_value=torch.nan)) if isinstance(tensor, np.ndarray): @@ -186,7 +195,8 @@ def validate_and_prepare_single_dict_task( ).to(dtype=torch.float32) task_n_targets = task_target.shape[0] task_n_covariates = task_past_covariates_tensor.shape[0] - task_n_future_covariates = len(task_future_covariates_list) + # number of known-future covariates + task_n_future_covariates = len(task_future_covariates_keys) return ( task_context_tensor,