This commit is contained in:
Dario Fumarola 2026-04-21 17:54:25 +02:00 committed by GitHub
commit d3d9c66c71
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 86 additions and 4 deletions

View file

@ -144,8 +144,10 @@ def validate_df_inputs(
df[timestamp_column] = pd.to_datetime(df[timestamp_column])
df = df.sort_values([id_column, timestamp_column])
# Get series lengths
series_lengths = df[id_column].value_counts(sort=False).to_list()
# Get series lengths in the exact order that appears in the sorted dataframe.
# This avoids dtype-specific ordering differences (e.g., string[python]) that can
# break the alignment with contiguous timestamp slices below.
series_lengths = df.groupby(id_column, sort=False).size().to_list()
def validate_freq(timestamps: pd.DatetimeIndex, series_id: str):
freq = pd.infer_freq(timestamps)
@ -273,8 +275,8 @@ def convert_df_input_to_list_of_dicts_input(
# Get the original order of time series IDs
original_order = df[id_column].unique()
# Get series lengths
series_lengths = df[id_column].value_counts(sort=False).to_list()
# Keep lengths aligned with dataframe row order regardless of ID dtype.
series_lengths = df.groupby(id_column, sort=False).size().to_list()
# If freq is not provided, infer from the first series with >= 3 points
if freq is None:

View file

@ -1,6 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import hashlib
from unittest.mock import patch
import numpy as np
@ -16,6 +17,22 @@ from test.util import create_df, create_future_df, get_forecast_start_times
# Tests for validate_df_inputs function
def _create_unequal_length_weekly_df(
num_series: int = 200, num_periods: int = 60, period_variation: int = 30, seed: int = 42
) -> pd.DataFrame:
rng = np.random.default_rng(seed=seed)
end_date = pd.date_range(start="2023-01-02", periods=num_periods + period_variation, freq="W-MON")[-1]
series_data = []
for i in range(num_series):
series_id = hashlib.sha256(f"series_{i}".encode()).hexdigest()
periods = int(rng.integers(num_periods - period_variation, num_periods + period_variation + 1))
timestamps = pd.date_range(end=end_date, periods=periods, freq="W-MON")
series_data.append(
pd.DataFrame({"item_id": series_id, "timestamp": timestamps, "target": rng.normal(size=periods)})
)
return pd.concat(series_data, ignore_index=True)
@pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"])
def test_validate_df_inputs_returns_correct_metadata_for_valid_inputs(freq):
"""Test that function returns validated dataframes, frequency, series lengths, and original order."""
@ -71,6 +88,47 @@ def test_validate_df_inputs_casts_mixed_dtypes_correctly():
assert validated_df["bool_cov"].dtype == np.float32 # booleans are cast to float32
def test_validate_df_inputs_accepts_string_python_ids_with_unequal_lengths():
"""Regression test for issue #440 with string[python] IDs and unequal series lengths."""
df = _create_unequal_length_weekly_df()
df["item_id"] = df["item_id"].astype("string[python]")
_, _, inferred_freq, series_lengths, _ = validate_df_inputs(
df=df,
future_df=None,
target_columns=["target"],
prediction_length=5,
)
assert inferred_freq == "W-MON"
assert len(series_lengths) == df["item_id"].nunique()
assert sum(series_lengths) == len(df)
def test_validate_df_inputs_has_consistent_metadata_for_object_and_string_python_ids():
"""Validation metadata should not depend on whether ID dtype is object or string[python]."""
object_df = _create_unequal_length_weekly_df(seed=7)
string_df = object_df.copy()
string_df["item_id"] = string_df["item_id"].astype("string[python]")
_, _, object_freq, object_lengths, object_order = validate_df_inputs(
df=object_df,
future_df=None,
target_columns=["target"],
prediction_length=5,
)
_, _, string_freq, string_lengths, string_order = validate_df_inputs(
df=string_df,
future_df=None,
target_columns=["target"],
prediction_length=5,
)
assert string_freq == object_freq
assert string_lengths == object_lengths
assert [str(x) for x in string_order] == [str(x) for x in object_order]
def test_validate_df_inputs_raises_error_when_series_has_insufficient_data():
"""Test that ValueError is raised for series with < 3 data points."""
# Create dataframe with one series having only 2 points
@ -460,3 +518,25 @@ def test_convert_df_with_mismatched_freq_uses_user_provided_freq(use_future_df):
# Verify the frequency matches user-provided freq
inferred_freq = pd.infer_freq(pred_ts)
assert inferred_freq == user_freq
def test_convert_df_with_validate_inputs_false_handles_string_python_ids():
"""validate_inputs=False should work with string[python] IDs and preserve per-series lengths."""
df = _create_unequal_length_weekly_df(seed=11)
df["item_id"] = df["item_id"].astype("string[python]")
df = df.sort_values(["item_id", "timestamp"])
inputs, original_order, _ = convert_df_input_to_list_of_dicts_input(
df=df,
future_df=None,
target_columns=["target"],
prediction_length=5,
validate_inputs=False,
freq="W-MON",
)
expected_lengths = df.groupby("item_id", sort=False).size().to_list()
observed_lengths = [task["target"].shape[1] for task in inputs]
assert observed_lengths == expected_lengths
assert len(original_order) == len(expected_lengths)