From e0ccdf59361e5a0b8e393ef93517dfe88f41e8b0 Mon Sep 17 00:00:00 2001 From: dario-fumarola Date: Thu, 26 Feb 2026 09:54:55 -0500 Subject: [PATCH] Fix string[python] id ordering in dataframe frequency validation --- src/chronos/df_utils.py | 10 +++--- test/test_df_utils.py | 80 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 4 deletions(-) diff --git a/src/chronos/df_utils.py b/src/chronos/df_utils.py index 3adc1b0..232c780 100644 --- a/src/chronos/df_utils.py +++ b/src/chronos/df_utils.py @@ -144,8 +144,10 @@ def validate_df_inputs( df[timestamp_column] = pd.to_datetime(df[timestamp_column]) df = df.sort_values([id_column, timestamp_column]) - # Get series lengths - series_lengths = df[id_column].value_counts(sort=False).to_list() + # Get series lengths in the exact order that appears in the sorted dataframe. + # This avoids dtype-specific ordering differences (e.g., string[python]) that can + # break the alignment with contiguous timestamp slices below. + series_lengths = df.groupby(id_column, sort=False).size().to_list() def validate_freq(timestamps: pd.DatetimeIndex, series_id: str): freq = pd.infer_freq(timestamps) @@ -273,8 +275,8 @@ def convert_df_input_to_list_of_dicts_input( # Get the original order of time series IDs original_order = df[id_column].unique() - # Get series lengths - series_lengths = df[id_column].value_counts(sort=False).to_list() + # Keep lengths aligned with dataframe row order regardless of ID dtype. + series_lengths = df.groupby(id_column, sort=False).size().to_list() # If freq is not provided, infer from the first series with >= 3 points if freq is None: diff --git a/test/test_df_utils.py b/test/test_df_utils.py index 7d5d814..1569c3d 100644 --- a/test/test_df_utils.py +++ b/test/test_df_utils.py @@ -1,6 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import hashlib from unittest.mock import patch import numpy as np @@ -16,6 +17,22 @@ from test.util import create_df, create_future_df, get_forecast_start_times # Tests for validate_df_inputs function +def _create_unequal_length_weekly_df( + num_series: int = 200, num_periods: int = 60, period_variation: int = 30, seed: int = 42 +) -> pd.DataFrame: + rng = np.random.default_rng(seed=seed) + end_date = pd.date_range(start="2023-01-02", periods=num_periods + period_variation, freq="W-MON")[-1] + series_data = [] + for i in range(num_series): + series_id = hashlib.sha256(f"series_{i}".encode()).hexdigest() + periods = int(rng.integers(num_periods - period_variation, num_periods + period_variation + 1)) + timestamps = pd.date_range(end=end_date, periods=periods, freq="W-MON") + series_data.append( + pd.DataFrame({"item_id": series_id, "timestamp": timestamps, "target": rng.normal(size=periods)}) + ) + return pd.concat(series_data, ignore_index=True) + + @pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"]) def test_validate_df_inputs_returns_correct_metadata_for_valid_inputs(freq): """Test that function returns validated dataframes, frequency, series lengths, and original order.""" @@ -71,6 +88,47 @@ def test_validate_df_inputs_casts_mixed_dtypes_correctly(): assert validated_df["bool_cov"].dtype == np.float32 # booleans are cast to float32 +def test_validate_df_inputs_accepts_string_python_ids_with_unequal_lengths(): + """Regression test for issue #440 with string[python] IDs and unequal series lengths.""" + df = _create_unequal_length_weekly_df() + df["item_id"] = df["item_id"].astype("string[python]") + + _, _, inferred_freq, series_lengths, _ = validate_df_inputs( + df=df, + future_df=None, + target_columns=["target"], + prediction_length=5, + ) + + assert inferred_freq == "W-MON" + assert len(series_lengths) == df["item_id"].nunique() + assert sum(series_lengths) == len(df) + + +def test_validate_df_inputs_has_consistent_metadata_for_object_and_string_python_ids(): + """Validation metadata should not depend on whether ID dtype is object or string[python].""" + object_df = _create_unequal_length_weekly_df(seed=7) + string_df = object_df.copy() + string_df["item_id"] = string_df["item_id"].astype("string[python]") + + _, _, object_freq, object_lengths, object_order = validate_df_inputs( + df=object_df, + future_df=None, + target_columns=["target"], + prediction_length=5, + ) + _, _, string_freq, string_lengths, string_order = validate_df_inputs( + df=string_df, + future_df=None, + target_columns=["target"], + prediction_length=5, + ) + + assert string_freq == object_freq + assert string_lengths == object_lengths + assert [str(x) for x in string_order] == [str(x) for x in object_order] + + def test_validate_df_inputs_raises_error_when_series_has_insufficient_data(): """Test that ValueError is raised for series with < 3 data points.""" # Create dataframe with one series having only 2 points @@ -460,3 +518,25 @@ def test_convert_df_with_mismatched_freq_uses_user_provided_freq(use_future_df): # Verify the frequency matches user-provided freq inferred_freq = pd.infer_freq(pred_ts) assert inferred_freq == user_freq + + +def test_convert_df_with_validate_inputs_false_handles_string_python_ids(): + """validate_inputs=False should work with string[python] IDs and preserve per-series lengths.""" + df = _create_unequal_length_weekly_df(seed=11) + df["item_id"] = df["item_id"].astype("string[python]") + df = df.sort_values(["item_id", "timestamp"]) + + inputs, original_order, _ = convert_df_input_to_list_of_dicts_input( + df=df, + future_df=None, + target_columns=["target"], + prediction_length=5, + validate_inputs=False, + freq="W-MON", + ) + + expected_lengths = df.groupby("item_id", sort=False).size().to_list() + observed_lengths = [task["target"].shape[1] for task in inputs] + + assert observed_lengths == expected_lengths + assert len(original_order) == len(expected_lengths)