Merge branch 'main' into transformers-v5

This commit is contained in:
Kashif Rasul 2026-01-22 12:08:39 +01:00 committed by GitHub
commit 056b30ecc9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 213 additions and 78 deletions

View file

@ -142,6 +142,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
prediction_length: int | None = None,
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
validate_inputs: bool = True,
freq: str | None = None,
**predict_kwargs,
) -> "pd.DataFrame":
"""
@ -164,8 +165,14 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
quantile_levels
Quantile levels to compute
validate_inputs
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
[ADVANCED] When True (default), validates dataframes before prediction. Setting to False removes the
validation overhead, but may silently lead to wrong predictions if data is misformatted. When False, you
must ensure: (1) all dataframes are sorted by (id_column, timestamp_column); (2) future_df (if provided)
has the same item IDs as df with exactly prediction_length rows of future timestamps per item; (3) all
timestamps are regularly spaced (e.g., with hourly frequency).
freq
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used when
validate_inputs=False. When provided, skips frequency inference from the data.
**predict_kwargs
Additional arguments passed to predict_quantiles
@ -200,6 +207,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
timestamp_column=timestamp_column,
target_columns=[target],
prediction_length=prediction_length,
freq=freq,
validate_inputs=validate_inputs,
)

View file

@ -825,6 +825,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
context_length: int | None = None,
cross_learning: bool = False,
validate_inputs: bool = True,
freq: str | None = None,
**predict_kwargs,
) -> "pd.DataFrame":
"""
@ -864,8 +865,14 @@ class Chronos2Pipeline(BaseChronosPipeline):
For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
- Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
validate_inputs
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
[ADVANCED] When True (default), validates dataframes before prediction. Setting to False removes the
validation overhead, but may silently lead to wrong predictions if data is misformatted. When False, you
must ensure: (1) all dataframes are sorted by (id_column, timestamp_column); (2) future_df (if provided)
has the same item IDs as df with exactly prediction_length rows of future timestamps per item; (3) all
timestamps are regularly spaced (e.g., with hourly frequency).
freq
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used when
validate_inputs=False. When provided, skips frequency inference from the data.
**predict_kwargs
Additional arguments passed to predict_quantiles
@ -896,6 +903,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
timestamp_column=timestamp_column,
target_columns=target,
prediction_length=prediction_length,
freq=freq,
validate_inputs=validate_inputs,
)

View file

@ -204,6 +204,7 @@ def convert_df_input_to_list_of_dicts_input(
id_column: str = "item_id",
timestamp_column: str = "timestamp",
validate_inputs: bool = True,
freq: str | None = None,
) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]:
"""
Convert from dataframe input format to a list of dictionaries input format.
@ -230,7 +231,14 @@ def convert_df_input_to_list_of_dicts_input(
timestamp_column
Name of column containing timestamps
validate_inputs
When True, the dataframe(s) will be validated be conversion
[ADVANCED] When True (default), validates dataframes before prediction. Setting to False removes the
validation overhead, but may silently lead to wrong predictions if data is misformatted. When False, you
must ensure: (1) all dataframes are sorted by (id_column, timestamp_column); (2) future_df (if provided)
has the same item IDs as df with exactly prediction_length rows of future timestamps per item; (3) all
timestamps are regularly spaced (e.g., with hourly frequency).
freq
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used
when validate_inputs=False. When provided, skips frequency inference from the data.
Returns
-------
@ -242,6 +250,16 @@ def convert_df_input_to_list_of_dicts_input(
import pandas as pd
if freq is not None and validate_inputs:
raise ValueError(
"freq can only be provided when validate_inputs=False. "
"When using freq with validate_inputs=False, you must ensure: "
"(1) all dataframes are sorted by (id_column, timestamp_column); "
"(2) future_df (if provided) has the same item IDs as df with exactly "
"prediction_length rows of future timestamps per item; "
"(3) all timestamps are regularly spaced."
)
if validate_inputs:
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
df,
@ -258,19 +276,19 @@ def convert_df_input_to_list_of_dicts_input(
# Get series lengths
series_lengths = df[id_column].value_counts(sort=False).to_list()
# If validation is skipped, the first freq in the dataframe is used
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
start_idx = 0
freq = None
for length in series_lengths:
if length < 3:
start_idx += length
continue
timestamps = timestamp_index[start_idx : start_idx + length]
freq = pd.infer_freq(timestamps)
break
# If freq is not provided, infer from the first series with >= 3 points
if freq is None:
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
start_idx = 0
for length in series_lengths:
if length < 3:
start_idx += length
continue
timestamps = timestamp_index[start_idx : start_idx + length]
freq = pd.infer_freq(timestamps)
break
assert freq is not None, "validate is False, but could not infer frequency from the dataframe"
assert freq is not None, "validate_inputs is False, but could not infer frequency from the dataframe"
# Convert to list of dicts format
inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = []

View file

@ -13,7 +13,6 @@ from chronos.df_utils import (
)
from test.util import create_df, create_future_df, get_forecast_start_times
# Tests for validate_df_inputs function
@ -22,7 +21,7 @@ def test_validate_df_inputs_returns_correct_metadata_for_valid_inputs(freq):
"""Test that function returns validated dataframes, frequency, series lengths, and original order."""
# Create test data with 2 series
df = create_df(series_ids=["A", "B"], n_points=[10, 15], target_cols=["target"], freq=freq)
# Call validate_df_inputs
validated_df, validated_future_df, inferred_freq, series_lengths, original_order = validate_df_inputs(
df=df,
@ -32,7 +31,7 @@ def test_validate_df_inputs_returns_correct_metadata_for_valid_inputs(freq):
id_column="item_id",
timestamp_column="timestamp",
)
# Verify key return values
assert validated_future_df is None
assert inferred_freq is not None
@ -46,15 +45,17 @@ def test_validate_df_inputs_returns_correct_metadata_for_valid_inputs(freq):
def test_validate_df_inputs_casts_mixed_dtypes_correctly():
"""Test that numeric columns are cast to float32 and categorical/string/object columns are cast to category."""
# Create dataframe with mixed column types
df = pd.DataFrame({
"item_id": ["A"] * 10,
"timestamp": pd.date_range(end="2001-10-01", periods=10, freq="h"),
"target": np.random.randn(10), # numeric
"numeric_cov": np.random.randint(0, 10, 10), # integer numeric
"string_cov": ["cat1"] * 5 + ["cat2"] * 5, # string
"bool_cov": [True, False] * 5, # boolean
})
df = pd.DataFrame(
{
"item_id": ["A"] * 10,
"timestamp": pd.date_range(end="2001-10-01", periods=10, freq="h"),
"target": np.random.randn(10), # numeric
"numeric_cov": np.random.randint(0, 10, 10), # integer numeric
"string_cov": ["cat1"] * 5 + ["cat2"] * 5, # string
"bool_cov": [True, False] * 5, # boolean
}
)
# Call validate_df_inputs
validated_df, _, _, _, _ = validate_df_inputs(
df=df,
@ -62,7 +63,7 @@ def test_validate_df_inputs_casts_mixed_dtypes_correctly():
target_columns=["target"],
prediction_length=5,
)
# Verify dtypes after validation
assert validated_df["target"].dtype == np.float32
assert validated_df["numeric_cov"].dtype == np.float32
@ -74,7 +75,7 @@ def test_validate_df_inputs_raises_error_when_series_has_insufficient_data():
"""Test that ValueError is raised for series with < 3 data points."""
# Create dataframe with one series having only 2 points
df = create_df(series_ids=["A", "B"], n_points=[10, 2], target_cols=["target"], freq="h")
# Verify error is raised with series ID in message
with pytest.raises(ValueError, match=r"Every time series must have at least 3 data points.*series B"):
validate_df_inputs(
@ -89,17 +90,13 @@ def test_validate_df_inputs_raises_error_when_future_df_has_mismatched_series_id
"""Test that ValueError is raised when future_df has different series IDs than df."""
# Create df with series A and B
df = create_df(series_ids=["A", "B"], n_points=[10, 15], target_cols=["target"], freq="h")
# Create future_df with only series A
forecast_start_times = get_forecast_start_times(df, freq="h")
future_df = create_future_df(
forecast_start_times=[forecast_start_times[0]],
series_ids=["A"],
n_points=[5],
covariates=None,
freq="h"
forecast_start_times=[forecast_start_times[0]], series_ids=["A"], n_points=[5], covariates=None, freq="h"
)
# Verify appropriate error is raised
with pytest.raises(ValueError, match=r"future_df must contain the same time series IDs as df"):
validate_df_inputs(
@ -114,7 +111,7 @@ def test_validate_df_inputs_raises_error_when_future_df_has_incorrect_lengths():
"""Test that ValueError is raised when future_df lengths don't match prediction_length."""
# Create df with series A and B with a covariate
df = create_df(series_ids=["A", "B"], n_points=[10, 13], target_cols=["target"], covariates=["cov1"], freq="h")
# Create future_df with varying lengths per series (3 and 7 instead of 5)
forecast_start_times = get_forecast_start_times(df, freq="h")
future_df = create_future_df(
@ -122,11 +119,13 @@ def test_validate_df_inputs_raises_error_when_future_df_has_incorrect_lengths():
series_ids=["A", "B"],
n_points=[3, 7], # incorrect lengths
covariates=["cov1"],
freq="h"
freq="h",
)
# Verify error message indicates which series have incorrect lengths
with pytest.raises(ValueError, match=r"future_df must contain prediction_length=5 values for each series.*different lengths"):
with pytest.raises(
ValueError, match=r"future_df must contain prediction_length=5 values for each series.*different lengths"
):
validate_df_inputs(
df=df,
future_df=future_df,
@ -141,42 +140,46 @@ def test_validate_df_inputs_raises_error_when_future_df_has_incorrect_lengths():
def test_convert_df_with_single_target_preserves_values():
"""Test conversion with single target column."""
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], freq="h")
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
df=df,
future_df=None,
target_columns=["target"],
prediction_length=5,
)
# Verify output list has correct length (one per series)
assert len(inputs) == 2
# Verify target arrays have correct shape and values match input
assert inputs[0]["target"].shape == (1, 10) # (n_targets=1, n_timesteps=10)
assert inputs[1]["target"].shape == (1, 12) # (n_targets=1, n_timesteps=12)
# Verify values are preserved
df_sorted = df.sort_values(["item_id", "timestamp"])
np.testing.assert_array_almost_equal(inputs[0]["target"][0], df_sorted[df_sorted["item_id"] == "A"]["target"].values)
np.testing.assert_array_almost_equal(inputs[1]["target"][0], df_sorted[df_sorted["item_id"] == "B"]["target"].values)
np.testing.assert_array_almost_equal(
inputs[0]["target"][0], df_sorted[df_sorted["item_id"] == "A"]["target"].values
)
np.testing.assert_array_almost_equal(
inputs[1]["target"][0], df_sorted[df_sorted["item_id"] == "B"]["target"].values
)
def test_convert_df_with_multiple_targets_preserves_values_and_shape():
"""Test conversion with multiple target columns."""
df = create_df(series_ids=["A", "B"], n_points=[10, 14], target_cols=["target1", "target2"], freq="h")
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
df=df,
future_df=None,
target_columns=["target1", "target2"],
prediction_length=5,
)
# Verify target arrays have shape (n_targets, n_timesteps)
assert inputs[0]["target"].shape == (2, 10)
assert inputs[1]["target"].shape == (2, 14)
# Verify all target values are preserved for both series
df_sorted = df.sort_values(["item_id", "timestamp"])
for i, series_id in enumerate(["A", "B"]):
@ -187,26 +190,28 @@ def test_convert_df_with_multiple_targets_preserves_values_and_shape():
def test_convert_df_with_past_covariates_includes_them_in_output():
"""Test conversion with past covariates only."""
df = create_df(series_ids=["A", "B"], n_points=[10, 16], target_cols=["target"], covariates=["cov1", "cov2"], freq="h")
df = create_df(
series_ids=["A", "B"], n_points=[10, 16], target_cols=["target"], covariates=["cov1", "cov2"], freq="h"
)
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
df=df,
future_df=None,
target_columns=["target"],
prediction_length=5,
)
# Verify output includes past_covariates dictionary
assert "past_covariates" in inputs[0]
assert "cov1" in inputs[0]["past_covariates"]
assert "cov2" in inputs[0]["past_covariates"]
# Verify covariate values match input for both series
assert inputs[0]["past_covariates"]["cov1"].shape == (10,)
assert inputs[0]["past_covariates"]["cov2"].shape == (10,)
assert inputs[1]["past_covariates"]["cov1"].shape == (16,)
assert inputs[1]["past_covariates"]["cov2"].shape == (16,)
# Verify no future_covariates key in output
assert "future_covariates" not in inputs[0]
@ -214,29 +219,29 @@ def test_convert_df_with_past_covariates_includes_them_in_output():
def test_convert_df_with_past_and_future_covariates_includes_both():
"""Test conversion with both past and future covariates."""
df = create_df(series_ids=["A", "B"], n_points=[10, 18], target_cols=["target"], covariates=["cov1"], freq="h")
forecast_start_times = get_forecast_start_times(df, freq="h")
future_df = create_future_df(
forecast_start_times=forecast_start_times,
series_ids=["A", "B"],
n_points=[5, 5],
covariates=["cov1"],
freq="h"
freq="h",
)
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
df=df,
future_df=future_df,
target_columns=["target"],
prediction_length=5,
)
# Verify output includes both past_covariates and future_covariates dictionaries for both series
assert "past_covariates" in inputs[0]
assert "future_covariates" in inputs[0]
assert "past_covariates" in inputs[1]
assert "future_covariates" in inputs[1]
# Verify all covariate values are preserved with correct shapes
assert inputs[0]["past_covariates"]["cov1"].shape == (10,)
assert inputs[0]["future_covariates"]["cov1"].shape == (5,)
@ -249,21 +254,21 @@ def test_convert_df_generates_prediction_timestamps_with_correct_frequency(freq)
"""Test that prediction timestamps follow the inferred frequency."""
# Use multiple series with irregular lengths
df = create_df(series_ids=["A", "B", "C"], n_points=[10, 15, 12], target_cols=["target"], freq=freq)
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
df=df,
future_df=None,
target_columns=["target"],
prediction_length=5,
)
# Verify timestamps for all series
for series_id in ["A", "B", "C"]:
# Verify timestamps start after last context timestamp
last_context_time = df[df["item_id"] == series_id]["timestamp"].max()
first_pred_time = prediction_timestamps[series_id][0]
assert first_pred_time > last_context_time
# Verify timestamps are evenly spaced according to frequency
pred_times = prediction_timestamps[series_id]
assert len(pred_times) == 5
@ -274,7 +279,7 @@ def test_convert_df_generates_prediction_timestamps_with_correct_frequency(freq)
def test_convert_df_skips_validation_when_disabled():
"""Test that validate_inputs=False skips validation."""
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], freq="h")
# Mock validate_df_inputs to verify it's not called when validation is disabled
with patch("chronos.df_utils.validate_df_inputs") as mock_validate:
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
@ -284,10 +289,10 @@ def test_convert_df_skips_validation_when_disabled():
prediction_length=5,
validate_inputs=False,
)
# Verify validate_df_inputs was not called
mock_validate.assert_not_called()
# Verify conversion still works
assert len(inputs) == 2
@ -300,17 +305,19 @@ def test_convert_df_preserves_all_values_with_random_inputs():
n_past_only_covariates = np.random.randint(1, 3)
n_future_covariates = np.random.randint(1, 3)
prediction_length = 5
series_ids = [f"series_{i}" for i in range(n_series)]
n_points = [np.random.randint(10, 20) for _ in range(n_series)]
target_cols = [f"target_{i}" for i in range(n_targets)]
past_only_covariates = [f"past_cov_{i}" for i in range(n_past_only_covariates)]
future_covariates = [f"future_cov_{i}" for i in range(n_future_covariates)]
all_covariates = past_only_covariates + future_covariates
# Create dataframe with all covariates
df = create_df(series_ids=series_ids, n_points=n_points, target_cols=target_cols, covariates=all_covariates, freq="h")
df = create_df(
series_ids=series_ids, n_points=n_points, target_cols=target_cols, covariates=all_covariates, freq="h"
)
# Create future_df with only future covariates (not past-only ones)
forecast_start_times = get_forecast_start_times(df, freq="h")
future_df = create_future_df(
@ -318,9 +325,9 @@ def test_convert_df_preserves_all_values_with_random_inputs():
series_ids=series_ids,
n_points=[prediction_length] * n_series,
covariates=future_covariates,
freq="h"
freq="h",
)
# Convert to list-of-dicts format
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
df=df,
@ -328,23 +335,23 @@ def test_convert_df_preserves_all_values_with_random_inputs():
target_columns=target_cols,
prediction_length=prediction_length,
)
# Verify all target values are preserved exactly
df_sorted = df.sort_values(["item_id", "timestamp"])
for i, series_id in enumerate(series_ids):
series_data = df_sorted[df_sorted["item_id"] == series_id]
assert inputs[i]["target"].shape == (n_targets, n_points[i])
for j, target_col in enumerate(target_cols):
np.testing.assert_array_almost_equal(inputs[i]["target"][j], series_data[target_col].values)
# Verify all past covariate values are preserved (both past-only and future covariates)
for i, series_id in enumerate(series_ids):
series_data = df_sorted[df_sorted["item_id"] == series_id]
assert "past_covariates" in inputs[i]
for cov in all_covariates:
np.testing.assert_array_almost_equal(inputs[i]["past_covariates"][cov], series_data[cov].values)
# Verify only future covariates are in future_covariates (not past-only ones)
future_df_sorted = future_df.sort_values(["item_id", "timestamp"])
for i, series_id in enumerate(series_ids):
@ -354,8 +361,102 @@ def test_convert_df_preserves_all_values_with_random_inputs():
assert set(inputs[i]["future_covariates"].keys()) == set(future_covariates)
for cov in future_covariates:
np.testing.assert_array_almost_equal(inputs[i]["future_covariates"][cov], series_future_data[cov].values)
# Verify output structure is correct
assert len(inputs) == n_series
assert list(original_order) == series_ids
assert len(prediction_timestamps) == n_series
def test_convert_df_with_freq_and_validate_inputs_raises_error():
"""Test that providing freq with validate_inputs=True raises ValueError."""
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], freq="h")
with pytest.raises(ValueError, match="freq can only be provided when validate_inputs=False"):
convert_df_input_to_list_of_dicts_input(
df=df,
future_df=None,
target_columns=["target"],
prediction_length=5,
freq="h",
validate_inputs=True,
)
@pytest.mark.parametrize("use_future_df", [True, False])
def test_convert_df_with_freq_and_validate_inputs_false(use_future_df):
"""Test that freq works with validate_inputs=False."""
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h")
prediction_length = 5
future_df = None
if use_future_df:
forecast_start_times = get_forecast_start_times(df, freq="h")
future_df = create_future_df(
forecast_start_times=forecast_start_times,
series_ids=["A", "B"],
n_points=[prediction_length, prediction_length],
covariates=["cov1"],
freq="h",
)
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
df=df,
future_df=future_df,
target_columns=["target"],
prediction_length=prediction_length,
freq="h",
validate_inputs=False,
)
assert len(inputs) == 2
assert len(prediction_timestamps) == 2
for series_id in ["A", "B"]:
assert len(prediction_timestamps[series_id]) == prediction_length
@pytest.mark.parametrize("use_future_df", [True, False])
def test_convert_df_with_mismatched_freq_uses_user_provided_freq(use_future_df):
"""Test that user-provided freq overrides data frequency when validate_inputs=False."""
# Create data with hourly frequency
data_freq = "h"
df = create_df(
series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq=data_freq
)
prediction_length = 5
# User provides daily frequency (different from data)
user_freq = "D"
future_df = None
if use_future_df:
# Create future_df with hourly frequency (matching data, not user freq)
forecast_start_times = get_forecast_start_times(df, freq=data_freq)
future_df = create_future_df(
forecast_start_times=forecast_start_times,
series_ids=["A", "B"],
n_points=[prediction_length, prediction_length],
covariates=["cov1"],
freq=data_freq,
)
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
df=df,
future_df=future_df,
target_columns=["target"],
prediction_length=prediction_length,
freq=user_freq,
validate_inputs=False,
)
# Prediction should work
assert len(inputs) == 2
assert len(prediction_timestamps) == 2
# Forecast timestamps should use user-provided freq (daily), not data freq (hourly)
for series_id in ["A", "B"]:
pred_ts = prediction_timestamps[series_id]
assert len(pred_ts) == prediction_length
# Verify the frequency matches user-provided freq
inferred_freq = pd.infer_freq(pred_ts)
assert inferred_freq == user_freq