Simplify diff

This commit is contained in:
Oleksandr Shchur 2026-02-18 14:29:50 +00:00
parent 1485c8af91
commit 40a7071ed8

View file

@ -19,13 +19,6 @@ if TYPE_CHECKING:
TensorOrArray: TypeAlias = torch.Tensor | np.ndarray
class DatasetMode(str, Enum):
TRAIN = "train"
VALIDATION = "validation"
TEST = "test"
# Type alias for raw input format
RawTask = Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]
@ -40,29 +33,67 @@ class PreparedTask(TypedDict):
n_future_covariates: int
def prepare_single_task(
task: RawTask,
idx: int,
prediction_length: int,
def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor:
"""
Left pads tensors in the list to the length of the longest tensor along the second axis, then concats
these equal length tensors along the first axis.
"""
max_len = max(tensor.shape[-1] for tensor in tensors)
padded = []
for tensor in tensors:
n_variates, length = tensor.shape
if length < max_len:
padding = torch.full((n_variates, max_len - length), fill_value=torch.nan, device=tensor.device)
tensor = torch.cat([padding, tensor], dim=-1)
padded.append(tensor)
return torch.cat(padded, dim=0)
def validate_and_prepare_single_dict_task(
task: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], idx: int, prediction_length: int
) -> PreparedTask:
"""Validate and prepare a single time series task."""
"""Validates and prepares a single dictionary task for Chronos2Model.
Parameters
----------
task
A dictionary representing a time series that contains:
- `target` (required): a 1-d or 2-d `torch.Tensor` or `np.ndarray` of shape (history_length,) or (n_variates, history_length).
Forecasts will be generated for items in `target`.
- `past_covariates` (optional): a dict of past-only covariates or past values of known future covariates. The keys of the dict
must be names of the covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `history_length`
of `target`.
- `future_covariates` (optional): a dict of future values of known future covariates. The keys of the dict must be names of the
covariates and values must be 1-d `torch.Tensor` or `np.ndarray` with length equal to the `prediction_length`. All keys in
`future_covariates` must be a subset of the keys in `past_covariates`.
idx
Index of this task in the list of tasks, used for error messages
prediction_length
Number of future time steps to predict, used to validate future covariates
Returns
------
PreparedTask
A dictionary containing preprocessed arrays ready for model consumption.
"""
allowed_keys = {"target", "past_covariates", "future_covariates"}
# validate keys
keys = set(task.keys())
if not keys.issubset(allowed_keys):
raise ValueError(
f"Found invalid keys in element at index {idx}. "
f"Allowed keys are {allowed_keys}, but found {keys}"
f"Found invalid keys in element at index {idx}. Allowed keys are {allowed_keys}, but found {keys}"
)
if "target" not in keys:
raise ValueError(f"Element at index {idx} does not contain the required key 'target'")
# Process target
# validate target - convert to numpy float32 (handles bfloat16 and other dtypes)
task_target = task["target"]
if isinstance(task_target, torch.Tensor):
task_target = task_target.to(torch.float32).numpy()
task_target = np.asarray(task_target, dtype=np.float32)
if task_target.ndim > 2:
raise ValueError(
"When the input is a list of dicts, the `target` should either be 1-d with shape (history_length,) "
@ -71,7 +102,7 @@ def prepare_single_task(
history_length = task_target.shape[-1]
task_target = task_target.reshape(-1, history_length)
# Process past_covariates
# validate past_covariates
cat_encoders: dict = {}
task_past_covariates = task.get("past_covariates", {})
if not isinstance(task_past_covariates, dict):
@ -80,6 +111,7 @@ def prepare_single_task(
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_past_covariates)}'
)
# gather keys and ensure known-future keys come last to match downstream assumptions
task_covariates_keys = sorted(task_past_covariates.keys())
task_future_covariates = task.get("future_covariates", {})
@ -89,77 +121,73 @@ def prepare_single_task(
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_future_covariates)}'
)
task_future_covariates_keys = sorted(task_future_covariates.keys())
if not set(task_future_covariates_keys).issubset(task_covariates_keys):
raise ValueError(
f"Expected keys in `future_covariates` to be a subset of `past_covariates` "
f"{task_covariates_keys}, but found {task_future_covariates_keys} in element at index {idx}"
f"Expected keys in `future_covariates` to be a subset of `past_covariates` {task_covariates_keys}, "
f"but found {task_future_covariates_keys} in element at index {idx}"
)
# Ordered: past-only first, then known-future
# create ordered keys: past-only first, then known-future (so known-future are the last rows)
task_past_only_keys = [k for k in task_covariates_keys if k not in task_future_covariates_keys]
task_ordered_covariate_keys = task_past_only_keys + task_future_covariates_keys
# Process past covariates
task_past_covariates_list: list[np.ndarray] = []
for key in task_ordered_covariate_keys:
tensor = task_past_covariates[key]
if isinstance(tensor, torch.Tensor):
tensor = tensor.to(torch.float32).numpy()
tensor = np.asarray(tensor)
# Encode categorical variates
# apply encoding to categorical variates
if not np.issubdtype(tensor.dtype, np.number):
# target encoding, if the target is 1-d
if task_target.shape[0] == 1:
cat_encoder = TargetEncoder(target_type="continuous", smooth=1.0)
X = tensor.astype(str).reshape(-1, 1)
y = task_target.reshape(-1)
mask = np.isfinite(y)
cat_encoder.fit(X[mask], y[mask])
# ordinal encoding, if the target is > 1-d
else:
cat_encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan)
cat_encoder.fit(tensor.astype(str).reshape(-1, 1))
tensor = cat_encoder.transform(tensor.astype(str).reshape(-1, 1)).reshape(tensor.shape)
cat_encoders[key] = cat_encoder
if tensor.ndim != 1 or len(tensor) != history_length:
raise ValueError(
f"Individual `past_covariates` must be 1-d with length equal to the length of `target` (= {history_length}), "
f"found: {key} with shape {tuple(tensor.shape)} in element at index {idx}"
)
task_past_covariates_list.append(tensor)
task_past_covariates_array = (
np.stack(task_past_covariates_list, axis=0)
if task_past_covariates_list
else np.zeros((0, history_length), dtype=np.float32)
)
if task_past_covariates_list:
task_past_covariates_array = np.stack(task_past_covariates_list, axis=0)
else:
task_past_covariates_array = np.zeros((0, history_length), dtype=np.float32)
# Process future covariates
# validate future_covariates (build rows in the same task_ordered_covariate_keys order)
task_future_covariates_list: list[np.ndarray] = []
for key in task_ordered_covariate_keys:
# future values of past-only covariates are filled with NaNs
tensor = task_future_covariates.get(key, np.full(prediction_length, np.nan))
if tensor is None:
tensor = np.full(prediction_length, np.nan)
if isinstance(tensor, torch.Tensor):
tensor = tensor.to(torch.float32).numpy()
tensor = np.asarray(tensor)
# apply encoding to categorical variates
if not np.issubdtype(tensor.dtype, np.number):
cat_encoder = cat_encoders[key]
tensor = cat_encoder.transform(tensor.astype(str).reshape(-1, 1)).reshape(tensor.shape)
if tensor.ndim != 1 or len(tensor) != prediction_length:
raise ValueError(
f"Individual `future_covariates` must be 1-d with length equal to the {prediction_length=}, "
f"found: {key} with shape {tuple(tensor.shape)} in element at index {idx}"
)
task_future_covariates_list.append(tensor)
if task_future_covariates_list:
task_future_covariates_array = np.stack(task_future_covariates_list, axis=0)
else:
task_future_covariates_array = np.zeros((0, prediction_length), dtype=np.float32)
task_future_covariates_array = (
np.stack(task_future_covariates_list, axis=0)
if task_future_covariates_list
else np.zeros((0, prediction_length), dtype=np.float32)
)
# future values of target series are filled with NaNs
task_future_covariates_target_padding = np.full(
(task_target.shape[0], prediction_length), np.nan, dtype=np.float32
)
@ -182,16 +210,21 @@ def prepare_tasks(
raw_tasks: Iterable[RawTask],
prediction_length: int,
min_past: int = 1,
mode: DatasetMode | str = DatasetMode.TRAIN,
mode: "DatasetMode | str" = "train",
) -> list[PreparedTask]:
"""Prepare multiple time series tasks for training/inference."""
"""Prepare multiple time series tasks for training/inference.
This function handles mode-specific preprocessing (e.g., filtering short series)
and calls validate_and_prepare_single_dict_task for each task.
"""
# Import here to avoid issues with forward reference
if isinstance(mode, str):
mode = DatasetMode(mode)
tasks: list[PreparedTask] = []
for idx, raw_task in enumerate(raw_tasks):
# For non-TEST modes, fix future_covariates
# For non-TEST modes, fix future_covariates (replace None/empty with NaN arrays)
if mode != DatasetMode.TEST:
raw_future_covariates = raw_task.get("future_covariates", {})
if raw_future_covariates:
@ -204,9 +237,9 @@ def prepare_tasks(
raw_task = {**raw_task, "future_covariates": fixed_future_covariates}
raw_task = cast(dict[str, TensorOrArray | Mapping[str, TensorOrArray]], raw_task)
prepared = prepare_single_task(raw_task, idx, prediction_length)
prepared = validate_and_prepare_single_dict_task(raw_task, idx, prediction_length)
# Filter by minimum length
# Filter by minimum length (except in TEST mode)
if mode != DatasetMode.TEST and prepared["context"].shape[-1] < min_past + prediction_length:
continue
@ -246,23 +279,6 @@ def validate_prepared_schema(task: Any) -> None:
)
def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor:
"""
Left pads tensors in the list to the length of the longest tensor along the second axis, then concats
these equal length tensors along the first axis.
"""
max_len = max(tensor.shape[-1] for tensor in tensors)
padded = []
for tensor in tensors:
n_variates, length = tensor.shape
if length < max_len:
padding = torch.full((n_variates, max_len - length), fill_value=torch.nan, device=tensor.device)
tensor = torch.cat([padding, tensor], dim=-1)
padded.append(tensor)
return torch.cat(padded, dim=0)
def convert_list_of_tensors_input_to_list_of_dicts_input(
list_of_tensors: Sequence[TensorOrArray],
) -> list[dict[str, torch.Tensor]]:
@ -427,6 +443,12 @@ def convert_fev_window_to_list_of_dicts_input(
return inputs, target_columns, past_dynamic_columns, known_dynamic_columns
class DatasetMode(str, Enum):
TRAIN = "train"
VALIDATION = "validation"
TEST = "test"
class Chronos2Dataset(IterableDataset):
"""
A dataset wrapper for Chronos-2 models.
@ -439,17 +461,14 @@ class Chronos2Dataset(IterableDataset):
1. Raw inputs (when `convert_inputs=True`, default): A sequence of dictionaries where each
dictionary may have the following keys:
- `target` (required): a 1-d or 2-d `torch.Tensor` or `np.ndarray` of shape (history_length,)
or (n_variates, history_length).
or (n_variates, history_length). Forecasts will be generated for items in `target`.
- `past_covariates` (optional): a dict of past-only covariates or past values of known future
covariates.
- `future_covariates` (optional): a dict of future values of known future covariates.
2. Pre-processed inputs (when `convert_inputs=False`): A sequence of prepared tasks, each with keys:
- `context`: 2-d array of shape (n_variates, history_length)
- `future_covariates`: 2-d array of shape (n_variates, prediction_length)
- `n_targets`, `n_covariates`, `n_future_covariates`: int metadata
Use `chronos.chronos2.preprocessing.prepare_tasks()` to create pre-processed inputs.
2. Pre-processed inputs (when `convert_inputs=False`): A sequence of `PreparedTask` dicts with keys:
`context`, `future_covariates`, `n_targets`, `n_covariates`, `n_future_covariates`.
Use `prepare_tasks()` to create pre-processed inputs.
context_length
The maximum context length used for training or inference
prediction_length
@ -471,7 +490,7 @@ class Chronos2Dataset(IterableDataset):
def __init__(
self,
inputs: "Sequence[PreparedTask | RawTask]",
inputs: Sequence[PreparedTask | RawTask],
context_length: int,
prediction_length: int,
batch_size: int,
@ -484,7 +503,7 @@ class Chronos2Dataset(IterableDataset):
assert mode in {DatasetMode.TRAIN, DatasetMode.VALIDATION, DatasetMode.TEST}, f"Invalid mode: {mode}"
if convert_inputs:
# Convert various input formats to list of dicts
# Convert various input formats to list of dicts if needed
if isinstance(inputs, (torch.Tensor, np.ndarray)):
inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
elif isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray)):
@ -668,12 +687,7 @@ class Chronos2Dataset(IterableDataset):
min_past: int = 1,
mode: str | DatasetMode = DatasetMode.TRAIN,
) -> "Chronos2Dataset":
"""Convert from different input formats to a Chronos2Dataset.
This method handles various input formats (tensors, list of tensors, list of dicts)
and creates a dataset with preprocessing.
"""
# Convert various input formats to list of dicts
"""Convert from different input formats to a Chronos2Dataset."""
if isinstance(inputs, (torch.Tensor, np.ndarray)):
inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
elif isinstance(inputs, list) and all([isinstance(x, (torch.Tensor, np.ndarray)) for x in inputs]):
@ -694,5 +708,4 @@ class Chronos2Dataset(IterableDataset):
output_patch_size=output_patch_size,
min_past=min_past,
mode=mode,
convert_inputs=True,
)