From b438bed63fc6ffd2e30d400dc70776404ac434d7 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Thu, 27 Nov 2025 16:52:37 +0100 Subject: [PATCH] Chronos-2: Add option to skip dataframe validation in `predict_df` (#400) *Issue #, if available:* *Description of changes:* This PR adds a `validate_inputs` argument to `predict_df` (defaults to `True`), which allows the user to disable dataframe validation when they know that their dataframe is in the right format. This reduces runtime by removing the input validation component, e.g., when calling this method from [AutoGluon](https://github.com/autogluon/autogluon/pull/5427), and also handles series with fewer than 3 timesteps. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- src/chronos/chronos2/pipeline.py | 4 ++++ src/chronos/df_utils.py | 39 +++++++++++++++++++++++++------- test/test_chronos2.py | 22 +++++++++++++----- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index c2a6cce..30785ef 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -785,6 +785,7 @@ class Chronos2Pipeline(BaseChronosPipeline): prediction_length: int | None = None, quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], batch_size: int = 256, + validate_inputs: bool = True, **predict_kwargs, ) -> "pd.DataFrame": """ @@ -814,6 +815,8 @@ class Chronos2Pipeline(BaseChronosPipeline): The batch size used for prediction. Note that the batch size here means the number of time series, including target(s) and covariates, which are input into the model. 
If your data has multiple target and/or covariates, the effective number of time series tasks in a batch will be lower than this value, by default 256 + validate_inputs + When True, the dataframe(s) will be validated before prediction **predict_kwargs Additional arguments passed to predict_quantiles @@ -844,6 +847,7 @@ class Chronos2Pipeline(BaseChronosPipeline): timestamp_column=timestamp_column, target_columns=target, prediction_length=prediction_length, + validate_inputs=validate_inputs, ) # Generate forecasts diff --git a/src/chronos/df_utils.py b/src/chronos/df_utils.py index e63ec4f..41f68e7 100644 --- a/src/chronos/df_utils.py +++ b/src/chronos/df_utils.py @@ -215,6 +215,7 @@ def convert_df_input_to_list_of_dicts_input( prediction_length: int, id_column: str = "item_id", timestamp_column: str = "timestamp", + validate_inputs: bool = True, ) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]: """ Convert from dataframe input format to a list of dictionaries input format. 
@@ -240,6 +241,8 @@ def convert_df_input_to_list_of_dicts_input( Name of column containing time series identifiers timestamp_column Name of column containing timestamps + validate_inputs + When True, the dataframe(s) will be validated before conversion Returns ------- @@ -251,14 +254,34 @@ def convert_df_input_to_list_of_dicts_input( import pandas as pd - df, future_df, freq, series_lengths, original_order = validate_df_inputs( - df, - future_df=future_df, - id_column=id_column, - timestamp_column=timestamp_column, - target_columns=target_columns, - prediction_length=prediction_length, - ) + if validate_inputs: + df, future_df, freq, series_lengths, original_order = validate_df_inputs( + df, + future_df=future_df, + id_column=id_column, + timestamp_column=timestamp_column, + target_columns=target_columns, + prediction_length=prediction_length, + ) + else: + # Get the original order of time series IDs + original_order = df[id_column].unique() + + # Get series lengths + series_lengths = df[id_column].value_counts(sort=False).to_list() + + # If validation is skipped, the first freq in the dataframe is used + timestamp_index = pd.DatetimeIndex(df[timestamp_column]) + start_idx = 0 + for length in series_lengths: + if length < 3: + start_idx += length + continue + timestamps = timestamp_index[start_idx : start_idx + length] + freq = pd.infer_freq(timestamps) + break + + assert freq is not None, "validate is False, but could not infer frequency from the dataframe" # Convert to list of dicts format inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = [] diff --git a/test/test_chronos2.py b/test/test_chronos2.py index 793c64c..abad44b 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -450,7 +450,10 @@ def test_pipeline_can_evaluate_on_dummy_fev_task(pipeline, task_kwargs): ], ) @pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"]) -def test_predict_df_works_for_valid_inputs(pipeline, context_setup, future_setup, 
expected_rows, freq): +@pytest.mark.parametrize("validate_inputs", [True, False]) +def test_predict_df_works_for_valid_inputs( + pipeline, context_setup, future_setup, expected_rows, freq, validate_inputs +): prediction_length = 3 df = create_df(**context_setup, freq=freq) forecast_start_times = get_forecast_start_times(df, freq) @@ -460,7 +463,13 @@ def test_predict_df_works_for_valid_inputs(pipeline, context_setup, future_setup target_columns = context_setup.get("target_cols", ["target"]) n_series = len(series_ids) n_targets = len(target_columns) - result = pipeline.predict_df(df, future_df=future_df, target=target_columns, prediction_length=prediction_length) + result = pipeline.predict_df( + df, + future_df=future_df, + target=target_columns, + prediction_length=prediction_length, + validate_inputs=validate_inputs, + ) assert len(result) == expected_rows assert "item_id" in result.columns and np.all( @@ -516,9 +525,10 @@ def test_predict_df_future_df_validation_errors(pipeline, future_data, error_mat pipeline.predict_df(df, future_df=future_df) -def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline): +@pytest.mark.parametrize("validate_inputs", [True, False]) +def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline, validate_inputs): df = create_df() - # Make timestamps non-uniform for series A + # Make timestamps non-uniform for series A (first series) df.loc[df["item_id"] == "A", "timestamp"] = [ "2023-01-01", "2023-01-02", @@ -532,8 +542,8 @@ def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline): "2023-01-11", ] - with pytest.raises(ValueError, match="not infer frequency"): - pipeline.predict_df(df) + with pytest.raises((ValueError, AssertionError), match="not infer frequency"): + pipeline.predict_df(df, validate_inputs=validate_inputs) def test_predict_df_with_inconsistent_frequencies_raises_error(pipeline):