mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Chronos-2: Add after_batch callback (#436)
*Issue #, if available:* *Description of changes:* Adds support for a custom callback that is invoked after each batch is processed during prediction. This allows AutoGluon to keep track of its time limit. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
2896499580
commit
efb86e02c2
3 changed files with 30 additions and 5 deletions
|
|
@ -9,7 +9,7 @@ import time
|
|||
import warnings
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -577,6 +577,8 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
# effective batch size increases by a factor of `len(unrolled_quantiles)` when making long-horizon predictions,
|
||||
# by default [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
||||
unrolled_quantiles = kwargs.pop("unrolled_quantiles", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
|
||||
# Optional zero-argument callback invoked after each batch has been processed (defaults to a no-op)
|
||||
after_batch_callback: Callable = kwargs.pop("after_batch", lambda: None)
|
||||
|
||||
if len(kwargs) > 0:
|
||||
raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}.")
|
||||
|
|
@ -641,6 +643,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
target_idx_ranges=batch_target_idx_ranges,
|
||||
)
|
||||
all_predictions.extend(batch_prediction)
|
||||
after_batch_callback()
|
||||
|
||||
return all_predictions
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from chronos import BaseChronosPipeline, Chronos2Pipeline
|
|||
from chronos.chronos2.config import Chronos2CoreConfig
|
||||
from chronos.chronos2.layers import MHA
|
||||
from chronos.df_utils import convert_df_input_to_list_of_dicts_input
|
||||
from test.util import create_df, create_future_df, get_forecast_start_times, validate_tensor
|
||||
from test.util import create_df, create_future_df, get_forecast_start_times, validate_tensor, timeout_callback
|
||||
|
||||
DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model"
|
||||
|
||||
|
|
@ -1008,6 +1008,17 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future
|
|||
assert not np.allclose(orig_result_before["predictions"].to_numpy(), result["predictions"].to_numpy())
|
||||
|
||||
|
||||
def test_when_predict_df_called_with_timeout_callback_then_timeout_error_is_raised(pipeline):
    """A tiny time budget passed as ``after_batch`` must abort ``predict_df`` with ``TimeoutError``."""
    series_count = 1000
    points_per_series = 2048
    # NOTE(review): the input is presumably sized so that prediction spans
    # several batches and the 0.1 s budget is exceeded before the last one.
    large_df = create_df(
        series_ids=list(range(series_count)),
        n_points=[points_per_series] * series_count,
    )

    # The clock starts as soon as the callback is built.
    enforce_time_limit = timeout_callback(0.1)
    with pytest.raises(TimeoutError, match="time limit exceeded"):
        pipeline.predict_df(large_df, prediction_length=48, after_batch=enforce_time_limit)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"])
|
||||
def test_pipeline_works_with_different_attention_implementations(attn_implementation):
|
||||
"""Test that the pipeline works with different attention implementations."""
|
||||
|
|
|
|||
17
test/util.py
17
test/util.py
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Optional, Tuple
|
||||
import time
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
|
@ -13,7 +14,6 @@ def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[tor
|
|||
assert a.dtype == dtype
|
||||
|
||||
|
||||
|
||||
def create_df(series_ids=["A", "B"], n_points=[10, 10], target_cols=["target"], covariates=None, freq="h"):
|
||||
"""Helper to create test context DataFrames."""
|
||||
series_dfs = []
|
||||
|
|
@ -44,4 +44,15 @@ def get_forecast_start_times(df, freq="h"):
|
|||
context_end_times = df.groupby("item_id")["timestamp"].max()
|
||||
forecast_start_times = [pd.date_range(end_time, periods=2, freq=freq)[-1] for end_time in context_end_times]
|
||||
|
||||
return forecast_start_times
|
||||
return forecast_start_times
|
||||
|
||||
|
||||
def timeout_callback(seconds: Optional[float]) -> Callable[[], None]:
    """Return a zero-argument callback that raises once a time limit is exceeded.

    The clock starts when this factory is called, not when the returned
    callback is first invoked.

    Args:
        seconds: Time budget in seconds, measured with ``time.monotonic()``.
            ``None`` disables the check, so the callback never raises.

    Returns:
        A callable taking no arguments and returning ``None``.

    Raises:
        TimeoutError: Raised by the returned callback (not by this factory)
            when strictly more than ``seconds`` have elapsed.
    """
    start_time = time.monotonic()

    def callback() -> None:
        # No budget configured: nothing to enforce.
        if seconds is None:
            return
        if time.monotonic() - start_time > seconds:
            raise TimeoutError("time limit exceeded")

    return callback
|
||||
|
|
|
|||
Loading…
Reference in a new issue