Chronos-2: Add option to skip dataframe validation in predict_df (#400)

*Issue #, if available:*

*Description of changes:* This PR adds a `validate_inputs ` argument to
`predict_df` (defaults to `True`), which allows the user to disable
dataframe validation when they know that their dataframe is in the right
format. This reduces runtime by removing the input validation component,
e.g., when calling this method from
[AutoGluon](https://github.com/autogluon/autogluon/pull/5427), and also
handles series with shorter than 3 timesteps.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Abdul Fatir 2025-11-27 16:52:37 +01:00 committed by GitHub
parent c5907ef52e
commit b438bed63f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 51 additions and 14 deletions

View file

@ -785,6 +785,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
prediction_length: int | None = None,
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
batch_size: int = 256,
validate_inputs: bool = True,
**predict_kwargs,
) -> "pd.DataFrame":
"""
@ -814,6 +815,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
The batch size used for prediction. Note that the batch size here means the number of time series, including target(s) and covariates,
which are input into the model. If your data has multiple target and/or covariates, the effective number of time series tasks in a batch
will be lower than this value, by default 256
validate_inputs
When True, the dataframe(s) will be validated before prediction
**predict_kwargs
Additional arguments passed to predict_quantiles
@ -844,6 +847,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
timestamp_column=timestamp_column,
target_columns=target,
prediction_length=prediction_length,
validate_inputs=validate_inputs,
)
# Generate forecasts

View file

@ -215,6 +215,7 @@ def convert_df_input_to_list_of_dicts_input(
prediction_length: int,
id_column: str = "item_id",
timestamp_column: str = "timestamp",
validate_inputs: bool = True,
) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]:
"""
Convert from dataframe input format to a list of dictionaries input format.
@ -240,6 +241,8 @@ def convert_df_input_to_list_of_dicts_input(
Name of column containing time series identifiers
timestamp_column
Name of column containing timestamps
validate_inputs
When True, the dataframe(s) will be validated be conversion
Returns
-------
@ -251,14 +254,34 @@ def convert_df_input_to_list_of_dicts_input(
import pandas as pd
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
df,
future_df=future_df,
id_column=id_column,
timestamp_column=timestamp_column,
target_columns=target_columns,
prediction_length=prediction_length,
)
if validate_inputs:
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
df,
future_df=future_df,
id_column=id_column,
timestamp_column=timestamp_column,
target_columns=target_columns,
prediction_length=prediction_length,
)
else:
# Get the original order of time series IDs
original_order = df[id_column].unique()
# Get series lengths
series_lengths = df[id_column].value_counts(sort=False).to_list()
# If validation is skipped, the first freq in the dataframe is used
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
start_idx = 0
for length in series_lengths:
if length < 3:
start_idx += length
continue
timestamps = timestamp_index[start_idx : start_idx + length]
freq = pd.infer_freq(timestamps)
break
assert freq is not None, "validate is False, but could not infer frequency from the dataframe"
# Convert to list of dicts format
inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = []

View file

@ -450,7 +450,10 @@ def test_pipeline_can_evaluate_on_dummy_fev_task(pipeline, task_kwargs):
],
)
@pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"])
def test_predict_df_works_for_valid_inputs(pipeline, context_setup, future_setup, expected_rows, freq):
@pytest.mark.parametrize("validate_inputs", [True, False])
def test_predict_df_works_for_valid_inputs(
pipeline, context_setup, future_setup, expected_rows, freq, validate_inputs
):
prediction_length = 3
df = create_df(**context_setup, freq=freq)
forecast_start_times = get_forecast_start_times(df, freq)
@ -460,7 +463,13 @@ def test_predict_df_works_for_valid_inputs(pipeline, context_setup, future_setup
target_columns = context_setup.get("target_cols", ["target"])
n_series = len(series_ids)
n_targets = len(target_columns)
result = pipeline.predict_df(df, future_df=future_df, target=target_columns, prediction_length=prediction_length)
result = pipeline.predict_df(
df,
future_df=future_df,
target=target_columns,
prediction_length=prediction_length,
validate_inputs=validate_inputs,
)
assert len(result) == expected_rows
assert "item_id" in result.columns and np.all(
@ -516,9 +525,10 @@ def test_predict_df_future_df_validation_errors(pipeline, future_data, error_mat
pipeline.predict_df(df, future_df=future_df)
def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline):
@pytest.mark.parametrize("validate_inputs", [True, False])
def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline, validate_inputs):
df = create_df()
# Make timestamps non-uniform for series A
# Make timestamps non-uniform for series A (first series)
df.loc[df["item_id"] == "A", "timestamp"] = [
"2023-01-01",
"2023-01-02",
@ -532,8 +542,8 @@ def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline):
"2023-01-11",
]
with pytest.raises(ValueError, match="not infer frequency"):
pipeline.predict_df(df)
with pytest.raises((ValueError, AssertionError), match="not infer frequency"):
pipeline.predict_df(df, validate_inputs=validate_inputs)
def test_predict_df_with_inconsistent_frequencies_raises_error(pipeline):