This commit is contained in:
Alexander März 2026-01-19 14:07:13 +00:00 committed by GitHub
commit 631ebcfddd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 907 additions and 6 deletions

View file

@ -1730,12 +1730,268 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c899976",
"metadata": {},
"outputs": [],
"source": []
"cell_type": "markdown",
"source": [
"## Custom Group IDs: Examples and Use Cases\n",
"\n",
"Custom group IDs let you control how Chronos-2 shares information between series during prediction:\n",
"- Default (no group_ids): each series is predicted independently (no cross-series sharing).\n",
"- cross_learning=True: all series in the batch are jointly predicted and share information.\n",
"- Custom group_ids: only series within the same group share information; groups remain independent.\n",
"\n",
"Use custom group IDs when you know meaningful clusters (e.g., geography, sector, etc.). This can boost accuracy, especially for short or noisy series, while avoiding contamination from unrelated series.\n"
],
"id": "e504694b8cfab326"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-12-09T13:39:45.846727Z",
"start_time": "2025-12-09T13:39:45.827756Z"
}
},
"cell_type": "code",
"source": [
"# Simulate ~30 weather stations with regional clustering (North, South, Coastal)\n",
"import numpy as np, pandas as pd\n",
"np.random.seed(123)\n",
"n = 200\n",
"prediction_length = 24\n",
"ts = pd.date_range('2020-01-01', periods=n, freq='D')\n",
"\n",
"north_ids = [f'station_north_{i+1}' for i in range(10)]\n",
"south_ids = [f'station_south_{i+1}' for i in range(10)]\n",
"coast_ids = [f'station_coast_{i+1}' for i in range(10)]\n",
"all_ids = north_ids + south_ids + coast_ids\n",
"regions = (['North']*len(north_ids) + ['South']*len(south_ids) + ['Coastal']*len(coast_ids))\n",
"\n",
"def synth_series(base, amp, noise, phase=0.0):\n",
" t = np.arange(n)\n",
" signal = base + amp*np.sin(2*np.pi*t/365.0 + phase)\n",
" return (signal + noise*np.random.randn(n)).astype('float32')\n",
"\n",
"data_frames = []\n",
"for sid, region in zip(all_ids, regions):\n",
" if region == 'North':\n",
" y = synth_series(base=5, amp=6, noise=0.8, phase=0.3)\n",
" elif region == 'South':\n",
" y = synth_series(base=18, amp=8, noise=0.8, phase=0.9)\n",
" else: # Coastal\n",
" y = synth_series(base=12, amp=3.5, noise=0.3, phase=0.5)\n",
" df_i = pd.DataFrame({'item_id': sid, 'timestamp': ts, 'target': y, 'region': region})\n",
" data_frames.append(df_i)\n",
"\n",
"weather_df = pd.concat(data_frames, ignore_index=True)\n",
"weather_df.head()"
],
"id": "326d90e1a05586e8",
"outputs": [
{
"data": {
"text/plain": [
" item_id timestamp target region\n",
"0 station_north_1 2020-01-01 5.904617 North\n",
"1 station_north_1 2020-01-02 7.669402 North\n",
"2 station_north_1 2020-01-03 7.195759 North\n",
"3 station_north_1 2020-01-04 5.861607 North\n",
"4 station_north_1 2020-01-05 6.700416 North"
],
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>item_id</th>\n",
" <th>timestamp</th>\n",
" <th>target</th>\n",
" <th>region</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-01</td>\n",
" <td>5.904617</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-02</td>\n",
" <td>7.669402</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-03</td>\n",
" <td>7.195759</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-04</td>\n",
" <td>5.861607</td>\n",
" <td>North</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>station_north_1</td>\n",
" <td>2020-01-05</td>\n",
" <td>6.700416</td>\n",
" <td>North</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-12-09T13:39:49.624289Z",
"start_time": "2025-12-09T13:39:49.251047Z"
}
},
"cell_type": "code",
"source": [
"# 1) Using predict_df with create_group_ids_dict_from_category (group by 'region')\n",
"from chronos import (\n",
" create_group_ids_dict_from_category,\n",
" create_group_ids_from_category\n",
")\n",
"\n",
"group_ids_cat = create_group_ids_dict_from_category(\n",
" df=weather_df,\n",
" id_column='item_id',\n",
" category_column='region'\n",
")\n",
"\n",
"# Split into context (train) and future truth (test)\n",
"test_df = weather_df.groupby('item_id').tail(prediction_length)\n",
"future_df = test_df.drop(columns=['target', 'region']).copy()\n",
"train_df = weather_df.drop(test_df.index).drop(columns=[\"region\"])\n",
"\n",
"pred_df_ids = pipeline.predict_df(\n",
" df=train_df,\n",
" future_df=future_df,\n",
" id_column='item_id',\n",
" timestamp_column='timestamp',\n",
" target='target',\n",
" prediction_length=prediction_length,\n",
" group_ids=group_ids_cat,\n",
" quantile_levels=[0.5],\n",
")\n",
"\n",
"# Compute MSE/MAE against ground truth\n",
"eval_df = pred_df_ids.merge(\n",
" test_df[['item_id', 'timestamp', 'target']],\n",
" on=['item_id', 'timestamp'], how='inner',\n",
")\n",
"y_true = eval_df['target'].to_numpy()\n",
"y_pred = eval_df['predictions'].to_numpy()\n",
"mse_ids = float(np.mean((y_pred - y_true) ** 2))\n",
"mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n",
"print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')"
],
"id": "9113a0f64f257772",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE=0.523, MAE=0.537\n"
]
}
],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-12-09T13:39:53.070369Z",
"start_time": "2025-12-09T13:39:52.884440Z"
}
},
"cell_type": "code",
"source": [
"# 2) Low-level API: predict_quantiles with group_ids list\n",
"\n",
"# Create group IDs as a list aligned with series_ids\n",
"group_ids = create_group_ids_from_category(\n",
"\t train=weather_df,\n",
"\t id_column=\"item_id\",\n",
"\t category_column=\"region\"\n",
"\t)\n",
"\n",
"# Build inputs in the same order as series_ids\n",
"series_ids = train_df['item_id'].unique().tolist()\n",
"inputs_list = [\n",
" {'target': train_df.loc[train_df['item_id']==sid, 'target'].to_numpy(dtype=np.float32) }\n",
" for sid in series_ids\n",
"]\n",
"\n",
"_, mean = pipeline.predict_quantiles(\n",
" inputs=inputs_list,\n",
" prediction_length=prediction_length,\n",
" quantile_levels=[0.5],\n",
" group_ids=group_ids,\n",
")\n",
"\n",
"pred_df_ids = []\n",
"for id in series_ids:\n",
" test_id_df = test_df[test_df['item_id']==id].copy()\n",
" test_id_df['predictions'] = mean[series_ids.index(id)].numpy().flatten()\n",
" test_id_df.drop(columns=['target'], inplace=True)\n",
" pred_df_ids.append(test_id_df)\n",
"pred_df_ids = pd.concat(pred_df_ids, ignore_index=True)\n",
"\n",
"# Compute MSE/MAE against ground truth\n",
"eval_df = pred_df_ids.merge(\n",
" test_df[['item_id', 'timestamp', 'target']],\n",
" on=['item_id', 'timestamp'], how='inner',\n",
")\n",
"y_true = eval_df['target'].to_numpy()\n",
"y_pred = eval_df['predictions'].to_numpy()\n",
"mse_ids = float(np.mean((y_pred - y_true) ** 2))\n",
"mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n",
"print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')"
],
"id": "629bab4c4ea6fe53",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE=0.523, MAE=0.537\n"
]
}
],
"execution_count": 5
}
],
"metadata": {

View file

@ -12,6 +12,12 @@ from .chronos import (
)
from .chronos2 import Chronos2ForecastingConfig, Chronos2Model, Chronos2Pipeline
from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline
from .utils import (
create_group_ids_dict_from_category,
create_group_ids_dict_from_mapping,
create_manual_group_ids_dict,
create_group_ids_from_category
)
__all__ = [
"__version__",
@ -27,4 +33,7 @@ __all__ = [
"Chronos2ForecastingConfig",
"Chronos2Model",
"Chronos2Pipeline",
"create_group_ids_dict_from_category",
"create_group_ids_dict_from_mapping",
"create_manual_group_ids_dict",
]

View file

@ -456,6 +456,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
batch_size: int = 256,
context_length: int | None = None,
cross_learning: bool = False,
group_ids: list[int] | torch.Tensor | None = None,
limit_prediction_length: bool = False,
**kwargs,
) -> list[torch.Tensor]:
@ -548,6 +549,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
- Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining.
For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
- Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
group_ids
Optional custom group IDs to control information sharing between time series.
If provided, must be a list or tensor of integers with length equal to the number of tasks in `inputs`.
Tasks with the same group ID will share information during prediction via cross-attention.
Cannot be used together with `cross_learning=True`. By default None (each task gets unique group ID).
Example: [0, 0, 1] means first two tasks share information, third is separate.
limit_prediction_length
If True, an error is raised when prediction_length is greater than model's default prediction length, by default False
@ -569,6 +576,40 @@ class Chronos2Pipeline(BaseChronosPipeline):
stacklevel=2,
)
cross_learning = kwargs.pop("predict_batches_jointly")
# Validate group_ids and cross_learning interaction
if group_ids is not None and cross_learning:
raise ValueError(
"Cannot specify both `group_ids` and `cross_learning=True`. "
"Use `group_ids` to define custom groups, or `cross_learning=True` to enable full batch-wide learning."
)
# Convert group_ids to tensor if provided
custom_group_ids_tensor = None
if group_ids is not None:
if isinstance(group_ids, list):
# Strict type check: only integers are allowed in the list
if not all(isinstance(x, (int, np.integer)) for x in group_ids):
raise TypeError("`group_ids` list must contain only integers")
if any(x < 0 for x in group_ids):
raise ValueError("`group_ids` must contain only non-negative integers")
custom_group_ids_tensor = torch.tensor(group_ids, dtype=torch.long)
elif isinstance(group_ids, torch.Tensor):
# Enforce integer dtype for tensor inputs
if torch.is_floating_point(group_ids) or group_ids.dtype == torch.bool:
raise TypeError("`group_ids` tensor must have an integer dtype")
if (group_ids < 0).any():
raise ValueError("`group_ids` must contain only non-negative integers")
custom_group_ids_tensor = group_ids.to(dtype=torch.long).clone()
else:
raise TypeError(f"`group_ids` must be a list or torch.Tensor, got {type(group_ids)}")
# Validate length matches number of tasks
if len(custom_group_ids_tensor) != len(inputs):
raise ValueError(
f"`group_ids` length ({len(custom_group_ids_tensor)}) must match number of tasks in inputs ({len(inputs)})"
)
# The maximum number of output patches to generate in a single forward pass before the long-horizon heuristic kicks in. Note: A value larger
# than the model's default max_output_patches may lead to degradation in forecast accuracy, defaults to a model-specific value
max_output_patches = kwargs.pop("max_output_patches", self.max_output_patches)
@ -623,6 +664,9 @@ class Chronos2Pipeline(BaseChronosPipeline):
)
all_predictions: list[torch.Tensor] = []
# Track the current task index for custom group ID mapping
current_task_idx = 0
for batch in test_loader:
assert batch["future_target"] is None
batch_context = batch["context"]
@ -630,7 +674,38 @@ class Chronos2Pipeline(BaseChronosPipeline):
batch_future_covariates = batch["future_covariates"]
batch_target_idx_ranges = batch["target_idx_ranges"]
if cross_learning:
# Apply custom group IDs if provided
if custom_group_ids_tensor is not None:
# Determine how many tasks are in this batch
num_tasks_in_batch = len(batch_target_idx_ranges)
# The key insight: batch_group_ids already maps variates to tasks
# We just need to replace the task IDs with our custom ones
# Create a mapping from old task IDs to new task IDs
old_group_ids = batch_group_ids.cpu().numpy()
# Preserve first-appearance order of group IDs (robust to non-consecutive IDs)
seen = set()
unique_old_ids_list: list[int] = []
for gid in old_group_ids.tolist():
if gid not in seen:
seen.add(int(gid))
unique_old_ids_list.append(int(gid))
# Map old group IDs to task indices (0, 1, 2, ..., num_tasks_in_batch-1)
# Then map those to custom group IDs
new_group_ids = old_group_ids.copy()
for task_offset, old_group_id in enumerate(unique_old_ids_list):
task_idx = current_task_idx + task_offset
custom_group_id = custom_group_ids_tensor[task_idx].item()
# Replace all occurrences of old_group_id with custom_group_id
new_group_ids[old_group_ids == old_group_id] = custom_group_id
batch_group_ids = torch.tensor(new_group_ids, dtype=torch.long)
current_task_idx += num_tasks_in_batch
elif cross_learning:
batch_group_ids = torch.zeros_like(batch_group_ids)
batch_prediction = self._predict_batch(
@ -824,6 +899,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
batch_size: int = 256,
context_length: int | None = None,
cross_learning: bool = False,
group_ids: dict[str, int] | None = None,
validate_inputs: bool = True,
freq: str | None = None,
**predict_kwargs,
@ -864,6 +940,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
- Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining.
For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
- Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
group_ids
Optional dictionary mapping series IDs (from id_column) to group IDs.
Series with the same group ID will share information during prediction.
Cannot be used together with `cross_learning=True`. By default None.
Example: {'series_A': 0, 'series_B': 0, 'series_C': 1} means series_A and series_B share info, series_C is separate.
If a series ID is not in the dictionary, it will be assigned a unique group ID.
validate_inputs
[ADVANCED] When True (default), validates dataframes before prediction. Setting to False removes the
validation overhead, but may silently lead to wrong predictions if data is misformatted. When False, you
@ -907,6 +989,48 @@ class Chronos2Pipeline(BaseChronosPipeline):
validate_inputs=validate_inputs,
)
# Convert dictionary group_ids to list format matching inputs order
group_ids_list = None
if group_ids is not None:
# Validate group_ids format
if not isinstance(group_ids, dict):
raise TypeError(f"`group_ids` must be a dictionary, got {type(group_ids)}")
if not all(isinstance(k, str) and isinstance(v, int) for k, v in group_ids.items()):
raise TypeError("`group_ids` dictionary must have string keys and integer values")
if any(v < 0 for v in group_ids.values()):
raise ValueError("`group_ids` values must be non-negative integers")
if cross_learning:
raise ValueError(
"Cannot specify both `group_ids` and `cross_learning=True`. "
"Use `group_ids` to define custom groups, or `cross_learning=True` to enable full batch-wide learning."
)
# Warn if series IDs in group_ids don't exist in dataframe
series_in_df = set(df[id_column].unique())
unknown_series = set(group_ids.keys()) - series_in_df
if unknown_series:
warnings.warn(
f"The following series IDs in `group_ids` were not found in the dataframe: {unknown_series}. "
f"They will be ignored.",
category=UserWarning,
)
# Create list of group IDs matching the order of inputs
# original_order contains the series IDs in the order they appear in inputs
group_ids_list = []
next_auto_group_id = max(group_ids.values()) + 1 if group_ids else 0
for series_id in original_order:
if series_id in group_ids:
group_ids_list.append(group_ids[series_id])
else:
# Assign unique group ID to series not in the mapping
group_ids_list.append(next_auto_group_id)
next_auto_group_id += 1
# Generate forecasts
quantiles, mean = self.predict_quantiles(
inputs=inputs,
@ -916,6 +1040,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
batch_size=batch_size,
context_length=context_length,
cross_learning=cross_learning,
group_ids=group_ids_list,
**predict_kwargs,
)
# since predict_df tasks are homogenous by input design, we can safely stack the list of tensors into a single tensor

View file

@ -7,6 +7,12 @@ from typing import List
import torch
from einops import repeat
try:
import pandas as pd
_PANDAS_AVAILABLE = True
except ImportError:
_PANDAS_AVAILABLE = False
def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
max_len = max(len(c) for c in tensors)
@ -210,3 +216,228 @@ def weighted_quantile(
# Reshape to original shape
final_shape = (*orig_samples_shape[:-1], len(query_quantile_levels))
return interpolated_quantiles.reshape(final_shape).to(dtype=orig_dtype)
def create_group_ids_dict_from_category(
df: "pd.DataFrame",
id_column: str,
category_column: str
) -> dict[str, int]:
"""
Create group_ids dictionary (for predict_df) from a categorical column.
This function is specifically designed for use with Chronos2Pipeline.predict_df(),
which accepts a dictionary mapping series IDs to group IDs.
Parameters
----------
df : pd.DataFrame
DataFrame containing id_column and category_column.
id_column : str
Name of the column containing time series identifiers.
category_column : str
Name of the categorical column to group by (e.g., "region", "product_type", "industry").
Returns
-------
dict[str, int]
Dictionary mapping series IDs to group IDs.
Series with the same category will have the same group ID.
Examples
--------
>>> import pandas as pd
>>> df = pd.DataFrame({
... 'item_id': ['A', 'A', 'B', 'B', 'C', 'C'],
... 'region': ['North', 'North', 'North', 'North', 'South', 'South'],
... 'value': [1, 2, 3, 4, 5, 6]
... })
>>> create_group_ids_dict_from_category(df, 'item_id', 'region')
{'A': 0, 'B': 0, 'C': 1} # A and B are North (group 0), C is South (group 1)
Notes
-----
- Use this for Chronos2Pipeline.predict_df() which expects dict[str, int]
- The function automatically handles the order of series IDs
- Categories are mapped to consecutive integers starting from 0
"""
if not _PANDAS_AVAILABLE:
raise ImportError("pandas is required for this function. Please install it with `pip install pandas`.")
# Get unique series and their category
series_categories = df.groupby(id_column, sort=False)[category_column].first()
# Map categories to group IDs
unique_categories = series_categories.unique()
category_to_group = {cat: i for i, cat in enumerate(unique_categories)}
# Create dictionary mapping series_id -> group_id
group_ids_dict = {
series_id: category_to_group[category]
for series_id, category in series_categories.items()
}
return group_ids_dict
def create_group_ids_dict_from_mapping(
df: "pd.DataFrame",
id_column: str,
category_to_group_map: dict[str, int]
) -> dict[str, int]:
"""
Create group_ids dictionary using a custom category-to-group mapping.
This allows you to manually specify which categories belong to which groups,
useful when you want custom grouping logic beyond simple one-to-one category mapping.
Parameters
----------
df : pd.DataFrame
DataFrame containing id_column and a category column.
id_column : str
Name of the column containing time series identifiers.
category_to_group_map : dict[str, int]
Dictionary mapping category names to group IDs.
Example: {'Retail': 0, 'Wholesale': 0, 'Food': 1, 'Services': 2}
Returns
-------
dict[str, int]
Dictionary mapping series IDs to group IDs.
Examples
--------
>>> import pandas as pd
>>> df = pd.DataFrame({
... 'item_id': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'],
... 'industry': ['Retail', 'Retail', 'Wholesale', 'Wholesale',
... 'Food', 'Food', 'Services', 'Services'],
... 'value': [1, 2, 3, 4, 5, 6, 7, 8]
... })
>>> mapping = {'Retail': 0, 'Wholesale': 0, 'Food': 1, 'Services': 2}
>>> create_group_ids_dict_from_mapping(df, 'item_id', mapping)
{'A': 0, 'B': 0, 'C': 1, 'D': 2}
Notes
-----
- This is more flexible than create_group_ids_dict_from_category()
- Useful when multiple categories should map to the same group
- All categories in the data must be in the mapping, or will raise KeyError
"""
if not _PANDAS_AVAILABLE:
raise ImportError("pandas is required for this function. Please install it with `pip install pandas`.")
# Infer category column by checking which column contains the mapping keys
category_column = None
for col in df.columns:
if col != id_column and df[col].dtype == 'object':
if any(cat in category_to_group_map for cat in df[col].unique()):
category_column = col
break
if category_column is None:
raise ValueError(
f"Could not infer category column. Available columns: {df.columns.tolist()}. "
f"Make sure one of them contains categories from the mapping: {list(category_to_group_map.keys())}"
)
# Get unique series and their category
series_categories = df.groupby(id_column, sort=False)[category_column].first()
# Create dictionary mapping series_id -> group_id using the custom mapping
group_ids_dict = {}
for series_id, category in series_categories.items():
if category not in category_to_group_map:
raise KeyError(
f"Category '{category}' for series '{series_id}' not found in mapping. "
f"Available categories in mapping: {list(category_to_group_map.keys())}"
)
group_ids_dict[series_id] = category_to_group_map[category]
return group_ids_dict
def create_manual_group_ids_dict(
series_ids: List[str],
group_assignments: List[int]
) -> dict[str, int]:
"""
Create group_ids dictionary from manual series ID and group assignment lists.
Parameters
----------
series_ids : list of str
List of series identifiers.
group_assignments : list of int
List of group IDs corresponding to each series.
Must have the same length as series_ids.
Returns
-------
dict[str, int]
Dictionary mapping series IDs to group IDs.
Examples
--------
>>> series_ids = ['store_1', 'store_2', 'store_3', 'store_4']
>>> groups = [0, 0, 1, 1] # First two together, last two together
>>> create_manual_group_ids_dict(series_ids, groups)
{'store_1': 0, 'store_2': 0, 'store_3': 1, 'store_4': 1}
Raises
------
ValueError
If series_ids and group_assignments have different lengths.
"""
if len(series_ids) != len(group_assignments):
raise ValueError(
f"Length mismatch: series_ids has {len(series_ids)} elements, "
f"but group_assignments has {len(group_assignments)} elements."
)
return dict(zip(series_ids, group_assignments))
def create_group_ids_from_category(
train: pd.DataFrame,
id_column: str,
category_column: str
) -> list[int]:
"""
Create group_ids list from a categorical column.
Parameters
----------
train : pd.DataFrame
Training data containing id_column and category_column.
id_column : str
Name of the column containing time series identifiers.
category_column : str
Name of the categorical column to group by (e.g., "region", "product_type").
Returns
-------
list[int]
List of group IDs, one per series, matching the order of series_ids.
Examples
--------
>>> train = pd.DataFrame({
... 'series_id': ['A', 'A', 'B', 'B', 'C', 'C'],
... 'region': ['North', 'North', 'North', 'North', 'South', 'South'],
... 'value': [1, 2, 3, 4, 5, 6]
... })
>>> create_group_ids_from_category(train, 'series_id', 'region')
[0, 0, 1] # A and B are North (group 0), C is South (group 1)
"""
# Get unique series and their category
series_categories = train.groupby(id_column, sort=False)[category_column].first()
# Map categories to group IDs
unique_categories = series_categories.unique()
category_to_group = {cat: i for i, cat in enumerate(unique_categories)}
# Create group_ids list
group_ids = [category_to_group[cat] for cat in series_categories.values]
return group_ids

View file

@ -1143,3 +1143,283 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
# Should match exactly or very close (numerical precision)
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
# ============================================================================
# Tests for custom group_ids functionality
# ============================================================================
@pytest.mark.parametrize("group_ids", [[0, 0, 1], torch.tensor([0, 0, 1])])
def test_predict_with_custom_group_ids_list_and_tensor(pipeline, group_ids):
"""Test basic functionality with custom group_ids as both list and tensor."""
inputs = [torch.rand(100), torch.rand(110), torch.rand(120)]
outputs = pipeline.predict(inputs, prediction_length=24, group_ids=group_ids)
assert isinstance(outputs, list) and len(outputs) == 3
for out in outputs:
validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, 24), dtype=torch.float32)
def test_predict_with_group_ids_univariate_batch(pipeline):
"""Test group_ids with homogeneous univariate batch."""
inputs = torch.rand(5, 1, 100)
group_ids = [0, 0, 1, 1, 2] # First two together, next two together, last one alone
outputs = pipeline.predict(inputs, prediction_length=12, group_ids=group_ids)
assert len(outputs) == 5
for out in outputs:
validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, 12), dtype=torch.float32)
def test_predict_with_group_ids_multivariate(pipeline):
"""Test group_ids with multivariate inputs."""
inputs = [torch.rand(2, 100), torch.rand(2, 110), torch.rand(2, 90)]
group_ids = [0, 0, 1] # First two share info, third is separate
outputs = pipeline.predict(inputs, prediction_length=16, group_ids=group_ids)
assert len(outputs) == 3
for out in outputs:
validate_tensor(out, (2, DEFAULT_MODEL_NUM_QUANTILES, 16), dtype=torch.float32)
def test_predict_with_group_ids_and_covariates(pipeline):
"""Test group_ids with covariates."""
prediction_length = 24
inputs = [
{
"target": torch.rand(100),
"past_covariates": {"temperature": torch.rand(100)},
"future_covariates": {"temperature": torch.rand(prediction_length)},
},
{
"target": torch.rand(110),
"past_covariates": {"temperature": torch.rand(110)},
"future_covariates": {"temperature": torch.rand(prediction_length)},
},
{
"target": torch.rand(90),
"past_covariates": {"temperature": torch.rand(90)},
"future_covariates": {"temperature": torch.rand(prediction_length)},
},
]
group_ids = [0, 0, 1]
outputs = pipeline.predict(inputs, prediction_length=prediction_length, group_ids=group_ids)
assert len(outputs) == 3
for out in outputs:
validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, prediction_length), dtype=torch.float32)
def test_predict_df_with_group_ids_dict(pipeline):
"""Test predict_df with dictionary group_ids."""
df = create_df(series_ids=["A", "B", "C"], n_points=[10, 10, 10])
group_ids = {"A": 0, "B": 0, "C": 1} # A and B share info, C is separate
pred_df = pipeline.predict_df(df, prediction_length=5, group_ids=group_ids)
assert isinstance(pred_df, pd.DataFrame)
assert len(pred_df) == 15 # 3 series * 5 predictions
assert set(pred_df["item_id"].unique()) == {"A", "B", "C"}
def test_predict_df_with_partial_group_ids(pipeline):
"""Test predict_df when only some series have group_ids assigned."""
df = create_df(series_ids=["A", "B", "C", "D"], n_points=[10, 10, 10, 10])
group_ids = {"A": 0, "B": 0} # Only A and B specified, C and D should get unique IDs
pred_df = pipeline.predict_df(df, prediction_length=5, group_ids=group_ids)
assert isinstance(pred_df, pd.DataFrame)
assert len(pred_df) == 20 # 4 series * 5 predictions
assert set(pred_df["item_id"].unique()) == {"A", "B", "C", "D"}
def test_predict_df_with_group_ids_and_covariates(pipeline):
"""Test predict_df with both group_ids and covariates."""
df = create_df(series_ids=["A", "B", "C"], n_points=[10, 10, 10], covariates=["temp"])
future_df = create_future_df(
get_forecast_start_times(df), series_ids=["A", "B", "C"], n_points=[5, 5, 5], covariates=["temp"]
)
group_ids = {"A": 0, "B": 0, "C": 1}
pred_df = pipeline.predict_df(df, future_df=future_df, prediction_length=5, group_ids=group_ids)
assert isinstance(pred_df, pd.DataFrame)
assert len(pred_df) == 15
def test_group_ids_cross_learning_mutual_exclusion(pipeline):
"""Test that error is raised when both group_ids and cross_learning are specified."""
inputs = [torch.rand(100), torch.rand(110), torch.rand(120)]
group_ids = [0, 0, 1]
with pytest.raises(ValueError, match="Cannot specify both `group_ids` and `cross_learning=True`"):
pipeline.predict(inputs, prediction_length=24, group_ids=group_ids, cross_learning=True)
def test_predict_df_group_ids_cross_learning_mutual_exclusion(pipeline):
"""Test that predict_df raises error when both group_ids and cross_learning are specified."""
df = create_df(series_ids=["A", "B"], n_points=[10, 10])
group_ids = {"A": 0, "B": 0}
with pytest.raises(ValueError, match="Cannot specify both `group_ids` and `cross_learning=True`"):
pipeline.predict_df(df, prediction_length=5, group_ids=group_ids, cross_learning=True)
def test_group_ids_length_mismatch_raises_error(pipeline):
"""Test that error is raised when group_ids length doesn't match inputs."""
inputs = [torch.rand(100), torch.rand(110), torch.rand(120)]
group_ids = [0, 0] # Only 2 IDs for 3 inputs
with pytest.raises(ValueError, match="length .* must match number of tasks"):
pipeline.predict(inputs, prediction_length=24, group_ids=group_ids)
def test_group_ids_negative_values_raises_error(pipeline):
"""Test that error is raised when group_ids contain negative values."""
inputs = [torch.rand(100), torch.rand(110), torch.rand(120)]
group_ids = [0, -1, 1] # Negative ID not allowed
with pytest.raises(ValueError, match="must contain only non-negative integers"):
pipeline.predict(inputs, prediction_length=24, group_ids=group_ids)
def test_group_ids_invalid_type_raises_error(pipeline):
"""Test that error is raised when group_ids is not list or tensor."""
inputs = [torch.rand(100), torch.rand(110)]
group_ids = "invalid" # String not allowed
with pytest.raises(TypeError, match="must be a list or torch.Tensor"):
pipeline.predict(inputs, prediction_length=24, group_ids=group_ids)
def test_predict_df_group_ids_invalid_type_raises_error(pipeline):
"""Test that predict_df raises error when group_ids is not a dict."""
df = create_df(series_ids=["A", "B"], n_points=[10, 10])
group_ids = [0, 0] # List not allowed for predict_df (needs dict)
with pytest.raises(TypeError, match="must be a dictionary"):
pipeline.predict_df(df, prediction_length=5, group_ids=group_ids)
def test_predict_df_group_ids_invalid_dict_values_raises_error(pipeline):
"""Test that predict_df raises error when group_ids dict has negative values."""
df = create_df(series_ids=["A", "B"], n_points=[10, 10])
group_ids = {"A": 0, "B": -1} # Negative value not allowed
with pytest.raises(ValueError, match="must be non-negative integers"):
pipeline.predict_df(df, prediction_length=5, group_ids=group_ids)
def test_predict_df_group_ids_warns_unknown_series(pipeline):
"""Test that predict_df warns when group_ids contains unknown series IDs."""
df = create_df(series_ids=["A", "B"], n_points=[10, 10])
group_ids = {"A": 0, "B": 0, "X": 1, "Y": 1} # X and Y don't exist
with pytest.warns(UserWarning, match="not found in the dataframe"):
pipeline.predict_df(df, prediction_length=5, group_ids=group_ids)
# ============================================================================
# Tests for group_ids helper functions
# ============================================================================
def test_create_group_ids_dict_from_category():
"""Test create_group_ids_dict_from_category helper function."""
from chronos import create_group_ids_dict_from_category
df = pd.DataFrame(
{
"item_id": ["A", "A", "B", "B", "C", "C"],
"region": ["North", "North", "North", "North", "South", "South"],
"value": [1, 2, 3, 4, 5, 6],
}
)
result = create_group_ids_dict_from_category(df, "item_id", "region")
assert isinstance(result, dict)
assert result == {"A": 0, "B": 0, "C": 1}
def test_create_group_ids_dict_from_mapping():
"""Test create_group_ids_dict_from_mapping helper function."""
from chronos import create_group_ids_dict_from_mapping
df = pd.DataFrame(
{
"item_id": ["A", "A", "B", "B", "C", "C", "D", "D"],
"industry": ["Retail", "Retail", "Wholesale", "Wholesale", "Food", "Food", "Services", "Services"],
"value": [1, 2, 3, 4, 5, 6, 7, 8],
}
)
mapping = {"Retail": 0, "Wholesale": 0, "Food": 1, "Services": 2}
result = create_group_ids_dict_from_mapping(df, "item_id", mapping)
assert isinstance(result, dict)
assert result == {"A": 0, "B": 0, "C": 1, "D": 2}
def test_create_group_ids_dict_from_mapping_missing_category_raises_error():
"""Test that create_group_ids_dict_from_mapping raises error for unmapped categories."""
from chronos import create_group_ids_dict_from_mapping
df = pd.DataFrame(
{
"item_id": ["A", "A", "B", "B"],
"industry": ["Retail", "Retail", "Tech", "Tech"],
"value": [1, 2, 3, 4],
}
)
mapping = {"Retail": 0} # Missing "Tech"
with pytest.raises(KeyError, match="not found in mapping"):
create_group_ids_dict_from_mapping(df, "item_id", mapping)
def test_create_manual_group_ids_dict():
"""Test create_manual_group_ids_dict helper function."""
from chronos import create_manual_group_ids_dict
series_ids = ["store_1", "store_2", "store_3", "store_4"]
groups = [0, 0, 1, 1]
result = create_manual_group_ids_dict(series_ids, groups)
assert isinstance(result, dict)
assert result == {"store_1": 0, "store_2": 0, "store_3": 1, "store_4": 1}
def test_create_manual_group_ids_dict_length_mismatch_raises_error():
"""Test that create_manual_group_ids_dict raises error on length mismatch."""
from chronos import create_manual_group_ids_dict
series_ids = ["A", "B", "C"]
groups = [0, 0] # Mismatched length
with pytest.raises(ValueError, match="Length mismatch"):
create_manual_group_ids_dict(series_ids, groups)
def test_create_group_ids_from_category():
"""Test create_group_ids_from_category helper function for list output."""
from chronos import create_group_ids_from_category
df = pd.DataFrame(
{
"item_id": ["A", "A", "B", "B", "C", "C"],
"region": ["North", "North", "North", "North", "South", "South"],
"value": [1, 2, 3, 4, 5, 6],
}
)
result = create_group_ids_from_category(df, "item_id", "region")
assert isinstance(result, list)
assert result == [0, 0, 1] # A and B are North (0), C is South (1)