Merge branch 'main' into transformers-v5

This commit is contained in:
Kashif Rasul 2026-02-19 17:04:25 +01:00 committed by GitHub
commit debad9d1fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 297 additions and 232 deletions

View file

@ -5,7 +5,7 @@
import math
from enum import Enum
from typing import TYPE_CHECKING, Iterator, Mapping, Sequence, TypeAlias, cast
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence, TypeAlias, TypedDict, cast
import numpy as np
import torch
@ -20,6 +20,16 @@ if TYPE_CHECKING:
TensorOrArray: TypeAlias = torch.Tensor | np.ndarray
class PreparedInput(TypedDict):
    """A preprocessed time series input ready for model training/inference.

    Instances are produced by `validate_and_prepare_single_dict_input`. Along the first axis of both
    tensors the rows are ordered as: target(s) first, then past-only covariates, then known-future
    covariates (so the known-future covariates are always the trailing rows).
    """

    # Target(s) concatenated with the past values of all covariates.
    context: torch.Tensor # (n_variates, history_length), float32
    # Future values of known-future covariates; rows belonging to targets and
    # past-only covariates are filled with NaN.
    future_covariates: torch.Tensor # (n_variates, prediction_length), float32
    # Number of leading rows of `context` that are target variables.
    n_targets: int
    # Total number of covariate rows (past-only plus known-future).
    n_covariates: int
    # Number of trailing rows that are known-future covariates.
    n_future_covariates: int
def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor:
"""
Left pads tensors in the list to the length of the longest tensor along the second axis, then concats
@ -37,14 +47,14 @@ def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor:
return torch.cat(padded, dim=0)
def validate_and_prepare_single_dict_task(
task: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], idx: int, prediction_length: int
) -> tuple[torch.Tensor, torch.Tensor, int, int, int]:
"""Validates and prepares a single dictionary task for Chronos2Model.
def validate_and_prepare_single_dict_input(
raw_input: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], idx: int, prediction_length: int
) -> PreparedInput:
"""Validates and prepares a single dictionary input for Chronos2Model.
Parameters
----------
task
raw_input
A dictionary representing a time series that contains:
- `target` (required): a 1-d or 2-d `torch.Tensor` or `np.ndarray` of shape (history_length,) or (n_variates, history_length).
Forecasts will be generated for items in `target`.
@ -55,27 +65,27 @@ def validate_and_prepare_single_dict_task(
covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `prediction_length`. All keys in
`future_covariates` must be a subset of the keys in `past_covariates`.
idx
Index of this task in the list of tasks, used for error messages
Index of this input in the list of inputs, used for error messages
prediction_length
Number of future time steps to predict, used to validate future covariates
Returns
------
A tuple containing:
- task_context_tensor: Concatenated tensor of target and past covariates of shape (group_size, history_length),
the first `task_n_targets` items along the first axis contain the target variables and the remaining items contain past-only covariates
A PreparedInput containing:
- context: Concatenated tensor of target and past covariates of shape (group_size, history_length),
the first `n_targets` items along the first axis contain the target variables and the remaining items contain past-only covariates
and past values of known future covariates.
- task_future_covariates_tensor: Tensor of future covariates of shape (group_size, prediction_length). The last `task_n_future_covariates`
- future_covariates: Tensor of future covariates of shape (group_size, prediction_length). The last `n_future_covariates`
items along the first axis contain future covariates. All the remaining elements corresponding to target and past-only covariates are NaNs.
- task_n_targets: Number of target variables
- task_n_covariates: Total number of covariates (sum of past-only and known future covariates)
- task_n_future_covariates: Number of known future covariates
- n_targets: Number of target variables
- n_covariates: Total number of covariates (sum of past-only and known future covariates)
- n_future_covariates: Number of known future covariates
"""
allowed_keys = {"target", "past_covariates", "future_covariates"}
# validate keys
keys = set(task.keys())
keys = set(raw_input.keys())
if not keys.issubset(allowed_keys):
raise ValueError(
f"Found invalid keys in element at index {idx}. Allowed keys are {allowed_keys}, but found {keys}"
@ -84,58 +94,58 @@ def validate_and_prepare_single_dict_task(
raise ValueError(f"Element at index {idx} does not contain the required key 'target'")
# validate target
task_target = task["target"]
if isinstance(task_target, np.ndarray):
task_target = torch.from_numpy(task_target)
assert isinstance(task_target, torch.Tensor)
if task_target.ndim > 2:
target = raw_input["target"]
if isinstance(target, np.ndarray):
target = torch.from_numpy(target)
assert isinstance(target, torch.Tensor)
if target.ndim > 2:
raise ValueError(
"When the input is a list of dicts, the `target` should either be 1-d with shape (history_length,) "
f" or 2-d with shape (n_variates, history_length). Found element at index {idx} with shape {tuple(task_target.shape)}."
f" or 2-d with shape (n_variates, history_length). Found element at index {idx} with shape {tuple(target.shape)}."
)
history_length = task_target.shape[-1]
task_target = task_target.view(-1, history_length)
history_length = target.shape[-1]
target = target.view(-1, history_length)
# validate past_covariates
cat_encoders: dict = {}
task_past_covariates = task.get("past_covariates", {})
if not isinstance(task_past_covariates, dict):
past_covariates = raw_input.get("past_covariates", {})
if not isinstance(past_covariates, dict):
raise ValueError(
f"Found invalid type for `past_covariates` in element at index {idx}. "
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_past_covariates)}'
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(past_covariates)}'
)
# gather keys and ensure known-future keys come last to match downstream assumptions
task_covariates_keys = sorted(task_past_covariates.keys())
covariates_keys = sorted(past_covariates.keys())
task_future_covariates = task.get("future_covariates", {})
if not isinstance(task_future_covariates, dict):
future_covariates = raw_input.get("future_covariates", {})
if not isinstance(future_covariates, dict):
raise ValueError(
f"Found invalid type for `future_covariates` in element at index {idx}. "
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_future_covariates)}'
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(future_covariates)}'
)
task_future_covariates_keys = sorted(task_future_covariates.keys())
if not set(task_future_covariates_keys).issubset(task_covariates_keys):
future_covariates_keys = sorted(future_covariates.keys())
if not set(future_covariates_keys).issubset(covariates_keys):
raise ValueError(
f"Expected keys in `future_covariates` to be a subset of `past_covariates` {task_covariates_keys}, "
f"but found {task_future_covariates_keys} in element at index {idx}"
f"Expected keys in `future_covariates` to be a subset of `past_covariates` {covariates_keys}, "
f"but found {future_covariates_keys} in element at index {idx}"
)
# create ordered keys: past-only first, then known-future (so known-future are the last rows)
task_past_only_keys = [k for k in task_covariates_keys if k not in task_future_covariates_keys] # past_only_keys
task_ordered_covariate_keys = task_past_only_keys + task_future_covariates_keys
past_only_keys = [k for k in covariates_keys if k not in future_covariates_keys]
ordered_covariate_keys = past_only_keys + future_covariates_keys
task_past_covariates_list: list[torch.Tensor] = []
for key in task_ordered_covariate_keys:
tensor = task_past_covariates[key]
past_covariates_list: list[torch.Tensor] = []
for key in ordered_covariate_keys:
tensor = past_covariates[key]
if isinstance(tensor, np.ndarray):
# apply encoding to categorical variates
if not np.issubdtype(tensor.dtype, np.number):
# target encoding, if the target is 1-d
if task_target.shape[0] == 1:
if target.shape[0] == 1:
cat_encoder = TargetEncoder(target_type="continuous", smooth=1.0)
X = tensor.astype(str).reshape(-1, 1)
y = task_target.view(-1).numpy()
y = target.view(-1).numpy()
mask = np.isfinite(y)
X = X[mask]
y = y[mask]
@ -153,18 +163,18 @@ def validate_and_prepare_single_dict_task(
f"Individual `past_covariates` must be 1-d with length equal to the length of `target` (= {history_length}), "
f"found: {key} with shape {tuple(tensor.shape)} in element at index {idx}"
)
task_past_covariates_list.append(tensor)
task_past_covariates_tensor = (
torch.stack(task_past_covariates_list, dim=0)
if task_past_covariates_list
else torch.zeros((0, history_length), device=task_target.device)
past_covariates_list.append(tensor)
past_covariates_tensor = (
torch.stack(past_covariates_list, dim=0)
if past_covariates_list
else torch.zeros((0, history_length), device=target.device)
)
# validate future_covariates (build rows in the same task_ordered_covariate_keys order)
task_future_covariates_list: list[torch.Tensor] = []
for key in task_ordered_covariate_keys:
# validate future_covariates (build rows in the same ordered_covariate_keys order)
future_covariates_list: list[torch.Tensor] = []
for key in ordered_covariate_keys:
# future values of past-only covariates are filled with NaNs
tensor = task_future_covariates.get(key, torch.full((prediction_length,), fill_value=torch.nan))
tensor = future_covariates.get(key, torch.full((prediction_length,), fill_value=torch.nan))
if isinstance(tensor, np.ndarray):
# apply encoding to categorical variates
if not np.issubdtype(tensor.dtype, np.number):
@ -177,35 +187,118 @@ def validate_and_prepare_single_dict_task(
f"Individual `future_covariates` must be 1-d with length equal to the {prediction_length=}, "
f"found: {key} with shape {tuple(tensor.shape)} in element at index {idx}"
)
task_future_covariates_list.append(tensor)
task_future_covariates_tensor = (
torch.stack(task_future_covariates_list, dim=0)
if task_future_covariates_list
else torch.zeros((0, prediction_length), device=task_target.device)
future_covariates_list.append(tensor)
future_covariates_tensor = (
torch.stack(future_covariates_list, dim=0)
if future_covariates_list
else torch.zeros((0, prediction_length), device=target.device)
)
# future values of target series are filled with NaNs
task_future_covariates_target_padding = torch.full(
(task_target.shape[0], prediction_length), fill_value=torch.nan, device=task_target.device
future_covariates_target_padding = torch.full(
(target.shape[0], prediction_length), fill_value=torch.nan, device=target.device
)
task_context_tensor = torch.cat([task_target, task_past_covariates_tensor], dim=0).to(dtype=torch.float32)
task_future_covariates_tensor = torch.cat(
[task_future_covariates_target_padding, task_future_covariates_tensor], dim=0
context_tensor = torch.cat([target, past_covariates_tensor], dim=0).to(dtype=torch.float32)
future_covariates_tensor = torch.cat(
[future_covariates_target_padding, future_covariates_tensor], dim=0
).to(dtype=torch.float32)
task_n_targets = task_target.shape[0]
task_n_covariates = task_past_covariates_tensor.shape[0]
n_targets = target.shape[0]
n_covariates = past_covariates_tensor.shape[0]
# number of known-future covariates
task_n_future_covariates = len(task_future_covariates_keys)
n_future_covariates = len(future_covariates_keys)
return (
task_context_tensor,
task_future_covariates_tensor,
task_n_targets,
task_n_covariates,
task_n_future_covariates,
return PreparedInput(
context=context_tensor,
future_covariates=future_covariates_tensor,
n_targets=n_targets,
n_covariates=n_covariates,
n_future_covariates=n_future_covariates,
)
def prepare_inputs(
    raw_inputs: Iterable[Mapping[str, Any]],
    prediction_length: int,
    min_past: int = 1,
    mode: "DatasetMode | str" = "train",
) -> list[PreparedInput]:
    """Turn raw dictionary inputs into a list of `PreparedInput` entries.

    Each raw input is passed through `validate_and_prepare_single_dict_input`. Outside of TEST mode
    two extra steps are applied: placeholder future covariates are filled in before validation, and
    series that are too short are dropped after validation.

    Parameters
    ----------
    raw_inputs
        Iterable of dictionaries with the keys accepted by `validate_and_prepare_single_dict_input`
        (`target`, and optionally `past_covariates` / `future_covariates`).
    prediction_length
        Number of future time steps; used to size NaN placeholders and for the length filter.
    min_past
        Minimum number of context steps required in addition to `prediction_length`, by default 1.
        Only consulted outside of TEST mode.
    mode
        Dataset mode. Anything that does not compare equal to `DatasetMode.TEST` enables the
        placeholder filling and the length filter.

    Returns
    -------
    The prepared inputs that survived filtering, in their original relative order.

    Raises
    ------
    ValueError
        If no input remains after processing.
    """
    applies_training_rules = mode != DatasetMode.TEST

    prepared_inputs: list[PreparedInput] = []
    for position, entry in enumerate(raw_inputs):
        if applies_training_rules:
            declared_future = entry.get("future_covariates", {})
            if declared_future:
                declared_future = cast(dict[str, TensorOrArray | None], declared_future)
                # A value of None or an empty array only signals "this covariate is known in the
                # future"; swap it for an all-NaN array of the expected length so validation passes.
                filled_future = {
                    name: np.full(prediction_length, np.nan) if series is None or len(series) == 0 else series
                    for name, series in declared_future.items()
                }
                entry = {**entry, "future_covariates": filled_future}

        candidate = validate_and_prepare_single_dict_input(
            cast(dict[str, TensorOrArray | Mapping[str, TensorOrArray]], entry),
            position,
            prediction_length,
        )

        # Series shorter than min_past + prediction_length are dropped (never in TEST mode).
        if applies_training_rules and candidate["context"].shape[-1] < min_past + prediction_length:
            continue
        prepared_inputs.append(candidate)

    if not prepared_inputs:
        raise ValueError(
            "The dataset is empty after filtering based on the length of the time series (length >= min_past + prediction_length). "
            "Please provide longer time series or reduce `min_past` or `prediction_length`. "
        )

    return prepared_inputs
def validate_prepared_schema(prepared_input: Any) -> None:
    """Check that an object has the structure of a `PreparedInput`.

    The checks run in this order: the object is a mapping, all five required keys are present,
    `context` and `future_covariates` are 2-d `torch.Tensor`s, and both tensors have the same
    size along the first axis. The values of the three count keys are not inspected.

    Raises
    ------
    TypeError
        If the object is not a mapping, a required key is missing, or one of the two tensors
        is not a 2-d `torch.Tensor`.
    ValueError
        If the two tensors disagree on the size of their first dimension.
    """
    hint = "Set convert_inputs=True when calling fit() to preprocess raw inputs."

    if not isinstance(prepared_input, Mapping):
        raise TypeError(f"Expected input to be a dict-like, got {type(prepared_input).__name__}. " + hint)

    required_keys = {"context", "future_covariates", "n_targets", "n_covariates", "n_future_covariates"}
    missing = required_keys - set(prepared_input.keys())
    if missing:
        raise TypeError(
            f"Input is missing required keys: {missing}. Set convert_inputs=True when calling fit() to preprocess raw inputs."
        )

    context = prepared_input["context"]
    context_is_valid = isinstance(context, torch.Tensor) and context.ndim == 2
    if not context_is_valid:
        raise TypeError(
            f"Expected 'context' to be 2-d torch.Tensor, got {type(context).__name__} "
            f"with shape {getattr(context, 'shape', 'N/A')}. " + hint
        )

    future_covariates = prepared_input["future_covariates"]
    future_is_valid = isinstance(future_covariates, torch.Tensor) and future_covariates.ndim == 2
    if not future_is_valid:
        raise TypeError(
            f"Expected 'future_covariates' to be 2-d torch.Tensor, got {type(future_covariates).__name__} "
            f"with shape {getattr(future_covariates, 'shape', 'N/A')}. " + hint
        )

    # Both tensors hold one row per variate, so their leading dimensions must agree.
    n_context_rows, n_future_rows = context.shape[0], future_covariates.shape[0]
    if n_context_rows == n_future_rows:
        return
    raise ValueError(
        f"Expected 'context' and 'future_covariates' to have the same first dimension, "
        f"got {n_context_rows} and {n_future_rows}. " + hint
    )
def convert_list_of_tensors_input_to_list_of_dicts_input(
list_of_tensors: Sequence[TensorOrArray],
) -> list[dict[str, torch.Tensor]]:
@ -383,49 +476,65 @@ class Chronos2Dataset(IterableDataset):
Arguments
----------
inputs
Time series data. Must be a list of dictionaries where each dictionary may have the following keys.
- `target` (required): a 1-d or 2-d `torch.Tensor` or `np.ndarray` of shape (history_length,) or (n_variates, history_length).
Forecasts will be generated for items in `target`.
- `past_covariates` (optional): a dict of past-only covariates or past values of known future covariates. The keys of the dict
must be names of the covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `history_length`
of `target`.
- `future_covariates` (optional): a dict of future values of known future covariates. The keys of the dict must be names of the
covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `prediction_length`. All keys in
`future_covariates` must be a subset of the keys in `past_covariates`.
Note: when the mode is set to TRAIN, the values inside `future_covariates` are not technically used for training the model;
however, this key is used to infer which covariates are known into the future. Therefore, if your task contains known future covariates,
make sure that this key exists in `inputs`. The values of individual future covariates may be set to `None` or an empty array.
Time series data. Can be either:
1. Raw inputs (when `convert_inputs=True`, default): A sequence of dictionaries where each
dictionary may have the following keys:
- `target` (required): a 1-d or 2-d `torch.Tensor` or `np.ndarray` of shape (history_length,)
or (n_variates, history_length). Forecasts will be generated for items in `target`.
- `past_covariates` (optional): a dict of past-only covariates or past values of known future
covariates.
- `future_covariates` (optional): a dict of future values of known future covariates.
2. Pre-processed inputs (when `convert_inputs=False`): A sequence of `PreparedInput` dicts with keys:
`context`, `future_covariates`, `n_targets`, `n_covariates`, `n_future_covariates`.
Use `prepare_inputs()` to create pre-processed inputs.
context_length
The maximum context length used for training or inference
prediction_length
The prediction horizon
batch_size
The batch size for training the model. Note that the batch size here means the number of time series, including target(s) and
covariates, that are input into the model. If your data has multiple target and/or covariates, the effective number of time series
tasks in a batch will be lower than this value.
The batch size for training the model. Note that the batch size here means the number of time series,
including target(s) and covariates, that are input into the model.
output_patch_size
The output patch size of the model. This is used to compute the number of patches needed to cover `prediction_length`
The output patch size of the model. This is used to compute the number of patches needed to cover
`prediction_length`
min_past
The minimum number of time steps the context must have during training. All time series shorter than `min_past + prediction_length`
are filtered out, by default 1
The minimum number of time steps the context must have during training. All time series shorter than
`min_past + prediction_length` are filtered out, by default 1
mode
`DatasetMode` governing whether to generate training, validation or test samples, by default "train"
convert_inputs
If True (default), preprocess raw inputs. If False, inputs are expected to be already preprocessed.
"""
def __init__(
self,
inputs: Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]],
inputs: TensorOrArray | Sequence[TensorOrArray] | Sequence[Mapping[str, Any]] | Sequence[PreparedInput],
context_length: int,
prediction_length: int,
batch_size: int,
output_patch_size: int,
min_past: int = 1,
mode: str | DatasetMode = DatasetMode.TRAIN,
convert_inputs: bool = True,
) -> None:
super().__init__()
assert mode in {DatasetMode.TRAIN, DatasetMode.VALIDATION, DatasetMode.TEST}, f"Invalid mode: {mode}"
self.tasks = Chronos2Dataset._prepare_tasks(inputs, prediction_length, min_past, mode)
self.inputs: Sequence[PreparedInput]
if convert_inputs:
if isinstance(inputs, (torch.Tensor, np.ndarray)):
inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
elif (
isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray))
):
inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(Sequence[TensorOrArray], inputs))
self.inputs = prepare_inputs(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode)
else:
validate_prepared_schema(inputs[0])
self.inputs = cast(Sequence[PreparedInput], inputs)
self.context_length = context_length
self.prediction_length = prediction_length
self.batch_size = batch_size
@ -433,54 +542,16 @@ class Chronos2Dataset(IterableDataset):
self.min_past = min_past
self.mode = mode
@staticmethod
def _prepare_tasks(
inputs: Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]],
prediction_length: int,
min_past: int,
mode: str | DatasetMode,
):
tasks = []
for idx, raw_task in enumerate(inputs):
if mode != DatasetMode.TEST:
raw_future_covariates = raw_task.get("future_covariates", {})
raw_future_covariates = cast(dict[str, TensorOrArray | None], raw_future_covariates)
if raw_future_covariates:
fixed_future_covariates = {}
for key, value in raw_future_covariates.items():
fixed_future_covariates[key] = (
np.full(prediction_length, np.nan) if value is None or len(value) == 0 else value
)
raw_task = {**raw_task, "future_covariates": fixed_future_covariates}
def _construct_slice(self, input_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]:
prepared = self.inputs[input_idx]
past_tensor = prepared["context"].clone() # shape: (n_targets + n_covariates, history_length)
future_tensor = prepared["future_covariates"].clone()
n_targets = int(prepared["n_targets"])
n_covariates = int(prepared["n_covariates"])
n_future_covariates = int(prepared["n_future_covariates"])
n_past_only_covariates = n_covariates - n_future_covariates
raw_task = cast(dict[str, TensorOrArray | Mapping[str, TensorOrArray]], raw_task)
# convert to a format compatible with model's forward
task = validate_and_prepare_single_dict_task(raw_task, idx, prediction_length)
if mode != DatasetMode.TEST and task[0].shape[-1] < min_past + prediction_length:
# filter tasks based on min_past + prediction_length
continue
tasks.append(task)
if len(tasks) == 0:
raise ValueError(
"The dataset is empty after filtering based on the length of the time series (length >= min_past + prediction_length). "
"Please provide longer time series or reduce `min_past` or `prediction_length`. "
)
return tasks
def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]:
(
task_past_tensor, # shape: (task_n_targets + task_n_covariates, history_length)
task_future_tensor,
task_n_targets,
task_n_covariates,
task_n_future_covariates,
) = self.tasks[task_idx]
task_past_tensor, task_future_tensor = task_past_tensor.clone(), task_future_tensor.clone()
task_n_past_only_covariates = task_n_covariates - task_n_future_covariates
full_length = task_past_tensor.shape[-1]
full_length = past_tensor.shape[-1]
if self.mode == DatasetMode.TRAIN:
# slice a random subsequence from the full series
@ -494,74 +565,74 @@ class Chronos2Dataset(IterableDataset):
if slice_idx >= self.context_length:
# slice series, if it is longer than context_length
task_context = task_past_tensor[:, slice_idx - self.context_length : slice_idx]
context = past_tensor[:, slice_idx - self.context_length : slice_idx]
else:
task_context = task_past_tensor[:, :slice_idx]
context = past_tensor[:, :slice_idx]
# In the TEST mode, we have no target available and the task_future_covariates can be directly used
# In the TRAIN and VALIDATION modes, the target and task_future_covariates need to be constructed from
# the task_context_tensor by slicing the appropriate indices which we do below
# In the TEST mode, we have no target available and the future_covariates can be directly used
# In the TRAIN and VALIDATION modes, the target and future_covariates need to be constructed from
# the context_tensor by slicing the appropriate indices which we do below
if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]:
# the first task_n_targets elements in task_context_tensor are the targets
task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone()
# the first n_targets elements in context_tensor are the targets
future_target = past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone()
# mask out all rows corresponding to covariates
task_future_target[task_n_targets:] = torch.nan
future_target[n_targets:] = torch.nan
if task_n_future_covariates > 0:
# the last task_n_future_covariates elements in task_context_tensor are the known covariates
task_future_covariates = task_past_tensor[
-task_n_future_covariates:, slice_idx : slice_idx + self.prediction_length
if n_future_covariates > 0:
# the last n_future_covariates elements in context_tensor are the known covariates
future_covariates = past_tensor[
-n_future_covariates:, slice_idx : slice_idx + self.prediction_length
]
else:
# zero-length tensor for easy concatenation later
task_future_covariates = torch.zeros((0, self.prediction_length))
future_covariates = torch.zeros((0, self.prediction_length))
# the leading task_n_targets + task_n_past_only_covariates elements are masked because the target(s)
# the leading n_targets + n_past_only_covariates elements are masked because the target(s)
# and past-only covariates are not known into the future
task_future_covariates_padding = torch.full(
(task_n_targets + task_n_past_only_covariates, self.prediction_length),
future_covariates_padding = torch.full(
(n_targets + n_past_only_covariates, self.prediction_length),
fill_value=torch.nan,
)
task_future_covariates = torch.cat([task_future_covariates_padding, task_future_covariates], dim=0)
future_covariates = torch.cat([future_covariates_padding, future_covariates], dim=0)
else:
task_future_target = None
task_future_covariates = task_future_tensor
future_target = None
future_covariates = future_tensor
# task_context: (task_n_targets + task_n_covariates, min(context_length, history_length))
# task_future_target: (task_n_targets + task_n_covariates, prediction_length), the future values of known future covariates
# context: (n_targets + n_covariates, min(context_length, history_length))
# future_target: (n_targets + n_covariates, prediction_length), the future values of known future covariates
# are ignored during loss computation
# task_future_covariates: (task_n_targets + task_n_past_only_covariates + task_n_future_covariates, prediction_length),
# future_covariates: (n_targets + n_past_only_covariates + n_future_covariates, prediction_length),
# the entries corresponding to targets and past-only covariates are NaNs
return task_context, task_future_target, task_future_covariates, task_n_targets
return context, future_target, future_covariates, n_targets
def _build_batch(self, task_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]:
"""Build a batch from given task indices."""
batch_context_tensor_list = []
batch_future_target_tensor_list = []
batch_future_covariates_tensor_list = []
def _build_batch(self, input_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]:
"""Build a batch from given input indices."""
batch_context_list = []
batch_future_target_list = []
batch_future_covariates_list = []
batch_group_ids_list = []
target_idx_ranges: list[tuple[int, int]] = []
target_start_idx = 0
for group_id, task_idx in enumerate(task_indices):
task_context, task_future_target, task_future_covariates, task_n_targets = self._construct_slice(task_idx)
for group_id, input_idx in enumerate(input_indices):
context, future_target, future_covariates, n_targets = self._construct_slice(input_idx)
group_size = task_context.shape[0]
task_group_ids = torch.full((group_size,), fill_value=group_id)
batch_context_tensor_list.append(task_context)
batch_future_target_tensor_list.append(task_future_target)
batch_future_covariates_tensor_list.append(task_future_covariates)
batch_group_ids_list.append(task_group_ids)
target_idx_ranges.append((target_start_idx, target_start_idx + task_n_targets))
group_size = context.shape[0]
group_ids = torch.full((group_size,), fill_value=group_id)
batch_context_list.append(context)
batch_future_target_list.append(future_target)
batch_future_covariates_list.append(future_covariates)
batch_group_ids_list.append(group_ids)
target_idx_ranges.append((target_start_idx, target_start_idx + n_targets))
target_start_idx += group_size
return {
"context": left_pad_and_cat_2D(batch_context_tensor_list),
"context": left_pad_and_cat_2D(batch_context_list),
"future_target": None
if self.mode == DatasetMode.TEST
else torch.cat(cast(list[torch.Tensor], batch_future_target_tensor_list), dim=0),
"future_covariates": torch.cat(batch_future_covariates_tensor_list, dim=0),
else torch.cat(cast(list[torch.Tensor], batch_future_target_list), dim=0),
"future_covariates": torch.cat(batch_future_covariates_list, dim=0),
"group_ids": torch.cat(batch_group_ids_list, dim=0),
"num_output_patches": self.num_output_patches,
"target_idx_ranges": target_idx_ranges,
@ -570,27 +641,27 @@ class Chronos2Dataset(IterableDataset):
def _generate_train_batches(self):
while True:
current_batch_size = 0
task_indices = []
input_indices = []
while current_batch_size < self.batch_size:
task_idx = np.random.randint(len(self.tasks))
task_indices.append(task_idx)
current_batch_size += self.tasks[task_idx][0].shape[0]
input_idx = np.random.randint(len(self.inputs))
input_indices.append(input_idx)
current_batch_size += self.inputs[input_idx]["context"].shape[0]
yield self._build_batch(task_indices)
yield self._build_batch(input_indices)
def _generate_sequential_batches(self):
task_idx = 0
while task_idx < len(self.tasks):
input_idx = 0
while input_idx < len(self.inputs):
current_batch_size = 0
task_indices = []
input_indices = []
while task_idx < len(self.tasks) and current_batch_size < self.batch_size:
task_indices.append(task_idx)
current_batch_size += self.tasks[task_idx][0].shape[0]
task_idx += 1
while input_idx < len(self.inputs) and current_batch_size < self.batch_size:
input_indices.append(input_idx)
current_batch_size += self.inputs[input_idx]["context"].shape[0]
input_idx += 1
yield self._build_batch(task_indices)
yield self._build_batch(input_indices)
def __iter__(self) -> Iterator:
"""
@ -617,39 +688,3 @@ class Chronos2Dataset(IterableDataset):
yield batch
else:
yield from self._generate_sequential_batches()
@classmethod
def convert_inputs(
    cls,
    inputs: TensorOrArray
    | Sequence[TensorOrArray]
    | Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]],
    context_length: int,
    prediction_length: int,
    batch_size: int,
    output_patch_size: int,
    min_past: int = 1,
    mode: str | DatasetMode = DatasetMode.TRAIN,
) -> "Chronos2Dataset":
    """Build a Chronos2Dataset from any of the supported input formats.

    Accepted formats are a single tensor/array, a list of tensors/arrays, or a list of dicts.
    The first two are converted to the list-of-dicts format before the dataset is constructed;
    a list of dicts is forwarded unchanged.

    Raises
    ------
    ValueError
        If `inputs` matches none of the accepted formats.
    """
    array_types = (torch.Tensor, np.ndarray)

    if isinstance(inputs, array_types):
        dict_inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
    elif not isinstance(inputs, list):
        raise ValueError("Unexpected inputs format")
    elif all(isinstance(item, array_types) for item in inputs):
        dict_inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(list[TensorOrArray], inputs))
    elif all(isinstance(item, dict) for item in inputs):
        # already in the list-of-dicts format
        dict_inputs = inputs
    else:
        raise ValueError("Unexpected inputs format")

    dict_inputs = cast(list[dict[str, TensorOrArray | dict[str, TensorOrArray]]], dict_inputs)

    return cls(
        dict_inputs,
        context_length=context_length,
        prediction_length=prediction_length,
        batch_size=batch_size,
        output_patch_size=output_patch_size,
        min_past=min_past,
        mode=mode,
    )

View file

@ -115,6 +115,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
callbacks: list["TrainerCallback"] | None = None,
remove_printer_callback: bool = False,
disable_data_parallel: bool = True,
convert_inputs: bool = True,
**extra_trainer_kwargs,
) -> "Chronos2Pipeline":
"""
@ -161,6 +162,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
If True, all instances of `PrinterCallback` are removed from callbacks
disable_data_parallel
If True, ensures that DataParallel is disabled and training happens on a single GPU
convert_inputs
If True (default), preprocess raw inputs (convert tensors, encode categoricals, validate).
If False, inputs are expected to be already preprocessed using `chronos.chronos2.dataset.prepare_inputs`.
This allows for efficient training on large datasets that don't fit in memory.
**extra_trainer_kwargs
Extra kwargs are directly forwarded to `TrainingArguments`
@ -229,7 +234,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
if min_past is None:
min_past = prediction_length
train_dataset = Chronos2Dataset.convert_inputs(
train_dataset = Chronos2Dataset(
inputs=inputs,
context_length=context_length,
prediction_length=prediction_length,
@ -237,6 +242,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
output_patch_size=self.model_output_patch_size,
min_past=min_past,
mode=DatasetMode.TRAIN,
convert_inputs=convert_inputs,
)
if output_dir is None:
@ -290,14 +296,14 @@ class Chronos2Pipeline(BaseChronosPipeline):
eval_dataset = None
callbacks = callbacks or []
if validation_inputs is not None:
# construct validation dataset
eval_dataset = Chronos2Dataset.convert_inputs(
eval_dataset = Chronos2Dataset(
inputs=validation_inputs,
context_length=context_length,
prediction_length=prediction_length,
batch_size=batch_size,
output_patch_size=self.model_output_patch_size,
mode=DatasetMode.VALIDATION,
convert_inputs=convert_inputs,
)
# set validation parameters
@ -610,8 +616,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
)
context_length = self.model_context_length
test_dataset = Chronos2Dataset.convert_inputs(
inputs=inputs,
test_dataset = Chronos2Dataset(
inputs,
context_length=context_length,
prediction_length=prediction_length,
batch_size=batch_size,
@ -1136,8 +1142,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
)
context_length = self.model_context_length
test_dataset = Chronos2Dataset.convert_inputs(
inputs=inputs,
test_dataset = Chronos2Dataset(
inputs,
context_length=context_length,
prediction_length=0,
batch_size=batch_size,

View file

@ -1143,3 +1143,27 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
# Should match exactly or very close (numerical precision)
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
def test_pipeline_can_be_finetuned_with_preprocessed_hf_dataset(pipeline):
    """Fine-tuning accepts inputs that were preprocessed and stored in a HuggingFace Dataset."""
    from chronos.chronos2.dataset import prepare_inputs

    prediction_length = 8
    raw_inputs = [{"target": torch.rand(series_length)} for series_length in (20, 25, 30)]

    # Preprocess up front, then round-trip through an Arrow-backed HF Dataset
    # (simulating lazy loading of already-prepared inputs).
    prepared_inputs = prepare_inputs(raw_inputs, prediction_length=prediction_length, min_past=1, mode="train")
    hf_dataset = datasets.Dataset.from_list(prepared_inputs).with_format("torch")

    # Fine-tune directly on the prepared inputs, skipping conversion.
    ft_pipeline = pipeline.fit(
        hf_dataset,
        prediction_length=prediction_length,
        num_steps=5,
        min_past=1,
        batch_size=32,
        convert_inputs=False,
    )

    # The fine-tuned pipeline must still forecast from the raw inputs.
    ft_outputs = ft_pipeline.predict(raw_inputs, prediction_length=prediction_length)
    assert len(ft_outputs) == len(raw_inputs)

    expected_shape = (1, DEFAULT_MODEL_NUM_QUANTILES, prediction_length)
    for forecast in ft_outputs:
        assert forecast.shape == expected_shape
        assert not torch.isnan(forecast).any()