mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 09:39:35 +00:00
*Issue #, if available:* *Description of changes:* Adds support for custom callbacks that are invoked after each batch is processed during prediction. This allows AutoGluon to keep track of its time limit. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
58 lines
2.2 KiB
Python
58 lines
2.2 KiB
Python
import time
|
|
from typing import Callable, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
|
|
|
|
def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None) -> None:
    """Assert that ``a`` is a tensor with the expected shape and, optionally, dtype.

    Intended for tests. Each failing check raises ``AssertionError`` with a
    message reporting the expected and the actual value, so failures are
    diagnosable even where pytest's assertion rewriting does not apply.

    Args:
        a: Object to check; must be a ``torch.Tensor``.
        shape: Expected shape, compared with ``a.shape`` using ``==``.
        dtype: Expected dtype. The dtype check is skipped when ``None``.

    Raises:
        AssertionError: If ``a`` is not a tensor, or its shape or dtype differs.
    """
    assert isinstance(a, torch.Tensor), f"expected torch.Tensor, got {type(a).__name__}"
    assert a.shape == shape, f"expected shape {shape}, got {tuple(a.shape)}"

    if dtype is not None:
        assert a.dtype == dtype, f"expected dtype {dtype}, got {a.dtype}"
|
|
|
|
|
|
def create_df(series_ids=("A", "B"), n_points=(10, 10), target_cols=("target",), covariates=None, freq="h"):
    """Helper to create test context DataFrames.

    Builds one block of rows per series and stacks them with a fresh
    ``RangeIndex``. Each series gets ``length`` timestamps ending at
    ``2001-10-01`` and spaced by ``freq``. It also gets one column of
    ``np.random.randn`` noise per target column and then per covariate.

    Defaults are tuples rather than lists so they cannot be shared-and-mutated
    across calls. Any iterable of names or lengths is accepted.

    Args:
        series_ids: Item identifiers, written to the ``item_id`` column.
        n_points: Rows per series, paired positionally with ``series_ids``.
            As with ``zip``, surplus entries in the longer sequence are ignored.
        target_cols: Names of the target columns.
        covariates: Optional names of covariate columns. Falsy means none.
        freq: pandas frequency string for the timestamps.

    Returns:
        DataFrame with columns ``item_id``, ``timestamp``, the target columns,
        then the covariate columns.
    """
    series_dfs = []
    for series_id, length in zip(series_ids, n_points):
        series_data = {"item_id": series_id, "timestamp": pd.date_range(end="2001-10-01", periods=length, freq=freq)}
        # Targets first, then covariates: same column and RNG-draw order as before.
        for col in (*target_cols, *(covariates or ())):
            series_data[col] = np.random.randn(length)
        series_dfs.append(pd.DataFrame(series_data))
    return pd.concat(series_dfs, ignore_index=True)
|
|
|
|
|
|
def create_future_df(forecast_start_times: list, series_ids=("A", "B"), n_points=(5, 5), covariates=None, freq="h"):
    """Helper to create test future DataFrames.

    Builds one block of rows per series and stacks them with a fresh
    ``RangeIndex``. Each series gets ``length`` timestamps starting at its entry
    in ``forecast_start_times`` and spaced by ``freq``. It also gets one column
    of ``np.random.randn`` noise per covariate.

    Defaults are tuples rather than lists so they cannot be shared-and-mutated
    across calls. Any iterable of names or lengths is accepted.

    Args:
        forecast_start_times: First timestamp of each series, paired positionally
            with ``series_ids``.
        series_ids: Item identifiers, written to the ``item_id`` column.
        n_points: Rows per series. As with ``zip``, iteration stops at the
            shortest of ``series_ids``, ``n_points`` and ``forecast_start_times``.
        covariates: Optional names of covariate columns. Falsy means none.
        freq: pandas frequency string for the timestamps.

    Returns:
        DataFrame with columns ``item_id``, ``timestamp``, then the covariates.
    """
    series_dfs = []
    for series_id, length, start in zip(series_ids, n_points, forecast_start_times):
        series_data = {"item_id": series_id, "timestamp": pd.date_range(start=start, periods=length, freq=freq)}
        for cov in covariates or ():
            series_data[cov] = np.random.randn(length)
        series_dfs.append(pd.DataFrame(series_data))
    return pd.concat(series_dfs, ignore_index=True)
|
|
|
|
|
|
def get_forecast_start_times(df, freq="h"):
    """Return, per item, the timestamp one ``freq`` step after its last context timestamp.

    For each item, the result is the second element of
    ``pd.date_range(last_timestamp, periods=2, freq=freq)``. Entries follow the
    order of ``df.groupby("item_id")``, which by default is sorted by
    ``item_id`` rather than the row order of ``df``.
    """
    last_seen_per_item = df.groupby("item_id")["timestamp"].max()

    starts = []
    for last_ts in last_seen_per_item:
        following = pd.date_range(start=last_ts, periods=2, freq=freq)
        starts.append(following[-1])
    return starts
|
|
|
|
|
|
def timeout_callback(seconds: float | None) -> Callable:
|
|
"""Return a callback object that raises an exception if time limit is exceeded."""
|
|
start_time = time.monotonic()
|
|
|
|
def callback() -> None:
|
|
if seconds is not None and time.monotonic() - start_time > seconds:
|
|
raise TimeoutError("time limit exceeded")
|
|
|
|
return callback
|