Rename the "task" vocabulary in src/chronos/chronos2/dataset.py to "input":

  PreparedTask                           -> PreparedInput
  validate_and_prepare_single_dict_task  -> validate_and_prepare_single_dict_input
  prepare_tasks(raw_tasks, ...)          -> prepare_inputs(raw_inputs, ...)
  Chronos2Dataset.tasks                  -> Chronos2Dataset.inputs
  task_idx / task_indices (locals)       -> input_idx / input_indices

Review notes:

  * In _construct_slice the element fetched from self.inputs is bound to
    `prepared` (the name prepare_inputs already uses for the same kind of
    object) instead of `input`, so the builtin input() is not shadowed.
    Only added lines were touched; removed and context lines are unchanged.
  * NOTE(review): line breaks, indentation and blank context lines were
    reconstructed from the hunk line counts. If a hunk is rejected because
    of whitespace, compare it against the target file or retry with
    whitespace-insensitive matching.
  * NOTE(review): the post-image blob id on the "index" line was not
    recomputed after the `prepared` rename.

diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py
index 1b8a77a..b0bbbeb 100644
--- a/src/chronos/chronos2/dataset.py
+++ b/src/chronos/chronos2/dataset.py
@@ -20,8 +20,8 @@ if TYPE_CHECKING:
 TensorOrArray: TypeAlias = torch.Tensor | np.ndarray
 
 
-class PreparedTask(TypedDict):
-    """A preprocessed time series task ready for model training/inference."""
+class PreparedInput(TypedDict):
+    """A preprocessed time series input ready for model training/inference."""
 
     context: torch.Tensor  # (n_variates, history_length), float32
     future_covariates: torch.Tensor  # (n_variates, prediction_length), float32
@@ -47,10 +47,10 @@ def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor:
     return torch.cat(padded, dim=0)
 
 
-def validate_and_prepare_single_dict_task(
+def validate_and_prepare_single_dict_input(
     task: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], idx: int, prediction_length: int
-) -> PreparedTask:
-    """Validates and prepares a single dictionary task for Chronos2Model.
+) -> PreparedInput:
+    """Validates and prepares a single dictionary input for Chronos2Model.
 
     Parameters
     ----------
@@ -207,7 +207,7 @@ def validate_and_prepare_single_dict_task(
     # number of known-future covariates
     task_n_future_covariates = len(task_future_covariates_keys)
 
-    return PreparedTask(
+    return PreparedInput(
         context=task_context_tensor,
         future_covariates=task_future_covariates_tensor,
         n_targets=task_n_targets,
@@ -216,23 +216,23 @@ def validate_and_prepare_single_dict_task(
     )
 
 
-def prepare_tasks(
-    raw_tasks: Iterable[Mapping[str, Any]],
+def prepare_inputs(
+    raw_inputs: Iterable[Mapping[str, Any]],
     prediction_length: int,
     min_past: int = 1,
     mode: "DatasetMode | str" = "train",
-) -> list[PreparedTask]:
-    """Prepare multiple time series tasks for training/inference.
+) -> list[PreparedInput]:
+    """Prepare multiple time series inputs for training/inference.
 
     This function handles mode-specific preprocessing (e.g., filtering short series)
-    and calls validate_and_prepare_single_dict_task for each task.
+    and calls validate_and_prepare_single_dict_input for each input.
     """
-    tasks: list[PreparedTask] = []
+    inputs: list[PreparedInput] = []
 
-    for idx, raw_task in enumerate(raw_tasks):
+    for idx, raw_input in enumerate(raw_inputs):
         # For non-TEST modes, fix future_covariates (replace None/empty with NaN arrays)
         if mode != DatasetMode.TEST:
-            raw_future_covariates = raw_task.get("future_covariates", {})
+            raw_future_covariates = raw_input.get("future_covariates", {})
             if raw_future_covariates:
                 raw_future_covariates = cast(dict[str, TensorOrArray | None], raw_future_covariates)
                 fixed_future_covariates = {}
@@ -240,42 +240,42 @@ def prepare_tasks(
                     fixed_future_covariates[key] = (
                         np.full(prediction_length, np.nan) if value is None or len(value) == 0 else value
                     )
-                raw_task = {**raw_task, "future_covariates": fixed_future_covariates}
+                raw_input = {**raw_input, "future_covariates": fixed_future_covariates}
 
-        raw_task = cast(dict[str, TensorOrArray | Mapping[str, TensorOrArray]], raw_task)
-        task = validate_and_prepare_single_dict_task(raw_task, idx, prediction_length)
+        raw_input = cast(dict[str, TensorOrArray | Mapping[str, TensorOrArray]], raw_input)
+        prepared = validate_and_prepare_single_dict_input(raw_input, idx, prediction_length)
 
         # Filter by minimum length (except in TEST mode)
-        if mode != DatasetMode.TEST and task["context"].shape[-1] < min_past + prediction_length:
+        if mode != DatasetMode.TEST and prepared["context"].shape[-1] < min_past + prediction_length:
             continue
 
-        tasks.append(task)
+        inputs.append(prepared)
 
-    if len(tasks) == 0:
+    if len(inputs) == 0:
         raise ValueError(
             "The dataset is empty after filtering based on the length of the time series (length >= min_past + prediction_length). "
             "Please provide longer time series or reduce `min_past` or `prediction_length`. "
         )
 
-    return tasks
+    return inputs
 
 
-def validate_prepared_schema(task: Any) -> None:
-    """Validate that a task matches the PreparedTask schema."""
-    if not isinstance(task, Mapping):
+def validate_prepared_schema(prepared_input: Any) -> None:
+    """Validate that an input matches the PreparedInput schema."""
+    if not isinstance(prepared_input, Mapping):
         raise TypeError(
-            f"Expected task to be a dict-like, got {type(task).__name__}. "
+            f"Expected input to be a dict-like, got {type(prepared_input).__name__}. "
             "Set convert_inputs=True when calling fit() to preprocess raw inputs."
         )
 
     required_keys = {"context", "future_covariates", "n_targets", "n_covariates", "n_future_covariates"}
-    missing = required_keys - set(task.keys())
+    missing = required_keys - set(prepared_input.keys())
     if missing:
         raise TypeError(
-            f"Task is missing required keys: {missing}. Set convert_inputs=True when calling fit() to preprocess raw inputs."
+            f"Input is missing required keys: {missing}. Set convert_inputs=True when calling fit() to preprocess raw inputs."
         )
 
-    context = task["context"]
+    context = prepared_input["context"]
     if not isinstance(context, torch.Tensor) or context.ndim != 2:
         raise TypeError(
             f"Expected 'context' to be 2-d torch.Tensor, got {type(context).__name__} "
@@ -283,7 +283,7 @@ def validate_prepared_schema(task: Any) -> None:
             "Set convert_inputs=True when calling fit() to preprocess raw inputs."
         )
 
-    future_covariates = task["future_covariates"]
+    future_covariates = prepared_input["future_covariates"]
     if not isinstance(future_covariates, torch.Tensor) or future_covariates.ndim != 2:
         raise TypeError(
             f"Expected 'future_covariates' to be 2-d torch.Tensor, got {type(future_covariates).__name__} "
@@ -486,9 +486,9 @@ class Chronos2Dataset(IterableDataset):
              covariates.
            - `future_covariates` (optional): a dict of future values of known future covariates.
 
-        2. Pre-processed inputs (when `convert_inputs=False`): A sequence of `PreparedTask` dicts with keys:
+        2. Pre-processed inputs (when `convert_inputs=False`): A sequence of `PreparedInput` dicts with keys:
            `context`, `future_covariates`, `n_targets`, `n_covariates`, `n_future_covariates`.
-           Use `prepare_tasks()` to create pre-processed inputs.
+           Use `prepare_inputs()` to create pre-processed inputs.
     context_length
         The maximum context length used for training or inference
     prediction_length
@@ -510,7 +510,7 @@ class Chronos2Dataset(IterableDataset):
 
     def __init__(
         self,
-        inputs: TensorOrArray | Sequence[TensorOrArray] | Sequence[Mapping[str, Any]] | Sequence[PreparedTask],
+        inputs: TensorOrArray | Sequence[TensorOrArray] | Sequence[Mapping[str, Any]] | Sequence[PreparedInput],
         context_length: int,
         prediction_length: int,
         batch_size: int,
@@ -522,7 +522,7 @@ class Chronos2Dataset(IterableDataset):
         super().__init__()
         assert mode in {DatasetMode.TRAIN, DatasetMode.VALIDATION, DatasetMode.TEST}, f"Invalid mode: {mode}"
 
-        self.tasks: Sequence[PreparedTask]
+        self.inputs: Sequence[PreparedInput]
         if convert_inputs:
             if isinstance(inputs, (torch.Tensor, np.ndarray)):
                 inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
@@ -530,10 +530,10 @@ class Chronos2Dataset(IterableDataset):
                 isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray))
             ):
                 inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(Sequence[TensorOrArray], inputs))
-            self.tasks = prepare_tasks(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode)
+            self.inputs = prepare_inputs(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode)
         else:
             validate_prepared_schema(inputs[0])
-            self.tasks = cast(Sequence[PreparedTask], inputs)
+            self.inputs = cast(Sequence[PreparedInput], inputs)
 
         self.context_length = context_length
         self.prediction_length = prediction_length
@@ -542,13 +542,13 @@ class Chronos2Dataset(IterableDataset):
         self.min_past = min_past
         self.mode = mode
 
-    def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]:
-        task = self.tasks[task_idx]
-        task_past_tensor = task["context"].clone()  # shape: (task_n_targets + task_n_covariates, history_length)
-        task_future_tensor = task["future_covariates"].clone()
-        task_n_targets = int(task["n_targets"])
-        task_n_covariates = int(task["n_covariates"])
-        task_n_future_covariates = int(task["n_future_covariates"])
+    def _construct_slice(self, input_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]:
+        prepared = self.inputs[input_idx]
+        task_past_tensor = prepared["context"].clone()  # shape: (task_n_targets + task_n_covariates, history_length)
+        task_future_tensor = prepared["future_covariates"].clone()
+        task_n_targets = int(prepared["n_targets"])
+        task_n_covariates = int(prepared["n_covariates"])
+        task_n_future_covariates = int(prepared["n_future_covariates"])
         task_n_past_only_covariates = task_n_covariates - task_n_future_covariates
 
         full_length = task_past_tensor.shape[-1]
@@ -606,8 +606,8 @@ class Chronos2Dataset(IterableDataset):
 
         return task_context, task_future_target, task_future_covariates, task_n_targets
 
-    def _build_batch(self, task_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]:
-        """Build a batch from given task indices."""
+    def _build_batch(self, input_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]:
+        """Build a batch from given input indices."""
         batch_context_tensor_list = []
         batch_future_target_tensor_list = []
         batch_future_covariates_tensor_list = []
@@ -615,8 +615,8 @@ class Chronos2Dataset(IterableDataset):
         target_idx_ranges: list[tuple[int, int]] = []
         target_start_idx = 0
 
-        for group_id, task_idx in enumerate(task_indices):
-            task_context, task_future_target, task_future_covariates, task_n_targets = self._construct_slice(task_idx)
+        for group_id, input_idx in enumerate(input_indices):
+            task_context, task_future_target, task_future_covariates, task_n_targets = self._construct_slice(input_idx)
 
             group_size = task_context.shape[0]
             task_group_ids = torch.full((group_size,), fill_value=group_id)
@@ -641,27 +641,27 @@ class Chronos2Dataset(IterableDataset):
     def _generate_train_batches(self):
         while True:
             current_batch_size = 0
-            task_indices = []
+            input_indices = []
 
             while current_batch_size < self.batch_size:
-                task_idx = np.random.randint(len(self.tasks))
-                task_indices.append(task_idx)
-                current_batch_size += self.tasks[task_idx]["context"].shape[0]
+                input_idx = np.random.randint(len(self.inputs))
+                input_indices.append(input_idx)
+                current_batch_size += self.inputs[input_idx]["context"].shape[0]
 
-            yield self._build_batch(task_indices)
+            yield self._build_batch(input_indices)
 
     def _generate_sequential_batches(self):
-        task_idx = 0
-        while task_idx < len(self.tasks):
+        input_idx = 0
+        while input_idx < len(self.inputs):
             current_batch_size = 0
-            task_indices = []
+            input_indices = []
 
-            while task_idx < len(self.tasks) and current_batch_size < self.batch_size:
-                task_indices.append(task_idx)
-                current_batch_size += self.tasks[task_idx]["context"].shape[0]
-                task_idx += 1
+            while input_idx < len(self.inputs) and current_batch_size < self.batch_size:
+                input_indices.append(input_idx)
+                current_batch_size += self.inputs[input_idx]["context"].shape[0]
+                input_idx += 1
 
-            yield self._build_batch(task_indices)
+            yield self._build_batch(input_indices)
 
     def __iter__(self) -> Iterator:
         """