mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Update logic
This commit is contained in:
parent
9a46c39eb9
commit
76463ed744
4 changed files with 49 additions and 63 deletions
|
|
@ -168,9 +168,8 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
|
||||
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
|
||||
freq
|
||||
Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used
|
||||
instead of inferring it from the data. This is useful when you already know the frequency and want to
|
||||
skip the inference overhead.
|
||||
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used when
|
||||
validate_inputs=False. When provided, skips frequency inference from the data.
|
||||
**predict_kwargs
|
||||
Additional arguments passed to predict_quantiles
|
||||
|
||||
|
|
|
|||
|
|
@ -868,10 +868,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
|
||||
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
|
||||
freq
|
||||
Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used
|
||||
instead of inferring it from the data. This is useful when you already know the frequency and want to
|
||||
skip the inference overhead. Only used when future_df is not provided, since timestamps are extracted
|
||||
from future_df when it's available.
|
||||
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used when
|
||||
validate_inputs=False. When provided, skips frequency inference from the data.
|
||||
**predict_kwargs
|
||||
Additional arguments passed to predict_quantiles
|
||||
|
||||
|
|
|
|||
|
|
@ -231,9 +231,8 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
timestamp_column
|
||||
Name of column containing timestamps
|
||||
freq
|
||||
Frequency string for timestamp generation. If provided, this frequency is used
|
||||
instead of inferring it from the data. Only used when future_df is not provided,
|
||||
since timestamps are extracted from future_df when it's available.
|
||||
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used
|
||||
when validate_inputs=False. When provided, skips frequency inference from the data.
|
||||
validate_inputs
|
||||
When True, the dataframe(s) will be validated before conversion
|
||||
|
||||
|
|
@ -247,8 +246,16 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
|
||||
import pandas as pd
|
||||
|
||||
if freq is not None and validate_inputs:
|
||||
raise ValueError(
|
||||
"freq can only be provided when validate_inputs=False. "
|
||||
"When using freq with validate_inputs=False, you must ensure: "
|
||||
"df (and future_df if provided) are sorted by (id_column, timestamp_column); "
|
||||
"future_df (if provided) contains exactly prediction_length rows per item."
|
||||
)
|
||||
|
||||
if validate_inputs:
|
||||
df, future_df, inferred_freq, series_lengths, original_order = validate_df_inputs(
|
||||
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
|
||||
df,
|
||||
future_df=future_df,
|
||||
id_column=id_column,
|
||||
|
|
@ -256,9 +263,6 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
target_columns=target_columns,
|
||||
prediction_length=prediction_length,
|
||||
)
|
||||
# Use provided freq if available, otherwise use inferred freq
|
||||
if freq is None:
|
||||
freq = inferred_freq
|
||||
else:
|
||||
# Get the original order of time series IDs
|
||||
original_order = df[id_column].unique()
|
||||
|
|
@ -286,36 +290,36 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
|
||||
indptr = np.concatenate([[0], np.cumsum(series_lengths)]).astype("int64")
|
||||
target_array = df[target_columns].to_numpy().T # Shape: (n_targets, len(df))
|
||||
last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,)
|
||||
offset = pd.tseries.frequencies.to_offset(freq)
|
||||
with warnings.catch_warnings():
|
||||
# Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822
|
||||
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
|
||||
# Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length)
|
||||
prediction_timestamps_array = pd.DatetimeIndex(
|
||||
np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel()
|
||||
)
|
||||
|
||||
past_covariates_dict = {
|
||||
col: df[col].to_numpy() for col in df.columns if col not in [id_column, timestamp_column] + target_columns
|
||||
}
|
||||
future_covariates_dict = {}
|
||||
|
||||
if future_df is not None:
|
||||
# Use timestamps from future_df
|
||||
prediction_timestamps_flat = pd.DatetimeIndex(future_df[timestamp_column])
|
||||
for col in future_df.columns.drop([id_column, timestamp_column]):
|
||||
future_covariates_dict[col] = future_df[col].to_numpy()
|
||||
else:
|
||||
# Generate timestamps from freq
|
||||
assert freq is not None, "freq must be provided or inferred when future_df is not provided"
|
||||
last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,)
|
||||
offset = pd.tseries.frequencies.to_offset(freq)
|
||||
with warnings.catch_warnings():
|
||||
# Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822
|
||||
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
|
||||
# Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length)
|
||||
prediction_timestamps_flat = pd.DatetimeIndex(
|
||||
np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel()
|
||||
)
|
||||
if validate_inputs:
|
||||
if (pd.DatetimeIndex(future_df[timestamp_column]) != pd.DatetimeIndex(prediction_timestamps_array)).any():
|
||||
raise ValueError(
|
||||
"future_df timestamps do not match the expected prediction timestamps. "
|
||||
"You can disable this check by setting `validate_inputs=False`"
|
||||
)
|
||||
|
||||
for i in range(len(series_lengths)):
|
||||
start_idx, end_idx = indptr[i], indptr[i + 1]
|
||||
future_start_idx, future_end_idx = i * prediction_length, (i + 1) * prediction_length
|
||||
|
||||
series_id = df[id_column].iloc[start_idx]
|
||||
prediction_timestamps[series_id] = prediction_timestamps_flat[future_start_idx:future_end_idx]
|
||||
prediction_timestamps[series_id] = prediction_timestamps_array[future_start_idx:future_end_idx]
|
||||
task: dict[str, np.ndarray | dict[str, np.ndarray]] = {"target": target_array[:, start_idx:end_idx]}
|
||||
|
||||
if len(past_covariates_dict) > 0:
|
||||
|
|
|
|||
|
|
@ -364,10 +364,24 @@ def test_convert_df_preserves_all_values_with_random_inputs():
|
|||
# Tests for freq parameter
|
||||
|
||||
|
||||
@pytest.mark.parametrize("validate_inputs", [True, False])
|
||||
def test_convert_df_with_freq_and_validate_inputs_raises_error():
|
||||
"""Test that providing freq with validate_inputs=True raises ValueError."""
|
||||
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], freq="h")
|
||||
|
||||
with pytest.raises(ValueError, match="freq can only be provided when validate_inputs=False"):
|
||||
convert_df_input_to_list_of_dicts_input(
|
||||
df=df,
|
||||
future_df=None,
|
||||
target_columns=["target"],
|
||||
prediction_length=5,
|
||||
freq="h",
|
||||
validate_inputs=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_future_df", [True, False])
|
||||
def test_convert_df_with_provided_freq(validate_inputs, use_future_df):
|
||||
"""Test that provided freq works with different combinations of validate_inputs and future_df."""
|
||||
def test_convert_df_with_freq_and_validate_inputs_false(use_future_df):
|
||||
"""Test that freq works with validate_inputs=False."""
|
||||
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h")
|
||||
prediction_length = 5
|
||||
|
||||
|
|
@ -388,39 +402,10 @@ def test_convert_df_with_provided_freq(validate_inputs, use_future_df):
|
|||
target_columns=["target"],
|
||||
prediction_length=prediction_length,
|
||||
freq="h",
|
||||
validate_inputs=validate_inputs,
|
||||
validate_inputs=False,
|
||||
)
|
||||
|
||||
assert len(inputs) == 2
|
||||
assert len(prediction_timestamps) == 2
|
||||
for series_id in ["A", "B"]:
|
||||
assert len(prediction_timestamps[series_id]) == prediction_length
|
||||
|
||||
|
||||
def test_convert_df_with_future_df_uses_future_df_timestamps():
|
||||
"""Test that timestamps from future_df are used when future_df is provided."""
|
||||
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h")
|
||||
|
||||
# Create future_df with 2h freq (different from df's 1h freq)
|
||||
forecast_start_times = get_forecast_start_times(df, freq="2h")
|
||||
future_df = create_future_df(
|
||||
forecast_start_times=forecast_start_times,
|
||||
series_ids=["A", "B"],
|
||||
n_points=[5, 5],
|
||||
covariates=["cov1"],
|
||||
freq="2h"
|
||||
)
|
||||
|
||||
inputs, _, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
|
||||
df=df,
|
||||
future_df=future_df,
|
||||
target_columns=["target"],
|
||||
prediction_length=5,
|
||||
validate_inputs=False,
|
||||
)
|
||||
|
||||
# Verify timestamps come from future_df (2h spacing)
|
||||
future_df_sorted = future_df.sort_values(["item_id", "timestamp"])
|
||||
for series_id in ["A", "B"]:
|
||||
expected_timestamps = pd.DatetimeIndex(future_df_sorted[future_df_sorted["item_id"] == series_id]["timestamp"])
|
||||
pd.testing.assert_index_equal(prediction_timestamps[series_id], expected_timestamps)
|
||||
|
|
|
|||
Loading…
Reference in a new issue