mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 09:39:35 +00:00
Rename task -> input in dataset.py
This commit is contained in:
parent
0e9db70afc
commit
5daf273ca6
1 changed files with 59 additions and 59 deletions
|
|
@ -20,8 +20,8 @@ if TYPE_CHECKING:
|
|||
TensorOrArray: TypeAlias = torch.Tensor | np.ndarray
|
||||
|
||||
|
||||
class PreparedTask(TypedDict):
|
||||
"""A preprocessed time series task ready for model training/inference."""
|
||||
class PreparedInput(TypedDict):
|
||||
"""A preprocessed time series input ready for model training/inference."""
|
||||
|
||||
context: torch.Tensor # (n_variates, history_length), float32
|
||||
future_covariates: torch.Tensor # (n_variates, prediction_length), float32
|
||||
|
|
@ -47,10 +47,10 @@ def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor:
|
|||
return torch.cat(padded, dim=0)
|
||||
|
||||
|
||||
def validate_and_prepare_single_dict_task(
|
||||
def validate_and_prepare_single_dict_input(
|
||||
task: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], idx: int, prediction_length: int
|
||||
) -> PreparedTask:
|
||||
"""Validates and prepares a single dictionary task for Chronos2Model.
|
||||
) -> PreparedInput:
|
||||
"""Validates and prepares a single dictionary input for Chronos2Model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -207,7 +207,7 @@ def validate_and_prepare_single_dict_task(
|
|||
# number of known-future covariates
|
||||
task_n_future_covariates = len(task_future_covariates_keys)
|
||||
|
||||
return PreparedTask(
|
||||
return PreparedInput(
|
||||
context=task_context_tensor,
|
||||
future_covariates=task_future_covariates_tensor,
|
||||
n_targets=task_n_targets,
|
||||
|
|
@ -216,23 +216,23 @@ def validate_and_prepare_single_dict_task(
|
|||
)
|
||||
|
||||
|
||||
def prepare_tasks(
|
||||
raw_tasks: Iterable[Mapping[str, Any]],
|
||||
def prepare_inputs(
|
||||
raw_inputs: Iterable[Mapping[str, Any]],
|
||||
prediction_length: int,
|
||||
min_past: int = 1,
|
||||
mode: "DatasetMode | str" = "train",
|
||||
) -> list[PreparedTask]:
|
||||
"""Prepare multiple time series tasks for training/inference.
|
||||
) -> list[PreparedInput]:
|
||||
"""Prepare multiple time series inputs for training/inference.
|
||||
|
||||
This function handles mode-specific preprocessing (e.g., filtering short series)
|
||||
and calls validate_and_prepare_single_dict_task for each task.
|
||||
and calls validate_and_prepare_single_dict_input for each input.
|
||||
"""
|
||||
tasks: list[PreparedTask] = []
|
||||
inputs: list[PreparedInput] = []
|
||||
|
||||
for idx, raw_task in enumerate(raw_tasks):
|
||||
for idx, raw_input in enumerate(raw_inputs):
|
||||
# For non-TEST modes, fix future_covariates (replace None/empty with NaN arrays)
|
||||
if mode != DatasetMode.TEST:
|
||||
raw_future_covariates = raw_task.get("future_covariates", {})
|
||||
raw_future_covariates = raw_input.get("future_covariates", {})
|
||||
if raw_future_covariates:
|
||||
raw_future_covariates = cast(dict[str, TensorOrArray | None], raw_future_covariates)
|
||||
fixed_future_covariates = {}
|
||||
|
|
@ -240,42 +240,42 @@ def prepare_tasks(
|
|||
fixed_future_covariates[key] = (
|
||||
np.full(prediction_length, np.nan) if value is None or len(value) == 0 else value
|
||||
)
|
||||
raw_task = {**raw_task, "future_covariates": fixed_future_covariates}
|
||||
raw_input = {**raw_input, "future_covariates": fixed_future_covariates}
|
||||
|
||||
raw_task = cast(dict[str, TensorOrArray | Mapping[str, TensorOrArray]], raw_task)
|
||||
task = validate_and_prepare_single_dict_task(raw_task, idx, prediction_length)
|
||||
raw_input = cast(dict[str, TensorOrArray | Mapping[str, TensorOrArray]], raw_input)
|
||||
prepared = validate_and_prepare_single_dict_input(raw_input, idx, prediction_length)
|
||||
|
||||
# Filter by minimum length (except in TEST mode)
|
||||
if mode != DatasetMode.TEST and task["context"].shape[-1] < min_past + prediction_length:
|
||||
if mode != DatasetMode.TEST and prepared["context"].shape[-1] < min_past + prediction_length:
|
||||
continue
|
||||
|
||||
tasks.append(task)
|
||||
inputs.append(prepared)
|
||||
|
||||
if len(tasks) == 0:
|
||||
if len(inputs) == 0:
|
||||
raise ValueError(
|
||||
"The dataset is empty after filtering based on the length of the time series (length >= min_past + prediction_length). "
|
||||
"Please provide longer time series or reduce `min_past` or `prediction_length`. "
|
||||
)
|
||||
|
||||
return tasks
|
||||
return inputs
|
||||
|
||||
|
||||
def validate_prepared_schema(task: Any) -> None:
|
||||
"""Validate that a task matches the PreparedTask schema."""
|
||||
if not isinstance(task, Mapping):
|
||||
def validate_prepared_schema(prepared_input: Any) -> None:
|
||||
"""Validate that an input matches the PreparedInput schema."""
|
||||
if not isinstance(prepared_input, Mapping):
|
||||
raise TypeError(
|
||||
f"Expected task to be a dict-like, got {type(task).__name__}. "
|
||||
f"Expected input to be a dict-like, got {type(prepared_input).__name__}. "
|
||||
"Set convert_inputs=True when calling fit() to preprocess raw inputs."
|
||||
)
|
||||
|
||||
required_keys = {"context", "future_covariates", "n_targets", "n_covariates", "n_future_covariates"}
|
||||
missing = required_keys - set(task.keys())
|
||||
missing = required_keys - set(prepared_input.keys())
|
||||
if missing:
|
||||
raise TypeError(
|
||||
f"Task is missing required keys: {missing}. Set convert_inputs=True when calling fit() to preprocess raw inputs."
|
||||
f"Input is missing required keys: {missing}. Set convert_inputs=True when calling fit() to preprocess raw inputs."
|
||||
)
|
||||
|
||||
context = task["context"]
|
||||
context = prepared_input["context"]
|
||||
if not isinstance(context, torch.Tensor) or context.ndim != 2:
|
||||
raise TypeError(
|
||||
f"Expected 'context' to be 2-d torch.Tensor, got {type(context).__name__} "
|
||||
|
|
@ -283,7 +283,7 @@ def validate_prepared_schema(task: Any) -> None:
|
|||
"Set convert_inputs=True when calling fit() to preprocess raw inputs."
|
||||
)
|
||||
|
||||
future_covariates = task["future_covariates"]
|
||||
future_covariates = prepared_input["future_covariates"]
|
||||
if not isinstance(future_covariates, torch.Tensor) or future_covariates.ndim != 2:
|
||||
raise TypeError(
|
||||
f"Expected 'future_covariates' to be 2-d torch.Tensor, got {type(future_covariates).__name__} "
|
||||
|
|
@ -486,9 +486,9 @@ class Chronos2Dataset(IterableDataset):
|
|||
covariates.
|
||||
- `future_covariates` (optional): a dict of future values of known future covariates.
|
||||
|
||||
2. Pre-processed inputs (when `convert_inputs=False`): A sequence of `PreparedTask` dicts with keys:
|
||||
2. Pre-processed inputs (when `convert_inputs=False`): A sequence of `PreparedInput` dicts with keys:
|
||||
`context`, `future_covariates`, `n_targets`, `n_covariates`, `n_future_covariates`.
|
||||
Use `prepare_tasks()` to create pre-processed inputs.
|
||||
Use `prepare_inputs()` to create pre-processed inputs.
|
||||
context_length
|
||||
The maximum context length used for training or inference
|
||||
prediction_length
|
||||
|
|
@ -510,7 +510,7 @@ class Chronos2Dataset(IterableDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: TensorOrArray | Sequence[TensorOrArray] | Sequence[Mapping[str, Any]] | Sequence[PreparedTask],
|
||||
inputs: TensorOrArray | Sequence[TensorOrArray] | Sequence[Mapping[str, Any]] | Sequence[PreparedInput],
|
||||
context_length: int,
|
||||
prediction_length: int,
|
||||
batch_size: int,
|
||||
|
|
@ -522,7 +522,7 @@ class Chronos2Dataset(IterableDataset):
|
|||
super().__init__()
|
||||
assert mode in {DatasetMode.TRAIN, DatasetMode.VALIDATION, DatasetMode.TEST}, f"Invalid mode: {mode}"
|
||||
|
||||
self.tasks: Sequence[PreparedTask]
|
||||
self.inputs: Sequence[PreparedInput]
|
||||
if convert_inputs:
|
||||
if isinstance(inputs, (torch.Tensor, np.ndarray)):
|
||||
inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
|
||||
|
|
@ -530,10 +530,10 @@ class Chronos2Dataset(IterableDataset):
|
|||
isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray))
|
||||
):
|
||||
inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(Sequence[TensorOrArray], inputs))
|
||||
self.tasks = prepare_tasks(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode)
|
||||
self.inputs = prepare_inputs(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode)
|
||||
else:
|
||||
validate_prepared_schema(inputs[0])
|
||||
self.tasks = cast(Sequence[PreparedTask], inputs)
|
||||
self.inputs = cast(Sequence[PreparedInput], inputs)
|
||||
|
||||
self.context_length = context_length
|
||||
self.prediction_length = prediction_length
|
||||
|
|
@ -542,13 +542,13 @@ class Chronos2Dataset(IterableDataset):
|
|||
self.min_past = min_past
|
||||
self.mode = mode
|
||||
|
||||
def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]:
|
||||
task = self.tasks[task_idx]
|
||||
task_past_tensor = task["context"].clone() # shape: (task_n_targets + task_n_covariates, history_length)
|
||||
task_future_tensor = task["future_covariates"].clone()
|
||||
task_n_targets = int(task["n_targets"])
|
||||
task_n_covariates = int(task["n_covariates"])
|
||||
task_n_future_covariates = int(task["n_future_covariates"])
|
||||
def _construct_slice(self, input_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]:
|
||||
input = self.inputs[input_idx]
|
||||
task_past_tensor = input["context"].clone() # shape: (task_n_targets + task_n_covariates, history_length)
|
||||
task_future_tensor = input["future_covariates"].clone()
|
||||
task_n_targets = int(input["n_targets"])
|
||||
task_n_covariates = int(input["n_covariates"])
|
||||
task_n_future_covariates = int(input["n_future_covariates"])
|
||||
task_n_past_only_covariates = task_n_covariates - task_n_future_covariates
|
||||
|
||||
full_length = task_past_tensor.shape[-1]
|
||||
|
|
@ -606,8 +606,8 @@ class Chronos2Dataset(IterableDataset):
|
|||
|
||||
return task_context, task_future_target, task_future_covariates, task_n_targets
|
||||
|
||||
def _build_batch(self, task_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]:
|
||||
"""Build a batch from given task indices."""
|
||||
def _build_batch(self, input_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]:
|
||||
"""Build a batch from given input indices."""
|
||||
batch_context_tensor_list = []
|
||||
batch_future_target_tensor_list = []
|
||||
batch_future_covariates_tensor_list = []
|
||||
|
|
@ -615,8 +615,8 @@ class Chronos2Dataset(IterableDataset):
|
|||
target_idx_ranges: list[tuple[int, int]] = []
|
||||
|
||||
target_start_idx = 0
|
||||
for group_id, task_idx in enumerate(task_indices):
|
||||
task_context, task_future_target, task_future_covariates, task_n_targets = self._construct_slice(task_idx)
|
||||
for group_id, input_idx in enumerate(input_indices):
|
||||
task_context, task_future_target, task_future_covariates, task_n_targets = self._construct_slice(input_idx)
|
||||
|
||||
group_size = task_context.shape[0]
|
||||
task_group_ids = torch.full((group_size,), fill_value=group_id)
|
||||
|
|
@ -641,27 +641,27 @@ class Chronos2Dataset(IterableDataset):
|
|||
def _generate_train_batches(self):
|
||||
while True:
|
||||
current_batch_size = 0
|
||||
task_indices = []
|
||||
input_indices = []
|
||||
|
||||
while current_batch_size < self.batch_size:
|
||||
task_idx = np.random.randint(len(self.tasks))
|
||||
task_indices.append(task_idx)
|
||||
current_batch_size += self.tasks[task_idx]["context"].shape[0]
|
||||
input_idx = np.random.randint(len(self.inputs))
|
||||
input_indices.append(input_idx)
|
||||
current_batch_size += self.inputs[input_idx]["context"].shape[0]
|
||||
|
||||
yield self._build_batch(task_indices)
|
||||
yield self._build_batch(input_indices)
|
||||
|
||||
def _generate_sequential_batches(self):
|
||||
task_idx = 0
|
||||
while task_idx < len(self.tasks):
|
||||
input_idx = 0
|
||||
while input_idx < len(self.inputs):
|
||||
current_batch_size = 0
|
||||
task_indices = []
|
||||
input_indices = []
|
||||
|
||||
while task_idx < len(self.tasks) and current_batch_size < self.batch_size:
|
||||
task_indices.append(task_idx)
|
||||
current_batch_size += self.tasks[task_idx]["context"].shape[0]
|
||||
task_idx += 1
|
||||
while input_idx < len(self.inputs) and current_batch_size < self.batch_size:
|
||||
input_indices.append(input_idx)
|
||||
current_batch_size += self.inputs[input_idx]["context"].shape[0]
|
||||
input_idx += 1
|
||||
|
||||
yield self._build_batch(task_indices)
|
||||
yield self._build_batch(input_indices)
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue