mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-22 17:21:01 +00:00
Merge 0e621fac5c into 1f099eb265
This commit is contained in:
commit
631ebcfddd
5 changed files with 907 additions and 6 deletions
|
|
@ -1730,12 +1730,268 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7c899976",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Custom Group IDs: Examples and Use Cases\n",
|
||||
"\n",
|
||||
"Custom group IDs let you control how Chronos-2 shares information between series during prediction:\n",
|
||||
"- Default (no group_ids): each series is predicted independently (no cross-series sharing).\n",
|
||||
"- cross_learning=True: all series in the batch are jointly predicted and share information.\n",
|
||||
"- Custom group_ids: only series within the same group share information; groups remain independent.\n",
|
||||
"\n",
|
||||
"Use custom group IDs when you know meaningful clusters (e.g., geography or sector). This can boost accuracy, especially for short or noisy series, while avoiding contamination from unrelated series.\n"
|
||||
],
|
||||
"id": "e504694b8cfab326"
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-12-09T13:39:45.846727Z",
|
||||
"start_time": "2025-12-09T13:39:45.827756Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Simulate 30 weather stations, 10 in each of three regions (North, South, Coastal)\n",
|
||||
"import numpy as np, pandas as pd\n",
|
||||
"np.random.seed(123)\n",
|
||||
"n = 200\n",
|
||||
"prediction_length = 24\n",
|
||||
"ts = pd.date_range('2020-01-01', periods=n, freq='D')\n",
|
||||
"\n",
|
||||
"north_ids = [f'station_north_{i+1}' for i in range(10)]\n",
|
||||
"south_ids = [f'station_south_{i+1}' for i in range(10)]\n",
|
||||
"coast_ids = [f'station_coast_{i+1}' for i in range(10)]\n",
|
||||
"all_ids = north_ids + south_ids + coast_ids\n",
|
||||
"regions = (['North']*len(north_ids) + ['South']*len(south_ids) + ['Coastal']*len(coast_ids))\n",
|
||||
"\n",
|
||||
"def synth_series(base, amp, noise, phase=0.0):\n",
|
||||
" t = np.arange(n)\n",
|
||||
" signal = base + amp*np.sin(2*np.pi*t/365.0 + phase)\n",
|
||||
" return (signal + noise*np.random.randn(n)).astype('float32')\n",
|
||||
"\n",
|
||||
"data_frames = []\n",
|
||||
"for sid, region in zip(all_ids, regions):\n",
|
||||
" if region == 'North':\n",
|
||||
" y = synth_series(base=5, amp=6, noise=0.8, phase=0.3)\n",
|
||||
" elif region == 'South':\n",
|
||||
" y = synth_series(base=18, amp=8, noise=0.8, phase=0.9)\n",
|
||||
" else: # Coastal\n",
|
||||
" y = synth_series(base=12, amp=3.5, noise=0.3, phase=0.5)\n",
|
||||
" df_i = pd.DataFrame({'item_id': sid, 'timestamp': ts, 'target': y, 'region': region})\n",
|
||||
" data_frames.append(df_i)\n",
|
||||
"\n",
|
||||
"weather_df = pd.concat(data_frames, ignore_index=True)\n",
|
||||
"weather_df.head()"
|
||||
],
|
||||
"id": "326d90e1a05586e8",
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
" item_id timestamp target region\n",
|
||||
"0 station_north_1 2020-01-01 5.904617 North\n",
|
||||
"1 station_north_1 2020-01-02 7.669402 North\n",
|
||||
"2 station_north_1 2020-01-03 7.195759 North\n",
|
||||
"3 station_north_1 2020-01-04 5.861607 North\n",
|
||||
"4 station_north_1 2020-01-05 6.700416 North"
|
||||
],
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>item_id</th>\n",
|
||||
" <th>timestamp</th>\n",
|
||||
" <th>target</th>\n",
|
||||
" <th>region</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>station_north_1</td>\n",
|
||||
" <td>2020-01-01</td>\n",
|
||||
" <td>5.904617</td>\n",
|
||||
" <td>North</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>station_north_1</td>\n",
|
||||
" <td>2020-01-02</td>\n",
|
||||
" <td>7.669402</td>\n",
|
||||
" <td>North</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>station_north_1</td>\n",
|
||||
" <td>2020-01-03</td>\n",
|
||||
" <td>7.195759</td>\n",
|
||||
" <td>North</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>station_north_1</td>\n",
|
||||
" <td>2020-01-04</td>\n",
|
||||
" <td>5.861607</td>\n",
|
||||
" <td>North</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>station_north_1</td>\n",
|
||||
" <td>2020-01-05</td>\n",
|
||||
" <td>6.700416</td>\n",
|
||||
" <td>North</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 3
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-12-09T13:39:49.624289Z",
|
||||
"start_time": "2025-12-09T13:39:49.251047Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# 1) Using predict_df with create_group_ids_dict_from_category (group by 'region')\n",
|
||||
"from chronos import (\n",
|
||||
" create_group_ids_dict_from_category,\n",
|
||||
" create_group_ids_from_category\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"group_ids_cat = create_group_ids_dict_from_category(\n",
|
||||
" df=weather_df,\n",
|
||||
" id_column='item_id',\n",
|
||||
" category_column='region'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Split into context (train) and future truth (test)\n",
|
||||
"test_df = weather_df.groupby('item_id').tail(prediction_length)\n",
|
||||
"future_df = test_df.drop(columns=['target', 'region']).copy()\n",
|
||||
"train_df = weather_df.drop(test_df.index).drop(columns=[\"region\"])\n",
|
||||
"\n",
|
||||
"pred_df_ids = pipeline.predict_df(\n",
|
||||
" df=train_df,\n",
|
||||
" future_df=future_df,\n",
|
||||
" id_column='item_id',\n",
|
||||
" timestamp_column='timestamp',\n",
|
||||
" target='target',\n",
|
||||
" prediction_length=prediction_length,\n",
|
||||
" group_ids=group_ids_cat,\n",
|
||||
" quantile_levels=[0.5],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Compute MSE/MAE against ground truth\n",
|
||||
"eval_df = pred_df_ids.merge(\n",
|
||||
" test_df[['item_id', 'timestamp', 'target']],\n",
|
||||
" on=['item_id', 'timestamp'], how='inner',\n",
|
||||
")\n",
|
||||
"y_true = eval_df['target'].to_numpy()\n",
|
||||
"y_pred = eval_df['predictions'].to_numpy()\n",
|
||||
"mse_ids = float(np.mean((y_pred - y_true) ** 2))\n",
|
||||
"mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n",
|
||||
"print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')"
|
||||
],
|
||||
"id": "9113a0f64f257772",
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"MSE=0.523, MAE=0.537\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 4
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-12-09T13:39:53.070369Z",
|
||||
"start_time": "2025-12-09T13:39:52.884440Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# 2) Low-level API: predict_quantiles with group_ids list\n",
|
||||
"\n",
|
||||
"# Create group IDs as a list aligned with series_ids\n",
|
||||
"group_ids = create_group_ids_from_category(\n",
|
||||
"\t train=weather_df,\n",
|
||||
"\t id_column=\"item_id\",\n",
|
||||
"\t category_column=\"region\"\n",
|
||||
"\t)\n",
|
||||
"\n",
|
||||
"# Build inputs in the same order as series_ids\n",
|
||||
"series_ids = train_df['item_id'].unique().tolist()\n",
|
||||
"inputs_list = [\n",
|
||||
" {'target': train_df.loc[train_df['item_id']==sid, 'target'].to_numpy(dtype=np.float32) }\n",
|
||||
" for sid in series_ids\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"_, mean = pipeline.predict_quantiles(\n",
|
||||
" inputs=inputs_list,\n",
|
||||
" prediction_length=prediction_length,\n",
|
||||
" quantile_levels=[0.5],\n",
|
||||
" group_ids=group_ids,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"pred_df_ids = []\n",
|
||||
"for id in series_ids:\n",
|
||||
" test_id_df = test_df[test_df['item_id']==id].copy()\n",
|
||||
" test_id_df['predictions'] = mean[series_ids.index(id)].numpy().flatten()\n",
|
||||
" test_id_df.drop(columns=['target'], inplace=True)\n",
|
||||
" pred_df_ids.append(test_id_df)\n",
|
||||
"pred_df_ids = pd.concat(pred_df_ids, ignore_index=True)\n",
|
||||
"\n",
|
||||
"# Compute MSE/MAE against ground truth\n",
|
||||
"eval_df = pred_df_ids.merge(\n",
|
||||
" test_df[['item_id', 'timestamp', 'target']],\n",
|
||||
" on=['item_id', 'timestamp'], how='inner',\n",
|
||||
")\n",
|
||||
"y_true = eval_df['target'].to_numpy()\n",
|
||||
"y_pred = eval_df['predictions'].to_numpy()\n",
|
||||
"mse_ids = float(np.mean((y_pred - y_true) ** 2))\n",
|
||||
"mae_ids = float(np.mean(np.abs(y_pred - y_true)))\n",
|
||||
"print(f'MSE={mse_ids:.3f}, MAE={mae_ids:.3f}')"
|
||||
],
|
||||
"id": "629bab4c4ea6fe53",
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"MSE=0.523, MAE=0.537\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 5
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
|||
|
|
@ -12,6 +12,12 @@ from .chronos import (
|
|||
)
|
||||
from .chronos2 import Chronos2ForecastingConfig, Chronos2Model, Chronos2Pipeline
|
||||
from .chronos_bolt import ChronosBoltConfig, ChronosBoltPipeline
|
||||
from .utils import (
|
||||
create_group_ids_dict_from_category,
|
||||
create_group_ids_dict_from_mapping,
|
||||
create_manual_group_ids_dict,
|
||||
create_group_ids_from_category
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
|
|
@ -27,4 +33,7 @@ __all__ = [
|
|||
"Chronos2ForecastingConfig",
|
||||
"Chronos2Model",
|
||||
"Chronos2Pipeline",
|
||||
"create_group_ids_dict_from_category",
|
||||
"create_group_ids_dict_from_mapping",
|
||||
"create_manual_group_ids_dict",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -456,6 +456,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
batch_size: int = 256,
|
||||
context_length: int | None = None,
|
||||
cross_learning: bool = False,
|
||||
group_ids: list[int] | torch.Tensor | None = None,
|
||||
limit_prediction_length: bool = False,
|
||||
**kwargs,
|
||||
) -> list[torch.Tensor]:
|
||||
|
|
@ -548,6 +549,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
- Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining.
|
||||
For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
|
||||
- Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
|
||||
group_ids
|
||||
Optional custom group IDs to control information sharing between time series.
|
||||
If provided, must be a list or tensor of integers with length equal to the number of tasks in `inputs`.
|
||||
Tasks with the same group ID will share information during prediction via cross-attention.
|
||||
Cannot be used together with `cross_learning=True`. By default None (each task gets unique group ID).
|
||||
Example: [0, 0, 1] means first two tasks share information, third is separate.
|
||||
limit_prediction_length
|
||||
If True, an error is raised when prediction_length is greater than model's default prediction length, by default False
|
||||
|
||||
|
|
@ -569,6 +576,40 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
stacklevel=2,
|
||||
)
|
||||
cross_learning = kwargs.pop("predict_batches_jointly")
|
||||
|
||||
# Validate group_ids and cross_learning interaction
|
||||
if group_ids is not None and cross_learning:
|
||||
raise ValueError(
|
||||
"Cannot specify both `group_ids` and `cross_learning=True`. "
|
||||
"Use `group_ids` to define custom groups, or `cross_learning=True` to enable full batch-wide learning."
|
||||
)
|
||||
|
||||
# Convert group_ids to tensor if provided
|
||||
custom_group_ids_tensor = None
|
||||
if group_ids is not None:
|
||||
if isinstance(group_ids, list):
|
||||
# Strict type check: only integers are allowed in the list
|
||||
if not all(isinstance(x, (int, np.integer)) for x in group_ids):
|
||||
raise TypeError("`group_ids` list must contain only integers")
|
||||
if any(x < 0 for x in group_ids):
|
||||
raise ValueError("`group_ids` must contain only non-negative integers")
|
||||
custom_group_ids_tensor = torch.tensor(group_ids, dtype=torch.long)
|
||||
elif isinstance(group_ids, torch.Tensor):
|
||||
# Enforce integer dtype for tensor inputs
|
||||
if torch.is_floating_point(group_ids) or group_ids.dtype == torch.bool:
|
||||
raise TypeError("`group_ids` tensor must have an integer dtype")
|
||||
if (group_ids < 0).any():
|
||||
raise ValueError("`group_ids` must contain only non-negative integers")
|
||||
custom_group_ids_tensor = group_ids.to(dtype=torch.long).clone()
|
||||
else:
|
||||
raise TypeError(f"`group_ids` must be a list or torch.Tensor, got {type(group_ids)}")
|
||||
|
||||
# Validate length matches number of tasks
|
||||
if len(custom_group_ids_tensor) != len(inputs):
|
||||
raise ValueError(
|
||||
f"`group_ids` length ({len(custom_group_ids_tensor)}) must match number of tasks in inputs ({len(inputs)})"
|
||||
)
|
||||
|
||||
# The maximum number of output patches to generate in a single forward pass before the long-horizon heuristic kicks in. Note: A value larger
|
||||
# than the model's default max_output_patches may lead to degradation in forecast accuracy, defaults to a model-specific value
|
||||
max_output_patches = kwargs.pop("max_output_patches", self.max_output_patches)
|
||||
|
|
@ -623,6 +664,9 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
)
|
||||
|
||||
all_predictions: list[torch.Tensor] = []
|
||||
# Track the current task index for custom group ID mapping
|
||||
current_task_idx = 0
|
||||
|
||||
for batch in test_loader:
|
||||
assert batch["future_target"] is None
|
||||
batch_context = batch["context"]
|
||||
|
|
@ -630,7 +674,38 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
batch_future_covariates = batch["future_covariates"]
|
||||
batch_target_idx_ranges = batch["target_idx_ranges"]
|
||||
|
||||
if cross_learning:
|
||||
# Apply custom group IDs if provided
|
||||
if custom_group_ids_tensor is not None:
|
||||
# Determine how many tasks are in this batch
|
||||
num_tasks_in_batch = len(batch_target_idx_ranges)
|
||||
|
||||
# The key insight: batch_group_ids already maps variates to tasks
|
||||
# We just need to replace the task IDs with our custom ones
|
||||
|
||||
# Create a mapping from old task IDs to new task IDs
|
||||
old_group_ids = batch_group_ids.cpu().numpy()
|
||||
# Preserve first-appearance order of group IDs (robust to non-consecutive IDs)
|
||||
seen = set()
|
||||
unique_old_ids_list: list[int] = []
|
||||
for gid in old_group_ids.tolist():
|
||||
if gid not in seen:
|
||||
seen.add(int(gid))
|
||||
unique_old_ids_list.append(int(gid))
|
||||
|
||||
# Map old group IDs to task indices (0, 1, 2, ..., num_tasks_in_batch-1)
|
||||
# Then map those to custom group IDs
|
||||
new_group_ids = old_group_ids.copy()
|
||||
|
||||
for task_offset, old_group_id in enumerate(unique_old_ids_list):
|
||||
task_idx = current_task_idx + task_offset
|
||||
custom_group_id = custom_group_ids_tensor[task_idx].item()
|
||||
# Replace all occurrences of old_group_id with custom_group_id
|
||||
new_group_ids[old_group_ids == old_group_id] = custom_group_id
|
||||
|
||||
batch_group_ids = torch.tensor(new_group_ids, dtype=torch.long)
|
||||
current_task_idx += num_tasks_in_batch
|
||||
|
||||
elif cross_learning:
|
||||
batch_group_ids = torch.zeros_like(batch_group_ids)
|
||||
|
||||
batch_prediction = self._predict_batch(
|
||||
|
|
@ -824,6 +899,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
batch_size: int = 256,
|
||||
context_length: int | None = None,
|
||||
cross_learning: bool = False,
|
||||
group_ids: dict[str, int] | None = None,
|
||||
validate_inputs: bool = True,
|
||||
freq: str | None = None,
|
||||
**predict_kwargs,
|
||||
|
|
@ -864,6 +940,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
- Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining.
|
||||
For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
|
||||
- Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
|
||||
group_ids
|
||||
Optional dictionary mapping series IDs (from id_column) to group IDs.
|
||||
Series with the same group ID will share information during prediction.
|
||||
Cannot be used together with `cross_learning=True`. By default None.
|
||||
Example: {'series_A': 0, 'series_B': 0, 'series_C': 1} means series_A and series_B share info, series_C is separate.
|
||||
If a series ID is not in the dictionary, it will be assigned a unique group ID.
|
||||
validate_inputs
|
||||
[ADVANCED] When True (default), validates dataframes before prediction. Setting to False removes the
|
||||
validation overhead, but may silently lead to wrong predictions if data is misformatted. When False, you
|
||||
|
|
@ -907,6 +989,48 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
validate_inputs=validate_inputs,
|
||||
)
|
||||
|
||||
# Convert dictionary group_ids to list format matching inputs order
|
||||
group_ids_list = None
|
||||
if group_ids is not None:
|
||||
# Validate group_ids format
|
||||
if not isinstance(group_ids, dict):
|
||||
raise TypeError(f"`group_ids` must be a dictionary, got {type(group_ids)}")
|
||||
|
||||
if not all(isinstance(k, str) and isinstance(v, int) for k, v in group_ids.items()):
|
||||
raise TypeError("`group_ids` dictionary must have string keys and integer values")
|
||||
|
||||
if any(v < 0 for v in group_ids.values()):
|
||||
raise ValueError("`group_ids` values must be non-negative integers")
|
||||
|
||||
if cross_learning:
|
||||
raise ValueError(
|
||||
"Cannot specify both `group_ids` and `cross_learning=True`. "
|
||||
"Use `group_ids` to define custom groups, or `cross_learning=True` to enable full batch-wide learning."
|
||||
)
|
||||
|
||||
# Warn if series IDs in group_ids don't exist in dataframe
|
||||
series_in_df = set(df[id_column].unique())
|
||||
unknown_series = set(group_ids.keys()) - series_in_df
|
||||
if unknown_series:
|
||||
warnings.warn(
|
||||
f"The following series IDs in `group_ids` were not found in the dataframe: {unknown_series}. "
|
||||
f"They will be ignored.",
|
||||
category=UserWarning,
|
||||
)
|
||||
|
||||
# Create list of group IDs matching the order of inputs
|
||||
# original_order contains the series IDs in the order they appear in inputs
|
||||
group_ids_list = []
|
||||
next_auto_group_id = max(group_ids.values()) + 1 if group_ids else 0
|
||||
|
||||
for series_id in original_order:
|
||||
if series_id in group_ids:
|
||||
group_ids_list.append(group_ids[series_id])
|
||||
else:
|
||||
# Assign unique group ID to series not in the mapping
|
||||
group_ids_list.append(next_auto_group_id)
|
||||
next_auto_group_id += 1
|
||||
|
||||
# Generate forecasts
|
||||
quantiles, mean = self.predict_quantiles(
|
||||
inputs=inputs,
|
||||
|
|
@ -916,6 +1040,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
batch_size=batch_size,
|
||||
context_length=context_length,
|
||||
cross_learning=cross_learning,
|
||||
group_ids=group_ids_list,
|
||||
**predict_kwargs,
|
||||
)
|
||||
# since predict_df tasks are homogenous by input design, we can safely stack the list of tensors into a single tensor
|
||||
|
|
|
|||
|
|
@ -7,6 +7,12 @@ from typing import List
|
|||
import torch
|
||||
from einops import repeat
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
_PANDAS_AVAILABLE = True
|
||||
except ImportError:
|
||||
_PANDAS_AVAILABLE = False
|
||||
|
||||
|
||||
def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
|
||||
max_len = max(len(c) for c in tensors)
|
||||
|
|
@ -210,3 +216,228 @@ def weighted_quantile(
|
|||
# Reshape to original shape
|
||||
final_shape = (*orig_samples_shape[:-1], len(query_quantile_levels))
|
||||
return interpolated_quantiles.reshape(final_shape).to(dtype=orig_dtype)
|
||||
|
||||
|
||||
def create_group_ids_dict_from_category(
    df: "pd.DataFrame",
    id_column: str,
    category_column: str
) -> dict[str, int]:
    """
    Build the ``group_ids`` dictionary expected by ``Chronos2Pipeline.predict_df()``
    from a categorical column of a long-format dataframe.

    Every series is assigned the integer group of its category, so series that
    share a category end up in the same group.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format dataframe holding both ``id_column`` and ``category_column``.
    id_column : str
        Column with the time series identifiers.
    category_column : str
        Column with the category to group by (e.g., "region", "product_type", "industry").

    Returns
    -------
    dict[str, int]
        Mapping from series ID to group ID. Group IDs are consecutive integers
        starting at 0, numbered in the order in which categories first appear.

    Raises
    ------
    ImportError
        If pandas is not installed.

    Examples
    --------
    >>> import pandas as pd
    >>> df = pd.DataFrame({
    ...     'item_id': ['A', 'A', 'B', 'B', 'C', 'C'],
    ...     'region': ['North', 'North', 'North', 'North', 'South', 'South'],
    ...     'value': [1, 2, 3, 4, 5, 6]
    ... })
    >>> create_group_ids_dict_from_category(df, 'item_id', 'region')
    {'A': 0, 'B': 0, 'C': 1}

    Notes
    -----
    - The category of a series is the first non-null value of ``category_column``
      among its rows.
    - Series appear in the result in the order they first occur in ``df``.
    """
    if not _PANDAS_AVAILABLE:
        raise ImportError("pandas is required for this function. Please install it with `pip install pandas`.")

    # One category per series; sort=False keeps the series in order of first appearance.
    category_per_series = df.groupby(id_column, sort=False)[category_column].first()

    # Number the distinct categories 0, 1, 2, ... in order of first appearance.
    group_of_category = {}
    for group_number, label in enumerate(category_per_series.unique()):
        group_of_category[label] = group_number

    # Translate each series' category into its group number.
    result = {}
    for key, label in category_per_series.items():
        result[key] = group_of_category[label]
    return result
|
||||
|
||||
|
||||
def create_group_ids_dict_from_mapping(
    df: "pd.DataFrame",
    id_column: str,
    category_to_group_map: dict[str, int],
    category_column: "str | None" = None,
) -> dict[str, int]:
    """
    Create group_ids dictionary using a custom category-to-group mapping.

    This allows you to manually specify which categories belong to which groups,
    useful when you want custom grouping logic beyond simple one-to-one category mapping.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing id_column and a category column.
    id_column : str
        Name of the column containing time series identifiers.
    category_to_group_map : dict[str, int]
        Dictionary mapping category names to group IDs.
        Example: {'Retail': 0, 'Wholesale': 0, 'Food': 1, 'Services': 2}
    category_column : str, optional
        Name of the column holding the categories. When omitted (default), the
        column is inferred: the first column other than ``id_column`` whose dtype
        is object, string or categorical and which contains at least one key of
        ``category_to_group_map`` is used. Pass it explicitly when several
        columns could match.

    Returns
    -------
    dict[str, int]
        Dictionary mapping series IDs to group IDs.

    Raises
    ------
    ImportError
        If pandas is not installed.
    ValueError
        If ``category_column`` is given but is not a column of ``df``, or if it is
        omitted and no suitable column can be inferred.
    KeyError
        If the category of some series is missing from ``category_to_group_map``.

    Examples
    --------
    >>> import pandas as pd
    >>> df = pd.DataFrame({
    ...     'item_id': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'],
    ...     'industry': ['Retail', 'Retail', 'Wholesale', 'Wholesale',
    ...                  'Food', 'Food', 'Services', 'Services'],
    ...     'value': [1, 2, 3, 4, 5, 6, 7, 8]
    ... })
    >>> mapping = {'Retail': 0, 'Wholesale': 0, 'Food': 1, 'Services': 2}
    >>> create_group_ids_dict_from_mapping(df, 'item_id', mapping)
    {'A': 0, 'B': 0, 'C': 1, 'D': 2}

    Notes
    -----
    - This is more flexible than create_group_ids_dict_from_category()
    - Useful when multiple categories should map to the same group
    - All categories in the data must be in the mapping, or will raise KeyError
    - The category of a series is the first non-null value of the category column
      among its rows
    """
    if not _PANDAS_AVAILABLE:
        raise ImportError("pandas is required for this function. Please install it with `pip install pandas`.")

    if category_column is None:
        # Infer the category column. Label columns may be stored as object, as a
        # dedicated string dtype, or as categorical, so all three are accepted;
        # checking only for dtype == 'object' would miss the latter two.
        for col in df.columns:
            if col == id_column:
                continue
            col_dtype = df[col].dtype
            holds_labels = (
                pd.api.types.is_object_dtype(col_dtype)
                or pd.api.types.is_string_dtype(col_dtype)
                or isinstance(col_dtype, pd.CategoricalDtype)
            )
            if holds_labels and any(cat in category_to_group_map for cat in df[col].unique()):
                category_column = col
                break

        if category_column is None:
            raise ValueError(
                f"Could not infer category column. Available columns: {df.columns.tolist()}. "
                f"Make sure one of them contains categories from the mapping: {list(category_to_group_map.keys())}"
            )
    elif category_column not in df.columns:
        raise ValueError(
            f"Category column '{category_column}' not found in dataframe. "
            f"Available columns: {df.columns.tolist()}"
        )

    # Get unique series and their category (order of first appearance is preserved)
    series_categories = df.groupby(id_column, sort=False)[category_column].first()

    # Create dictionary mapping series_id -> group_id using the custom mapping
    group_ids_dict = {}
    for series_id, category in series_categories.items():
        if category not in category_to_group_map:
            raise KeyError(
                f"Category '{category}' for series '{series_id}' not found in mapping. "
                f"Available categories in mapping: {list(category_to_group_map.keys())}"
            )
        group_ids_dict[series_id] = category_to_group_map[category]

    return group_ids_dict
|
||||
|
||||
|
||||
def create_manual_group_ids_dict(
    series_ids: List[str],
    group_assignments: List[int]
) -> dict[str, int]:
    """
    Pair up series identifiers with hand-picked group IDs.

    Parameters
    ----------
    series_ids : list of str
        Series identifiers.
    group_assignments : list of int
        Group ID for each entry of ``series_ids``, position by position.
        Must be exactly as long as ``series_ids``.

    Returns
    -------
    dict[str, int]
        Mapping from series ID to its group ID. If a series ID occurs more than
        once, the last assignment wins.

    Raises
    ------
    ValueError
        If series_ids and group_assignments have different lengths.

    Examples
    --------
    >>> series_ids = ['store_1', 'store_2', 'store_3', 'store_4']
    >>> groups = [0, 0, 1, 1]  # First two together, last two together
    >>> create_manual_group_ids_dict(series_ids, groups)
    {'store_1': 0, 'store_2': 0, 'store_3': 1, 'store_4': 1}
    """
    lengths_agree = len(series_ids) == len(group_assignments)
    if not lengths_agree:
        raise ValueError(
            f"Length mismatch: series_ids has {len(series_ids)} elements, "
            f"but group_assignments has {len(group_assignments)} elements."
        )

    assignment = {}
    for identifier, group in zip(series_ids, group_assignments):
        assignment[identifier] = group
    return assignment
|
||||
|
||||
def create_group_ids_from_category(
    train: "pd.DataFrame",
    id_column: str,
    category_column: str
) -> list[int]:
    """
    Create group_ids list from a categorical column.

    Unlike create_group_ids_dict_from_category(), which returns a dictionary keyed
    by series ID, this returns a plain list with one group ID per series.

    Parameters
    ----------
    train : pd.DataFrame
        Training data containing id_column and category_column.
    id_column : str
        Name of the column containing time series identifiers.
    category_column : str
        Name of the categorical column to group by (e.g., "region", "product_type").

    Returns
    -------
    list[int]
        List of group IDs, one per series, in the order in which the series first
        appear in ``train`` (i.e. the order of ``train[id_column].unique()``).
        Categories are numbered 0, 1, 2, ... in order of first appearance.

    Raises
    ------
    ImportError
        If pandas is not installed.

    Examples
    --------
    >>> train = pd.DataFrame({
    ...     'series_id': ['A', 'A', 'B', 'B', 'C', 'C'],
    ...     'region': ['North', 'North', 'North', 'North', 'South', 'South'],
    ...     'value': [1, 2, 3, 4, 5, 6]
    ... })
    >>> create_group_ids_from_category(train, 'series_id', 'region')
    [0, 0, 1]

    Notes
    -----
    - The inputs passed to the model must be built in the same series order as the
      returned list, otherwise group IDs are attached to the wrong series.
    - The category of a series is the first non-null value of ``category_column``
      among its rows.
    """
    # The annotation on `train` is a string on purpose: pandas is an optional
    # import in this module, so `pd` may be undefined when the function is defined.
    if not _PANDAS_AVAILABLE:
        raise ImportError("pandas is required for this function. Please install it with `pip install pandas`.")

    # Get unique series and their category (sort=False keeps first-appearance order)
    series_categories = train.groupby(id_column, sort=False)[category_column].first()

    # Map categories to group IDs
    unique_categories = series_categories.unique()
    category_to_group = {cat: i for i, cat in enumerate(unique_categories)}

    # Create group_ids list
    group_ids = [category_to_group[cat] for cat in series_categories.values]

    return group_ids
|
||||
|
|
|
|||
|
|
@ -1143,3 +1143,283 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
|
|||
for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
|
||||
# Should match exactly or very close (numerical precision)
|
||||
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for custom group_ids functionality
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize("group_ids", [[0, 0, 1], torch.tensor([0, 0, 1])])
def test_predict_with_custom_group_ids_list_and_tensor(pipeline, group_ids):
    """`predict` accepts custom group_ids given either as a list or as a tensor."""
    horizon = 24
    # Three univariate series of different lengths.
    series = [torch.rand(length) for length in (100, 110, 120)]

    forecasts = pipeline.predict(series, prediction_length=horizon, group_ids=group_ids)

    assert isinstance(forecasts, list)
    assert len(forecasts) == 3
    expected_shape = (1, DEFAULT_MODEL_NUM_QUANTILES, horizon)
    for forecast in forecasts:
        validate_tensor(forecast, expected_shape, dtype=torch.float32)
|
||||
|
||||
|
||||
def test_predict_with_group_ids_univariate_batch(pipeline):
    """Custom group_ids work with a homogeneous univariate batch passed as one tensor."""
    horizon = 12
    batch = torch.rand(5, 1, 100)
    # Series 0-1 form one group, series 2-3 another, and series 4 is on its own.
    grouping = [0, 0, 1, 1, 2]

    forecasts = pipeline.predict(batch, prediction_length=horizon, group_ids=grouping)

    assert len(forecasts) == 5
    expected_shape = (1, DEFAULT_MODEL_NUM_QUANTILES, horizon)
    for forecast in forecasts:
        validate_tensor(forecast, expected_shape, dtype=torch.float32)
|
||||
|
||||
|
||||
def test_predict_with_group_ids_multivariate(pipeline):
    """Custom group_ids work when every task has two target variates."""
    horizon = 16
    tasks = [torch.rand(2, length) for length in (100, 110, 90)]
    # The first two tasks share information; the third is kept separate.
    grouping = [0, 0, 1]

    forecasts = pipeline.predict(tasks, prediction_length=horizon, group_ids=grouping)

    assert len(forecasts) == 3
    expected_shape = (2, DEFAULT_MODEL_NUM_QUANTILES, horizon)
    for forecast in forecasts:
        validate_tensor(forecast, expected_shape, dtype=torch.float32)
|
||||
|
||||
|
||||
def test_predict_with_group_ids_and_covariates(pipeline):
    """``predict`` accepts ``group_ids`` for dict inputs that carry past and future covariates."""
    prediction_length = 24

    def make_task(context_length):
        # Dict values are evaluated top to bottom: target, past covariate, future covariate.
        return {
            "target": torch.rand(context_length),
            "past_covariates": {"temperature": torch.rand(context_length)},
            "future_covariates": {"temperature": torch.rand(prediction_length)},
        }

    tasks = [make_task(context_length) for context_length in (100, 110, 90)]
    grouping = [0, 0, 1]

    forecasts = pipeline.predict(tasks, prediction_length=prediction_length, group_ids=grouping)

    assert len(forecasts) == 3
    expected_shape = (1, DEFAULT_MODEL_NUM_QUANTILES, prediction_length)
    for forecast in forecasts:
        validate_tensor(forecast, expected_shape, dtype=torch.float32)
|
||||
|
||||
|
||||
def test_predict_df_with_group_ids_dict(pipeline):
    """``predict_df`` accepts ``group_ids`` given as a series-id -> group-id dictionary."""
    horizon = 5
    context_df = create_df(series_ids=["A", "B", "C"], n_points=[10, 10, 10])
    # A and B are assigned the same group; C gets its own.
    grouping = {"A": 0, "B": 0, "C": 1}

    forecast_df = pipeline.predict_df(context_df, prediction_length=horizon, group_ids=grouping)

    assert isinstance(forecast_df, pd.DataFrame)
    # 3 series * 5 forecast steps.
    assert len(forecast_df) == 15
    returned_ids = set(forecast_df["item_id"].unique())
    assert returned_ids == {"A", "B", "C"}
|
||||
|
||||
|
||||
def test_predict_df_with_partial_group_ids(pipeline):
    """``predict_df`` still forecasts every series when ``group_ids`` covers only some of them."""
    horizon = 5
    context_df = create_df(series_ids=["A", "B", "C", "D"], n_points=[10, 10, 10, 10])
    # Only A and B are listed; C and D are intentionally left out of the mapping.
    partial_grouping = {"A": 0, "B": 0}

    forecast_df = pipeline.predict_df(context_df, prediction_length=horizon, group_ids=partial_grouping)

    assert isinstance(forecast_df, pd.DataFrame)
    # 4 series * 5 forecast steps.
    assert len(forecast_df) == 20
    returned_ids = set(forecast_df["item_id"].unique())
    assert returned_ids == {"A", "B", "C", "D"}
|
||||
|
||||
|
||||
def test_predict_df_with_group_ids_and_covariates(pipeline):
    """Test predict_df with both group_ids and covariates.

    Besides the row count, this also checks that every input series appears in
    the output, matching the sibling ``predict_df`` group-id tests.
    """
    series_ids = ["A", "B", "C"]
    df = create_df(series_ids=series_ids, n_points=[10, 10, 10], covariates=["temp"])
    future_df = create_future_df(
        get_forecast_start_times(df), series_ids=series_ids, n_points=[5, 5, 5], covariates=["temp"]
    )
    group_ids = {"A": 0, "B": 0, "C": 1}  # A and B share info, C is separate

    pred_df = pipeline.predict_df(df, future_df=future_df, prediction_length=5, group_ids=group_ids)

    assert isinstance(pred_df, pd.DataFrame)
    assert len(pred_df) == 15  # 3 series * 5 predictions
    # The row count alone would not catch a series being dropped and another duplicated.
    assert set(pred_df["item_id"].unique()) == set(series_ids)
|
||||
|
||||
|
||||
def test_group_ids_cross_learning_mutual_exclusion(pipeline):
    """``predict`` must raise when ``group_ids`` is combined with ``cross_learning=True``."""
    contexts = [torch.rand(length) for length in (100, 110, 120)]
    grouping = [0, 0, 1]
    expected_message = "Cannot specify both `group_ids` and `cross_learning=True`"

    with pytest.raises(ValueError, match=expected_message):
        pipeline.predict(contexts, prediction_length=24, group_ids=grouping, cross_learning=True)
|
||||
|
||||
|
||||
def test_predict_df_group_ids_cross_learning_mutual_exclusion(pipeline):
    """``predict_df`` must raise when ``group_ids`` is combined with ``cross_learning=True``."""
    context_df = create_df(series_ids=["A", "B"], n_points=[10, 10])
    grouping = {"A": 0, "B": 0}
    expected_message = "Cannot specify both `group_ids` and `cross_learning=True`"

    with pytest.raises(ValueError, match=expected_message):
        pipeline.predict_df(context_df, prediction_length=5, group_ids=grouping, cross_learning=True)
|
||||
|
||||
|
||||
def test_group_ids_length_mismatch_raises_error(pipeline):
    """``predict`` must raise when ``group_ids`` has fewer entries than there are inputs."""
    contexts = [torch.rand(length) for length in (100, 110, 120)]
    # Two group ids for three inputs.
    too_few_ids = [0, 0]
    expected_pattern = "length .* must match number of tasks"

    with pytest.raises(ValueError, match=expected_pattern):
        pipeline.predict(contexts, prediction_length=24, group_ids=too_few_ids)
|
||||
|
||||
|
||||
def test_group_ids_negative_values_raises_error(pipeline):
    """``predict`` must raise when ``group_ids`` contains a negative entry."""
    contexts = [torch.rand(length) for length in (100, 110, 120)]
    # The middle id is negative.
    ids_with_negative = [0, -1, 1]
    expected_message = "must contain only non-negative integers"

    with pytest.raises(ValueError, match=expected_message):
        pipeline.predict(contexts, prediction_length=24, group_ids=ids_with_negative)
|
||||
|
||||
|
||||
def test_group_ids_invalid_type_raises_error(pipeline):
    """Test that error is raised when group_ids is not list or tensor."""
    inputs = [torch.rand(100), torch.rand(110)]
    group_ids = "invalid"  # String not allowed

    # ``match`` is applied with ``re.search``, so the dot in ``torch.Tensor`` is escaped
    # to require a literal "." rather than any character.
    with pytest.raises(TypeError, match=r"must be a list or torch\.Tensor"):
        pipeline.predict(inputs, prediction_length=24, group_ids=group_ids)
|
||||
|
||||
|
||||
def test_predict_df_group_ids_invalid_type_raises_error(pipeline):
    """``predict_df`` must raise a ``TypeError`` when ``group_ids`` is a list instead of a dict."""
    context_df = create_df(series_ids=["A", "B"], n_points=[10, 10])
    # A list is passed where ``predict_df`` is expected to require a dictionary.
    list_grouping = [0, 0]
    expected_message = "must be a dictionary"

    with pytest.raises(TypeError, match=expected_message):
        pipeline.predict_df(context_df, prediction_length=5, group_ids=list_grouping)
|
||||
|
||||
|
||||
def test_predict_df_group_ids_invalid_dict_values_raises_error(pipeline):
    """``predict_df`` must raise when a ``group_ids`` dictionary value is negative."""
    context_df = create_df(series_ids=["A", "B"], n_points=[10, 10])
    # Series B is mapped to a negative group id.
    grouping_with_negative = {"A": 0, "B": -1}
    expected_message = "must be non-negative integers"

    with pytest.raises(ValueError, match=expected_message):
        pipeline.predict_df(context_df, prediction_length=5, group_ids=grouping_with_negative)
|
||||
|
||||
|
||||
def test_predict_df_group_ids_warns_unknown_series(pipeline):
    """``predict_df`` must warn when ``group_ids`` names series that are absent from the frame."""
    context_df = create_df(series_ids=["A", "B"], n_points=[10, 10])
    # "X" and "Y" are not present in ``context_df``.
    grouping = {"A": 0, "B": 0, "X": 1, "Y": 1}
    expected_message = "not found in the dataframe"

    with pytest.warns(UserWarning, match=expected_message):
        pipeline.predict_df(context_df, prediction_length=5, group_ids=grouping)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for group_ids helper functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_create_group_ids_dict_from_category():
    """``create_group_ids_dict_from_category`` maps each series id to a group id per category."""
    from chronos import create_group_ids_dict_from_category

    # Two rows per series: A and B are in "North", C is in "South".
    frame = pd.DataFrame(
        {
            "item_id": [series for series in ("A", "B", "C") for _ in range(2)],
            "region": ["North"] * 4 + ["South"] * 2,
            "value": list(range(1, 7)),
        }
    )

    mapping = create_group_ids_dict_from_category(frame, "item_id", "region")

    assert isinstance(mapping, dict)
    expected = {"A": 0, "B": 0, "C": 1}
    assert mapping == expected
|
||||
|
||||
|
||||
def test_create_group_ids_dict_from_mapping():
    """``create_group_ids_dict_from_mapping`` translates categories to group ids via a mapping."""
    from chronos import create_group_ids_dict_from_mapping

    industries = ("Retail", "Wholesale", "Food", "Services")
    # Two rows per series, one industry per series.
    frame = pd.DataFrame(
        {
            "item_id": [series for series in ("A", "B", "C", "D") for _ in range(2)],
            "industry": [industry for industry in industries for _ in range(2)],
            "value": list(range(1, 9)),
        }
    )
    # Retail and Wholesale are merged into group 0.
    category_to_group = {"Retail": 0, "Wholesale": 0, "Food": 1, "Services": 2}

    # NOTE(review): the category column ("industry") is not passed explicitly here —
    # confirm how the helper locates it.
    mapping = create_group_ids_dict_from_mapping(frame, "item_id", category_to_group)

    assert isinstance(mapping, dict)
    expected = {"A": 0, "B": 0, "C": 1, "D": 2}
    assert mapping == expected
|
||||
|
||||
|
||||
def test_create_group_ids_dict_from_mapping_missing_category_raises_error():
    """``create_group_ids_dict_from_mapping`` must raise ``KeyError`` for an unmapped category."""
    from chronos import create_group_ids_dict_from_mapping

    frame = pd.DataFrame(
        {
            "item_id": [series for series in ("A", "B") for _ in range(2)],
            "industry": ["Retail"] * 2 + ["Tech"] * 2,
            "value": list(range(1, 5)),
        }
    )
    # "Tech" appears in the frame but has no entry in the mapping.
    incomplete_mapping = {"Retail": 0}
    expected_message = "not found in mapping"

    # NOTE(review): the category column ("industry") is not passed explicitly here —
    # confirm how the helper locates it.
    with pytest.raises(KeyError, match=expected_message):
        create_group_ids_dict_from_mapping(frame, "item_id", incomplete_mapping)
|
||||
|
||||
|
||||
def test_create_manual_group_ids_dict():
    """``create_manual_group_ids_dict`` pairs each series id with the group at the same position."""
    from chronos import create_manual_group_ids_dict

    stores = [f"store_{index}" for index in range(1, 5)]
    assignments = [0, 0, 1, 1]

    mapping = create_manual_group_ids_dict(stores, assignments)

    assert isinstance(mapping, dict)
    expected = {"store_1": 0, "store_2": 0, "store_3": 1, "store_4": 1}
    assert mapping == expected
|
||||
|
||||
|
||||
def test_create_manual_group_ids_dict_length_mismatch_raises_error():
    """``create_manual_group_ids_dict`` must raise when ids and groups differ in length."""
    from chronos import create_manual_group_ids_dict

    names = ["A", "B", "C"]
    # Only two groups for three series ids.
    too_few_groups = [0, 0]
    expected_message = "Length mismatch"

    with pytest.raises(ValueError, match=expected_message):
        create_manual_group_ids_dict(names, too_few_groups)
|
||||
|
||||
|
||||
def test_create_group_ids_from_category():
    """``create_group_ids_from_category`` returns the group ids as a list, one per series."""
    from chronos import create_group_ids_from_category

    # Two rows per series: A and B are in "North", C is in "South".
    frame = pd.DataFrame(
        {
            "item_id": [series for series in ("A", "B", "C") for _ in range(2)],
            "region": ["North"] * 4 + ["South"] * 2,
            "value": list(range(1, 7)),
        }
    )

    group_list = create_group_ids_from_category(frame, "item_id", "region")

    assert isinstance(group_list, list)
    # A and B (North) -> 0, C (South) -> 1.
    expected = [0, 0, 1]
    assert group_list == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue