diff --git a/src/chronos/base.py b/src/chronos/base.py index 7592c46..e738717 100644 --- a/src/chronos/base.py +++ b/src/chronos/base.py @@ -142,6 +142,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): prediction_length: int | None = None, quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], validate_inputs: bool = True, + freq: str | None = None, **predict_kwargs, ) -> "pd.DataFrame": """ @@ -166,6 +167,10 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): validate_inputs When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a regular frequency, and item IDs match between past and future data. Setting to False disables these checks. + freq + Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used + instead of inferring it from the data. This is useful when you already know the frequency and want to + skip the inference overhead. **predict_kwargs Additional arguments passed to predict_quantiles @@ -200,6 +205,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): timestamp_column=timestamp_column, target_columns=[target], prediction_length=prediction_length, + freq=freq, validate_inputs=validate_inputs, ) diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index b3ffb05..5537282 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -825,6 +825,7 @@ class Chronos2Pipeline(BaseChronosPipeline): context_length: int | None = None, cross_learning: bool = False, validate_inputs: bool = True, + freq: str | None = None, **predict_kwargs, ) -> "pd.DataFrame": """ @@ -866,6 +867,11 @@ class Chronos2Pipeline(BaseChronosPipeline): validate_inputs When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a regular frequency, and item IDs match between past and future data. Setting to False disables these checks. + freq + Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used + instead of inferring it from the data. This is useful when you already know the frequency and want to + skip the inference overhead. Only used when future_df is not provided, since timestamps are extracted + from future_df when it's available. **predict_kwargs Additional arguments passed to predict_quantiles @@ -896,6 +902,7 @@ class Chronos2Pipeline(BaseChronosPipeline): timestamp_column=timestamp_column, target_columns=target, prediction_length=prediction_length, + freq=freq, validate_inputs=validate_inputs, ) diff --git a/src/chronos/df_utils.py b/src/chronos/df_utils.py index af422ef..08169b0 100644 --- a/src/chronos/df_utils.py +++ b/src/chronos/df_utils.py @@ -203,6 +203,7 @@ def convert_df_input_to_list_of_dicts_input( prediction_length: int, id_column: str = "item_id", timestamp_column: str = "timestamp", + freq: str | None = None, validate_inputs: bool = True, ) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]: """ @@ -229,8 +230,12 @@ def convert_df_input_to_list_of_dicts_input( Name of column containing time series identifiers timestamp_column Name of column containing timestamps + freq + Frequency string for timestamp generation. If provided, this frequency is used + instead of inferring it from the data. Only used when future_df is not provided, + since timestamps are extracted from future_df when it's available. validate_inputs - When True, the dataframe(s) will be validated be conversion + When True, the dataframe(s) will be validated before conversion Returns ------- @@ -243,7 +248,7 @@ def convert_df_input_to_list_of_dicts_input( import pandas as pd if validate_inputs: - df, future_df, freq, series_lengths, original_order = validate_df_inputs( + df, future_df, inferred_freq, series_lengths, original_order = validate_df_inputs( df, future_df=future_df, id_column=id_column, @@ -251,6 +256,9 @@ def convert_df_input_to_list_of_dicts_input( target_columns=target_columns, prediction_length=prediction_length, ) + # Use provided freq if available, otherwise use inferred freq + if freq is None: + freq = inferred_freq else: # Get the original order of time series IDs original_order = df[id_column].unique() @@ -258,19 +266,19 @@ def convert_df_input_to_list_of_dicts_input( # Get series lengths series_lengths = df[id_column].value_counts(sort=False).to_list() - # If validation is skipped, the first freq in the dataframe is used - timestamp_index = pd.DatetimeIndex(df[timestamp_column]) - start_idx = 0 - freq = None - for length in series_lengths: - if length < 3: - start_idx += length - continue - timestamps = timestamp_index[start_idx : start_idx + length] - freq = pd.infer_freq(timestamps) - break + # If freq is not provided, infer from the first series with >= 3 points + if freq is None: + timestamp_index = pd.DatetimeIndex(df[timestamp_column]) + start_idx = 0 + for length in series_lengths: + if length < 3: + start_idx += length + continue + timestamps = timestamp_index[start_idx : start_idx + length] + freq = pd.infer_freq(timestamps) + break - assert freq is not None, "validate is False, but could not infer frequency from the dataframe" + assert freq is not None, "validate is False, but could not infer frequency from the dataframe" # Convert to list of dicts format inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = [] @@ -278,29 +286,29 @@ def convert_df_input_to_list_of_dicts_input( indptr = np.concatenate([[0], np.cumsum(series_lengths)]).astype("int64") target_array = df[target_columns].to_numpy().T # Shape: (n_targets, len(df)) - last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,) - offset = pd.tseries.frequencies.to_offset(freq) - with warnings.catch_warnings(): - # Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822 - warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning) - # Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length) - prediction_timestamps_array = pd.DatetimeIndex( - np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel() - ) past_covariates_dict = { col: df[col].to_numpy() for col in df.columns if col not in [id_column, timestamp_column] + target_columns } future_covariates_dict = {} + if future_df is not None: + # Use timestamps from future_df + prediction_timestamps_array = pd.DatetimeIndex(future_df[timestamp_column]) for col in future_df.columns.drop([id_column, timestamp_column]): future_covariates_dict[col] = future_df[col].to_numpy() - if validate_inputs: - if (pd.DatetimeIndex(future_df[timestamp_column]) != pd.DatetimeIndex(prediction_timestamps_array)).any(): - raise ValueError( - "future_df timestamps do not match the expected prediction timestamps. " - "You can disable this check by setting `validate_inputs=False`" - ) + else: + # Generate timestamps from freq + assert freq is not None, "freq must be provided or inferred when future_df is not provided" + last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,) + offset = pd.tseries.frequencies.to_offset(freq) + with warnings.catch_warnings(): + # Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822 + warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning) + # Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length) + prediction_timestamps_array = pd.DatetimeIndex( + np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel() + ) for i in range(len(series_lengths)): start_idx, end_idx = indptr[i], indptr[i + 1] diff --git a/test/test_df_utils.py b/test/test_df_utils.py index 2ffe1e8..47518a0 100644 --- a/test/test_df_utils.py +++ b/test/test_df_utils.py @@ -359,3 +359,68 @@ def test_convert_df_preserves_all_values_with_random_inputs(): assert len(inputs) == n_series assert list(original_order) == series_ids assert len(prediction_timestamps) == n_series + + +# Tests for freq parameter + + +@pytest.mark.parametrize("validate_inputs", [True, False]) +@pytest.mark.parametrize("use_future_df", [True, False]) +def test_convert_df_with_provided_freq(validate_inputs, use_future_df): + """Test that provided freq works with different combinations of validate_inputs and future_df.""" + df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h") + prediction_length = 5 + + future_df = None + if use_future_df: + forecast_start_times = get_forecast_start_times(df, freq="h") + future_df = create_future_df( + forecast_start_times=forecast_start_times, + series_ids=["A", "B"], + n_points=[prediction_length, prediction_length], + covariates=["cov1"], + freq="h" + ) + + inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input( + df=df, + future_df=future_df, + target_columns=["target"], + prediction_length=prediction_length, + freq="h", + validate_inputs=validate_inputs, + ) + + assert len(inputs) == 2 + assert len(prediction_timestamps) == 2 + for series_id in ["A", "B"]: + assert len(prediction_timestamps[series_id]) == prediction_length + + +def test_convert_df_with_future_df_uses_future_df_timestamps(): + """Test that timestamps from future_df are used when future_df is provided.""" + df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h") + + # Create future_df with 2h freq (different from df's 1h freq) + forecast_start_times = get_forecast_start_times(df, freq="2h") + future_df = create_future_df( + forecast_start_times=forecast_start_times, + series_ids=["A", "B"], + n_points=[5, 5], + covariates=["cov1"], + freq="2h" + ) + + inputs, _, prediction_timestamps = convert_df_input_to_list_of_dicts_input( + df=df, + future_df=future_df, + target_columns=["target"], + prediction_length=5, + validate_inputs=False, + ) + + # Verify timestamps come from future_df (2h spacing) + future_df_sorted = future_df.sort_values(["item_id", "timestamp"]) + for series_id in ["A", "B"]: + expected_timestamps = pd.DatetimeIndex(future_df_sorted[future_df_sorted["item_id"] == series_id]["timestamp"]) + pd.testing.assert_index_equal(prediction_timestamps[series_id], expected_timestamps)