mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Add predict_df support for Chronos and Chronos-Bolt (#371)
*Issue #, if available:* *Description of changes:* This PR adds `predict_df` to the base pipeline which enables pandas support for the univariate Chronos and Chronos-Bolt models. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
22c036bf07
commit
e48f48071f
10 changed files with 671 additions and 381 deletions
|
|
@ -17,8 +17,10 @@ import torch
|
|||
if TYPE_CHECKING:
|
||||
import datasets
|
||||
import fev
|
||||
import pandas as pd
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
from .utils import left_pad_and_stack_1D
|
||||
|
||||
|
||||
|
|
@ -53,6 +55,14 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
# for easy access to the inner HF-style model
|
||||
self.inner_model = inner_model
|
||||
|
||||
@property
|
||||
def model_context_length(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def model_prediction_length(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _prepare_and_validate_context(self, context: Union[torch.Tensor, List[torch.Tensor]]):
|
||||
if isinstance(context, list):
|
||||
context = left_pad_and_stack_1D(context)
|
||||
|
|
@ -122,6 +132,106 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def predict_df(
|
||||
self,
|
||||
df: "pd.DataFrame",
|
||||
*,
|
||||
id_column: str = "item_id",
|
||||
timestamp_column: str = "timestamp",
|
||||
target: str = "target",
|
||||
prediction_length: int | None = None,
|
||||
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
**predict_kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
Perform forecasting on time series data in a long-format pandas DataFrame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df
|
||||
Time series data in long format with an id column, a timestamp, and one target column.
|
||||
Any other columns, if present, will be ignored
|
||||
id_column
|
||||
The name of the column which contains the unique time series identifiers, by default "item_id"
|
||||
timestamp_column
|
||||
The name of the column which contains timestamps, by default "timestamp"
|
||||
All time series in the dataframe must have regular timestamps with the same frequency (no gaps)
|
||||
target
|
||||
The name of the column which contains the target variables to be forecasted, by default "target"
|
||||
prediction_length
|
||||
Number of steps to predict for each time series
|
||||
quantile_levels
|
||||
Quantile levels to compute
|
||||
**predict_kwargs
|
||||
Additional arguments passed to predict_quantiles
|
||||
|
||||
Returns
|
||||
-------
|
||||
The forecasts dataframe generated by the model with the following columns
|
||||
- `id_column`: The time series ID
|
||||
- `timestamp_column`: Future timestamps
|
||||
- "target_name": The name of the target column
|
||||
- "predictions": The point predictions generated by the model
|
||||
- One column for predictions at each quantile level in `quantile_levels`
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
from .df_utils import convert_df_input_to_list_of_dicts_input
|
||||
except ImportError:
|
||||
raise ImportError("pandas is required for predict_df. Please install it with `pip install pandas`.")
|
||||
|
||||
if not isinstance(target, str):
|
||||
raise ValueError(
|
||||
f"Expected `target` to be str, but found {type(target)}. {self.__class__.__name__} only supports univariate forecasting."
|
||||
)
|
||||
|
||||
if prediction_length is None:
|
||||
prediction_length = self.model_prediction_length
|
||||
|
||||
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
|
||||
df=df,
|
||||
future_df=None,
|
||||
id_column=id_column,
|
||||
timestamp_column=timestamp_column,
|
||||
target_columns=[target],
|
||||
prediction_length=prediction_length,
|
||||
)
|
||||
|
||||
# NOTE: any covariates, if present, are ignored here
|
||||
context = [torch.tensor(item["target"]).squeeze(0) for item in inputs] # squeeze the extra variate dim
|
||||
|
||||
# Generate forecasts
|
||||
quantiles, mean = self.predict_quantiles(
|
||||
inputs=context,
|
||||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
limit_prediction_length=False,
|
||||
**predict_kwargs,
|
||||
)
|
||||
|
||||
quantiles_np = quantiles.numpy() # [n_series, horizon, num_quantiles]
|
||||
mean_np = mean.numpy() # [n_series, horizon]
|
||||
|
||||
results_dfs = []
|
||||
for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()):
|
||||
q_pred = quantiles_np[i] # (horizon, num_quantiles)
|
||||
point_pred = mean_np[i] # (horizon)
|
||||
|
||||
series_forecast_data = {id_column: series_id, timestamp_column: future_ts, "target_name": target}
|
||||
series_forecast_data["predictions"] = point_pred
|
||||
for q_idx, q_level in enumerate(quantile_levels):
|
||||
series_forecast_data[str(q_level)] = q_pred[:, q_idx]
|
||||
|
||||
results_dfs.append(pd.DataFrame(series_forecast_data))
|
||||
|
||||
predictions_df = pd.concat(results_dfs, ignore_index=True)
|
||||
predictions_df.set_index(id_column, inplace=True)
|
||||
predictions_df = predictions_df.loc[original_order]
|
||||
predictions_df.reset_index(inplace=True)
|
||||
|
||||
return predictions_df
|
||||
|
||||
def predict_fev(
|
||||
self, task: "fev.Task", batch_size: int = 32, **kwargs
|
||||
) -> tuple[list["datasets.DatasetDict"], float]:
|
||||
|
|
|
|||
|
|
@ -377,6 +377,14 @@ class ChronosPipeline(BaseChronosPipeline):
|
|||
self.tokenizer = tokenizer
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def model_context_length(self) -> int:
|
||||
return self.model.config.context_length
|
||||
|
||||
@property
|
||||
def model_prediction_length(self) -> int:
|
||||
return self.model.config.prediction_length
|
||||
|
||||
def _prepare_and_validate_context(self, context: Union[torch.Tensor, List[torch.Tensor]]):
|
||||
if isinstance(context, list):
|
||||
context = left_pad_and_stack_1D(context)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from torch.utils.data import IterableDataset
|
|||
if TYPE_CHECKING:
|
||||
import datasets
|
||||
import fev
|
||||
import pandas as pd
|
||||
|
||||
|
||||
TensorOrArray: TypeAlias = torch.Tensor | np.ndarray
|
||||
|
|
@ -275,308 +274,6 @@ def convert_tensor_input_to_list_of_dicts_input(tensor: TensorOrArray) -> list[d
|
|||
return output
|
||||
|
||||
|
||||
def _validate_df_types_and_cast(
|
||||
df: "pd.DataFrame",
|
||||
future_df: "pd.DataFrame | None",
|
||||
target_columns: list[str],
|
||||
id_column: str = "item_id",
|
||||
timestamp_column: str = "timestamp",
|
||||
) -> tuple["pd.DataFrame", "pd.DataFrame | None"]:
|
||||
import pandas as pd
|
||||
|
||||
astype_dict = {}
|
||||
future_astype_dict = {}
|
||||
for col in df.columns.drop([id_column, timestamp_column]):
|
||||
col_dtype = df[col].dtype
|
||||
if col in target_columns and not pd.api.types.is_numeric_dtype(df[col]):
|
||||
raise ValueError(f"All target columns must be numeric but got {col=} with dtype={col_dtype}")
|
||||
|
||||
if (
|
||||
pd.api.types.is_object_dtype(df[col])
|
||||
or pd.api.types.is_string_dtype(df[col])
|
||||
or isinstance(col_dtype, pd.CategoricalDtype)
|
||||
):
|
||||
astype_dict[col] = "category"
|
||||
elif pd.api.types.is_numeric_dtype(df[col]) or pd.api.types.is_bool_dtype(df[col]):
|
||||
astype_dict[col] = "float32"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"All columns must contain numeric, object, category, string, or bool dtype but got {col=} with dtype={col_dtype}"
|
||||
)
|
||||
|
||||
if future_df is not None and col in future_df.columns:
|
||||
if future_df[col].dtype != col_dtype:
|
||||
raise ValueError(
|
||||
f"Column {col} in future_df has dtype {future_df[col].dtype} but column in df has dtype {col_dtype}"
|
||||
)
|
||||
future_astype_dict[col] = astype_dict[col]
|
||||
|
||||
df = df.astype(astype_dict, copy=True)
|
||||
if future_df is not None:
|
||||
future_df = future_df.astype(future_astype_dict, copy=True)
|
||||
|
||||
return df, future_df
|
||||
|
||||
|
||||
def validate_df_inputs(
|
||||
df: "pd.DataFrame",
|
||||
future_df: "pd.DataFrame | None",
|
||||
target_columns: list[str],
|
||||
prediction_length: int,
|
||||
id_column: str = "item_id",
|
||||
timestamp_column: str = "timestamp",
|
||||
) -> tuple["pd.DataFrame", "pd.DataFrame | None", "pd.Timedelta", list[int], list[int] | None, np.ndarray]:
|
||||
"""
|
||||
Validates and prepares dataframe inputs passed to `Chronos2Pipeline.predict_df`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df
|
||||
Input dataframe containing time series data with columns:
|
||||
- id_column: Identifier for each time series
|
||||
- timestamp_column: Timestamps for each observation
|
||||
- target_columns: One or more target variables to forecast
|
||||
- Additional columns are treated as covariates
|
||||
future_df
|
||||
Optional dataframe containing future covariate values with columns:
|
||||
- id_column: Identifier for each time series
|
||||
- timestamp_column: Future timestamps
|
||||
- Subset of covariate columns from df
|
||||
target_columns
|
||||
Names of target columns to forecast
|
||||
prediction_length
|
||||
Number of future time steps to predict
|
||||
id_column
|
||||
Name of column containing time series identifiers
|
||||
timestamp_column
|
||||
Name of column containing timestamps
|
||||
|
||||
Returns
|
||||
-------
|
||||
A tuple containing:
|
||||
- Validated and sorted input dataframe
|
||||
- Validated and sorted future dataframe (if provided)
|
||||
- Inferred frequency of the time series
|
||||
- List of series lengths from input dataframe
|
||||
- List of series lengths from future dataframe (if provided)
|
||||
- Original order of time series IDs
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If validation fails for:
|
||||
- Missing required columns
|
||||
- Invalid data types
|
||||
- Inconsistent frequencies
|
||||
- Insufficient data points
|
||||
- Mismatched series between df and future_df
|
||||
- Invalid future_df lengths
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
|
||||
required_cols = [id_column, timestamp_column] + target_columns
|
||||
missing_cols = [col for col in required_cols if col not in df.columns]
|
||||
if missing_cols:
|
||||
raise ValueError(f"df does not contain all expected columns. Missing columns: {missing_cols}")
|
||||
|
||||
if future_df is not None:
|
||||
future_required_cols = [id_column, timestamp_column]
|
||||
missing_future_cols = [col for col in future_required_cols if col not in future_df.columns]
|
||||
targets_in_future = [col for col in future_df.columns if col in target_columns]
|
||||
extra_future_cols = [col for col in future_df.columns if col not in df.columns]
|
||||
if missing_future_cols:
|
||||
raise ValueError(
|
||||
f"future_df does not contain all expected columns. Missing columns: {missing_future_cols}"
|
||||
)
|
||||
if targets_in_future:
|
||||
raise ValueError(
|
||||
f"future_df cannot contain target columns. Target columns found in future_df: {targets_in_future}"
|
||||
)
|
||||
if extra_future_cols:
|
||||
raise ValueError(f"future_df cannot contain columns not present in df. Extra columns: {extra_future_cols}")
|
||||
|
||||
df, future_df = _validate_df_types_and_cast(
|
||||
df, future_df, id_column=id_column, timestamp_column=timestamp_column, target_columns=target_columns
|
||||
)
|
||||
|
||||
# Get the original order of time series IDs
|
||||
original_order = df[id_column].unique()
|
||||
|
||||
# Sort and prepare df
|
||||
df[timestamp_column] = pd.to_datetime(df[timestamp_column])
|
||||
df = df.sort_values([id_column, timestamp_column])
|
||||
|
||||
# Get series lengths
|
||||
series_lengths = df[id_column].value_counts(sort=False).to_list()
|
||||
|
||||
def validate_freq(timestamps: pd.Series, series_id: str):
|
||||
freq = pd.infer_freq(timestamps)
|
||||
if not freq:
|
||||
raise ValueError(f"Could not infer frequency for series {series_id}")
|
||||
return freq
|
||||
|
||||
# Validate each series
|
||||
all_freqs = []
|
||||
start_idx = 0
|
||||
for length in series_lengths:
|
||||
if length < 3:
|
||||
series_id = df.iloc[start_idx][id_column]
|
||||
raise ValueError(
|
||||
f"Every time series must have at least 3 data points, found {length=} for series {series_id}"
|
||||
)
|
||||
|
||||
series_data = df.iloc[start_idx : start_idx + length]
|
||||
timestamps = series_data[timestamp_column]
|
||||
series_id = series_data.iloc[0][id_column]
|
||||
all_freqs.append(validate_freq(timestamps, series_id))
|
||||
start_idx += length
|
||||
|
||||
if len(set(all_freqs)) > 1:
|
||||
raise ValueError("All time series must have the same frequency")
|
||||
|
||||
inferred_freq = all_freqs[0]
|
||||
|
||||
# Sort future_df if provided and validate its series lengths
|
||||
future_series_lengths = None
|
||||
if future_df is not None:
|
||||
future_df[timestamp_column] = pd.to_datetime(future_df[timestamp_column])
|
||||
future_df = future_df.sort_values([id_column, timestamp_column])
|
||||
|
||||
# Validate that future_df contains all series from df
|
||||
context_ids = set(df[id_column].unique())
|
||||
future_ids = set(future_df[id_column].unique())
|
||||
if context_ids != future_ids:
|
||||
raise ValueError("future_df must contain the same time series IDs as df")
|
||||
|
||||
future_series_lengths = future_df[id_column].value_counts(sort=False).to_list()
|
||||
|
||||
# Validate future series lengths match prediction_length
|
||||
future_start_idx = 0
|
||||
for future_length in future_series_lengths:
|
||||
future_series_data = future_df.iloc[future_start_idx : future_start_idx + future_length]
|
||||
future_timestamps = future_series_data[timestamp_column]
|
||||
future_series_id = future_series_data.iloc[0][id_column]
|
||||
if future_length != prediction_length:
|
||||
raise ValueError(
|
||||
f"Future covariates all time series must have length {prediction_length}, got {future_length} for series {future_series_id}"
|
||||
)
|
||||
if future_length < 3 or inferred_freq != validate_freq(future_timestamps, future_series_id):
|
||||
raise ValueError(
|
||||
f"Future covariates must have the same frequency as context, found series {future_series_id} with a different frequency"
|
||||
)
|
||||
future_start_idx += future_length
|
||||
|
||||
assert len(series_lengths) == len(future_series_lengths)
|
||||
|
||||
return df, future_df, inferred_freq, series_lengths, future_series_lengths, original_order
|
||||
|
||||
|
||||
def convert_df_input_to_list_of_dicts_input(
|
||||
df: "pd.DataFrame",
|
||||
future_df: "pd.DataFrame | None",
|
||||
target_columns: list[str],
|
||||
prediction_length: int,
|
||||
id_column: str = "item_id",
|
||||
timestamp_column: str = "timestamp",
|
||||
) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]:
|
||||
"""
|
||||
Convert from dataframe input format to a list of dictionaries input format.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df
|
||||
Input dataframe containing time series data with columns:
|
||||
- id_column: Identifier for each time series
|
||||
- timestamp_column: Timestamps for each observation
|
||||
- target_columns: One or more target variables to forecast
|
||||
- Additional columns are treated as covariates
|
||||
future_df
|
||||
Optional dataframe containing future covariate values with columns:
|
||||
- id_column: Identifier for each time series
|
||||
- timestamp_column: Future timestamps
|
||||
- Subset of covariate columns from df
|
||||
target_columns
|
||||
Names of target columns to forecast
|
||||
prediction_length
|
||||
Number of future time steps to predict
|
||||
id_column
|
||||
Name of column containing time series identifiers
|
||||
timestamp_column
|
||||
Name of column containing timestamps
|
||||
|
||||
Returns
|
||||
-------
|
||||
A tuple containing:
|
||||
- List of dictionaries in the format expected by `Chronos2Pipeline.predict`
|
||||
- Original order of time series IDs
|
||||
- Dictionary mapping series IDs to future time index
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df, future_df, freq, series_lengths, future_series_lengths, original_order = validate_df_inputs(
|
||||
df,
|
||||
future_df=future_df,
|
||||
id_column=id_column,
|
||||
timestamp_column=timestamp_column,
|
||||
target_columns=target_columns,
|
||||
prediction_length=prediction_length,
|
||||
)
|
||||
|
||||
# Convert to list of dicts format
|
||||
inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = []
|
||||
prediction_timestamps: dict[str, pd.DatetimeIndex] = {}
|
||||
start_idx: int = 0
|
||||
future_start_idx: int = 0
|
||||
|
||||
for i, length in enumerate(series_lengths):
|
||||
series_data = df.iloc[start_idx : start_idx + length]
|
||||
# Extract target(s)
|
||||
target_data = series_data[target_columns].to_numpy().T # Shape: (n_targets, history_length)
|
||||
task: dict[str, np.ndarray | dict[str, np.ndarray]] = {"target": target_data}
|
||||
|
||||
# Generate future timestamps
|
||||
series_id = series_data.iloc[0][id_column]
|
||||
last_timestamp = series_data[timestamp_column].iloc[-1]
|
||||
future_ts = pd.date_range(start=last_timestamp, periods=prediction_length + 1, freq=freq)[1:]
|
||||
prediction_timestamps[series_id] = future_ts
|
||||
|
||||
# Handle covariates if present
|
||||
covariate_cols = [
|
||||
col for col in series_data.columns if col not in [id_column, timestamp_column] + target_columns
|
||||
]
|
||||
|
||||
if covariate_cols:
|
||||
past_covariates = {col: series_data[col].to_numpy() for col in covariate_cols}
|
||||
task["past_covariates"] = past_covariates
|
||||
|
||||
# Handle future covariates
|
||||
if future_df is not None:
|
||||
assert future_series_lengths is not None
|
||||
future_length = future_series_lengths[i]
|
||||
future_data = future_df.iloc[future_start_idx : future_start_idx + future_length]
|
||||
assert future_data[timestamp_column].iloc[0] == future_ts[0], (
|
||||
f"the first timestamp in future_df must be the first forecast timestamp, found mismatch "
|
||||
f"({future_data[timestamp_column].iloc[0]} != {future_ts[0]}) in series {series_id}"
|
||||
)
|
||||
|
||||
if len(future_data) > 0:
|
||||
future_covariates = {
|
||||
col: future_data[col].to_numpy() for col in covariate_cols if col in future_data.columns
|
||||
}
|
||||
if future_covariates:
|
||||
task["future_covariates"] = future_covariates
|
||||
future_start_idx += future_length
|
||||
|
||||
inputs.append(task)
|
||||
start_idx += length
|
||||
|
||||
assert len(inputs) == len(series_lengths)
|
||||
|
||||
return inputs, original_order, prediction_timestamps
|
||||
|
||||
|
||||
def _cast_fev_features(
|
||||
past_data: "datasets.Dataset",
|
||||
future_data: "datasets.Dataset",
|
||||
|
|
|
|||
|
|
@ -20,12 +20,8 @@ from transformers import AutoConfig
|
|||
import chronos.chronos2
|
||||
from chronos.base import BaseChronosPipeline, ForecastType
|
||||
from chronos.chronos2 import Chronos2Model
|
||||
from chronos.chronos2.dataset import (
|
||||
Chronos2Dataset,
|
||||
DatasetMode,
|
||||
TensorOrArray,
|
||||
convert_df_input_to_list_of_dicts_input,
|
||||
)
|
||||
from chronos.chronos2.dataset import Chronos2Dataset, DatasetMode, TensorOrArray
|
||||
from chronos.df_utils import convert_df_input_to_list_of_dicts_input
|
||||
from chronos.utils import interpolate_quantiles, weighted_quantile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -541,9 +537,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
output_patch_size=self.model_output_patch_size,
|
||||
mode=DatasetMode.TEST,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset, batch_size=None, pin_memory=True, shuffle=False, drop_last=False
|
||||
)
|
||||
test_loader = DataLoader(test_dataset, batch_size=None, pin_memory=True, shuffle=False, drop_last=False)
|
||||
|
||||
all_predictions: list[torch.Tensor] = []
|
||||
for batch in test_loader:
|
||||
|
|
|
|||
|
|
@ -407,8 +407,14 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
def __init__(self, model: ChronosBoltModelForForecasting):
|
||||
super().__init__(inner_model=model) # type: ignore
|
||||
self.model = model
|
||||
self.model_context_length: int = self.model.config.chronos_config["context_length"]
|
||||
self.model_prediction_length: int = self.model.config.chronos_config["prediction_length"]
|
||||
|
||||
@property
|
||||
def model_context_length(self) -> int:
|
||||
return self.model.chronos_config.context_length
|
||||
|
||||
@property
|
||||
def model_prediction_length(self) -> int:
|
||||
return self.model.chronos_config.prediction_length
|
||||
|
||||
@property
|
||||
def quantiles(self) -> List[float]:
|
||||
|
|
|
|||
314
src/chronos/df_utils.py
Normal file
314
src/chronos/df_utils.py
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def _validate_df_types_and_cast(
|
||||
df: "pd.DataFrame",
|
||||
future_df: "pd.DataFrame | None",
|
||||
target_columns: list[str],
|
||||
id_column: str = "item_id",
|
||||
timestamp_column: str = "timestamp",
|
||||
) -> tuple["pd.DataFrame", "pd.DataFrame | None"]:
|
||||
import pandas as pd
|
||||
|
||||
astype_dict = {}
|
||||
future_astype_dict = {}
|
||||
for col in df.columns.drop([id_column, timestamp_column]):
|
||||
col_dtype = df[col].dtype
|
||||
if col in target_columns and not pd.api.types.is_numeric_dtype(df[col]):
|
||||
raise ValueError(f"All target columns must be numeric but got {col=} with dtype={col_dtype}")
|
||||
|
||||
if (
|
||||
pd.api.types.is_object_dtype(df[col])
|
||||
or pd.api.types.is_string_dtype(df[col])
|
||||
or isinstance(col_dtype, pd.CategoricalDtype)
|
||||
):
|
||||
astype_dict[col] = "category"
|
||||
elif pd.api.types.is_numeric_dtype(df[col]) or pd.api.types.is_bool_dtype(df[col]):
|
||||
astype_dict[col] = "float32"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"All columns must contain numeric, object, category, string, or bool dtype but got {col=} with dtype={col_dtype}"
|
||||
)
|
||||
|
||||
if future_df is not None and col in future_df.columns:
|
||||
if future_df[col].dtype != col_dtype:
|
||||
raise ValueError(
|
||||
f"Column {col} in future_df has dtype {future_df[col].dtype} but column in df has dtype {col_dtype}"
|
||||
)
|
||||
future_astype_dict[col] = astype_dict[col]
|
||||
|
||||
df = df.astype(astype_dict, copy=True)
|
||||
if future_df is not None:
|
||||
future_df = future_df.astype(future_astype_dict, copy=True)
|
||||
|
||||
return df, future_df
|
||||
|
||||
|
||||
def validate_df_inputs(
|
||||
df: "pd.DataFrame",
|
||||
future_df: "pd.DataFrame | None",
|
||||
target_columns: list[str],
|
||||
prediction_length: int,
|
||||
id_column: str = "item_id",
|
||||
timestamp_column: str = "timestamp",
|
||||
) -> tuple["pd.DataFrame", "pd.DataFrame | None", "pd.Timedelta", list[int], list[int] | None, np.ndarray]:
|
||||
"""
|
||||
Validates and prepares dataframe inputs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df
|
||||
Input dataframe containing time series data with columns:
|
||||
- id_column: Identifier for each time series
|
||||
- timestamp_column: Timestamps for each observation
|
||||
- target_columns: One or more target variables to forecast
|
||||
- Additional columns are treated as covariates
|
||||
future_df
|
||||
Optional dataframe containing future covariate values with columns:
|
||||
- id_column: Identifier for each time series
|
||||
- timestamp_column: Future timestamps
|
||||
- Subset of covariate columns from df
|
||||
target_columns
|
||||
Names of target columns to forecast
|
||||
prediction_length
|
||||
Number of future time steps to predict
|
||||
id_column
|
||||
Name of column containing time series identifiers
|
||||
timestamp_column
|
||||
Name of column containing timestamps
|
||||
|
||||
Returns
|
||||
-------
|
||||
A tuple containing:
|
||||
- Validated and sorted input dataframe
|
||||
- Validated and sorted future dataframe (if provided)
|
||||
- Inferred frequency of the time series
|
||||
- List of series lengths from input dataframe
|
||||
- List of series lengths from future dataframe (if provided)
|
||||
- Original order of time series IDs
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If validation fails for:
|
||||
- Missing required columns
|
||||
- Invalid data types
|
||||
- Inconsistent frequencies
|
||||
- Insufficient data points
|
||||
- Mismatched series between df and future_df
|
||||
- Invalid future_df lengths
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
|
||||
required_cols = [id_column, timestamp_column] + target_columns
|
||||
missing_cols = [col for col in required_cols if col not in df.columns]
|
||||
if missing_cols:
|
||||
raise ValueError(f"df does not contain all expected columns. Missing columns: {missing_cols}")
|
||||
|
||||
if future_df is not None:
|
||||
future_required_cols = [id_column, timestamp_column]
|
||||
missing_future_cols = [col for col in future_required_cols if col not in future_df.columns]
|
||||
targets_in_future = [col for col in future_df.columns if col in target_columns]
|
||||
extra_future_cols = [col for col in future_df.columns if col not in df.columns]
|
||||
if missing_future_cols:
|
||||
raise ValueError(
|
||||
f"future_df does not contain all expected columns. Missing columns: {missing_future_cols}"
|
||||
)
|
||||
if targets_in_future:
|
||||
raise ValueError(
|
||||
f"future_df cannot contain target columns. Target columns found in future_df: {targets_in_future}"
|
||||
)
|
||||
if extra_future_cols:
|
||||
raise ValueError(f"future_df cannot contain columns not present in df. Extra columns: {extra_future_cols}")
|
||||
|
||||
df, future_df = _validate_df_types_and_cast(
|
||||
df, future_df, id_column=id_column, timestamp_column=timestamp_column, target_columns=target_columns
|
||||
)
|
||||
|
||||
# Get the original order of time series IDs
|
||||
original_order = df[id_column].unique()
|
||||
|
||||
# Sort and prepare df
|
||||
df[timestamp_column] = pd.to_datetime(df[timestamp_column])
|
||||
df = df.sort_values([id_column, timestamp_column])
|
||||
|
||||
# Get series lengths
|
||||
series_lengths = df[id_column].value_counts(sort=False).to_list()
|
||||
|
||||
def validate_freq(timestamps: pd.Series, series_id: str):
|
||||
freq = pd.infer_freq(timestamps)
|
||||
if not freq:
|
||||
raise ValueError(f"Could not infer frequency for series {series_id}")
|
||||
return freq
|
||||
|
||||
# Validate each series
|
||||
all_freqs = []
|
||||
start_idx = 0
|
||||
for length in series_lengths:
|
||||
if length < 3:
|
||||
series_id = df.iloc[start_idx][id_column]
|
||||
raise ValueError(
|
||||
f"Every time series must have at least 3 data points, found {length=} for series {series_id}"
|
||||
)
|
||||
|
||||
series_data = df.iloc[start_idx : start_idx + length]
|
||||
timestamps = series_data[timestamp_column]
|
||||
series_id = series_data.iloc[0][id_column]
|
||||
all_freqs.append(validate_freq(timestamps, series_id))
|
||||
start_idx += length
|
||||
|
||||
if len(set(all_freqs)) > 1:
|
||||
raise ValueError("All time series must have the same frequency")
|
||||
|
||||
inferred_freq = all_freqs[0]
|
||||
|
||||
# Sort future_df if provided and validate its series lengths
|
||||
future_series_lengths = None
|
||||
if future_df is not None:
|
||||
future_df[timestamp_column] = pd.to_datetime(future_df[timestamp_column])
|
||||
future_df = future_df.sort_values([id_column, timestamp_column])
|
||||
|
||||
# Validate that future_df contains all series from df
|
||||
context_ids = set(df[id_column].unique())
|
||||
future_ids = set(future_df[id_column].unique())
|
||||
if context_ids != future_ids:
|
||||
raise ValueError("future_df must contain the same time series IDs as df")
|
||||
|
||||
future_series_lengths = future_df[id_column].value_counts(sort=False).to_list()
|
||||
|
||||
# Validate future series lengths match prediction_length
|
||||
future_start_idx = 0
|
||||
for future_length in future_series_lengths:
|
||||
future_series_data = future_df.iloc[future_start_idx : future_start_idx + future_length]
|
||||
future_timestamps = future_series_data[timestamp_column]
|
||||
future_series_id = future_series_data.iloc[0][id_column]
|
||||
if future_length != prediction_length:
|
||||
raise ValueError(
|
||||
f"Future covariates all time series must have length {prediction_length}, got {future_length} for series {future_series_id}"
|
||||
)
|
||||
if future_length < 3 or inferred_freq != validate_freq(future_timestamps, future_series_id):
|
||||
raise ValueError(
|
||||
f"Future covariates must have the same frequency as context, found series {future_series_id} with a different frequency"
|
||||
)
|
||||
future_start_idx += future_length
|
||||
|
||||
assert len(series_lengths) == len(future_series_lengths)
|
||||
|
||||
return df, future_df, inferred_freq, series_lengths, future_series_lengths, original_order
|
||||
|
||||
|
||||
def convert_df_input_to_list_of_dicts_input(
|
||||
df: "pd.DataFrame",
|
||||
future_df: "pd.DataFrame | None",
|
||||
target_columns: list[str],
|
||||
prediction_length: int,
|
||||
id_column: str = "item_id",
|
||||
timestamp_column: str = "timestamp",
|
||||
) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]:
|
||||
"""
|
||||
Convert from dataframe input format to a list of dictionaries input format.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df
|
||||
Input dataframe containing time series data with columns:
|
||||
- id_column: Identifier for each time series
|
||||
- timestamp_column: Timestamps for each observation
|
||||
- target_columns: One or more target variables to forecast
|
||||
- Additional columns are treated as covariates
|
||||
future_df
|
||||
Optional dataframe containing future covariate values with columns:
|
||||
- id_column: Identifier for each time series
|
||||
- timestamp_column: Future timestamps
|
||||
- Subset of covariate columns from df
|
||||
target_columns
|
||||
Names of target columns to forecast
|
||||
prediction_length
|
||||
Number of future time steps to predict
|
||||
id_column
|
||||
Name of column containing time series identifiers
|
||||
timestamp_column
|
||||
Name of column containing timestamps
|
||||
|
||||
Returns
|
||||
-------
|
||||
A tuple containing:
|
||||
- Time series converted to list of dictionaries format
|
||||
- Original order of time series IDs
|
||||
- Dictionary mapping series IDs to future time index
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df, future_df, freq, series_lengths, future_series_lengths, original_order = validate_df_inputs(
|
||||
df,
|
||||
future_df=future_df,
|
||||
id_column=id_column,
|
||||
timestamp_column=timestamp_column,
|
||||
target_columns=target_columns,
|
||||
prediction_length=prediction_length,
|
||||
)
|
||||
|
||||
# Convert to list of dicts format
|
||||
inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = []
|
||||
prediction_timestamps: dict[str, pd.DatetimeIndex] = {}
|
||||
start_idx: int = 0
|
||||
future_start_idx: int = 0
|
||||
|
||||
for i, length in enumerate(series_lengths):
|
||||
series_data = df.iloc[start_idx : start_idx + length]
|
||||
# Extract target(s)
|
||||
target_data = series_data[target_columns].to_numpy().T # Shape: (n_targets, history_length)
|
||||
task: dict[str, np.ndarray | dict[str, np.ndarray]] = {"target": target_data}
|
||||
|
||||
# Generate future timestamps
|
||||
series_id = series_data.iloc[0][id_column]
|
||||
last_timestamp = series_data[timestamp_column].iloc[-1]
|
||||
future_ts = pd.date_range(start=last_timestamp, periods=prediction_length + 1, freq=freq)[1:]
|
||||
prediction_timestamps[series_id] = future_ts
|
||||
|
||||
# Handle covariates if present
|
||||
covariate_cols = [
|
||||
col for col in series_data.columns if col not in [id_column, timestamp_column] + target_columns
|
||||
]
|
||||
|
||||
if covariate_cols:
|
||||
past_covariates = {col: series_data[col].to_numpy() for col in covariate_cols}
|
||||
task["past_covariates"] = past_covariates
|
||||
|
||||
# Handle future covariates
|
||||
if future_df is not None:
|
||||
assert future_series_lengths is not None
|
||||
future_length = future_series_lengths[i]
|
||||
future_data = future_df.iloc[future_start_idx : future_start_idx + future_length]
|
||||
assert future_data[timestamp_column].iloc[0] == future_ts[0], (
|
||||
f"the first timestamp in future_df must be the first forecast timestamp, found mismatch "
|
||||
f"({future_data[timestamp_column].iloc[0]} != {future_ts[0]}) in series {series_id}"
|
||||
)
|
||||
|
||||
if len(future_data) > 0:
|
||||
future_covariates = {
|
||||
col: future_data[col].to_numpy() for col in covariate_cols if col in future_data.columns
|
||||
}
|
||||
if future_covariates:
|
||||
task["future_covariates"] = future_covariates
|
||||
future_start_idx += future_length
|
||||
|
||||
inputs.append(task)
|
||||
start_idx += length
|
||||
|
||||
assert len(inputs) == len(series_lengths)
|
||||
|
||||
return inputs, original_order, prediction_timestamps
|
||||
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
|
@ -12,7 +14,14 @@ from chronos import (
|
|||
ChronosPipeline,
|
||||
MeanScaleUniformBins,
|
||||
)
|
||||
from test.util import validate_tensor
|
||||
from test.util import create_df, get_forecast_start_times, validate_tensor
|
||||
|
||||
DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos-model"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline() -> ChronosPipeline:
|
||||
return BaseChronosPipeline.from_pretrained(DUMMY_MODEL_PATH, device_map="cpu")
|
||||
|
||||
|
||||
def test_base_chronos_pipeline_loads_from_huggingface():
|
||||
|
|
@ -167,11 +176,7 @@ def test_tokenizer_random_data(use_eos_token: bool):
|
|||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=model_dtype,
|
||||
)
|
||||
pipeline = ChronosPipeline.from_pretrained(DUMMY_MODEL_PATH, device_map="cpu", torch_dtype=model_dtype)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
|
||||
|
|
@ -238,11 +243,7 @@ def test_pipeline_predict_quantiles(
|
|||
prediction_length: int,
|
||||
quantile_levels: list[int],
|
||||
):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=model_dtype,
|
||||
)
|
||||
pipeline = ChronosPipeline.from_pretrained(DUMMY_MODEL_PATH, device_map="cpu", torch_dtype=model_dtype)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
|
||||
|
|
@ -284,11 +285,7 @@ def test_pipeline_predict_quantiles(
|
|||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=model_dtype,
|
||||
)
|
||||
pipeline = ChronosPipeline.from_pretrained(DUMMY_MODEL_PATH, device_map="cpu", torch_dtype=model_dtype)
|
||||
d_model = pipeline.model.model.config.d_model
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
|
|
@ -312,6 +309,88 @@ def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
|||
validate_tensor(scale, shape=(1,), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"context_setup, expected_rows",
|
||||
[
|
||||
# Targets only
|
||||
({}, 6), # 2 series * 3 predictions
|
||||
# Different context lengths
|
||||
(
|
||||
{"series_ids": ["X", "Y", "Z"], "n_points": [10, 17, 56], "target_cols": ["custom_target"]},
|
||||
9,
|
||||
), # 3 series * 3 predictions
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"])
|
||||
def test_predict_df_works_for_valid_inputs(pipeline, context_setup, expected_rows, freq):
|
||||
prediction_length = 3
|
||||
df = create_df(**context_setup, freq=freq)
|
||||
forecast_start_times = get_forecast_start_times(df, freq)
|
||||
|
||||
series_ids = context_setup.get("series_ids", ["A", "B"])
|
||||
target_columns = context_setup.get("target_cols", ["target"])
|
||||
n_series = len(series_ids)
|
||||
n_targets = len(target_columns)
|
||||
result = pipeline.predict_df(df, target=target_columns[0], prediction_length=prediction_length)
|
||||
|
||||
assert len(result) == expected_rows
|
||||
assert "item_id" in result.columns and np.all(
|
||||
result["item_id"].to_numpy() == np.array(series_ids).repeat(n_targets * prediction_length)
|
||||
)
|
||||
assert "target_name" in result.columns and np.all(
|
||||
result["target_name"].to_numpy() == np.tile(np.array(target_columns).repeat(prediction_length), n_series)
|
||||
)
|
||||
assert "timestamp" in result.columns and np.all(
|
||||
result.groupby("item_id")["timestamp"].min().to_numpy() == pd.to_datetime(forecast_start_times).to_numpy()
|
||||
)
|
||||
assert "predictions" in result.columns
|
||||
assert all(str(q) in result.columns for q in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
|
||||
|
||||
|
||||
def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline):
|
||||
df = create_df()
|
||||
# Make timestamps non-uniform for series A
|
||||
df.loc[df["item_id"] == "A", "timestamp"] = [
|
||||
"2023-01-01",
|
||||
"2023-01-02",
|
||||
"2023-01-04",
|
||||
"2023-01-05",
|
||||
"2023-01-06",
|
||||
"2023-01-07",
|
||||
"2023-01-08",
|
||||
"2023-01-09",
|
||||
"2023-01-10",
|
||||
"2023-01-11",
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="not infer frequency"):
|
||||
pipeline.predict_df(df)
|
||||
|
||||
|
||||
def test_predict_df_with_inconsistent_frequencies_raises_error(pipeline):
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"item_id": ["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"],
|
||||
"timestamp": [
|
||||
"2023-01-01",
|
||||
"2023-01-02",
|
||||
"2023-01-03",
|
||||
"2023-01-04",
|
||||
"2023-01-05",
|
||||
"2023-01-01",
|
||||
"2023-02-01",
|
||||
"2023-03-01",
|
||||
"2023-04-01",
|
||||
"2023-05-01",
|
||||
],
|
||||
"target": [1.0] * 10,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="same frequency"):
|
||||
pipeline.predict_df(df)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
|
||||
def test_tokenizer_number_of_buckets(n_tokens):
|
||||
config = ChronosConfig(
|
||||
|
|
|
|||
|
|
@ -14,9 +14,9 @@ import torch
|
|||
|
||||
from chronos import BaseChronosPipeline, Chronos2Pipeline
|
||||
from chronos.chronos2.config import Chronos2CoreConfig
|
||||
from chronos.chronos2.dataset import convert_df_input_to_list_of_dicts_input
|
||||
from chronos.chronos2.layers import MHA
|
||||
from test.util import validate_tensor
|
||||
from chronos.df_utils import convert_df_input_to_list_of_dicts_input
|
||||
from test.util import create_df, create_future_df, get_forecast_start_times, validate_tensor
|
||||
|
||||
DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model"
|
||||
|
||||
|
|
@ -387,39 +387,6 @@ def test_pipeline_can_evaluate_on_dummy_fev_task(pipeline, task_kwargs):
|
|||
assert isinstance(eval_summary["test_error"], float)
|
||||
|
||||
|
||||
def create_df(series_ids=["A", "B"], n_points=[10, 10], target_cols=["target"], covariates=None, freq="h"):
|
||||
"""Helper to create test context DataFrames."""
|
||||
series_dfs = []
|
||||
for series_id, length in zip(series_ids, n_points):
|
||||
series_data = {"item_id": series_id, "timestamp": pd.date_range(end="2001-10-01", periods=length, freq=freq)}
|
||||
for target_col in target_cols:
|
||||
series_data[target_col] = np.random.randn(length)
|
||||
if covariates:
|
||||
for cov in covariates:
|
||||
series_data[cov] = np.random.randn(length)
|
||||
series_dfs.append(pd.DataFrame(series_data))
|
||||
return pd.concat(series_dfs, ignore_index=True)
|
||||
|
||||
|
||||
def create_future_df(forecast_start_times: list, series_ids=["A", "B"], n_points=[5, 5], covariates=None, freq="h"):
|
||||
"""Helper to create test future DataFrames."""
|
||||
series_dfs = []
|
||||
for series_id, length, start in zip(series_ids, n_points, forecast_start_times):
|
||||
series_data = {"item_id": series_id, "timestamp": pd.date_range(start=start, periods=length, freq=freq)}
|
||||
if covariates:
|
||||
for cov in covariates:
|
||||
series_data[cov] = np.random.randn(length)
|
||||
series_dfs.append(pd.DataFrame(series_data))
|
||||
return pd.concat(series_dfs, ignore_index=True)
|
||||
|
||||
|
||||
def get_forecast_start_times(df, freq="h"):
|
||||
context_end_times = df.groupby("item_id")["timestamp"].max()
|
||||
forecast_start_times = [pd.date_range(end_time, periods=2, freq=freq)[-1] for end_time in context_end_times]
|
||||
|
||||
return forecast_start_times
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"context_setup, future_setup, expected_rows",
|
||||
[
|
||||
|
|
|
|||
|
|
@ -5,12 +5,21 @@ from pathlib import Path
|
|||
|
||||
import datasets
|
||||
import fev
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from chronos import BaseChronosPipeline, ChronosBoltPipeline
|
||||
from chronos.chronos_bolt import InstanceNorm, Patch
|
||||
from test.util import validate_tensor
|
||||
from test.util import create_df, get_forecast_start_times, validate_tensor
|
||||
|
||||
DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos-bolt-model"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline() -> ChronosBoltPipeline:
|
||||
return BaseChronosPipeline.from_pretrained(DUMMY_MODEL_PATH, device_map="cpu")
|
||||
|
||||
|
||||
def test_base_chronos_pipeline_loads_from_huggingface():
|
||||
|
|
@ -20,11 +29,7 @@ def test_base_chronos_pipeline_loads_from_huggingface():
|
|||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
def test_pipeline_predict(torch_dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = ChronosBoltPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-bolt-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
pipeline = ChronosBoltPipeline.from_pretrained(DUMMY_MODEL_PATH, device_map="cpu", torch_dtype=torch_dtype)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
expected_num_quantiles = len(pipeline.quantiles)
|
||||
|
|
@ -84,11 +89,7 @@ def test_pipeline_predict_quantiles(
|
|||
prediction_length: int,
|
||||
quantile_levels: list[float],
|
||||
):
|
||||
pipeline = ChronosBoltPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-bolt-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
pipeline = ChronosBoltPipeline.from_pretrained(DUMMY_MODEL_PATH, device_map="cpu", torch_dtype=torch_dtype)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
|
||||
|
|
@ -127,11 +128,7 @@ def test_pipeline_predict_quantiles(
|
|||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
|
||||
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = ChronosBoltPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-bolt-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=model_dtype,
|
||||
)
|
||||
pipeline = ChronosBoltPipeline.from_pretrained(DUMMY_MODEL_PATH, device_map="cpu", torch_dtype=model_dtype)
|
||||
d_model = pipeline.model.config.d_model
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = context.to(dtype=input_dtype)
|
||||
|
|
@ -160,6 +157,88 @@ def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
|||
validate_tensor(loc_scale[1], shape=(1,), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"context_setup, expected_rows",
|
||||
[
|
||||
# Targets only
|
||||
({}, 6), # 2 series * 3 predictions
|
||||
# Different context lengths
|
||||
(
|
||||
{"series_ids": ["X", "Y", "Z"], "n_points": [10, 17, 56], "target_cols": ["custom_target"]},
|
||||
9,
|
||||
), # 3 series * 3 predictions
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"])
|
||||
def test_predict_df_works_for_valid_inputs(pipeline, context_setup, expected_rows, freq):
|
||||
prediction_length = 3
|
||||
df = create_df(**context_setup, freq=freq)
|
||||
forecast_start_times = get_forecast_start_times(df, freq)
|
||||
|
||||
series_ids = context_setup.get("series_ids", ["A", "B"])
|
||||
target_columns = context_setup.get("target_cols", ["target"])
|
||||
n_series = len(series_ids)
|
||||
n_targets = len(target_columns)
|
||||
result = pipeline.predict_df(df, target=target_columns[0], prediction_length=prediction_length)
|
||||
|
||||
assert len(result) == expected_rows
|
||||
assert "item_id" in result.columns and np.all(
|
||||
result["item_id"].to_numpy() == np.array(series_ids).repeat(n_targets * prediction_length)
|
||||
)
|
||||
assert "target_name" in result.columns and np.all(
|
||||
result["target_name"].to_numpy() == np.tile(np.array(target_columns).repeat(prediction_length), n_series)
|
||||
)
|
||||
assert "timestamp" in result.columns and np.all(
|
||||
result.groupby("item_id")["timestamp"].min().to_numpy() == pd.to_datetime(forecast_start_times).to_numpy()
|
||||
)
|
||||
assert "predictions" in result.columns
|
||||
assert all(str(q) in result.columns for q in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
|
||||
|
||||
|
||||
def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline):
|
||||
df = create_df()
|
||||
# Make timestamps non-uniform for series A
|
||||
df.loc[df["item_id"] == "A", "timestamp"] = [
|
||||
"2023-01-01",
|
||||
"2023-01-02",
|
||||
"2023-01-04",
|
||||
"2023-01-05",
|
||||
"2023-01-06",
|
||||
"2023-01-07",
|
||||
"2023-01-08",
|
||||
"2023-01-09",
|
||||
"2023-01-10",
|
||||
"2023-01-11",
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="not infer frequency"):
|
||||
pipeline.predict_df(df)
|
||||
|
||||
|
||||
def test_predict_df_with_inconsistent_frequencies_raises_error(pipeline):
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"item_id": ["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"],
|
||||
"timestamp": [
|
||||
"2023-01-01",
|
||||
"2023-01-02",
|
||||
"2023-01-03",
|
||||
"2023-01-04",
|
||||
"2023-01-05",
|
||||
"2023-01-01",
|
||||
"2023-02-01",
|
||||
"2023-03-01",
|
||||
"2023-04-01",
|
||||
"2023-05-01",
|
||||
],
|
||||
"target": [1.0] * 10,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="same frequency"):
|
||||
pipeline.predict_df(df)
|
||||
|
||||
|
||||
# The following tests have been taken from
|
||||
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py
|
||||
# Author: Caner Turkmen <atturkm@amazon.com>
|
||||
|
|
|
|||
36
test/util.py
36
test/util.py
|
|
@ -1,5 +1,7 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
|
||||
|
|
@ -9,3 +11,37 @@ def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[tor
|
|||
|
||||
if dtype is not None:
|
||||
assert a.dtype == dtype
|
||||
|
||||
|
||||
|
||||
def create_df(series_ids=["A", "B"], n_points=[10, 10], target_cols=["target"], covariates=None, freq="h"):
|
||||
"""Helper to create test context DataFrames."""
|
||||
series_dfs = []
|
||||
for series_id, length in zip(series_ids, n_points):
|
||||
series_data = {"item_id": series_id, "timestamp": pd.date_range(end="2001-10-01", periods=length, freq=freq)}
|
||||
for target_col in target_cols:
|
||||
series_data[target_col] = np.random.randn(length)
|
||||
if covariates:
|
||||
for cov in covariates:
|
||||
series_data[cov] = np.random.randn(length)
|
||||
series_dfs.append(pd.DataFrame(series_data))
|
||||
return pd.concat(series_dfs, ignore_index=True)
|
||||
|
||||
|
||||
def create_future_df(forecast_start_times: list, series_ids=["A", "B"], n_points=[5, 5], covariates=None, freq="h"):
|
||||
"""Helper to create test future DataFrames."""
|
||||
series_dfs = []
|
||||
for series_id, length, start in zip(series_ids, n_points, forecast_start_times):
|
||||
series_data = {"item_id": series_id, "timestamp": pd.date_range(start=start, periods=length, freq=freq)}
|
||||
if covariates:
|
||||
for cov in covariates:
|
||||
series_data[cov] = np.random.randn(length)
|
||||
series_dfs.append(pd.DataFrame(series_data))
|
||||
return pd.concat(series_dfs, ignore_index=True)
|
||||
|
||||
|
||||
def get_forecast_start_times(df, freq="h"):
|
||||
context_end_times = df.groupby("item_id")["timestamp"].max()
|
||||
forecast_start_times = [pd.date_range(end_time, periods=2, freq=freq)[-1] for end_time in context_end_times]
|
||||
|
||||
return forecast_start_times
|
||||
Loading…
Reference in a new issue