mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Merge e0ccdf5936 into 32111085d8
This commit is contained in:
commit
d3d9c66c71
2 changed files with 86 additions and 4 deletions
|
|
@ -144,8 +144,10 @@ def validate_df_inputs(
|
|||
df[timestamp_column] = pd.to_datetime(df[timestamp_column])
|
||||
df = df.sort_values([id_column, timestamp_column])
|
||||
|
||||
# Get series lengths
|
||||
series_lengths = df[id_column].value_counts(sort=False).to_list()
|
||||
# Get series lengths in the exact order that appears in the sorted dataframe.
|
||||
# This avoids dtype-specific ordering differences (e.g., string[python]) that can
|
||||
# break the alignment with contiguous timestamp slices below.
|
||||
series_lengths = df.groupby(id_column, sort=False).size().to_list()
|
||||
|
||||
def validate_freq(timestamps: pd.DatetimeIndex, series_id: str):
|
||||
freq = pd.infer_freq(timestamps)
|
||||
|
|
@ -273,8 +275,8 @@ def convert_df_input_to_list_of_dicts_input(
|
|||
# Get the original order of time series IDs
|
||||
original_order = df[id_column].unique()
|
||||
|
||||
# Get series lengths
|
||||
series_lengths = df[id_column].value_counts(sort=False).to_list()
|
||||
# Keep lengths aligned with dataframe row order regardless of ID dtype.
|
||||
series_lengths = df.groupby(id_column, sort=False).size().to_list()
|
||||
|
||||
# If freq is not provided, infer from the first series with >= 3 points
|
||||
if freq is None:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import hashlib
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -16,6 +17,22 @@ from test.util import create_df, create_future_df, get_forecast_start_times
|
|||
# Tests for validate_df_inputs function
|
||||
|
||||
|
||||
def _create_unequal_length_weekly_df(
|
||||
num_series: int = 200, num_periods: int = 60, period_variation: int = 30, seed: int = 42
|
||||
) -> pd.DataFrame:
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
end_date = pd.date_range(start="2023-01-02", periods=num_periods + period_variation, freq="W-MON")[-1]
|
||||
series_data = []
|
||||
for i in range(num_series):
|
||||
series_id = hashlib.sha256(f"series_{i}".encode()).hexdigest()
|
||||
periods = int(rng.integers(num_periods - period_variation, num_periods + period_variation + 1))
|
||||
timestamps = pd.date_range(end=end_date, periods=periods, freq="W-MON")
|
||||
series_data.append(
|
||||
pd.DataFrame({"item_id": series_id, "timestamp": timestamps, "target": rng.normal(size=periods)})
|
||||
)
|
||||
return pd.concat(series_data, ignore_index=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"])
|
||||
def test_validate_df_inputs_returns_correct_metadata_for_valid_inputs(freq):
|
||||
"""Test that function returns validated dataframes, frequency, series lengths, and original order."""
|
||||
|
|
@ -71,6 +88,47 @@ def test_validate_df_inputs_casts_mixed_dtypes_correctly():
|
|||
assert validated_df["bool_cov"].dtype == np.float32 # booleans are cast to float32
|
||||
|
||||
|
||||
def test_validate_df_inputs_accepts_string_python_ids_with_unequal_lengths():
    """Regression test for issue #440: string[python] IDs combined with unequal series lengths."""
    df = _create_unequal_length_weekly_df()
    # Switch the ID column to the pandas extension string dtype under test.
    df["item_id"] = df["item_id"].astype("string[python]")

    validation_kwargs = {
        "df": df,
        "future_df": None,
        "target_columns": ["target"],
        "prediction_length": 5,
    }
    _, _, inferred_freq, series_lengths, _ = validate_df_inputs(**validation_kwargs)

    # Frequency inference must succeed on the weekly (Monday-anchored) data.
    assert inferred_freq == "W-MON"
    # One length per series, and together they must account for every row.
    assert len(series_lengths) == df["item_id"].nunique()
    assert sum(series_lengths) == len(df)
|
||||
|
||||
|
||||
def test_validate_df_inputs_has_consistent_metadata_for_object_and_string_python_ids():
    """Validation metadata must be the same for object and string[python] ID columns."""
    object_df = _create_unequal_length_weekly_df(seed=7)
    # Same data, only the ID column dtype differs.
    string_df = object_df.copy()
    string_df["item_id"] = string_df["item_id"].astype("string[python]")

    # Both calls share every argument except the dataframe itself.
    shared_kwargs = {
        "future_df": None,
        "target_columns": ["target"],
        "prediction_length": 5,
    }
    _, _, object_freq, object_lengths, object_order = validate_df_inputs(df=object_df, **shared_kwargs)
    _, _, string_freq, string_lengths, string_order = validate_df_inputs(df=string_df, **shared_kwargs)

    assert string_freq == object_freq
    assert string_lengths == object_lengths
    # Compare IDs as plain strings so the element type of each dtype does not matter.
    assert [str(x) for x in string_order] == [str(x) for x in object_order]
|
||||
|
||||
|
||||
def test_validate_df_inputs_raises_error_when_series_has_insufficient_data():
|
||||
"""Test that ValueError is raised for series with < 3 data points."""
|
||||
# Create dataframe with one series having only 2 points
|
||||
|
|
@ -460,3 +518,25 @@ def test_convert_df_with_mismatched_freq_uses_user_provided_freq(use_future_df):
|
|||
# Verify the frequency matches user-provided freq
|
||||
inferred_freq = pd.infer_freq(pred_ts)
|
||||
assert inferred_freq == user_freq
|
||||
|
||||
|
||||
def test_convert_df_with_validate_inputs_false_handles_string_python_ids():
    """With validate_inputs=False, string[python] IDs must still yield correct per-series lengths."""
    df = _create_unequal_length_weekly_df(seed=11)
    df["item_id"] = df["item_id"].astype("string[python]")
    # Validation is skipped below, so the test sorts the frame by ID and timestamp itself.
    df = df.sort_values(["item_id", "timestamp"])

    conversion_kwargs = {
        "df": df,
        "future_df": None,
        "target_columns": ["target"],
        "prediction_length": 5,
        "validate_inputs": False,
        "freq": "W-MON",
    }
    inputs, original_order, _ = convert_df_input_to_list_of_dicts_input(**conversion_kwargs)

    # Row counts per ID, in the order the IDs appear in the sorted dataframe.
    expected_lengths = df.groupby("item_id", sort=False).size().to_list()
    # The second axis of each task's target holds that series' length.
    observed_lengths = [task["target"].shape[1] for task in inputs]

    assert observed_lengths == expected_lengths
    assert len(original_order) == len(expected_lengths)
|
||||
|
|
|
|||
Loading…
Reference in a new issue