From 76463ed744703987e075c190dacd9eac9afe1827 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Mon, 19 Jan 2026 08:41:01 +0000 Subject: [PATCH] Update logic --- src/chronos/base.py | 5 ++-- src/chronos/chronos2/pipeline.py | 6 ++-- src/chronos/df_utils.py | 50 +++++++++++++++++-------------- test/test_df_utils.py | 51 +++++++++++--------------------- 4 files changed, 49 insertions(+), 63 deletions(-) diff --git a/src/chronos/base.py b/src/chronos/base.py index e738717..daecf4c 100644 --- a/src/chronos/base.py +++ b/src/chronos/base.py @@ -168,9 +168,8 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a regular frequency, and item IDs match between past and future data. Setting to False disables these checks. freq - Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used - instead of inferring it from the data. This is useful when you already know the frequency and want to - skip the inference overhead. + Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used when + validate_inputs=False. When provided, skips frequency inference from the data. **predict_kwargs Additional arguments passed to predict_quantiles diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 5537282..f066bee 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -868,10 +868,8 @@ class Chronos2Pipeline(BaseChronosPipeline): When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a regular frequency, and item IDs match between past and future data. Setting to False disables these checks. freq - Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used - instead of inferring it from the data. This is useful when you already know the frequency and want to - skip the inference overhead. Only used when future_df is not provided, since timestamps are extracted - from future_df when it's available. + Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used when + validate_inputs=False. When provided, skips frequency inference from the data. **predict_kwargs Additional arguments passed to predict_quantiles diff --git a/src/chronos/df_utils.py b/src/chronos/df_utils.py index 223ad8a..00731b2 100644 --- a/src/chronos/df_utils.py +++ b/src/chronos/df_utils.py @@ -231,9 +231,8 @@ def convert_df_input_to_list_of_dicts_input( timestamp_column Name of column containing timestamps freq - Frequency string for timestamp generation. If provided, this frequency is used - instead of inferring it from the data. Only used when future_df is not provided, - since timestamps are extracted from future_df when it's available. + Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used + when validate_inputs=False. When provided, skips frequency inference from the data. validate_inputs When True, the dataframe(s) will be validated before conversion @@ -247,8 +246,16 @@ def convert_df_input_to_list_of_dicts_input( import pandas as pd + if freq is not None and validate_inputs: + raise ValueError( + "freq can only be provided when validate_inputs=False. " + "When using freq with validate_inputs=False, you must ensure: " + "df (and future_df if provided) are sorted by (id_column, timestamp_column); " + "future_df (if provided) contains exactly prediction_length rows per item." + ) + if validate_inputs: - df, future_df, inferred_freq, series_lengths, original_order = validate_df_inputs( + df, future_df, freq, series_lengths, original_order = validate_df_inputs( df, future_df=future_df, id_column=id_column, @@ -256,9 +263,6 @@ def convert_df_input_to_list_of_dicts_input( target_columns=target_columns, prediction_length=prediction_length, ) - # Use provided freq if available, otherwise use inferred freq - if freq is None: - freq = inferred_freq else: # Get the original order of time series IDs original_order = df[id_column].unique() @@ -286,36 +290,36 @@ def convert_df_input_to_list_of_dicts_input( indptr = np.concatenate([[0], np.cumsum(series_lengths)]).astype("int64") target_array = df[target_columns].to_numpy().T # Shape: (n_targets, len(df)) + last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,) + offset = pd.tseries.frequencies.to_offset(freq) + with warnings.catch_warnings(): + # Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822 + warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning) + # Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length) + prediction_timestamps_array = pd.DatetimeIndex( + np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel() + ) past_covariates_dict = { col: df[col].to_numpy() for col in df.columns if col not in [id_column, timestamp_column] + target_columns } future_covariates_dict = {} - if future_df is not None: - # Use timestamps from future_df - prediction_timestamps_flat = pd.DatetimeIndex(future_df[timestamp_column]) for col in future_df.columns.drop([id_column, timestamp_column]): future_covariates_dict[col] = future_df[col].to_numpy() - else: - # Generate timestamps from freq - assert freq is not None, "freq must be provided or inferred when future_df is not provided" - last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,) - offset = pd.tseries.frequencies.to_offset(freq) - with warnings.catch_warnings(): - # Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822 - warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning) - # Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length) - prediction_timestamps_flat = pd.DatetimeIndex( - np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel() - ) + if validate_inputs: + if (pd.DatetimeIndex(future_df[timestamp_column]) != pd.DatetimeIndex(prediction_timestamps_array)).any(): + raise ValueError( + "future_df timestamps do not match the expected prediction timestamps. " + "You can disable this check by setting `validate_inputs=False`" + ) for i in range(len(series_lengths)): start_idx, end_idx = indptr[i], indptr[i + 1] future_start_idx, future_end_idx = i * prediction_length, (i + 1) * prediction_length series_id = df[id_column].iloc[start_idx] - prediction_timestamps[series_id] = prediction_timestamps_flat[future_start_idx:future_end_idx] + prediction_timestamps[series_id] = prediction_timestamps_array[future_start_idx:future_end_idx] task: dict[str, np.ndarray | dict[str, np.ndarray]] = {"target": target_array[:, start_idx:end_idx]} if len(past_covariates_dict) > 0: diff --git a/test/test_df_utils.py b/test/test_df_utils.py index 47518a0..6f0438e 100644 --- a/test/test_df_utils.py +++ b/test/test_df_utils.py @@ -364,10 +364,24 @@ def test_convert_df_preserves_all_values_with_random_inputs(): # Tests for freq parameter -@pytest.mark.parametrize("validate_inputs", [True, False]) +def test_convert_df_with_freq_and_validate_inputs_raises_error(): + """Test that providing freq with validate_inputs=True raises ValueError.""" + df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], freq="h") + + with pytest.raises(ValueError, match="freq can only be provided when validate_inputs=False"): + convert_df_input_to_list_of_dicts_input( + df=df, + future_df=None, + target_columns=["target"], + prediction_length=5, + freq="h", + validate_inputs=True, + ) + + @pytest.mark.parametrize("use_future_df", [True, False]) -def test_convert_df_with_provided_freq(validate_inputs, use_future_df): - """Test that provided freq works with different combinations of validate_inputs and future_df.""" +def test_convert_df_with_freq_and_validate_inputs_false(use_future_df): + """Test that freq works with validate_inputs=False.""" df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h") prediction_length = 5 @@ -388,39 +402,10 @@ def test_convert_df_with_provided_freq(validate_inputs, use_future_df): target_columns=["target"], prediction_length=prediction_length, freq="h", - validate_inputs=validate_inputs, + validate_inputs=False, ) assert len(inputs) == 2 assert len(prediction_timestamps) == 2 for series_id in ["A", "B"]: assert len(prediction_timestamps[series_id]) == prediction_length - - -def test_convert_df_with_future_df_uses_future_df_timestamps(): - """Test that timestamps from future_df are used when future_df is provided.""" - df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h") - - # Create future_df with 2h freq (different from df's 1h freq) - forecast_start_times = get_forecast_start_times(df, freq="2h") - future_df = create_future_df( - forecast_start_times=forecast_start_times, - series_ids=["A", "B"], - n_points=[5, 5], - covariates=["cov1"], - freq="2h" - ) - - inputs, _, prediction_timestamps = convert_df_input_to_list_of_dicts_input( - df=df, - future_df=future_df, - target_columns=["target"], - prediction_length=5, - validate_inputs=False, - ) - - # Verify timestamps come from future_df (2h spacing) - future_df_sorted = future_df.sort_values(["item_id", "timestamp"]) - for series_id in ["A", "B"]: - expected_timestamps = pd.DatetimeIndex(future_df_sorted[future_df_sorted["item_id"] == series_id]["timestamp"]) - pd.testing.assert_index_equal(prediction_timestamps[series_id], expected_timestamps)