From efb86e02c23bf11b700260f04f67c1b56cb51482 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Wed, 17 Dec 2025 11:45:16 +0100 Subject: [PATCH] Chronos-2: Add after_batch callback (#436) *Issue #, if available:* *Description of changes:* Adds support for custom callbacks after each batch is processed during prediction. This allows for keeping track of the time limit in AutoGluon. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- src/chronos/chronos2/pipeline.py | 5 ++++- test/test_chronos2.py | 13 ++++++++++++- test/util.py | 17 ++++++++++++++--- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index e99d8d9..91ed3da 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -9,7 +9,7 @@ import time import warnings from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Callable import numpy as np import torch @@ -577,6 +577,8 @@ class Chronos2Pipeline(BaseChronosPipeline): # effective batch size increases by a factor of `len(unrolled_quantiles)` when making long-horizon predictions, # by default [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] unrolled_quantiles = kwargs.pop("unrolled_quantiles", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) + # A callback which is called after each batch has been processed + after_batch_callback: Callable = kwargs.pop("after_batch", lambda: None) if len(kwargs) > 0: raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}.") @@ -641,6 +643,7 @@ class Chronos2Pipeline(BaseChronosPipeline): target_idx_ranges=batch_target_idx_ranges, ) all_predictions.extend(batch_prediction) + after_batch_callback() return all_predictions diff --git a/test/test_chronos2.py b/test/test_chronos2.py index 3fac726..3d8884e 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -16,7 +16,7 @@ from chronos import BaseChronosPipeline, Chronos2Pipeline from chronos.chronos2.config import Chronos2CoreConfig from chronos.chronos2.layers import MHA from chronos.df_utils import convert_df_input_to_list_of_dicts_input -from test.util import create_df, create_future_df, get_forecast_start_times, validate_tensor +from test.util import create_df, create_future_df, get_forecast_start_times, validate_tensor, timeout_callback DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model" @@ -1008,6 +1008,17 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future assert not np.allclose(orig_result_before["predictions"].to_numpy(), result["predictions"].to_numpy()) +def test_when_predict_df_called_with_timeout_callback_then_timeout_error_is_raised(pipeline): + num_series = 1000 + large_df = create_df(series_ids=[j for j in range(num_series)], n_points=[2048] * num_series) + with pytest.raises(TimeoutError, match="time limit exceeded"): + pipeline.predict_df( + large_df, + prediction_length=48, + after_batch=timeout_callback(0.1), + ) + + @pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) def test_pipeline_works_with_different_attention_implementations(attn_implementation): """Test that the pipeline works with different attention implementations.""" diff --git a/test/util.py b/test/util.py index bc7b878..943e072 100644 --- a/test/util.py +++ b/test/util.py @@ -1,4 +1,5 @@ -from typing import Optional, Tuple +import time +from typing import Callable, Optional, Tuple import numpy as np import pandas as pd @@ -13,7 +14,6 @@ def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[tor assert a.dtype == dtype - def create_df(series_ids=["A", "B"], n_points=[10, 10], target_cols=["target"], covariates=None, freq="h"): """Helper to create test context DataFrames.""" series_dfs = [] @@ -44,4 +44,15 @@ def get_forecast_start_times(df, freq="h"): context_end_times = df.groupby("item_id")["timestamp"].max() forecast_start_times = [pd.date_range(end_time, periods=2, freq=freq)[-1] for end_time in context_end_times] - return forecast_start_times \ No newline at end of file + return forecast_start_times + + +def timeout_callback(seconds: float | None) -> Callable: + """Return a callback object that raises an exception if time limit is exceeded.""" + start_time = time.monotonic() + + def callback() -> None: + if seconds is not None and time.monotonic() - start_time > seconds: + raise TimeoutError("time limit exceeded") + + return callback