Address PR comments

This commit is contained in:
Oleksandr Shchur 2026-02-19 08:58:53 +00:00
parent 588f19f1f8
commit 8928819100
2 changed files with 30 additions and 50 deletions

View file

@ -227,9 +227,6 @@ def prepare_tasks(
This function handles mode-specific preprocessing (e.g., filtering short series)
and calls validate_and_prepare_single_dict_task for each task.
"""
if isinstance(mode, str):
mode = DatasetMode(mode)
tasks: list[PreparedTask] = []
for idx, raw_task in enumerate(raw_tasks):
@ -268,20 +265,37 @@ def validate_prepared_schema(task: Any) -> None:
if not isinstance(task, Mapping):
raise TypeError(
f"Expected task to be a dict-like, got {type(task).__name__}. "
"Set convert_inputs=True to preprocess raw inputs."
"Set convert_inputs=True when calling fit() to preprocess raw inputs."
)
required_keys = {"context", "future_covariates", "n_targets", "n_covariates", "n_future_covariates"}
missing = required_keys - set(task.keys())
if missing:
raise TypeError(f"Task is missing required keys: {missing}. Set convert_inputs=True to preprocess raw inputs.")
raise TypeError(
f"Task is missing required keys: {missing}. Set convert_inputs=True when calling fit() to preprocess raw inputs."
)
context = task["context"]
if not isinstance(context, (np.ndarray, torch.Tensor)) or context.ndim != 2:
if not isinstance(context, torch.Tensor) or context.ndim != 2:
raise TypeError(
f"Expected 'context' to be 2-d array, got {type(context).__name__} "
f"Expected 'context' to be 2-d torch.Tensor, got {type(context).__name__} "
f"with shape {getattr(context, 'shape', 'N/A')}. "
"Set convert_inputs=True to preprocess raw inputs."
"Set convert_inputs=True when calling fit() to preprocess raw inputs."
)
future_covariates = task["future_covariates"]
if not isinstance(future_covariates, torch.Tensor) or future_covariates.ndim != 2:
raise TypeError(
f"Expected 'future_covariates' to be 2-d torch.Tensor, got {type(future_covariates).__name__} "
f"with shape {getattr(future_covariates, 'shape', 'N/A')}. "
"Set convert_inputs=True when calling fit() to preprocess raw inputs."
)
if context.shape[0] != future_covariates.shape[0]:
raise ValueError(
f"Expected 'context' and 'future_covariates' to have the same first dimension, "
f"got {context.shape[0]} and {future_covariates.shape[0]}. "
"Set convert_inputs=True when calling fit() to preprocess raw inputs."
)
@ -512,7 +526,9 @@ class Chronos2Dataset(IterableDataset):
if convert_inputs:
if isinstance(inputs, (torch.Tensor, np.ndarray)):
inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
elif isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray)):
elif (
isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray))
):
inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(Sequence[TensorOrArray], inputs))
self.tasks = prepare_tasks(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode)
else:
@ -528,7 +544,7 @@ class Chronos2Dataset(IterableDataset):
def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]:
task = self.tasks[task_idx]
task_past_tensor = task["context"].clone()
task_past_tensor = task["context"].clone() # shape: (task_n_targets + task_n_covariates, history_length)
task_future_tensor = task["future_covariates"].clone()
task_n_targets = int(task["n_targets"])
task_n_covariates = int(task["n_covariates"])
@ -672,39 +688,3 @@ class Chronos2Dataset(IterableDataset):
yield batch
else:
yield from self._generate_sequential_batches()
@classmethod
def convert_inputs(
cls,
inputs: TensorOrArray
| Sequence[TensorOrArray]
| Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]],
context_length: int,
prediction_length: int,
batch_size: int,
output_patch_size: int,
min_past: int = 1,
mode: str | DatasetMode = DatasetMode.TRAIN,
) -> "Chronos2Dataset":
"""Convert from different input formats to a Chronos2Dataset."""
if isinstance(inputs, (torch.Tensor, np.ndarray)):
inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
elif isinstance(inputs, list) and all([isinstance(x, (torch.Tensor, np.ndarray)) for x in inputs]):
inputs = cast(list[TensorOrArray], inputs)
inputs = convert_list_of_tensors_input_to_list_of_dicts_input(inputs)
elif isinstance(inputs, list) and all([isinstance(x, dict) for x in inputs]):
pass
else:
raise ValueError("Unexpected inputs format")
inputs = cast(list[dict[str, TensorOrArray | dict[str, TensorOrArray]]], inputs)
return cls(
inputs,
context_length=context_length,
prediction_length=prediction_length,
batch_size=batch_size,
output_patch_size=output_patch_size,
min_past=min_past,
mode=mode,
)

View file

@ -616,8 +616,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
)
context_length = self.model_context_length
test_dataset = Chronos2Dataset.convert_inputs(
inputs=inputs,
test_dataset = Chronos2Dataset(
inputs,
context_length=context_length,
prediction_length=prediction_length,
batch_size=batch_size,
@ -1142,8 +1142,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
)
context_length = self.model_context_length
test_dataset = Chronos2Dataset.convert_inputs(
inputs=inputs,
test_dataset = Chronos2Dataset(
inputs,
context_length=context_length,
prediction_length=0,
batch_size=batch_size,