Update logic: restrict `freq` to validate_inputs=False

This commit is contained in:
Oleksandr Shchur 2026-01-19 08:41:01 +00:00
parent 9a46c39eb9
commit 76463ed744
4 changed files with 49 additions and 63 deletions

View file

@ -168,9 +168,8 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
freq
Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used
instead of inferring it from the data. This is useful when you already know the frequency and want to
skip the inference overhead.
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used when
validate_inputs=False. When provided, skips frequency inference from the data.
**predict_kwargs
Additional arguments passed to predict_quantiles

View file

@ -868,10 +868,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
freq
Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used
instead of inferring it from the data. This is useful when you already know the frequency and want to
skip the inference overhead. Only used when future_df is not provided, since timestamps are extracted
from future_df when it's available.
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used when
validate_inputs=False. When provided, skips frequency inference from the data.
**predict_kwargs
Additional arguments passed to predict_quantiles

View file

@ -231,9 +231,8 @@ def convert_df_input_to_list_of_dicts_input(
timestamp_column
Name of column containing timestamps
freq
Frequency string for timestamp generation. If provided, this frequency is used
instead of inferring it from the data. Only used when future_df is not provided,
since timestamps are extracted from future_df when it's available.
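Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used
when validate_inputs=False; passing it together with validate_inputs=True raises a ValueError.
When provided, skips frequency inference from the data. The caller must then ensure that df
(and future_df, if provided) are sorted by (id_column, timestamp_column) and that future_df
(if provided) contains exactly prediction_length rows per item.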
validate_inputs
When True, the dataframe(s) will be validated before conversion
@ -247,8 +246,16 @@ def convert_df_input_to_list_of_dicts_input(
import pandas as pd
if freq is not None and validate_inputs:
raise ValueError(
"freq can only be provided when validate_inputs=False. "
"When using freq with validate_inputs=False, you must ensure: "
"df (and future_df if provided) are sorted by (id_column, timestamp_column); "
"future_df (if provided) contains exactly prediction_length rows per item."
)
if validate_inputs:
df, future_df, inferred_freq, series_lengths, original_order = validate_df_inputs(
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
df,
future_df=future_df,
id_column=id_column,
@ -256,9 +263,6 @@ def convert_df_input_to_list_of_dicts_input(
target_columns=target_columns,
prediction_length=prediction_length,
)
# Use provided freq if available, otherwise use inferred freq
if freq is None:
freq = inferred_freq
else:
# Get the original order of time series IDs
original_order = df[id_column].unique()
@ -286,36 +290,36 @@ def convert_df_input_to_list_of_dicts_input(
indptr = np.concatenate([[0], np.cumsum(series_lengths)]).astype("int64")
target_array = df[target_columns].to_numpy().T # Shape: (n_targets, len(df))
last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,)
offset = pd.tseries.frequencies.to_offset(freq)
with warnings.catch_warnings():
# Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
# Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length)
prediction_timestamps_array = pd.DatetimeIndex(
np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel()
)
past_covariates_dict = {
col: df[col].to_numpy() for col in df.columns if col not in [id_column, timestamp_column] + target_columns
}
future_covariates_dict = {}
if future_df is not None:
# Use timestamps from future_df
prediction_timestamps_flat = pd.DatetimeIndex(future_df[timestamp_column])
for col in future_df.columns.drop([id_column, timestamp_column]):
future_covariates_dict[col] = future_df[col].to_numpy()
else:
# Generate timestamps from freq
assert freq is not None, "freq must be provided or inferred when future_df is not provided"
last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,)
offset = pd.tseries.frequencies.to_offset(freq)
with warnings.catch_warnings():
# Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
# Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length)
prediction_timestamps_flat = pd.DatetimeIndex(
np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel()
)
if validate_inputs:
if (pd.DatetimeIndex(future_df[timestamp_column]) != pd.DatetimeIndex(prediction_timestamps_array)).any():
raise ValueError(
"future_df timestamps do not match the expected prediction timestamps. "
"You can disable this check by setting `validate_inputs=False`"
)
for i in range(len(series_lengths)):
start_idx, end_idx = indptr[i], indptr[i + 1]
future_start_idx, future_end_idx = i * prediction_length, (i + 1) * prediction_length
series_id = df[id_column].iloc[start_idx]
prediction_timestamps[series_id] = prediction_timestamps_flat[future_start_idx:future_end_idx]
prediction_timestamps[series_id] = prediction_timestamps_array[future_start_idx:future_end_idx]
task: dict[str, np.ndarray | dict[str, np.ndarray]] = {"target": target_array[:, start_idx:end_idx]}
if len(past_covariates_dict) > 0:

View file

@ -364,10 +364,24 @@ def test_convert_df_preserves_all_values_with_random_inputs():
# Tests for freq parameter
@pytest.mark.parametrize("validate_inputs", [True, False])
def test_convert_df_with_freq_and_validate_inputs_raises_error():
"""Test that providing freq with validate_inputs=True raises ValueError."""
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], freq="h")
with pytest.raises(ValueError, match="freq can only be provided when validate_inputs=False"):
convert_df_input_to_list_of_dicts_input(
df=df,
future_df=None,
target_columns=["target"],
prediction_length=5,
freq="h",
validate_inputs=True,
)
@pytest.mark.parametrize("use_future_df", [True, False])
def test_convert_df_with_provided_freq(validate_inputs, use_future_df):
"""Test that provided freq works with different combinations of validate_inputs and future_df."""
def test_convert_df_with_freq_and_validate_inputs_false(use_future_df):
"""Test that freq works with validate_inputs=False."""
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h")
prediction_length = 5
@ -388,39 +402,10 @@ def test_convert_df_with_provided_freq(validate_inputs, use_future_df):
target_columns=["target"],
prediction_length=prediction_length,
freq="h",
validate_inputs=validate_inputs,
validate_inputs=False,
)
assert len(inputs) == 2
assert len(prediction_timestamps) == 2
for series_id in ["A", "B"]:
assert len(prediction_timestamps[series_id]) == prediction_length
def test_convert_df_with_future_df_uses_future_df_timestamps():
"""Test that timestamps from future_df are used when future_df is provided."""
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h")
# Create future_df with 2h freq (different from df's 1h freq)
forecast_start_times = get_forecast_start_times(df, freq="2h")
future_df = create_future_df(
forecast_start_times=forecast_start_times,
series_ids=["A", "B"],
n_points=[5, 5],
covariates=["cov1"],
freq="2h"
)
inputs, _, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
df=df,
future_df=future_df,
target_columns=["target"],
prediction_length=5,
validate_inputs=False,
)
# Verify timestamps come from future_df (2h spacing)
future_df_sorted = future_df.sort_values(["item_id", "timestamp"])
for series_id in ["A", "B"]:
expected_timestamps = pd.DatetimeIndex(future_df_sorted[future_df_sorted["item_id"] == series_id]["timestamp"])
pd.testing.assert_index_equal(prediction_timestamps[series_id], expected_timestamps)