mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Allow explicitly passing the frequency
This commit is contained in:
parent
f889ae6647
commit
57fb54f7ed
4 changed files with 115 additions and 29 deletions
|
|
@ -142,6 +142,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
prediction_length: int | None = None,
|
||||
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
validate_inputs: bool = True,
|
||||
freq: str | None = None,
|
||||
**predict_kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
|
|
@ -166,6 +167,10 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
validate_inputs
|
||||
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
|
||||
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
|
||||
freq
|
||||
Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used
|
||||
instead of inferring it from the data. This is useful when you already know the frequency and want to
|
||||
skip the inference overhead.
|
||||
**predict_kwargs
|
||||
Additional arguments passed to predict_quantiles
|
||||
|
||||
|
|
@ -200,6 +205,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
timestamp_column=timestamp_column,
|
||||
target_columns=[target],
|
||||
prediction_length=prediction_length,
|
||||
freq=freq,
|
||||
validate_inputs=validate_inputs,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -825,6 +825,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
context_length: int | None = None,
|
||||
cross_learning: bool = False,
|
||||
validate_inputs: bool = True,
|
||||
freq: str | None = None,
|
||||
**predict_kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
|
|
@ -866,6 +867,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
validate_inputs
|
||||
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
|
||||
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
|
||||
freq
|
||||
Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used
|
||||
instead of inferring it from the data. This is useful when you already know the frequency and want to
|
||||
skip the inference overhead. Only used when future_df is not provided, since timestamps are extracted
|
||||
from future_df when it's available.
|
||||
**predict_kwargs
|
||||
Additional arguments passed to predict_quantiles
|
||||
|
||||
|
|
@ -896,6 +902,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
timestamp_column=timestamp_column,
|
||||
target_columns=target,
|
||||
prediction_length=prediction_length,
|
||||
freq=freq,
|
||||
validate_inputs=validate_inputs,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -203,6 +203,7 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
prediction_length: int,
|
||||
id_column: str = "item_id",
|
||||
timestamp_column: str = "timestamp",
|
||||
freq: str | None = None,
|
||||
validate_inputs: bool = True,
|
||||
) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]:
|
||||
"""
|
||||
|
|
@ -229,8 +230,12 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
Name of column containing time series identifiers
|
||||
timestamp_column
|
||||
Name of column containing timestamps
|
||||
freq
|
||||
Frequency string for timestamp generation. If provided, this frequency is used
|
||||
instead of inferring it from the data. Only used when future_df is not provided,
|
||||
since timestamps are extracted from future_df when it's available.
|
||||
validate_inputs
|
||||
When True, the dataframe(s) will be validated be conversion
|
||||
When True, the dataframe(s) will be validated before conversion
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -243,7 +248,7 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
import pandas as pd
|
||||
|
||||
if validate_inputs:
|
||||
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
|
||||
df, future_df, inferred_freq, series_lengths, original_order = validate_df_inputs(
|
||||
df,
|
||||
future_df=future_df,
|
||||
id_column=id_column,
|
||||
|
|
@ -251,6 +256,9 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
target_columns=target_columns,
|
||||
prediction_length=prediction_length,
|
||||
)
|
||||
# Use provided freq if available, otherwise use inferred freq
|
||||
if freq is None:
|
||||
freq = inferred_freq
|
||||
else:
|
||||
# Get the original order of time series IDs
|
||||
original_order = df[id_column].unique()
|
||||
|
|
@ -258,19 +266,19 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
# Get series lengths
|
||||
series_lengths = df[id_column].value_counts(sort=False).to_list()
|
||||
|
||||
# If validation is skipped, the first freq in the dataframe is used
|
||||
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
|
||||
start_idx = 0
|
||||
freq = None
|
||||
for length in series_lengths:
|
||||
if length < 3:
|
||||
start_idx += length
|
||||
continue
|
||||
timestamps = timestamp_index[start_idx : start_idx + length]
|
||||
freq = pd.infer_freq(timestamps)
|
||||
break
|
||||
# If freq is not provided, infer from the first series with >= 3 points
|
||||
if freq is None:
|
||||
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
|
||||
start_idx = 0
|
||||
for length in series_lengths:
|
||||
if length < 3:
|
||||
start_idx += length
|
||||
continue
|
||||
timestamps = timestamp_index[start_idx : start_idx + length]
|
||||
freq = pd.infer_freq(timestamps)
|
||||
break
|
||||
|
||||
assert freq is not None, "validate is False, but could not infer frequency from the dataframe"
|
||||
assert freq is not None, "validate is False, but could not infer frequency from the dataframe"
|
||||
|
||||
# Convert to list of dicts format
|
||||
inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = []
|
||||
|
|
@ -278,29 +286,29 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
|
||||
indptr = np.concatenate([[0], np.cumsum(series_lengths)]).astype("int64")
|
||||
target_array = df[target_columns].to_numpy().T # Shape: (n_targets, len(df))
|
||||
last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,)
|
||||
offset = pd.tseries.frequencies.to_offset(freq)
|
||||
with warnings.catch_warnings():
|
||||
# Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822
|
||||
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
|
||||
# Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length)
|
||||
prediction_timestamps_array = pd.DatetimeIndex(
|
||||
np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel()
|
||||
)
|
||||
|
||||
past_covariates_dict = {
|
||||
col: df[col].to_numpy() for col in df.columns if col not in [id_column, timestamp_column] + target_columns
|
||||
}
|
||||
future_covariates_dict = {}
|
||||
|
||||
if future_df is not None:
|
||||
# Use timestamps from future_df
|
||||
prediction_timestamps_array = pd.DatetimeIndex(future_df[timestamp_column])
|
||||
for col in future_df.columns.drop([id_column, timestamp_column]):
|
||||
future_covariates_dict[col] = future_df[col].to_numpy()
|
||||
if validate_inputs:
|
||||
if (pd.DatetimeIndex(future_df[timestamp_column]) != pd.DatetimeIndex(prediction_timestamps_array)).any():
|
||||
raise ValueError(
|
||||
"future_df timestamps do not match the expected prediction timestamps. "
|
||||
"You can disable this check by setting `validate_inputs=False`"
|
||||
)
|
||||
else:
|
||||
# Generate timestamps from freq
|
||||
assert freq is not None, "freq must be provided or inferred when future_df is not provided"
|
||||
last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,)
|
||||
offset = pd.tseries.frequencies.to_offset(freq)
|
||||
with warnings.catch_warnings():
|
||||
# Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822
|
||||
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
|
||||
# Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length)
|
||||
prediction_timestamps_array = pd.DatetimeIndex(
|
||||
np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel()
|
||||
)
|
||||
|
||||
for i in range(len(series_lengths)):
|
||||
start_idx, end_idx = indptr[i], indptr[i + 1]
|
||||
|
|
|
|||
|
|
@ -359,3 +359,68 @@ def test_convert_df_preserves_all_values_with_random_inputs():
|
|||
assert len(inputs) == n_series
|
||||
assert list(original_order) == series_ids
|
||||
assert len(prediction_timestamps) == n_series
|
||||
|
||||
|
||||
# Tests for freq parameter
|
||||
|
||||
|
||||
@pytest.mark.parametrize("validate_inputs", [True, False])
@pytest.mark.parametrize("use_future_df", [True, False])
def test_convert_df_with_provided_freq(validate_inputs, use_future_df):
    """Conversion accepts an explicit ``freq`` with and without validation and ``future_df``.

    Every combination of ``validate_inputs`` and ``use_future_df`` is exercised; the
    test only checks the number of converted series and the number of prediction
    timestamps produced per series.
    """
    item_ids = ["A", "B"]
    horizon = 5
    past_df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h")

    if use_future_df:
        start_times = get_forecast_start_times(past_df, freq="h")
        known_future = create_future_df(
            forecast_start_times=start_times,
            series_ids=["A", "B"],
            n_points=[horizon, horizon],
            covariates=["cov1"],
            freq="h",
        )
    else:
        known_future = None

    inputs, _, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=past_df,
        future_df=known_future,
        target_columns=["target"],
        prediction_length=horizon,
        freq="h",
        validate_inputs=validate_inputs,
    )

    assert len(inputs) == 2
    assert len(prediction_timestamps) == 2
    for item_id in item_ids:
        assert len(prediction_timestamps[item_id]) == horizon
|
||||
|
||||
|
||||
def test_convert_df_with_future_df_uses_future_df_timestamps():
    """Prediction timestamps are taken from ``future_df`` whenever it is supplied."""
    item_ids = ["A", "B"]
    past_df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h")

    # The future frame is built with a 2h step on purpose: it differs from the 1h
    # step of the past frame, so the assertion below can tell the two sources apart.
    start_times = get_forecast_start_times(past_df, freq="2h")
    future_df = create_future_df(
        forecast_start_times=start_times,
        series_ids=["A", "B"],
        n_points=[5, 5],
        covariates=["cov1"],
        freq="2h",
    )

    converted = convert_df_input_to_list_of_dicts_input(
        df=past_df,
        future_df=future_df,
        target_columns=["target"],
        prediction_length=5,
        validate_inputs=False,
    )
    _, _, prediction_timestamps = converted

    # Compare each series against the matching (sorted) rows of future_df.
    ordered_future = future_df.sort_values(["item_id", "timestamp"])
    for item_id in item_ids:
        belongs_to_item = ordered_future["item_id"] == item_id
        expected_timestamps = pd.DatetimeIndex(ordered_future[belongs_to_item]["timestamp"])
        pd.testing.assert_index_equal(prediction_timestamps[item_id], expected_timestamps)
|
||||
|
|
|
|||
Loading…
Reference in a new issue