Allow explicitly passing the frequency

This commit is contained in:
Oleksandr Shchur 2026-01-17 10:13:30 +00:00
parent f889ae6647
commit 57fb54f7ed
4 changed files with 115 additions and 29 deletions

View file

@ -142,6 +142,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
prediction_length: int | None = None,
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
validate_inputs: bool = True,
freq: str | None = None,
**predict_kwargs,
) -> "pd.DataFrame":
"""
@ -166,6 +167,10 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
validate_inputs
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
freq
Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used
instead of inferring it from the data. This is useful when you already know the frequency and want to
skip the inference overhead.
**predict_kwargs
Additional arguments passed to predict_quantiles
@ -200,6 +205,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
timestamp_column=timestamp_column,
target_columns=[target],
prediction_length=prediction_length,
freq=freq,
validate_inputs=validate_inputs,
)

View file

@ -825,6 +825,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
context_length: int | None = None,
cross_learning: bool = False,
validate_inputs: bool = True,
freq: str | None = None,
**predict_kwargs,
) -> "pd.DataFrame":
"""
@ -866,6 +867,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
validate_inputs
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
freq
Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used
instead of inferring it from the data. This is useful when you already know the frequency and want to
skip the inference overhead. Only used when future_df is not provided, since timestamps are extracted
from future_df when it's available.
**predict_kwargs
Additional arguments passed to predict_quantiles
@ -896,6 +902,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
timestamp_column=timestamp_column,
target_columns=target,
prediction_length=prediction_length,
freq=freq,
validate_inputs=validate_inputs,
)

View file

@ -203,6 +203,7 @@ def convert_df_input_to_list_of_dicts_input(
prediction_length: int,
id_column: str = "item_id",
timestamp_column: str = "timestamp",
freq: str | None = None,
validate_inputs: bool = True,
) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]:
"""
@ -229,8 +230,12 @@ def convert_df_input_to_list_of_dicts_input(
Name of column containing time series identifiers
timestamp_column
Name of column containing timestamps
freq
Frequency string for timestamp generation. If provided, this frequency is used
instead of inferring it from the data. Only used when future_df is not provided,
since timestamps are extracted from future_df when it's available.
validate_inputs
When True, the dataframe(s) will be validated be conversion
When True, the dataframe(s) will be validated before conversion
Returns
-------
@ -243,7 +248,7 @@ def convert_df_input_to_list_of_dicts_input(
import pandas as pd
if validate_inputs:
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
df, future_df, inferred_freq, series_lengths, original_order = validate_df_inputs(
df,
future_df=future_df,
id_column=id_column,
@ -251,6 +256,9 @@ def convert_df_input_to_list_of_dicts_input(
target_columns=target_columns,
prediction_length=prediction_length,
)
# Use provided freq if available, otherwise use inferred freq
if freq is None:
freq = inferred_freq
else:
# Get the original order of time series IDs
original_order = df[id_column].unique()
@ -258,19 +266,19 @@ def convert_df_input_to_list_of_dicts_input(
# Get series lengths
series_lengths = df[id_column].value_counts(sort=False).to_list()
# If validation is skipped, the first freq in the dataframe is used
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
start_idx = 0
freq = None
for length in series_lengths:
if length < 3:
start_idx += length
continue
timestamps = timestamp_index[start_idx : start_idx + length]
freq = pd.infer_freq(timestamps)
break
# If freq is not provided, infer from the first series with >= 3 points
if freq is None:
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
start_idx = 0
for length in series_lengths:
if length < 3:
start_idx += length
continue
timestamps = timestamp_index[start_idx : start_idx + length]
freq = pd.infer_freq(timestamps)
break
assert freq is not None, "validate is False, but could not infer frequency from the dataframe"
assert freq is not None, "validate is False, but could not infer frequency from the dataframe"
# Convert to list of dicts format
inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = []
@ -278,29 +286,29 @@ def convert_df_input_to_list_of_dicts_input(
indptr = np.concatenate([[0], np.cumsum(series_lengths)]).astype("int64")
target_array = df[target_columns].to_numpy().T # Shape: (n_targets, len(df))
last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,)
offset = pd.tseries.frequencies.to_offset(freq)
with warnings.catch_warnings():
# Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
# Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length)
prediction_timestamps_array = pd.DatetimeIndex(
np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel()
)
past_covariates_dict = {
col: df[col].to_numpy() for col in df.columns if col not in [id_column, timestamp_column] + target_columns
}
future_covariates_dict = {}
if future_df is not None:
# Use timestamps from future_df
prediction_timestamps_array = pd.DatetimeIndex(future_df[timestamp_column])
for col in future_df.columns.drop([id_column, timestamp_column]):
future_covariates_dict[col] = future_df[col].to_numpy()
if validate_inputs:
if (pd.DatetimeIndex(future_df[timestamp_column]) != pd.DatetimeIndex(prediction_timestamps_array)).any():
raise ValueError(
"future_df timestamps do not match the expected prediction timestamps. "
"You can disable this check by setting `validate_inputs=False`"
)
else:
# Generate timestamps from freq
assert freq is not None, "freq must be provided or inferred when future_df is not provided"
last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,)
offset = pd.tseries.frequencies.to_offset(freq)
with warnings.catch_warnings():
# Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
# Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length)
prediction_timestamps_array = pd.DatetimeIndex(
np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel()
)
for i in range(len(series_lengths)):
start_idx, end_idx = indptr[i], indptr[i + 1]

View file

@ -359,3 +359,68 @@ def test_convert_df_preserves_all_values_with_random_inputs():
assert len(inputs) == n_series
assert list(original_order) == series_ids
assert len(prediction_timestamps) == n_series
# Tests for freq parameter
@pytest.mark.parametrize("validate_inputs", [True, False])
@pytest.mark.parametrize("use_future_df", [True, False])
def test_convert_df_with_provided_freq(validate_inputs, use_future_df):
    """Check that an explicitly passed ``freq`` is accepted in every mode.

    The conversion is exercised for all four combinations of
    ``validate_inputs`` and presence/absence of ``future_df``. The explicit
    frequency ("h") is the same one used to build the input data. Only the
    number of converted series and the number of prediction timestamps per
    series are asserted.
    """
    item_ids = ("A", "B")
    horizon = 5
    history = create_df(
        series_ids=list(item_ids),
        n_points=[10, 12],
        target_cols=["target"],
        covariates=["cov1"],
        freq="h",
    )

    if use_future_df:
        # Future covariates start right after each series ends, at the same hourly frequency.
        known_future = create_future_df(
            forecast_start_times=get_forecast_start_times(history, freq="h"),
            series_ids=list(item_ids),
            n_points=[horizon, horizon],
            covariates=["cov1"],
            freq="h",
        )
    else:
        known_future = None

    inputs, _, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=history,
        future_df=known_future,
        target_columns=["target"],
        prediction_length=horizon,
        freq="h",
        validate_inputs=validate_inputs,
    )

    assert len(inputs) == 2
    assert len(prediction_timestamps) == 2
    for item_id in item_ids:
        assert len(prediction_timestamps[item_id]) == horizon
def test_convert_df_with_future_df_uses_future_df_timestamps():
    """Check that prediction timestamps are taken from ``future_df`` when it is given.

    The historical data is hourly while ``future_df`` is built with a 2-hour
    step, and validation is disabled. The returned prediction timestamps for
    each series must equal that series' timestamps in ``future_df``.
    """
    item_ids = ("A", "B")
    history = create_df(
        series_ids=list(item_ids),
        n_points=[10, 12],
        target_cols=["target"],
        covariates=["cov1"],
        freq="h",
    )

    # Deliberately use a 2h step here so it differs from the 1h step of the history.
    start_times = get_forecast_start_times(history, freq="2h")
    known_future = create_future_df(
        forecast_start_times=start_times,
        series_ids=list(item_ids),
        n_points=[5, 5],
        covariates=["cov1"],
        freq="2h",
    )

    _, _, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=history,
        future_df=known_future,
        target_columns=["target"],
        prediction_length=5,
        validate_inputs=False,
    )

    # Compare against future_df ordered by series and time (2h spacing).
    ordered_future = known_future.sort_values(["item_id", "timestamp"])
    for item_id in item_ids:
        belongs_to_item = ordered_future["item_id"] == item_id
        expected = pd.DatetimeIndex(ordered_future[belongs_to_item]["timestamp"])
        pd.testing.assert_index_equal(prediction_timestamps[item_id], expected)