Chronos-2: Add after_batch callback (#436)

*Issue #, if available:*

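*Description of changes:* Adds an optional `after_batch` keyword argument
to Chronos-2 prediction: a zero-argument callback that is invoked after each
batch is processed. This lets callers such as AutoGluon keep track of a time
limit, for example by raising an exception from the callback once the limit
is exceeded.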


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Abdul Fatir 2025-12-17 11:45:16 +01:00 committed by GitHub
parent 2896499580
commit efb86e02c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 30 additions and 5 deletions

View file

@ -9,7 +9,7 @@ import time
import warnings
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Callable
import numpy as np
import torch
@ -577,6 +577,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
# effective batch size increases by a factor of `len(unrolled_quantiles)` when making long-horizon predictions,
# by default [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
unrolled_quantiles = kwargs.pop("unrolled_quantiles", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
# A callback which is called after each batch has been processed
after_batch_callback: Callable = kwargs.pop("after_batch", lambda: None)
if len(kwargs) > 0:
raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}.")
@ -641,6 +643,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
target_idx_ranges=batch_target_idx_ranges,
)
all_predictions.extend(batch_prediction)
after_batch_callback()
return all_predictions

View file

@ -16,7 +16,7 @@ from chronos import BaseChronosPipeline, Chronos2Pipeline
from chronos.chronos2.config import Chronos2CoreConfig
from chronos.chronos2.layers import MHA
from chronos.df_utils import convert_df_input_to_list_of_dicts_input
from test.util import create_df, create_future_df, get_forecast_start_times, validate_tensor
from test.util import create_df, create_future_df, get_forecast_start_times, validate_tensor, timeout_callback
DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model"
@ -1008,6 +1008,17 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future
assert not np.allclose(orig_result_before["predictions"].to_numpy(), result["predictions"].to_numpy())
def test_when_predict_df_called_with_timeout_callback_then_timeout_error_is_raised(pipeline):
    """Check that an exception raised by ``after_batch`` propagates out of ``predict_df``."""
    n_series = 1000
    # Many long series, so prediction runs well past the 0.1 s limit passed to the callback.
    context_df = create_df(series_ids=list(range(n_series)), n_points=[2048] * n_series)
    after_batch = timeout_callback(0.1)

    with pytest.raises(TimeoutError, match="time limit exceeded"):
        pipeline.predict_df(context_df, prediction_length=48, after_batch=after_batch)
@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"])
def test_pipeline_works_with_different_attention_implementations(attn_implementation):
"""Test that the pipeline works with different attention implementations."""

View file

@ -1,4 +1,5 @@
from typing import Optional, Tuple
import time
from typing import Callable, Optional, Tuple
import numpy as np
import pandas as pd
@ -13,7 +14,6 @@ def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[tor
assert a.dtype == dtype
def create_df(series_ids=["A", "B"], n_points=[10, 10], target_cols=["target"], covariates=None, freq="h"):
"""Helper to create test context DataFrames."""
series_dfs = []
@ -44,4 +44,15 @@ def get_forecast_start_times(df, freq="h"):
context_end_times = df.groupby("item_id")["timestamp"].max()
forecast_start_times = [pd.date_range(end_time, periods=2, freq=freq)[-1] for end_time in context_end_times]
return forecast_start_times
return forecast_start_times
def timeout_callback(seconds: float | None) -> Callable:
"""Return a callback object that raises an exception if time limit is exceeded."""
start_time = time.monotonic()
def callback() -> None:
if seconds is not None and time.monotonic() - start_time > seconds:
raise TimeoutError("time limit exceeded")
return callback