mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
fix validate_and_prepare_single_dict_task by separating "past only" and "future known" cov keys (#344)
Closes #345 Hi @abdulfatir Here is the bugfix about the function "validate_and_prepare_single_dict_task", which had 2 issue points: 1. Originally, one of this func return, the "task_n_future_covariates", will return the ["past only" + "future known"]covariates number, by `task_n_future_covariates = len(task_future_covariates_list)` as `task_future_covariates_list ` is filled by for` key in task_covariates_keys` 2. The code seems not to guarantee the last "future known" rows are atcually what we expected, even there is a sorted option. So, this PR fixed them by separating "past only" and "future known" covs from the "past_covariates" input, and explicitly put the "past only" covs rows above "future known" cov rows, supported by a temp list "ordered_covariate_keys".
This commit is contained in:
parent
6d46e628ae
commit
6c6915501c
1 changed files with 26 additions and 16 deletions
|
|
@ -105,9 +105,29 @@ def validate_and_prepare_single_dict_task(
|
|||
f"Found invalid type for `past_covariates` in element at index {idx}. "
|
||||
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_past_covariates)}'
|
||||
)
|
||||
|
||||
# gather keys and ensure known-future keys come last to match downstream assumptions
|
||||
task_covariates_keys = sorted(task_past_covariates.keys())
|
||||
|
||||
task_future_covariates = task.get("future_covariates", {})
|
||||
if not isinstance(task_future_covariates, dict):
|
||||
raise ValueError(
|
||||
f"Found invalid type for `future_covariates` in element at index {idx}. "
|
||||
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_future_covariates)}'
|
||||
)
|
||||
task_future_covariates_keys = sorted(task_future_covariates.keys())
|
||||
if not set(task_future_covariates_keys).issubset(task_covariates_keys):
|
||||
raise ValueError(
|
||||
f"Expected keys in `future_covariates` to be a subset of `past_covariates` {task_covariates_keys}, "
|
||||
f"but found {task_future_covariates_keys} in element at index {idx}"
|
||||
)
|
||||
|
||||
# create ordered keys: past-only first, then known-future (so known-future are the last rows)
|
||||
task_past_only_keys = [k for k in task_covariates_keys if k not in task_future_covariates_keys] # past_only_keys
|
||||
task_ordered_covariate_keys = task_past_only_keys + task_future_covariates_keys
|
||||
|
||||
task_past_covariates_list: list[torch.Tensor] = []
|
||||
for key in task_covariates_keys:
|
||||
for key in task_ordered_covariate_keys:
|
||||
tensor = task_past_covariates[key]
|
||||
if isinstance(tensor, np.ndarray):
|
||||
# apply encoding to categorical variates
|
||||
|
|
@ -140,21 +160,10 @@ def validate_and_prepare_single_dict_task(
|
|||
if task_past_covariates_list
|
||||
else torch.zeros((0, history_length), device=task_target.device)
|
||||
)
|
||||
# validate future_covariates
|
||||
task_future_covariates = task.get("future_covariates", {})
|
||||
if not isinstance(task_future_covariates, dict):
|
||||
raise ValueError(
|
||||
f"Found invalid type for `future_covariates` in element at index {idx}. "
|
||||
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_future_covariates)}'
|
||||
)
|
||||
task_future_covariates_keys = sorted(task_future_covariates.keys())
|
||||
if not set(task_future_covariates_keys).issubset(task_covariates_keys):
|
||||
raise ValueError(
|
||||
f"Expected keys in `future_covariates` to be a subset of `past_covariates` {task_covariates_keys}, "
|
||||
f"but found {task_future_covariates_keys} in element at index {idx}"
|
||||
)
|
||||
|
||||
# validate future_covariates (build rows in the same task_ordered_covariate_keys order)
|
||||
task_future_covariates_list: list[torch.Tensor] = []
|
||||
for key in task_covariates_keys:
|
||||
for key in task_ordered_covariate_keys:
|
||||
# future values of past-only covariates are filled with NaNs
|
||||
tensor = task_future_covariates.get(key, torch.full((prediction_length,), fill_value=torch.nan))
|
||||
if isinstance(tensor, np.ndarray):
|
||||
|
|
@ -186,7 +195,8 @@ def validate_and_prepare_single_dict_task(
|
|||
).to(dtype=torch.float32)
|
||||
task_n_targets = task_target.shape[0]
|
||||
task_n_covariates = task_past_covariates_tensor.shape[0]
|
||||
task_n_future_covariates = len(task_future_covariates_list)
|
||||
# number of known-future covariates
|
||||
task_n_future_covariates = len(task_future_covariates_keys)
|
||||
|
||||
return (
|
||||
task_context_tensor,
|
||||
|
|
|
|||
Loading…
Reference in a new issue