mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Address PR comments
This commit is contained in:
parent
588f19f1f8
commit
8928819100
2 changed files with 30 additions and 50 deletions
|
|
@ -227,9 +227,6 @@ def prepare_tasks(
|
|||
This function handles mode-specific preprocessing (e.g., filtering short series)
|
||||
and calls validate_and_prepare_single_dict_task for each task.
|
||||
"""
|
||||
if isinstance(mode, str):
|
||||
mode = DatasetMode(mode)
|
||||
|
||||
tasks: list[PreparedTask] = []
|
||||
|
||||
for idx, raw_task in enumerate(raw_tasks):
|
||||
|
|
@ -268,20 +265,37 @@ def validate_prepared_schema(task: Any) -> None:
|
|||
if not isinstance(task, Mapping):
|
||||
raise TypeError(
|
||||
f"Expected task to be a dict-like, got {type(task).__name__}. "
|
||||
"Set convert_inputs=True to preprocess raw inputs."
|
||||
"Set convert_inputs=True when calling fit() to preprocess raw inputs."
|
||||
)
|
||||
|
||||
required_keys = {"context", "future_covariates", "n_targets", "n_covariates", "n_future_covariates"}
|
||||
missing = required_keys - set(task.keys())
|
||||
if missing:
|
||||
raise TypeError(f"Task is missing required keys: {missing}. Set convert_inputs=True to preprocess raw inputs.")
|
||||
raise TypeError(
|
||||
f"Task is missing required keys: {missing}. Set convert_inputs=True when calling fit() to preprocess raw inputs."
|
||||
)
|
||||
|
||||
context = task["context"]
|
||||
if not isinstance(context, (np.ndarray, torch.Tensor)) or context.ndim != 2:
|
||||
if not isinstance(context, torch.Tensor) or context.ndim != 2:
|
||||
raise TypeError(
|
||||
f"Expected 'context' to be 2-d array, got {type(context).__name__} "
|
||||
f"Expected 'context' to be 2-d torch.Tensor, got {type(context).__name__} "
|
||||
f"with shape {getattr(context, 'shape', 'N/A')}. "
|
||||
"Set convert_inputs=True to preprocess raw inputs."
|
||||
"Set convert_inputs=True when calling fit() to preprocess raw inputs."
|
||||
)
|
||||
|
||||
future_covariates = task["future_covariates"]
|
||||
if not isinstance(future_covariates, torch.Tensor) or future_covariates.ndim != 2:
|
||||
raise TypeError(
|
||||
f"Expected 'future_covariates' to be 2-d torch.Tensor, got {type(future_covariates).__name__} "
|
||||
f"with shape {getattr(future_covariates, 'shape', 'N/A')}. "
|
||||
"Set convert_inputs=True when calling fit() to preprocess raw inputs."
|
||||
)
|
||||
|
||||
if context.shape[0] != future_covariates.shape[0]:
|
||||
raise ValueError(
|
||||
f"Expected 'context' and 'future_covariates' to have the same first dimension, "
|
||||
f"got {context.shape[0]} and {future_covariates.shape[0]}. "
|
||||
"Set convert_inputs=True when calling fit() to preprocess raw inputs."
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -512,7 +526,9 @@ class Chronos2Dataset(IterableDataset):
|
|||
if convert_inputs:
|
||||
if isinstance(inputs, (torch.Tensor, np.ndarray)):
|
||||
inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
|
||||
elif isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray)):
|
||||
elif (
|
||||
isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray))
|
||||
):
|
||||
inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(Sequence[TensorOrArray], inputs))
|
||||
self.tasks = prepare_tasks(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode)
|
||||
else:
|
||||
|
|
@ -528,7 +544,7 @@ class Chronos2Dataset(IterableDataset):
|
|||
|
||||
def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]:
|
||||
task = self.tasks[task_idx]
|
||||
task_past_tensor = task["context"].clone()
|
||||
task_past_tensor = task["context"].clone() # shape: (task_n_targets + task_n_covariates, history_length)
|
||||
task_future_tensor = task["future_covariates"].clone()
|
||||
task_n_targets = int(task["n_targets"])
|
||||
task_n_covariates = int(task["n_covariates"])
|
||||
|
|
@ -672,39 +688,3 @@ class Chronos2Dataset(IterableDataset):
|
|||
yield batch
|
||||
else:
|
||||
yield from self._generate_sequential_batches()
|
||||
|
||||
@classmethod
|
||||
def convert_inputs(
|
||||
cls,
|
||||
inputs: TensorOrArray
|
||||
| Sequence[TensorOrArray]
|
||||
| Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]],
|
||||
context_length: int,
|
||||
prediction_length: int,
|
||||
batch_size: int,
|
||||
output_patch_size: int,
|
||||
min_past: int = 1,
|
||||
mode: str | DatasetMode = DatasetMode.TRAIN,
|
||||
) -> "Chronos2Dataset":
|
||||
"""Convert from different input formats to a Chronos2Dataset."""
|
||||
if isinstance(inputs, (torch.Tensor, np.ndarray)):
|
||||
inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
|
||||
elif isinstance(inputs, list) and all([isinstance(x, (torch.Tensor, np.ndarray)) for x in inputs]):
|
||||
inputs = cast(list[TensorOrArray], inputs)
|
||||
inputs = convert_list_of_tensors_input_to_list_of_dicts_input(inputs)
|
||||
elif isinstance(inputs, list) and all([isinstance(x, dict) for x in inputs]):
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Unexpected inputs format")
|
||||
|
||||
inputs = cast(list[dict[str, TensorOrArray | dict[str, TensorOrArray]]], inputs)
|
||||
|
||||
return cls(
|
||||
inputs,
|
||||
context_length=context_length,
|
||||
prediction_length=prediction_length,
|
||||
batch_size=batch_size,
|
||||
output_patch_size=output_patch_size,
|
||||
min_past=min_past,
|
||||
mode=mode,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -616,8 +616,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
)
|
||||
context_length = self.model_context_length
|
||||
|
||||
test_dataset = Chronos2Dataset.convert_inputs(
|
||||
inputs=inputs,
|
||||
test_dataset = Chronos2Dataset(
|
||||
inputs,
|
||||
context_length=context_length,
|
||||
prediction_length=prediction_length,
|
||||
batch_size=batch_size,
|
||||
|
|
@ -1142,8 +1142,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
)
|
||||
context_length = self.model_context_length
|
||||
|
||||
test_dataset = Chronos2Dataset.convert_inputs(
|
||||
inputs=inputs,
|
||||
test_dataset = Chronos2Dataset(
|
||||
inputs,
|
||||
context_length=context_length,
|
||||
prediction_length=0,
|
||||
batch_size=batch_size,
|
||||
|
|
|
|||
Loading…
Reference in a new issue