diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index cb75571..2e1b6a1 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -5,7 +5,7 @@ import math from enum import Enum -from typing import TYPE_CHECKING, Iterator, Mapping, Sequence, TypeAlias, cast +from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence, TypeAlias, TypedDict, cast import numpy as np import torch @@ -20,6 +20,16 @@ if TYPE_CHECKING: TensorOrArray: TypeAlias = torch.Tensor | np.ndarray +class PreparedInput(TypedDict): + """A preprocessed time series input ready for model training/inference.""" + + context: torch.Tensor # (n_variates, history_length), float32 + future_covariates: torch.Tensor # (n_variates, prediction_length), float32 + n_targets: int + n_covariates: int + n_future_covariates: int + + def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor: """ Left pads tensors in the list to the length of the longest tensor along the second axis, then concats @@ -37,14 +47,14 @@ def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor: return torch.cat(padded, dim=0) -def validate_and_prepare_single_dict_task( - task: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], idx: int, prediction_length: int -) -> tuple[torch.Tensor, torch.Tensor, int, int, int]: - """Validates and prepares a single dictionary task for Chronos2Model. +def validate_and_prepare_single_dict_input( + raw_input: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], idx: int, prediction_length: int +) -> PreparedInput: + """Validates and prepares a single dictionary input for Chronos2Model. Parameters ---------- - task + raw_input A dictionary representing a time series that contains: - `target` (required): a 1-d or 2-d `torch.Tensor` or `np.ndarray` of shape (history_length,) or (n_variates, history_length). Forecasts will be generated for items in `target`. @@ -55,27 +65,27 @@ def validate_and_prepare_single_dict_task( covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `prediction_length`. All keys in `future_covariates` must be a subset of the keys in `past_covariates`. idx - Index of this task in the list of tasks, used for error messages + Index of this input in the list of inputs, used for error messages prediction_length Number of future time steps to predict, used to validate future covariates Returns ------ - A tuple containing: - - task_context_tensor: Concatenated tensor of target and past covariates of shape (group_size, history_length), - the first `task_n_targets` items along the first axis contain the target variables and the remaining items contain past-only covariates + A PreparedInput containing: + - context: Concatenated tensor of target and past covariates of shape (group_size, history_length), + the first `n_targets` items along the first axis contain the target variables and the remaining items contain past-only covariates and past values of known future covariates. - - task_future_covariates_tensor: Tensor of future covariates of shape (group_size, prediction_length). The last `task_n_future_covariates` + - future_covariates: Tensor of future covariates of shape (group_size, prediction_length). The last `n_future_covariates` items along the first axis contain future covariates. All the remaining elements corresponding to target and past-only covariates are NaNs. - - task_n_targets: Number of target variables - - task_n_covariates: Total number of covariates (sum of past-only and known future covariates) - - task_n_future_covariates: Number of known future covariates + - n_targets: Number of target variables + - n_covariates: Total number of covariates (sum of past-only and known future covariates) + - n_future_covariates: Number of known future covariates """ allowed_keys = {"target", "past_covariates", "future_covariates"} # validate keys - keys = set(task.keys()) + keys = set(raw_input.keys()) if not keys.issubset(allowed_keys): raise ValueError( f"Found invalid keys in element at index {idx}. Allowed keys are {allowed_keys}, but found {keys}" @@ -84,58 +94,58 @@ def validate_and_prepare_single_dict_task( raise ValueError(f"Element at index {idx} does not contain the required key 'target'") # validate target - task_target = task["target"] - if isinstance(task_target, np.ndarray): - task_target = torch.from_numpy(task_target) - assert isinstance(task_target, torch.Tensor) - if task_target.ndim > 2: + target = raw_input["target"] + if isinstance(target, np.ndarray): + target = torch.from_numpy(target) + assert isinstance(target, torch.Tensor) + if target.ndim > 2: raise ValueError( "When the input is a list of dicts, the `target` should either be 1-d with shape (history_length,) " - f" or 2-d with shape (n_variates, history_length). Found element at index {idx} with shape {tuple(task_target.shape)}." + f" or 2-d with shape (n_variates, history_length). Found element at index {idx} with shape {tuple(target.shape)}." ) - history_length = task_target.shape[-1] - task_target = task_target.view(-1, history_length) + history_length = target.shape[-1] + target = target.view(-1, history_length) # validate past_covariates cat_encoders: dict = {} - task_past_covariates = task.get("past_covariates", {}) - if not isinstance(task_past_covariates, dict): + past_covariates = raw_input.get("past_covariates", {}) + if not isinstance(past_covariates, dict): raise ValueError( f"Found invalid type for `past_covariates` in element at index {idx}. " - f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_past_covariates)}' + f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(past_covariates)}' ) # gather keys and ensure known-future keys come last to match downstream assumptions - task_covariates_keys = sorted(task_past_covariates.keys()) + covariates_keys = sorted(past_covariates.keys()) - task_future_covariates = task.get("future_covariates", {}) - if not isinstance(task_future_covariates, dict): + future_covariates = raw_input.get("future_covariates", {}) + if not isinstance(future_covariates, dict): raise ValueError( f"Found invalid type for `future_covariates` in element at index {idx}. " - f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_future_covariates)}' + f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(future_covariates)}' ) - task_future_covariates_keys = sorted(task_future_covariates.keys()) - if not set(task_future_covariates_keys).issubset(task_covariates_keys): + future_covariates_keys = sorted(future_covariates.keys()) + if not set(future_covariates_keys).issubset(covariates_keys): raise ValueError( - f"Expected keys in `future_covariates` to be a subset of `past_covariates` {task_covariates_keys}, " - f"but found {task_future_covariates_keys} in element at index {idx}" + f"Expected keys in `future_covariates` to be a subset of `past_covariates` {covariates_keys}, " + f"but found {future_covariates_keys} in element at index {idx}" ) # create ordered keys: past-only first, then known-future (so known-future are the last rows) - task_past_only_keys = [k for k in task_covariates_keys if k not in task_future_covariates_keys] # past_only_keys - task_ordered_covariate_keys = task_past_only_keys + task_future_covariates_keys + past_only_keys = [k for k in covariates_keys if k not in future_covariates_keys] + ordered_covariate_keys = past_only_keys + future_covariates_keys - task_past_covariates_list: list[torch.Tensor] = [] - for key in task_ordered_covariate_keys: - tensor = task_past_covariates[key] + past_covariates_list: list[torch.Tensor] = [] + for key in ordered_covariate_keys: + tensor = past_covariates[key] if isinstance(tensor, np.ndarray): # apply encoding to categorical variates if not np.issubdtype(tensor.dtype, np.number): # target encoding, if the target is 1-d - if task_target.shape[0] == 1: + if target.shape[0] == 1: cat_encoder = TargetEncoder(target_type="continuous", smooth=1.0) X = tensor.astype(str).reshape(-1, 1) - y = task_target.view(-1).numpy() + y = target.view(-1).numpy() mask = np.isfinite(y) X = X[mask] y = y[mask] @@ -153,18 +163,18 @@ def validate_and_prepare_single_dict_task( f"Individual `past_covariates` must be 1-d with length equal to the length of `target` (= {history_length}), " f"found: {key} with shape {tuple(tensor.shape)} in element at index {idx}" ) - task_past_covariates_list.append(tensor) - task_past_covariates_tensor = ( - torch.stack(task_past_covariates_list, dim=0) - if task_past_covariates_list - else torch.zeros((0, history_length), device=task_target.device) + past_covariates_list.append(tensor) + past_covariates_tensor = ( + torch.stack(past_covariates_list, dim=0) + if past_covariates_list + else torch.zeros((0, history_length), device=target.device) ) - # validate future_covariates (build rows in the same task_ordered_covariate_keys order) - task_future_covariates_list: list[torch.Tensor] = [] - for key in task_ordered_covariate_keys: + # validate future_covariates (build rows in the same ordered_covariate_keys order) + future_covariates_list: list[torch.Tensor] = [] + for key in ordered_covariate_keys: # future values of past-only covariates are filled with NaNs - tensor = task_future_covariates.get(key, torch.full((prediction_length,), fill_value=torch.nan)) + tensor = future_covariates.get(key, torch.full((prediction_length,), fill_value=torch.nan)) if isinstance(tensor, np.ndarray): # apply encoding to categorical variates if not np.issubdtype(tensor.dtype, np.number): @@ -177,35 +187,118 @@ def validate_and_prepare_single_dict_task( f"Individual `future_covariates` must be 1-d with length equal to the {prediction_length=}, " f"found: {key} with shape {tuple(tensor.shape)} in element at index {idx}" ) - task_future_covariates_list.append(tensor) - task_future_covariates_tensor = ( - torch.stack(task_future_covariates_list, dim=0) - if task_future_covariates_list - else torch.zeros((0, prediction_length), device=task_target.device) + future_covariates_list.append(tensor) + future_covariates_tensor = ( + torch.stack(future_covariates_list, dim=0) + if future_covariates_list + else torch.zeros((0, prediction_length), device=target.device) ) # future values of target series are filled with NaNs - task_future_covariates_target_padding = torch.full( - (task_target.shape[0], prediction_length), fill_value=torch.nan, device=task_target.device + future_covariates_target_padding = torch.full( + (target.shape[0], prediction_length), fill_value=torch.nan, device=target.device ) - task_context_tensor = torch.cat([task_target, task_past_covariates_tensor], dim=0).to(dtype=torch.float32) - task_future_covariates_tensor = torch.cat( - [task_future_covariates_target_padding, task_future_covariates_tensor], dim=0 + context_tensor = torch.cat([target, past_covariates_tensor], dim=0).to(dtype=torch.float32) + future_covariates_tensor = torch.cat( + [future_covariates_target_padding, future_covariates_tensor], dim=0 ).to(dtype=torch.float32) - task_n_targets = task_target.shape[0] - task_n_covariates = task_past_covariates_tensor.shape[0] + n_targets = target.shape[0] + n_covariates = past_covariates_tensor.shape[0] # number of known-future covariates - task_n_future_covariates = len(task_future_covariates_keys) + n_future_covariates = len(future_covariates_keys) - return ( - task_context_tensor, - task_future_covariates_tensor, - task_n_targets, - task_n_covariates, - task_n_future_covariates, + return PreparedInput( + context=context_tensor, + future_covariates=future_covariates_tensor, + n_targets=n_targets, + n_covariates=n_covariates, + n_future_covariates=n_future_covariates, ) +def prepare_inputs( + raw_inputs: Iterable[Mapping[str, Any]], + prediction_length: int, + min_past: int = 1, + mode: "DatasetMode | str" = "train", +) -> list[PreparedInput]: + """Prepare multiple time series inputs for training/inference. + + This function handles mode-specific preprocessing (e.g., filtering short series) + and calls validate_and_prepare_single_dict_input for each input. + """ + inputs: list[PreparedInput] = [] + + for idx, raw_input in enumerate(raw_inputs): + # For non-TEST modes, fix future_covariates (replace None/empty with NaN arrays) + if mode != DatasetMode.TEST: + raw_future_covariates = raw_input.get("future_covariates", {}) + if raw_future_covariates: + raw_future_covariates = cast(dict[str, TensorOrArray | None], raw_future_covariates) + fixed_future_covariates = {} + for key, value in raw_future_covariates.items(): + fixed_future_covariates[key] = ( + np.full(prediction_length, np.nan) if value is None or len(value) == 0 else value + ) + raw_input = {**raw_input, "future_covariates": fixed_future_covariates} + + raw_input = cast(dict[str, TensorOrArray | Mapping[str, TensorOrArray]], raw_input) + prepared = validate_and_prepare_single_dict_input(raw_input, idx, prediction_length) + + # Filter by minimum length (except in TEST mode) + if mode != DatasetMode.TEST and prepared["context"].shape[-1] < min_past + prediction_length: + continue + + inputs.append(prepared) + + if len(inputs) == 0: + raise ValueError( + "The dataset is empty after filtering based on the length of the time series (length >= min_past + prediction_length). " + "Please provide longer time series or reduce `min_past` or `prediction_length`. " + ) + + return inputs + + +def validate_prepared_schema(prepared_input: Any) -> None: + """Validate that an input matches the PreparedInput schema.""" + if not isinstance(prepared_input, Mapping): + raise TypeError( + f"Expected input to be a dict-like, got {type(prepared_input).__name__}. " + "Set convert_inputs=True when calling fit() to preprocess raw inputs." + ) + + required_keys = {"context", "future_covariates", "n_targets", "n_covariates", "n_future_covariates"} + missing = required_keys - set(prepared_input.keys()) + if missing: + raise TypeError( + f"Input is missing required keys: {missing}. Set convert_inputs=True when calling fit() to preprocess raw inputs." + ) + + context = prepared_input["context"] + if not isinstance(context, torch.Tensor) or context.ndim != 2: + raise TypeError( + f"Expected 'context' to be 2-d torch.Tensor, got {type(context).__name__} " + f"with shape {getattr(context, 'shape', 'N/A')}. " + "Set convert_inputs=True when calling fit() to preprocess raw inputs." + ) + + future_covariates = prepared_input["future_covariates"] + if not isinstance(future_covariates, torch.Tensor) or future_covariates.ndim != 2: + raise TypeError( + f"Expected 'future_covariates' to be 2-d torch.Tensor, got {type(future_covariates).__name__} " + f"with shape {getattr(future_covariates, 'shape', 'N/A')}. " + "Set convert_inputs=True when calling fit() to preprocess raw inputs." + ) + + if context.shape[0] != future_covariates.shape[0]: + raise ValueError( + f"Expected 'context' and 'future_covariates' to have the same first dimension, " + f"got {context.shape[0]} and {future_covariates.shape[0]}. " + "Set convert_inputs=True when calling fit() to preprocess raw inputs." + ) + + def convert_list_of_tensors_input_to_list_of_dicts_input( list_of_tensors: Sequence[TensorOrArray], ) -> list[dict[str, torch.Tensor]]: @@ -383,49 +476,65 @@ class Chronos2Dataset(IterableDataset): Arguments ---------- inputs - Time series data. Must be a list of dictionaries where each dictionary may have the following keys. - - `target` (required): a 1-d or 2-d `torch.Tensor` or `np.ndarray` of shape (history_length,) or (n_variates, history_length). - Forecasts will be generated for items in `target`. - - `past_covariates` (optional): a dict of past-only covariates or past values of known future covariates. The keys of the dict - must be names of the covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `history_length` - of `target`. - - `future_covariates` (optional): a dict of future values of known future covariates. The keys of the dict must be names of the - covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `prediction_length`. All keys in - `future_covariates` must be a subset of the keys in `past_covariates`. - Note: when the mode is set to TRAIN, the values inside `future_covariates` are not technically used for training the model; - however, this key is used to infer which covariates are known into the future. Therefore, if your task contains known future covariates, - make sure that this key exists in `inputs`. The values of individual future covariates may be set to `None` or an empty array. + Time series data. Can be either: + + 1. Raw inputs (when `convert_inputs=True`, default): A sequence of dictionaries where each + dictionary may have the following keys: + - `target` (required): a 1-d or 2-d `torch.Tensor` or `np.ndarray` of shape (history_length,) + or (n_variates, history_length). Forecasts will be generated for items in `target`. + - `past_covariates` (optional): a dict of past-only covariates or past values of known future + covariates. + - `future_covariates` (optional): a dict of future values of known future covariates. + + 2. Pre-processed inputs (when `convert_inputs=False`): A sequence of `PreparedInput` dicts with keys: + `context`, `future_covariates`, `n_targets`, `n_covariates`, `n_future_covariates`. + Use `prepare_inputs()` to create pre-processed inputs. context_length The maximum context length used for training or inference prediction_length The prediction horizon batch_size - The batch size for training the model. Note that the batch size here means the number of time series, including target(s) and - covariates, that are input into the model. If your data has multiple target and/or covariates, the effective number of time series - tasks in a batch will be lower than this value. + The batch size for training the model. Note that the batch size here means the number of time series, + including target(s) and covariates, that are input into the model. output_patch_size - The output patch size of the model. This is used to compute the number of patches needed to cover `prediction_length` + The output patch size of the model. This is used to compute the number of patches needed to cover + `prediction_length` min_past - The minimum number of time steps the context must have during training. All time series shorter than `min_past + prediction_length` - are filtered out, by default 1 + The minimum number of time steps the context must have during training. All time series shorter than + `min_past + prediction_length` are filtered out, by default 1 mode `DatasetMode` governing whether to generate training, validation or test samples, by default "train" + convert_inputs + If True (default), preprocess raw inputs. If False, inputs are expected to be already preprocessed. """ def __init__( self, - inputs: Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]], + inputs: TensorOrArray | Sequence[TensorOrArray] | Sequence[Mapping[str, Any]] | Sequence[PreparedInput], context_length: int, prediction_length: int, batch_size: int, output_patch_size: int, min_past: int = 1, mode: str | DatasetMode = DatasetMode.TRAIN, + convert_inputs: bool = True, ) -> None: super().__init__() assert mode in {DatasetMode.TRAIN, DatasetMode.VALIDATION, DatasetMode.TEST}, f"Invalid mode: {mode}" - self.tasks = Chronos2Dataset._prepare_tasks(inputs, prediction_length, min_past, mode) + self.inputs: Sequence[PreparedInput] + if convert_inputs: + if isinstance(inputs, (torch.Tensor, np.ndarray)): + inputs = convert_tensor_input_to_list_of_dicts_input(inputs) + elif ( + isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray)) + ): + inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(Sequence[TensorOrArray], inputs)) + self.inputs = prepare_inputs(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode) + else: + validate_prepared_schema(inputs[0]) + self.inputs = cast(Sequence[PreparedInput], inputs) + self.context_length = context_length self.prediction_length = prediction_length self.batch_size = batch_size @@ -433,54 +542,16 @@ class Chronos2Dataset(IterableDataset): self.min_past = min_past self.mode = mode - @staticmethod - def _prepare_tasks( - inputs: Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]], - prediction_length: int, - min_past: int, - mode: str | DatasetMode, - ): - tasks = [] - for idx, raw_task in enumerate(inputs): - if mode != DatasetMode.TEST: - raw_future_covariates = raw_task.get("future_covariates", {}) - raw_future_covariates = cast(dict[str, TensorOrArray | None], raw_future_covariates) - if raw_future_covariates: - fixed_future_covariates = {} - for key, value in raw_future_covariates.items(): - fixed_future_covariates[key] = ( - np.full(prediction_length, np.nan) if value is None or len(value) == 0 else value - ) - raw_task = {**raw_task, "future_covariates": fixed_future_covariates} + def _construct_slice(self, input_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]: + prepared = self.inputs[input_idx] + past_tensor = prepared["context"].clone() # shape: (n_targets + n_covariates, history_length) + future_tensor = prepared["future_covariates"].clone() + n_targets = int(prepared["n_targets"]) + n_covariates = int(prepared["n_covariates"]) + n_future_covariates = int(prepared["n_future_covariates"]) + n_past_only_covariates = n_covariates - n_future_covariates - raw_task = cast(dict[str, TensorOrArray | Mapping[str, TensorOrArray]], raw_task) - # convert to a format compatible with model's forward - task = validate_and_prepare_single_dict_task(raw_task, idx, prediction_length) - - if mode != DatasetMode.TEST and task[0].shape[-1] < min_past + prediction_length: - # filter tasks based on min_past + prediction_length - continue - tasks.append(task) - - if len(tasks) == 0: - raise ValueError( - "The dataset is empty after filtering based on the length of the time series (length >= min_past + prediction_length). " - "Please provide longer time series or reduce `min_past` or `prediction_length`. " - ) - return tasks - - def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]: - ( - task_past_tensor, # shape: (task_n_targets + task_n_covariates, history_length) - task_future_tensor, - task_n_targets, - task_n_covariates, - task_n_future_covariates, - ) = self.tasks[task_idx] - task_past_tensor, task_future_tensor = task_past_tensor.clone(), task_future_tensor.clone() - task_n_past_only_covariates = task_n_covariates - task_n_future_covariates - - full_length = task_past_tensor.shape[-1] + full_length = past_tensor.shape[-1] if self.mode == DatasetMode.TRAIN: # slice a random subsequence from the full series @@ -494,74 +565,74 @@ class Chronos2Dataset(IterableDataset): if slice_idx >= self.context_length: # slice series, if it is longer than context_length - task_context = task_past_tensor[:, slice_idx - self.context_length : slice_idx] + context = past_tensor[:, slice_idx - self.context_length : slice_idx] else: - task_context = task_past_tensor[:, :slice_idx] + context = past_tensor[:, :slice_idx] - # In the TEST mode, we have no target available and the task_future_covariates can be directly used - # In the TRAIN and VALIDATION modes, the target and task_future_covariates need to be constructed from - # the task_context_tensor by slicing the appropriate indices which we do below + # In the TEST mode, we have no target available and the future_covariates can be directly used + # In the TRAIN and VALIDATION modes, the target and future_covariates need to be constructed from + # the context_tensor by slicing the appropriate indices which we do below if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]: - # the first task_n_targets elements in task_context_tensor are the targets - task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone() + # the first n_targets elements in context_tensor are the targets + future_target = past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone() # mask out all rows corresponding to covariates - task_future_target[task_n_targets:] = torch.nan + future_target[n_targets:] = torch.nan - if task_n_future_covariates > 0: - # the last task_n_future_covariates elements in task_context_tensor are the known covariates - task_future_covariates = task_past_tensor[ - -task_n_future_covariates:, slice_idx : slice_idx + self.prediction_length + if n_future_covariates > 0: + # the last n_future_covariates elements in context_tensor are the known covariates + future_covariates = past_tensor[ + -n_future_covariates:, slice_idx : slice_idx + self.prediction_length ] else: # zero-length tensor for easy concatenation later - task_future_covariates = torch.zeros((0, self.prediction_length)) + future_covariates = torch.zeros((0, self.prediction_length)) - # the leading task_n_targets + task_n_past_only_covariates elements are masked because the target(s) + # the leading n_targets + n_past_only_covariates elements are masked because the target(s) # and past-only covariates are not known into the future - task_future_covariates_padding = torch.full( - (task_n_targets + task_n_past_only_covariates, self.prediction_length), + future_covariates_padding = torch.full( + (n_targets + n_past_only_covariates, self.prediction_length), fill_value=torch.nan, ) - task_future_covariates = torch.cat([task_future_covariates_padding, task_future_covariates], dim=0) + future_covariates = torch.cat([future_covariates_padding, future_covariates], dim=0) else: - task_future_target = None - task_future_covariates = task_future_tensor + future_target = None + future_covariates = future_tensor - # task_context: (task_n_targets + task_n_covariates, min(context_length, history_length)) - # task_future_target: (task_n_targets + task_n_covariates, prediction_length), the future values of known future covariates + # context: (n_targets + n_covariates, min(context_length, history_length)) + # future_target: (n_targets + n_covariates, prediction_length), the future values of known future covariates # are ignored during loss computation - # task_future_covariates: (task_n_targets + task_n_past_only_covariates + task_n_future_covariates, prediction_length), + # future_covariates: (n_targets + n_past_only_covariates + n_future_covariates, prediction_length), # the entries corresponding to targets and past-only covariates are NaNs - return task_context, task_future_target, task_future_covariates, task_n_targets + return context, future_target, future_covariates, n_targets - def _build_batch(self, task_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]: - """Build a batch from given task indices.""" - batch_context_tensor_list = [] - batch_future_target_tensor_list = [] - batch_future_covariates_tensor_list = [] + def _build_batch(self, input_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]: + """Build a batch from given input indices.""" + batch_context_list = [] + batch_future_target_list = [] + batch_future_covariates_list = [] batch_group_ids_list = [] target_idx_ranges: list[tuple[int, int]] = [] target_start_idx = 0 - for group_id, task_idx in enumerate(task_indices): - task_context, task_future_target, task_future_covariates, task_n_targets = self._construct_slice(task_idx) + for group_id, input_idx in enumerate(input_indices): + context, future_target, future_covariates, n_targets = self._construct_slice(input_idx) - group_size = task_context.shape[0] - task_group_ids = torch.full((group_size,), fill_value=group_id) - batch_context_tensor_list.append(task_context) - batch_future_target_tensor_list.append(task_future_target) - batch_future_covariates_tensor_list.append(task_future_covariates) - batch_group_ids_list.append(task_group_ids) - target_idx_ranges.append((target_start_idx, target_start_idx + task_n_targets)) + group_size = context.shape[0] + group_ids = torch.full((group_size,), fill_value=group_id) + batch_context_list.append(context) + batch_future_target_list.append(future_target) + batch_future_covariates_list.append(future_covariates) + batch_group_ids_list.append(group_ids) + target_idx_ranges.append((target_start_idx, target_start_idx + n_targets)) target_start_idx += group_size return { - "context": left_pad_and_cat_2D(batch_context_tensor_list), + "context": left_pad_and_cat_2D(batch_context_list), "future_target": None if self.mode == DatasetMode.TEST - else torch.cat(cast(list[torch.Tensor], batch_future_target_tensor_list), dim=0), - "future_covariates": torch.cat(batch_future_covariates_tensor_list, dim=0), + else torch.cat(cast(list[torch.Tensor], batch_future_target_list), dim=0), + "future_covariates": torch.cat(batch_future_covariates_list, dim=0), "group_ids": torch.cat(batch_group_ids_list, dim=0), "num_output_patches": self.num_output_patches, "target_idx_ranges": target_idx_ranges, @@ -570,27 +641,27 @@ class Chronos2Dataset(IterableDataset): def _generate_train_batches(self): while True: current_batch_size = 0 - task_indices = [] + input_indices = [] while current_batch_size < self.batch_size: - task_idx = np.random.randint(len(self.tasks)) - task_indices.append(task_idx) - current_batch_size += self.tasks[task_idx][0].shape[0] + input_idx = np.random.randint(len(self.inputs)) + input_indices.append(input_idx) + current_batch_size += self.inputs[input_idx]["context"].shape[0] - yield self._build_batch(task_indices) + yield self._build_batch(input_indices) def _generate_sequential_batches(self): - task_idx = 0 - while task_idx < len(self.tasks): + input_idx = 0 + while input_idx < len(self.inputs): current_batch_size = 0 - task_indices = [] + input_indices = [] - while task_idx < len(self.tasks) and current_batch_size < self.batch_size: - task_indices.append(task_idx) - current_batch_size += self.tasks[task_idx][0].shape[0] - task_idx += 1 + while input_idx < len(self.inputs) and current_batch_size < self.batch_size: + input_indices.append(input_idx) + current_batch_size += self.inputs[input_idx]["context"].shape[0] + input_idx += 1 - yield self._build_batch(task_indices) + yield self._build_batch(input_indices) def __iter__(self) -> Iterator: """ @@ -617,39 +688,3 @@ class Chronos2Dataset(IterableDataset): yield batch else: yield from self._generate_sequential_batches() - - @classmethod - def convert_inputs( - cls, - inputs: TensorOrArray - | Sequence[TensorOrArray] - | Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]], - context_length: int, - prediction_length: int, - batch_size: int, - output_patch_size: int, - min_past: int = 1, - mode: str | DatasetMode = DatasetMode.TRAIN, - ) -> "Chronos2Dataset": - """Convert from different input formats to a Chronos2Dataset.""" - if isinstance(inputs, (torch.Tensor, np.ndarray)): - inputs = convert_tensor_input_to_list_of_dicts_input(inputs) - elif isinstance(inputs, list) and all([isinstance(x, (torch.Tensor, np.ndarray)) for x in inputs]): - inputs = cast(list[TensorOrArray], inputs) - inputs = convert_list_of_tensors_input_to_list_of_dicts_input(inputs) - elif isinstance(inputs, list) and all([isinstance(x, dict) for x in inputs]): - pass - else: - raise ValueError("Unexpected inputs format") - - inputs = cast(list[dict[str, TensorOrArray | dict[str, TensorOrArray]]], inputs) - - return cls( - inputs, - context_length=context_length, - prediction_length=prediction_length, - batch_size=batch_size, - output_patch_size=output_patch_size, - min_past=min_past, - mode=mode, - ) diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index c7ccbbb..223689d 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -115,6 +115,7 @@ class Chronos2Pipeline(BaseChronosPipeline): callbacks: list["TrainerCallback"] | None = None, remove_printer_callback: bool = False, disable_data_parallel: bool = True, + convert_inputs: bool = True, **extra_trainer_kwargs, ) -> "Chronos2Pipeline": """ @@ -161,6 +162,10 @@ class Chronos2Pipeline(BaseChronosPipeline): If True, all instances of `PrinterCallback` are removed from callbacks disable_data_parallel If True, ensures that DataParallel is disabled and training happens on a single GPU + convert_inputs + If True (default), preprocess raw inputs (convert tensors, encode categoricals, validate). + If False, inputs are expected to be already preprocessed using `chronos.chronos2.dataset.prepare_inputs`. + This allows for efficient training on large datasets that don't fit in memory. **extra_trainer_kwargs Extra kwargs are directly forwarded to `TrainingArguments` @@ -229,7 +234,7 @@ class Chronos2Pipeline(BaseChronosPipeline): if min_past is None: min_past = prediction_length - train_dataset = Chronos2Dataset.convert_inputs( + train_dataset = Chronos2Dataset( inputs=inputs, context_length=context_length, prediction_length=prediction_length, @@ -237,6 +242,7 @@ class Chronos2Pipeline(BaseChronosPipeline): output_patch_size=self.model_output_patch_size, min_past=min_past, mode=DatasetMode.TRAIN, + convert_inputs=convert_inputs, ) if output_dir is None: @@ -290,14 +296,14 @@ class Chronos2Pipeline(BaseChronosPipeline): eval_dataset = None callbacks = callbacks or [] if validation_inputs is not None: - # construct validation dataset - eval_dataset = Chronos2Dataset.convert_inputs( + eval_dataset = Chronos2Dataset( inputs=validation_inputs, context_length=context_length, prediction_length=prediction_length, batch_size=batch_size, output_patch_size=self.model_output_patch_size, mode=DatasetMode.VALIDATION, + convert_inputs=convert_inputs, ) # set validation parameters @@ -610,8 +616,8 @@ class Chronos2Pipeline(BaseChronosPipeline): ) context_length = self.model_context_length - test_dataset = Chronos2Dataset.convert_inputs( - inputs=inputs, + test_dataset = Chronos2Dataset( + inputs, context_length=context_length, prediction_length=prediction_length, batch_size=batch_size, @@ -1136,8 +1142,8 @@ class Chronos2Pipeline(BaseChronosPipeline): ) context_length = self.model_context_length - test_dataset = Chronos2Dataset.convert_inputs( - inputs=inputs, + test_dataset = Chronos2Dataset( + inputs, context_length=context_length, prediction_length=0, batch_size=batch_size, diff --git a/test/test_chronos2.py b/test/test_chronos2.py index 3d8884e..2d95a37 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -1143,3 +1143,27 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline): for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped): # Should match exactly or very close (numerical precision) assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4) + + +def test_pipeline_can_be_finetuned_with_preprocessed_hf_dataset(pipeline): + """Test that fine-tuning works with preprocessed inputs from a HuggingFace Dataset.""" + from chronos.chronos2.dataset import prepare_inputs + + prediction_length = 8 + raw_inputs = [{"target": torch.rand(20)}, {"target": torch.rand(25)}, {"target": torch.rand(30)}] + + # Preprocess and convert to HF Dataset (simulating Arrow-based lazy loading) + prepared_tasks = prepare_inputs(raw_inputs, prediction_length=prediction_length, min_past=1, mode="train") + hf_dataset = datasets.Dataset.from_list(prepared_tasks).with_format("torch") + + # Fine-tune with preprocessed inputs + ft_pipeline = pipeline.fit( + hf_dataset, prediction_length=prediction_length, num_steps=5, min_past=1, batch_size=32, convert_inputs=False + ) + + # Verify fine-tuned model can predict + ft_outputs = ft_pipeline.predict(raw_inputs, prediction_length=prediction_length) + assert len(ft_outputs) == len(raw_inputs) + for ft_out in ft_outputs: + assert ft_out.shape == (1, DEFAULT_MODEL_NUM_QUANTILES, prediction_length) + assert not torch.isnan(ft_out).any()