2025-12-17 10:45:16 +00:00
|
|
|
import time
|
|
|
|
|
from typing import Callable, Optional, Tuple
|
2024-11-29 15:54:21 +00:00
|
|
|
|
2025-11-11 17:37:19 +00:00
|
|
|
import numpy as np
|
|
|
|
|
import pandas as pd
|
2024-11-29 15:54:21 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
2025-10-20 08:34:20 +00:00
|
|
|
def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None) -> None:
    """Assert that ``a`` is a tensor with the expected shape and, optionally, dtype.

    Args:
        a: Object to check; must be a ``torch.Tensor``.
        shape: Expected shape. Any sequence of ints is accepted (tuple, list or
            ``torch.Size``); it is compared element-wise against ``a.shape``.
        dtype: If given, the exact dtype ``a`` must have.

    Raises:
        AssertionError: If any check fails. Each message reports the expected and
            actual values, because pytest may not rewrite asserts in a helper
            module like this one.
    """
    assert isinstance(a, torch.Tensor), f"expected a torch.Tensor, got {type(a).__name__}"
    # Normalize both sides to tuples: torch.Size is a tuple subclass, so comparing
    # it against a list would otherwise always be False.
    expected_shape = tuple(shape)
    assert tuple(a.shape) == expected_shape, f"expected shape {expected_shape}, got {tuple(a.shape)}"
    if dtype is not None:
        assert a.dtype == dtype, f"expected dtype {dtype}, got {a.dtype}"
|
2025-11-11 17:37:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_df(series_ids=["A", "B"], n_points=[10, 10], target_cols=["target"], covariates=None, freq="h"):
|
|
|
|
|
"""Helper to create test context DataFrames."""
|
|
|
|
|
series_dfs = []
|
|
|
|
|
for series_id, length in zip(series_ids, n_points):
|
|
|
|
|
series_data = {"item_id": series_id, "timestamp": pd.date_range(end="2001-10-01", periods=length, freq=freq)}
|
|
|
|
|
for target_col in target_cols:
|
|
|
|
|
series_data[target_col] = np.random.randn(length)
|
|
|
|
|
if covariates:
|
|
|
|
|
for cov in covariates:
|
|
|
|
|
series_data[cov] = np.random.randn(length)
|
|
|
|
|
series_dfs.append(pd.DataFrame(series_data))
|
|
|
|
|
return pd.concat(series_dfs, ignore_index=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_future_df(forecast_start_times: list, series_ids=["A", "B"], n_points=[5, 5], covariates=None, freq="h"):
|
|
|
|
|
"""Helper to create test future DataFrames."""
|
|
|
|
|
series_dfs = []
|
|
|
|
|
for series_id, length, start in zip(series_ids, n_points, forecast_start_times):
|
|
|
|
|
series_data = {"item_id": series_id, "timestamp": pd.date_range(start=start, periods=length, freq=freq)}
|
|
|
|
|
if covariates:
|
|
|
|
|
for cov in covariates:
|
|
|
|
|
series_data[cov] = np.random.randn(length)
|
|
|
|
|
series_dfs.append(pd.DataFrame(series_data))
|
|
|
|
|
return pd.concat(series_dfs, ignore_index=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_forecast_start_times(df, freq="h"):
    """Return, for each ``item_id``, the timestamp one ``freq`` step after its last timestamp.

    The result is a list of ``pd.Timestamp`` ordered as ``groupby("item_id")`` orders
    the groups (sorted by ``item_id``), not by order of appearance in ``df``.
    """
    last_seen = df.groupby("item_id")["timestamp"].max()
    starts = []
    for last_ts in last_seen:
        # The second element of a 2-period range beginning at last_ts is the next
        # step on the ``freq`` grid (this also handles anchored frequencies).
        next_two = pd.date_range(last_ts, periods=2, freq=freq)
        starts.append(next_two[-1])
    return starts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def timeout_callback(seconds: float | None) -> Callable:
|
|
|
|
|
"""Return a callback object that raises an exception if time limit is exceeded."""
|
|
|
|
|
start_time = time.monotonic()
|
|
|
|
|
|
|
|
|
|
def callback() -> None:
|
|
|
|
|
if seconds is not None and time.monotonic() - start_time > seconds:
|
|
|
|
|
raise TimeoutError("time limit exceeded")
|
|
|
|
|
|
|
|
|
|
return callback
|