mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Test that user freq is used
This commit is contained in:
parent
0232330e50
commit
32b4b2e002
1 changed files with 45 additions and 0 deletions
|
|
@ -409,3 +409,48 @@ def test_convert_df_with_freq_and_validate_inputs_false(use_future_df):
|
|||
assert len(prediction_timestamps) == 2
|
||||
for series_id in ["A", "B"]:
|
||||
assert len(prediction_timestamps[series_id]) == prediction_length
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_future_df", [True, False])
|
||||
def test_convert_df_with_mismatched_freq_uses_user_provided_freq(use_future_df):
|
||||
"""Test that user-provided freq overrides data frequency when validate_inputs=False."""
|
||||
# Create data with hourly frequency
|
||||
data_freq = "h"
|
||||
df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq=data_freq)
|
||||
prediction_length = 5
|
||||
|
||||
# User provides daily frequency (different from data)
|
||||
user_freq = "D"
|
||||
|
||||
future_df = None
|
||||
if use_future_df:
|
||||
# Create future_df with hourly frequency (matching data, not user freq)
|
||||
forecast_start_times = get_forecast_start_times(df, freq=data_freq)
|
||||
future_df = create_future_df(
|
||||
forecast_start_times=forecast_start_times,
|
||||
series_ids=["A", "B"],
|
||||
n_points=[prediction_length, prediction_length],
|
||||
covariates=["cov1"],
|
||||
freq=data_freq
|
||||
)
|
||||
|
||||
inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
|
||||
df=df,
|
||||
future_df=future_df,
|
||||
target_columns=["target"],
|
||||
prediction_length=prediction_length,
|
||||
freq=user_freq,
|
||||
validate_inputs=False,
|
||||
)
|
||||
|
||||
# Prediction should work
|
||||
assert len(inputs) == 2
|
||||
assert len(prediction_timestamps) == 2
|
||||
|
||||
# Forecast timestamps should use user-provided freq (daily), not data freq (hourly)
|
||||
for series_id in ["A", "B"]:
|
||||
pred_ts = prediction_timestamps[series_id]
|
||||
assert len(pred_ts) == prediction_length
|
||||
# Verify the frequency matches user-provided freq
|
||||
inferred_freq = pd.infer_freq(pred_ts)
|
||||
assert inferred_freq == user_freq
|
||||
|
|
|
|||
Loading…
Reference in a new issue