mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Chronos-2: Add option to skip dataframe validation in predict_df (#400)
*Issue #, if available:* *Description of changes:* This PR adds a `validate_inputs` argument to `predict_df` (defaults to `True`), which allows the user to disable dataframe validation when they know that their dataframe is in the right format. This reduces runtime by removing the input validation component, e.g., when calling this method from [AutoGluon](https://github.com/autogluon/autogluon/pull/5427), and also handles series with fewer than 3 timesteps. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
c5907ef52e
commit
b438bed63f
3 changed files with 51 additions and 14 deletions
|
|
@ -785,6 +785,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
prediction_length: int | None = None,
|
||||
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
batch_size: int = 256,
|
||||
validate_inputs: bool = True,
|
||||
**predict_kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
|
|
@ -814,6 +815,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
The batch size used for prediction. Note that the batch size here means the number of time series, including target(s) and covariates,
|
||||
which are input into the model. If your data has multiple target and/or covariates, the effective number of time series tasks in a batch
|
||||
will be lower than this value, by default 256
|
||||
validate_inputs
|
||||
When True, the dataframe(s) will be validated before prediction
|
||||
**predict_kwargs
|
||||
Additional arguments passed to predict_quantiles
|
||||
|
||||
|
|
@ -844,6 +847,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
timestamp_column=timestamp_column,
|
||||
target_columns=target,
|
||||
prediction_length=prediction_length,
|
||||
validate_inputs=validate_inputs,
|
||||
)
|
||||
|
||||
# Generate forecasts
|
||||
|
|
|
|||
|
|
@ -215,6 +215,7 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
prediction_length: int,
|
||||
id_column: str = "item_id",
|
||||
timestamp_column: str = "timestamp",
|
||||
validate_inputs: bool = True,
|
||||
) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]:
|
||||
"""
|
||||
Convert from dataframe input format to a list of dictionaries input format.
|
||||
|
|
@ -240,6 +241,8 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
Name of column containing time series identifiers
|
||||
timestamp_column
|
||||
Name of column containing timestamps
|
||||
validate_inputs
|
||||
When True, the dataframe(s) will be validated before conversion
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -251,14 +254,34 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
|
||||
import pandas as pd
|
||||
|
||||
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
|
||||
df,
|
||||
future_df=future_df,
|
||||
id_column=id_column,
|
||||
timestamp_column=timestamp_column,
|
||||
target_columns=target_columns,
|
||||
prediction_length=prediction_length,
|
||||
)
|
||||
if validate_inputs:
|
||||
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
|
||||
df,
|
||||
future_df=future_df,
|
||||
id_column=id_column,
|
||||
timestamp_column=timestamp_column,
|
||||
target_columns=target_columns,
|
||||
prediction_length=prediction_length,
|
||||
)
|
||||
else:
|
||||
# Get the original order of time series IDs
|
||||
original_order = df[id_column].unique()
|
||||
|
||||
# Get series lengths
|
||||
series_lengths = df[id_column].value_counts(sort=False).to_list()
|
||||
|
||||
# If validation is skipped, the freq is inferred from the first series with at least 3 timesteps
|
||||
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
|
||||
start_idx = 0
|
||||
for length in series_lengths:
|
||||
if length < 3:
|
||||
start_idx += length
|
||||
continue
|
||||
timestamps = timestamp_index[start_idx : start_idx + length]
|
||||
freq = pd.infer_freq(timestamps)
|
||||
break
|
||||
|
||||
assert freq is not None, "validate is False, but could not infer frequency from the dataframe"
|
||||
|
||||
# Convert to list of dicts format
|
||||
inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = []
|
||||
|
|
|
|||
|
|
@ -450,7 +450,10 @@ def test_pipeline_can_evaluate_on_dummy_fev_task(pipeline, task_kwargs):
|
|||
],
|
||||
)
|
||||
@pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"])
|
||||
def test_predict_df_works_for_valid_inputs(pipeline, context_setup, future_setup, expected_rows, freq):
|
||||
@pytest.mark.parametrize("validate_inputs", [True, False])
|
||||
def test_predict_df_works_for_valid_inputs(
|
||||
pipeline, context_setup, future_setup, expected_rows, freq, validate_inputs
|
||||
):
|
||||
prediction_length = 3
|
||||
df = create_df(**context_setup, freq=freq)
|
||||
forecast_start_times = get_forecast_start_times(df, freq)
|
||||
|
|
@ -460,7 +463,13 @@ def test_predict_df_works_for_valid_inputs(pipeline, context_setup, future_setup
|
|||
target_columns = context_setup.get("target_cols", ["target"])
|
||||
n_series = len(series_ids)
|
||||
n_targets = len(target_columns)
|
||||
result = pipeline.predict_df(df, future_df=future_df, target=target_columns, prediction_length=prediction_length)
|
||||
result = pipeline.predict_df(
|
||||
df,
|
||||
future_df=future_df,
|
||||
target=target_columns,
|
||||
prediction_length=prediction_length,
|
||||
validate_inputs=validate_inputs,
|
||||
)
|
||||
|
||||
assert len(result) == expected_rows
|
||||
assert "item_id" in result.columns and np.all(
|
||||
|
|
@ -516,9 +525,10 @@ def test_predict_df_future_df_validation_errors(pipeline, future_data, error_mat
|
|||
pipeline.predict_df(df, future_df=future_df)
|
||||
|
||||
|
||||
def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline):
|
||||
@pytest.mark.parametrize("validate_inputs", [True, False])
|
||||
def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline, validate_inputs):
|
||||
df = create_df()
|
||||
# Make timestamps non-uniform for series A
|
||||
# Make timestamps non-uniform for series A (first series)
|
||||
df.loc[df["item_id"] == "A", "timestamp"] = [
|
||||
"2023-01-01",
|
||||
"2023-01-02",
|
||||
|
|
@ -532,8 +542,8 @@ def test_predict_df_with_non_uniform_timestamps_raises_error(pipeline):
|
|||
"2023-01-11",
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="not infer frequency"):
|
||||
pipeline.predict_df(df)
|
||||
with pytest.raises((ValueError, AssertionError), match="not infer frequency"):
|
||||
pipeline.predict_df(df, validate_inputs=validate_inputs)
|
||||
|
||||
|
||||
def test_predict_df_with_inconsistent_frequencies_raises_error(pipeline):
|
||||
|
|
|
|||
Loading…
Reference in a new issue