chronos-forecasting/test/util.py
Abdul Fatir efb86e02c2
Chronos-2: Add after_batch callback (#436)
*Issue #, if available:*

*Description of changes:* Adds support for custom callbacks after each
batch is processed during prediction. This allows for keeping track of
the time limit in AutoGluon.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
2025-12-17 11:45:16 +01:00

58 lines
2.2 KiB
Python

import time
from typing import Callable, Optional, Tuple
import numpy as np
import pandas as pd
import torch
def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None) -> None:
    """Assert that ``a`` is a tensor with the given shape and, optionally, dtype.

    Args:
        a: Object under test; must be a ``torch.Tensor``.
        shape: Expected shape, compared with ``a.shape`` via ``==``.
        dtype: Expected dtype; the dtype check is skipped when ``None``.

    Raises:
        AssertionError: If any check fails. The message states the expected
            and actual values.
    """
    # This helper module is not a ``test_*.py`` file, so pytest does not rewrite
    # its asserts. Explicit messages are the only diagnostic a failing test gets.
    assert isinstance(a, torch.Tensor), f"expected torch.Tensor, got {type(a).__name__}"
    assert a.shape == shape, f"expected shape {tuple(shape)}, got {tuple(a.shape)}"
    if dtype is not None:
        assert a.dtype == dtype, f"expected dtype {dtype}, got {a.dtype}"
def create_df(
    series_ids=("A", "B"),
    n_points=(10, 10),
    target_cols=("target",),
    covariates=None,
    freq="h",
) -> pd.DataFrame:
    """Helper to create test context DataFrames.

    Builds a long-format frame with an ``item_id`` column, a ``timestamp``
    column, and one column per name in ``target_cols`` and ``covariates``.
    The value columns are filled with ``np.random.randn`` noise. Each series'
    timestamps end at 2001-10-01.

    Args:
        series_ids: Series identifiers, paired positionally with ``n_points``.
            Extra entries in the longer of the two are ignored, because
            ``zip`` truncates.
        n_points: Number of rows for each series.
        target_cols: Names of the target columns.
        covariates: Optional names of additional covariate columns.
        freq: Frequency passed to ``pd.date_range``.

    Returns:
        All series concatenated, with a fresh ``RangeIndex``.
    """
    # Defaults are tuples rather than lists: a list default is created once and
    # shared by every call, so any accidental in-place mutation would leak
    # between tests.
    # Targets first, then covariates. This keeps the column order and the order
    # of random draws the same as filling them in two separate loops.
    value_cols = [*target_cols, *(covariates or ())]
    series_dfs = []
    for series_id, length in zip(series_ids, n_points):
        series_data = {"item_id": series_id, "timestamp": pd.date_range(end="2001-10-01", periods=length, freq=freq)}
        for col in value_cols:
            series_data[col] = np.random.randn(length)
        series_dfs.append(pd.DataFrame(series_data))
    return pd.concat(series_dfs, ignore_index=True)
def create_future_df(
    forecast_start_times: list,
    series_ids=("A", "B"),
    n_points=(5, 5),
    covariates=None,
    freq="h",
) -> pd.DataFrame:
    """Helper to create test future DataFrames.

    Builds a long-format frame with an ``item_id`` column, a ``timestamp``
    column, and one ``np.random.randn`` column per name in ``covariates``.
    There are no target columns.

    Args:
        forecast_start_times: First timestamp for each series, paired
            positionally with ``series_ids`` and ``n_points``.
        series_ids: Series identifiers.
        n_points: Number of future rows for each series.
        covariates: Optional names of covariate columns.
        freq: Frequency passed to ``pd.date_range``.

    Returns:
        All series concatenated, with a fresh ``RangeIndex``.
    """
    # Tuple defaults avoid the shared-mutable-default pitfall of list defaults.
    cov_cols = list(covariates or ())
    series_dfs = []
    for series_id, length, start in zip(series_ids, n_points, forecast_start_times):
        series_data = {"item_id": series_id, "timestamp": pd.date_range(start=start, periods=length, freq=freq)}
        for cov in cov_cols:
            series_data[cov] = np.random.randn(length)
        series_dfs.append(pd.DataFrame(series_data))
    return pd.concat(series_dfs, ignore_index=True)
def get_forecast_start_times(df, freq="h"):
    """Return the first timestamp after each series' last observation.

    For every ``item_id`` in ``df``, the latest ``timestamp`` is found and
    advanced by one step of ``freq`` using ``pd.date_range``.

    The list is ordered by sorted ``item_id``, which is the default ``groupby``
    order, not by order of appearance in ``df``. Keep this in mind when pairing
    it positionally with a list of series ids.
    """
    last_timestamps = df.groupby("item_id")["timestamp"].max()
    starts = []
    for last_ts in last_timestamps:
        # A two-point range beginning at the last timestamp; its second point
        # is the next step on the ``freq`` grid.
        next_two = pd.date_range(last_ts, periods=2, freq=freq)
        starts.append(next_two[-1])
    return starts
def timeout_callback(seconds: float | None) -> Callable:
"""Return a callback object that raises an exception if time limit is exceeded."""
start_time = time.monotonic()
def callback() -> None:
if seconds is not None and time.monotonic() - start_time > seconds:
raise TimeoutError("time limit exceeded")
return callback