mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 09:39:35 +00:00
Add vLLM plugin for Chronos-2 inference
Signed-off-by: Li Zhang <lzhanga@amazon.com>
This commit is contained in:
parent
f951d9aefa
commit
ba47d25a04
20 changed files with 2707 additions and 0 deletions
|
|
@ -52,6 +52,10 @@ test = [
|
|||
"fev>=0.6.1",
|
||||
"pandas[pyarrow]>=2.0,<2.4",
|
||||
]
|
||||
vllm = [
|
||||
"vllm>=0.13.0",
|
||||
"pydantic>=2.0",
|
||||
]
|
||||
typecheck = ["mypy~=1.9"]
|
||||
dev = [
|
||||
"gluonts[pro]~=0.16",
|
||||
|
|
@ -65,6 +69,12 @@ dev = [
|
|||
"pandas>=2.0,<2.4",
|
||||
]
|
||||
|
||||
[project.entry-points."vllm.general_plugins"]
|
||||
chronos2 = "chronos.chronos2.vllm:register_chronos2_model"
|
||||
|
||||
[project.entry-points."vllm.io_processor_plugins"]
|
||||
chronos2 = "chronos.chronos2.vllm:get_chronos2_io_processor"
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/amazon-science/chronos-forecasting"
|
||||
Issues = "https://github.com/amazon-science/chronos-forecasting/issues"
|
||||
|
|
|
|||
179
src/chronos/chronos2/vllm/README.md
Normal file
179
src/chronos/chronos2/vllm/README.md
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
# Chronos-2 vLLM Plugin
|
||||
|
||||
A [vLLM](https://github.com/vllm-project/vllm) plugin that adds support for [Chronos-2](https://github.com/amazon-science/chronos-forecasting) time series forecasting via the `/pooling` API endpoint.
|
||||
|
||||
## Overview
|
||||
|
||||
Chronos-2 is an encoder-only time series foundation model for zero-shot forecasting. This plugin integrates it with vLLM using the **IOProcessor** plugin interface, so forecast requests are served through vLLM's standard pooling endpoint.
|
||||
|
||||
### Features
|
||||
|
||||
- **Zero-shot forecasting** — no fine-tuning required
|
||||
- **Quantile predictions** — probabilistic forecasts with customizable quantile levels
|
||||
- **Cross-series learning** — information sharing across time series in a batch
|
||||
- **Covariates support** — past and future covariates (numeric and categorical)
|
||||
- **Batch forecasting** — process multiple time series in a single request
|
||||
|
||||
## Installation
|
||||
|
||||
Requires Python 3.10+ and vLLM 0.13.0+.
|
||||
|
||||
```bash
|
||||
pip install chronos-forecasting[vllm]
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Start the Server
|
||||
|
||||
```bash
|
||||
vllm serve amazon/chronos-2 \
|
||||
--io-processor-plugin chronos2 \
|
||||
--runner pooling \
|
||||
--enforce-eager \
|
||||
--no-enable-prefix-caching \
|
||||
--skip-tokenizer-init \
|
||||
--enable-mm-embeds \
|
||||
--dtype float32 \
|
||||
--max-model-len 8192
|
||||
```
|
||||
|
||||
### 2. Send a Forecast Request
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/pooling \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "amazon/chronos-2",
|
||||
"task": "plugin",
|
||||
"data": {
|
||||
"inputs": [
|
||||
{
|
||||
"target": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
|
||||
"item_id": "series_1"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"prediction_length": 5,
|
||||
"quantile_levels": [0.1, 0.5, 0.9]
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### 3. Parse the Response
|
||||
|
||||
```json
|
||||
{
|
||||
"request_id": null,
|
||||
"created_at": 1739397600,
|
||||
"data": {
|
||||
"predictions": [
|
||||
{
|
||||
"mean": [11.0, 12.1, 13.0, 14.2, 15.1],
|
||||
"0.1": [9.5, 10.3, 11.0, 11.8, 12.5],
|
||||
"0.5": [11.0, 12.0, 13.0, 14.0, 15.0],
|
||||
"0.9": [12.5, 13.8, 15.0, 16.3, 17.5],
|
||||
"item_id": "series_1"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Request Format
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
|---|---|---|---|
|
||||
| `model` | `str` | ✅ | Model name (e.g., `"amazon/chronos-2"`) |
|
||||
| `task` | `str` | ✅ | Must be `"plugin"` |
|
||||
| `data.inputs` | `list` | ✅ | List of time series inputs (1–1024) |
|
||||
| `data.parameters` | `dict` | | Forecast parameters |
|
||||
|
||||
#### Time Series Input (`data.inputs[*]`)
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
|---|---|---|---|
|
||||
| `target` | `list[float]` or `list[list[float]]` | ✅ | Historical values (min 5 observations). 1-D for univariate, 2-D for multivariate. |
|
||||
| `item_id` | `str` | | Identifier echoed in response |
|
||||
| `start` | `str` | | ISO 8601 timestamp |
|
||||
| `past_covariates` | `dict[str, list]` | | Past covariate arrays (must match target length) |
|
||||
| `future_covariates` | `dict[str, list]` | | Future covariate arrays (must match `prediction_length`) |
|
||||
|
||||
#### Parameters (`data.parameters`)
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|---|---|---|---|
|
||||
| `prediction_length` | `int` | `1` | Forecast horizon (1–1024) |
|
||||
| `quantile_levels` | `list[float]` | `[0.1, 0.5, 0.9]` | Quantile levels in (0, 1) |
|
||||
| `freq` | `str` | `null` | Pandas frequency string (e.g., `"D"`, `"H"`) |
|
||||
| `batch_size` | `int` | `256` | Inference batch size |
|
||||
| `cross_learning` | `bool` | `false` | Enable cross-series learning |
|
||||
|
||||
### Response Format
|
||||
|
||||
Each prediction in `data.predictions` contains:
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `mean` | `list[float]` | Point forecast (mean/median) |
|
||||
| `"0.1"`, `"0.5"`, etc. | `list[float]` | Named quantile columns matching `quantile_levels` |
|
||||
| `item_id` | `str` | Echoed from input (if provided) |
|
||||
|
||||
## Architecture
|
||||
|
||||
The vLLM model wrapper (`Chronos2ForForecasting`) is a thin adapter that delegates all computation to the existing `chronos.chronos2.model.Chronos2Model`. No model architecture is duplicated.
|
||||
|
||||
### Module Structure
|
||||
|
||||
```
|
||||
src/chronos/chronos2/vllm/
|
||||
├── __init__.py # Plugin entry point & registration
|
||||
├── model.py # Chronos2ForForecasting (thin vLLM wrapper)
|
||||
├── multimodal.py # MM pipeline for "timeseries" modality
|
||||
├── io_processor.py # Chronos2IOProcessor (request/response handling)
|
||||
├── protocol/
|
||||
│ ├── __init__.py
|
||||
│ ├── forecast.py # Pydantic models (TimeSeriesInput, ForecastParameters, etc.)
|
||||
│ ├── validation.py # Input validation logic
|
||||
│ └── data_prep.py # Tensor preparation from validated inputs
|
||||
└── utils/
|
||||
├── __init__.py
|
||||
├── helpers.py # Utility functions
|
||||
└── quantiles.py # Quantile selection & interpolation
|
||||
```
|
||||
|
||||
### Key Classes
|
||||
|
||||
| Class | File | Purpose |
|
||||
|---|---|---|
|
||||
| `Chronos2ForForecasting` | `model.py` | Thin vLLM wrapper — delegates to `chronos.chronos2.model.Chronos2Model` |
|
||||
| `Chronos2IOProcessor` | `io_processor.py` | Request parsing, validation, pre/post processing |
|
||||
| `ForecastParameters` | `protocol/forecast.py` | Pydantic validation for forecast parameters |
|
||||
| `TimeSeriesInput` | `protocol/forecast.py` | Pydantic validation for time series inputs |
|
||||
| `ForecastPrediction` | `protocol/forecast.py` | Pydantic model for forecast output |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Server Flags
|
||||
|
||||
The following flags are required for Chronos-2:
|
||||
|
||||
```bash
|
||||
--io-processor-plugin chronos2 # Enable the forecast IOProcessor
|
||||
--enforce-eager # Chronos-2 doesn't support CUDA graphs
|
||||
--no-enable-prefix-caching # Not applicable for time series
|
||||
--skip-tokenizer-init # Chronos-2 doesn't use a text tokenizer
|
||||
```
|
||||
|
||||
### Plugin Not Loading
|
||||
|
||||
1. Verify installation: `pip list | grep chronos-forecasting`
|
||||
2. Check entry points:
|
||||
```python
|
||||
from importlib.metadata import entry_points
|
||||
print(list(entry_points(group='vllm.general_plugins')))
|
||||
```
|
||||
3. Enable debug logging: `VLLM_LOGGING_LEVEL=DEBUG vllm serve ...`
|
||||
48
src/chronos/chronos2/vllm/__init__.py
Normal file
48
src/chronos/chronos2/vllm/__init__.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""vLLM Chronos-2 Time Series Forecasting Model Plugin.
|
||||
|
||||
This plugin registers the Chronos-2 model with vLLM's ModelRegistry,
|
||||
allowing it to be used with vLLM's inference engine for time series forecasting.
|
||||
"""
|
||||
|
||||
|
||||
def register_chronos2_model() -> None:
|
||||
"""Register Chronos-2 models with vLLM's ModelRegistry.
|
||||
|
||||
This function is called automatically when the plugin is loaded
|
||||
through vLLM's plugin discovery mechanism.
|
||||
"""
|
||||
try:
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.registry import ModelRegistry
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Chronos2Model",
|
||||
"chronos.chronos2.vllm.model:Chronos2ForForecasting",
|
||||
)
|
||||
|
||||
logger.info("Successfully registered Chronos-2 model with vLLM")
|
||||
|
||||
except Exception as e:
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
logger.error(f"Failed to register Chronos-2 model: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def get_chronos2_io_processor():
|
||||
"""
|
||||
Factory function for IOProcessor plugin registration.
|
||||
|
||||
This function is called by vLLM when --io-processor-plugin is set to 'chronos2'.
|
||||
Returns the fully qualified name (module.Class format) as a string.
|
||||
"""
|
||||
return "chronos.chronos2.vllm.io_processor.Chronos2IOProcessor"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"register_chronos2_model",
|
||||
"get_chronos2_io_processor",
|
||||
]
|
||||
247
src/chronos/chronos2/vllm/io_processor.py
Normal file
247
src/chronos/chronos2/vllm/io_processor.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
"""IOProcessor for Chronos-2 time series forecasting.
|
||||
|
||||
Thin orchestrator that delegates to:
|
||||
- protocol.data_prep: tensor preparation from validated inputs (via Chronos2Dataset)
|
||||
- protocol.validation: cross-series validation
|
||||
- multimodal.MODALITY: MM prompt construction
|
||||
- utils: quantile selection, helpers
|
||||
|
||||
Flow:
|
||||
1. parse_request → validate inputs via Pydantic models
|
||||
2. pre_process → prepare_request() → timeseries MM prompts (one per batch)
|
||||
3. post_process → select_quantiles() → per-series predictions
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
from chronos.chronos2.vllm.multimodal import MODALITY
|
||||
from chronos.chronos2.vllm.protocol.data_prep import PreparedRequest, prepare_request
|
||||
from chronos.chronos2.vllm.protocol.forecast import (
|
||||
ForecastParameters,
|
||||
ForecastPrediction,
|
||||
TimeSeriesInput,
|
||||
)
|
||||
from chronos.chronos2.vllm.protocol.validation import validate_cross_series
|
||||
from chronos.chronos2.vllm.utils.helpers import empty_prediction, tensor_to_list
|
||||
from chronos.chronos2.vllm.utils.quantiles import select_quantiles
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PostProcessInfo:
|
||||
"""Lightweight cache of data needed for post_process — avoids storing full tensors."""
|
||||
|
||||
item_ids: list[str | None]
|
||||
parameters: ForecastParameters
|
||||
target_idx_ranges: list[list[tuple[int, int]]] # per-batch list of (start, end) ranges
|
||||
|
||||
|
||||
class Chronos2IOProcessor(IOProcessor[dict, dict]):
|
||||
"""IOProcessor for Chronos-2 time series forecasting."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
super().__init__(vllm_config)
|
||||
config = vllm_config.model_config.hf_config
|
||||
config.is_encoder_decoder = False
|
||||
# Set projection_dim=0 so vLLM uses IdentityPooler (no projection layer).
|
||||
# This avoids requiring --hf-overrides '{"projection_dim": 0}' on the CLI.
|
||||
if not hasattr(config, "projection_dim") or config.projection_dim != 0:
|
||||
config.projection_dim = 0
|
||||
|
||||
cc = getattr(config, "chronos_config", {})
|
||||
self._chronos_config = (
|
||||
cc if isinstance(cc, dict) else (cc.__dict__ if hasattr(cc, "__dict__") else {})
|
||||
)
|
||||
self._context_length = self._chronos_config.get("context_length", 8192)
|
||||
self._output_patch_size = self._chronos_config.get("output_patch_size", 16)
|
||||
self._model_quantiles = self._chronos_config.get(
|
||||
"quantiles",
|
||||
[
|
||||
0.01,
|
||||
0.05,
|
||||
0.1,
|
||||
0.15,
|
||||
0.2,
|
||||
0.25,
|
||||
0.3,
|
||||
0.35,
|
||||
0.4,
|
||||
0.45,
|
||||
0.5,
|
||||
0.55,
|
||||
0.6,
|
||||
0.65,
|
||||
0.7,
|
||||
0.75,
|
||||
0.8,
|
||||
0.85,
|
||||
0.9,
|
||||
0.95,
|
||||
0.99,
|
||||
],
|
||||
)
|
||||
self._request_info: dict[str, _PostProcessInfo] = {}
|
||||
logger.info("Initialized Chronos2IOProcessor")
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Request parsing
|
||||
# -----------------------------------------------------------
|
||||
|
||||
def parse_request(self, request: Any) -> dict:
|
||||
data = request.data if hasattr(request, "data") else request
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Expected dict, got {type(data)}")
|
||||
if "inputs" not in data:
|
||||
raise ValueError("Request must contain 'inputs' field")
|
||||
raw_inputs = data["inputs"]
|
||||
if not isinstance(raw_inputs, list) or len(raw_inputs) == 0:
|
||||
raise ValueError("'inputs' must be a non-empty list")
|
||||
|
||||
validated_inputs = []
|
||||
for i, ts in enumerate(raw_inputs):
|
||||
try:
|
||||
validated_inputs.append(TimeSeriesInput(**ts))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid input at index {i}: {e}") from e
|
||||
|
||||
try:
|
||||
validated_params = ForecastParameters(**data.get("parameters", {}))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid parameters: {e}") from e
|
||||
|
||||
validate_cross_series(validated_inputs, validated_params)
|
||||
return {"inputs": validated_inputs, "parameters": validated_params}
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Pre-process: inputs → MM prompts
|
||||
# -----------------------------------------------------------
|
||||
|
||||
def pre_process(
|
||||
self,
|
||||
prompt: dict,
|
||||
request_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> PromptType | Sequence[PromptType]:
|
||||
prepared = prepare_request(
|
||||
inputs=prompt["inputs"],
|
||||
parameters=prompt["parameters"],
|
||||
context_length=self._context_length,
|
||||
output_patch_size=self._output_patch_size,
|
||||
)
|
||||
|
||||
# Cache only lightweight post-processing metadata
|
||||
if request_id is not None:
|
||||
self._request_info[request_id] = _PostProcessInfo(
|
||||
item_ids=prepared.item_ids,
|
||||
parameters=prepared.parameters,
|
||||
target_idx_ranges=[batch.target_idx_ranges for batch in prepared.batches],
|
||||
)
|
||||
|
||||
# Build one MM prompt per batch (Chronos2Dataset handles batch splitting)
|
||||
prompts: list[PromptType] = []
|
||||
for batch in prepared.batches:
|
||||
prompts.append(
|
||||
{
|
||||
"prompt_token_ids": [1],
|
||||
"multi_modal_data": {
|
||||
MODALITY: {
|
||||
"context": batch.context,
|
||||
"future_covariates": batch.future_covariates,
|
||||
"group_ids": batch.group_ids,
|
||||
"num_output_patches": batch.num_output_patches,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return prompts if len(prompts) > 1 else prompts[0]
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Post-process: model output → predictions
|
||||
# -----------------------------------------------------------
|
||||
|
||||
def post_process(
|
||||
self,
|
||||
model_output: Sequence[PoolingRequestOutput],
|
||||
request_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
info = self._request_info.pop(request_id, None) if request_id else None
|
||||
if info is None:
|
||||
logger.error("No pending request for request_id=%s", request_id)
|
||||
return {"predictions": []}
|
||||
|
||||
parameters = info.parameters
|
||||
item_ids = info.item_ids
|
||||
|
||||
try:
|
||||
# Collect predictions across all batches, trimmed to exact prediction_length
|
||||
pred_len = parameters.prediction_length
|
||||
all_predictions: list[torch.Tensor] = []
|
||||
for batch_idx, output in enumerate(model_output):
|
||||
tensor = output.outputs.data
|
||||
while tensor.ndim > 3:
|
||||
tensor = tensor.squeeze(0)
|
||||
|
||||
batch_ranges = info.target_idx_ranges[batch_idx]
|
||||
for start, end in batch_ranges:
|
||||
all_predictions.append(tensor[start:end, :, :pred_len])
|
||||
|
||||
quantiles_out, mean_out = select_quantiles(
|
||||
all_predictions, self._model_quantiles, parameters.quantile_levels
|
||||
)
|
||||
|
||||
result: dict[str, Any] = {"predictions": [], "request_id": request_id}
|
||||
for i, (q_tensor, m_tensor) in enumerate(zip(quantiles_out, mean_out)):
|
||||
pred: dict[str, Any] = {"mean": tensor_to_list(m_tensor)}
|
||||
for q, q_vals in zip(parameters.quantile_levels, q_tensor.unbind(dim=-1)):
|
||||
pred[str(q)] = tensor_to_list(q_vals)
|
||||
if i < len(item_ids) and item_ids[i] is not None:
|
||||
pred["item_id"] = item_ids[i]
|
||||
result["predictions"].append(pred)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to decode predictions: %s", e, exc_info=True)
|
||||
result = {
|
||||
"predictions": [
|
||||
empty_prediction(parameters.prediction_length, parameters.quantile_levels)
|
||||
for _ in item_ids
|
||||
],
|
||||
"request_id": request_id,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Response formatting
|
||||
# -----------------------------------------------------------
|
||||
|
||||
def validate_or_generate_params(self, params: Any = None) -> PoolingParams:
|
||||
return PoolingParams(task="plugin")
|
||||
|
||||
def output_to_response(self, plugin_output: dict) -> IOProcessorResponse:
|
||||
validated = []
|
||||
for pred in plugin_output.get("predictions", []):
|
||||
try:
|
||||
validated.append(ForecastPrediction(**pred).model_dump(exclude_none=True))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to validate prediction: %s", e)
|
||||
validated.append(pred)
|
||||
return IOProcessorResponse(
|
||||
request_id=plugin_output.get("request_id"),
|
||||
created_at=int(time.time()),
|
||||
data={"predictions": validated},
|
||||
)
|
||||
196
src/chronos/chronos2/vllm/model.py
Normal file
196
src/chronos/chronos2/vllm/model.py
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
"""Chronos-2 vLLM model wrapper.
|
||||
|
||||
Thin wrapper around the existing ``chronos.chronos2.model.Chronos2Model``
|
||||
that plugs into vLLM's multimodal (MM) interface. No model architecture
|
||||
is duplicated — all computation is delegated to the upstream implementation.
|
||||
|
||||
Architecture:
|
||||
pre_process → IOProcessor prepares context/covariates/group_ids tensors
|
||||
→ Chronos2Dataset handles batching and group_id construction
|
||||
→ returns MM prompt(s) with timeseries data
|
||||
forward() → receives pre-batched context/future_covariates/group_ids as kwargs
|
||||
→ delegates to chronos.chronos2.model.Chronos2Model
|
||||
pooler → IdentityPooler passes output through unchanged
|
||||
post_process → IOProcessor selects/interpolates quantiles
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import IdentityPooler
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
IsAttentionFree,
|
||||
MultiModalEmbeddings,
|
||||
SupportsMultiModal,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces_base import attn_type
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
|
||||
from chronos.chronos2.model import Chronos2Model
|
||||
|
||||
from .multimodal import (
|
||||
MODALITY,
|
||||
ChronosInputBuilder,
|
||||
ChronosMultiModalProcessor,
|
||||
ChronosProcessingInfo,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@attn_type("attention_free")
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
ChronosMultiModalProcessor,
|
||||
info=ChronosProcessingInfo,
|
||||
dummy_inputs=ChronosInputBuilder,
|
||||
)
|
||||
class Chronos2ForForecasting(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
"""Chronos-2 forecasting model for vLLM.
|
||||
|
||||
Delegates all computation to ``chronos.chronos2.model.Chronos2Model``.
|
||||
Receives pre-batched tensors (context, future_covariates, group_ids)
|
||||
directly in forward() via the timeseries MM pipeline. Batching is
|
||||
handled upstream by ``Chronos2Dataset`` in the IOProcessor.
|
||||
"""
|
||||
|
||||
supports_multimodal_raw_input_only = True
|
||||
is_pooling_model = True
|
||||
|
||||
# Required by VllmModel interface
|
||||
packed_modules_mapping: dict[str, Any] = {}
|
||||
supported_lora_modules: list[str] = []
|
||||
embedding_modules: dict[str, str] = {}
|
||||
embedding_padding_modules: list[str] = []
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith(MODALITY):
|
||||
return None
|
||||
raise ValueError(f"Only {MODALITY} modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
config.is_encoder_decoder = False
|
||||
if not hasattr(config, "projection_dim") or config.projection_dim != 0:
|
||||
config.projection_dim = 0
|
||||
vllm_config.model_config.hf_text_config = None
|
||||
|
||||
self.config = config
|
||||
self.model_name = vllm_config.model_config.model
|
||||
self.d_model = getattr(config, "d_model", 768)
|
||||
|
||||
cc = getattr(config, "chronos_config", {})
|
||||
self.chronos_config = (
|
||||
cc if isinstance(cc, dict) else (cc.__dict__ if hasattr(cc, "__dict__") else {})
|
||||
)
|
||||
self.output_patch_size = self.chronos_config.get("output_patch_size", 16)
|
||||
self.max_output_patches = self.chronos_config.get("max_output_patches", 64)
|
||||
|
||||
# Instantiate the upstream Chronos2Model — reuses all existing layers
|
||||
self.model = Chronos2Model(config)
|
||||
|
||||
self.pooler = IdentityPooler()
|
||||
|
||||
logger.info(
|
||||
"Initialized Chronos2ForForecasting (d_model=%d, delegating to chronos.chronos2.model)",
|
||||
self.d_model,
|
||||
)
|
||||
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Return empty embeddings — Chronos-2 has no token vocabulary."""
|
||||
return torch.empty((input_ids.shape[0], 0))
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
"""Run Chronos-2 inference by delegating to the upstream model.
|
||||
|
||||
Receives pre-batched context, future_covariates, group_ids from
|
||||
the timeseries MM pipeline. Batching is handled upstream by
|
||||
Chronos2Dataset in the IOProcessor.
|
||||
"""
|
||||
input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)
|
||||
|
||||
context: torch.Tensor | None = kwargs.get("context") # type: ignore[assignment]
|
||||
future_covariates: torch.Tensor | None = kwargs.get("future_covariates") # type: ignore[assignment]
|
||||
group_ids: torch.Tensor | None = kwargs.get("group_ids") # type: ignore[assignment]
|
||||
num_output_patches: int | None = kwargs.get("num_output_patches") # type: ignore[assignment]
|
||||
|
||||
if context is None:
|
||||
# Warmup/profiling pass — return zeros
|
||||
return torch.zeros(input_len, 0, device=positions.device, dtype=torch.float32)
|
||||
|
||||
# Determine num_output_patches from preprocessor or fall back to computing from future_covariates
|
||||
if num_output_patches is None:
|
||||
prediction_length = future_covariates.shape[1] if future_covariates is not None else 0
|
||||
if prediction_length == 0:
|
||||
prediction_length = self.output_patch_size * self.max_output_patches
|
||||
num_output_patches = int(math.ceil(prediction_length / self.output_patch_size))
|
||||
num_output_patches = min(num_output_patches, self.max_output_patches)
|
||||
|
||||
prediction_length = num_output_patches * self.output_patch_size
|
||||
|
||||
# Pad or trim future_covariates to match output_size
|
||||
if future_covariates is not None:
|
||||
if prediction_length > future_covariates.shape[1]:
|
||||
pad_size = prediction_length - future_covariates.shape[1]
|
||||
pad_tensor = torch.full(
|
||||
(future_covariates.shape[0], pad_size),
|
||||
fill_value=float("nan"),
|
||||
device=future_covariates.device,
|
||||
)
|
||||
future_covariates = torch.cat([future_covariates, pad_tensor], dim=1)
|
||||
else:
|
||||
future_covariates = future_covariates[:, :prediction_length]
|
||||
|
||||
# Delegate to the upstream Chronos2Model — single forward pass
|
||||
model_kwargs: dict[str, Any] = {
|
||||
"context": context,
|
||||
"num_output_patches": num_output_patches,
|
||||
}
|
||||
if group_ids is not None:
|
||||
model_kwargs["group_ids"] = group_ids
|
||||
if future_covariates is not None:
|
||||
model_kwargs["future_covariates"] = future_covariates
|
||||
|
||||
output = self.model(**model_kwargs)
|
||||
batch_prediction = output.quantile_preds[..., :prediction_length]
|
||||
|
||||
# Expand to match input_len for vLLM pipeline compatibility.
|
||||
# IdentityPooler passes through unchanged; post_process squeezes.
|
||||
hidden_states = batch_prediction[None].expand(
|
||||
input_len, *(-1 for _ in range(batch_prediction.ndim))
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""Load weights into the upstream Chronos2Model.
|
||||
|
||||
Prefix checkpoint names with 'model.' to match our wrapping.
|
||||
Uses AutoWeightsLoader for robust weight loading.
|
||||
"""
|
||||
prefixed = [(f"model.{name}", tensor) for name, tensor in weights]
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(prefixed)
|
||||
237
src/chronos/chronos2/vllm/multimodal.py
Normal file
237
src/chronos/chronos2/vllm/multimodal.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
"""Multimodal boilerplate for the "timeseries" modality in vLLM.
|
||||
|
||||
Provides the MM processing pipeline classes that route timeseries
|
||||
dict data (context, future_covariates, group_ids) through vLLM's
|
||||
multimodal infrastructure.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
import torch
|
||||
from transformers import BatchFeature
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalFieldElem,
|
||||
MultiModalInputs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
PlaceholderRange,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
DictEmbeddingItems,
|
||||
ModalityDataItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptUpdate,
|
||||
)
|
||||
|
||||
# The modality name used throughout the MM pipeline.
|
||||
MODALITY = "timeseries"
|
||||
|
||||
# Field names expected in the timeseries MM data dict.
|
||||
REQUIRED_FIELDS = frozenset({"context", "future_covariates", "group_ids", "num_output_patches"})
|
||||
|
||||
|
||||
def _field_config() -> dict[str, MultiModalFieldConfig]:
|
||||
"""Shared field config for all timeseries fields."""
|
||||
return {
|
||||
"context": MultiModalFieldConfig.shared(MODALITY, batch_size=1),
|
||||
"future_covariates": MultiModalFieldConfig.shared(MODALITY, batch_size=1),
|
||||
"group_ids": MultiModalFieldConfig.shared(MODALITY, batch_size=1),
|
||||
"num_output_patches": MultiModalFieldConfig.shared(MODALITY, batch_size=1),
|
||||
}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Processing info: tells vLLM what modalities we support
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
class ChronosProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {MODALITY: None}
|
||||
|
||||
def get_data_parser(self) -> MultiModalDataParser:
|
||||
return ChronosMultiModalDataParser()
|
||||
|
||||
def build_data_parser(self) -> MultiModalDataParser:
|
||||
return ChronosMultiModalDataParser()
|
||||
|
||||
@property # type: ignore[override]
|
||||
def data_parser(self) -> MultiModalDataParser:
|
||||
if not hasattr(self, "_data_parser"):
|
||||
self._data_parser = ChronosMultiModalDataParser()
|
||||
return self._data_parser
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Dummy input builder: provides profiling data for GPU warmup
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
class ChronosInputBuilder(BaseDummyInputsBuilder[ChronosProcessingInfo]):
|
||||
"""Provides dummy data for vLLM's GPU profiling/warmup pass."""
|
||||
|
||||
def __init__(self, info: ChronosProcessingInfo):
|
||||
super().__init__(info)
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
return ""
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> MultiModalDataDict:
|
||||
return {
|
||||
MODALITY: {
|
||||
"context": torch.ones(100, 2048, dtype=torch.float32),
|
||||
"future_covariates": torch.ones(100, 1024, dtype=torch.float32),
|
||||
"group_ids": torch.zeros(100, dtype=torch.long),
|
||||
"num_output_patches": 64,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Data parser: routes timeseries dict data through the MM pipeline
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
class ChronosMultiModalDataParser(MultiModalDataParser):
|
||||
"""Parses timeseries dict data for vLLM's MM pipeline."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _parse_timeseries_data(
|
||||
self,
|
||||
data: dict[str, torch.Tensor],
|
||||
) -> ModalityDataItems[Any, Any] | None:
|
||||
if isinstance(data, dict):
|
||||
return DictEmbeddingItems(
|
||||
data,
|
||||
modality=MODALITY,
|
||||
required_fields=REQUIRED_FIELDS,
|
||||
fields_factory=lambda _: _field_config(),
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_subparsers(self) -> Mapping[str, Any]:
|
||||
return {MODALITY: self._parse_timeseries_data}
|
||||
|
||||
def parse_mm_data(self, mm_data: MultiModalDataDict, **kwargs: Any) -> MultiModalDataItems:
|
||||
if MODALITY not in mm_data:
|
||||
mm_data = {MODALITY: mm_data}
|
||||
|
||||
ts_data = mm_data[MODALITY]
|
||||
items = self._parse_timeseries_data(ts_data)
|
||||
if items is None:
|
||||
raise ValueError("Failed to parse timeseries data")
|
||||
|
||||
return MultiModalDataItems({MODALITY: items})
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Processor: converts parsed data into MultiModalInputs
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
|
||||
class ChronosMultiModalProcessor(BaseMultiModalProcessor):
|
||||
"""Processes timeseries MM data into MultiModalInputs for vLLM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
info: ChronosProcessingInfo,
|
||||
dummy_inputs: BaseDummyInputsBuilder[ChronosProcessingInfo],
|
||||
*,
|
||||
cache: MultiModalProcessorOnlyCache | None = None,
|
||||
) -> None:
|
||||
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return _field_config()
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
return []
|
||||
|
||||
def apply(self, *args: Any, **kwargs: Any) -> MultiModalInputs:
|
||||
# Handle both old and new vLLM calling conventions:
|
||||
# Old: apply(prompt, mm_items, hf_processor_mm_kwargs, ...)
|
||||
# New: apply(processor_inputs, timing_ctx)
|
||||
# Internal (from get_dummy_mm_inputs): apply(prompt=..., mm_items=..., ...)
|
||||
|
||||
ts_data: dict[str, Any] = {}
|
||||
|
||||
if args and hasattr(args[0], "prompt"):
|
||||
# New vLLM: first arg is ProcessorInputs dataclass
|
||||
processor_inputs = args[0]
|
||||
mm_items = getattr(processor_inputs, "mm_items", None) or getattr(
|
||||
processor_inputs, "mm_data_items", None
|
||||
)
|
||||
if mm_items is not None and isinstance(mm_items, MultiModalDataItems):
|
||||
if MODALITY in mm_items:
|
||||
ts_items = mm_items[MODALITY]
|
||||
ts_data = ts_items.data if hasattr(ts_items, "data") else {}
|
||||
else:
|
||||
# Old vLLM / direct call: extract from positional/keyword args
|
||||
mm_items = args[1] if len(args) > 1 else kwargs.get("mm_items")
|
||||
mm_data = kwargs.get("mm_data")
|
||||
|
||||
if mm_items is not None and isinstance(mm_items, MultiModalDataItems):
|
||||
if MODALITY in mm_items:
|
||||
ts_items = mm_items[MODALITY]
|
||||
ts_data = ts_items.data if hasattr(ts_items, "data") else {}
|
||||
elif mm_data is not None and isinstance(mm_data, dict):
|
||||
ts_data = mm_data.get(MODALITY, mm_data)
|
||||
|
||||
mm_placeholders = {MODALITY: [PlaceholderRange(offset=0, length=0)]}
|
||||
|
||||
# Build MultiModalKwargsItems directly to ensure proper modality keying.
|
||||
# from_hf_inputs + BatchFeature can produce empty modalities when
|
||||
# the data doesn't match the expected shared field batch structure.
|
||||
field_config = _field_config()
|
||||
mm_item_dict: dict[str, MultiModalFieldElem] = {}
|
||||
for key, config in field_config.items():
|
||||
tensor = ts_data.get(key) if isinstance(ts_data, dict) else None
|
||||
if tensor is not None:
|
||||
mm_item_dict[key] = MultiModalFieldElem(
|
||||
data=tensor,
|
||||
field=config.field,
|
||||
)
|
||||
|
||||
mm_kwargs_items = MultiModalKwargsItems({MODALITY: [MultiModalKwargsItem(mm_item_dict)]})
|
||||
|
||||
# Unique hash per request (required by vLLM v0.16+)
|
||||
ts_hash = hashlib.sha256(
|
||||
str(id(ts_data)).encode() + str(time.monotonic()).encode()
|
||||
).hexdigest()[:16]
|
||||
|
||||
return MultiModalInputs(
|
||||
type="multimodal",
|
||||
prompt_token_ids=[1],
|
||||
mm_kwargs=mm_kwargs_items,
|
||||
mm_hashes={MODALITY: [ts_hash]},
|
||||
mm_placeholders=mm_placeholders,
|
||||
)
|
||||
17
src/chronos/chronos2/vllm/protocol/__init__.py
Normal file
17
src/chronos/chronos2/vllm/protocol/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Chronos-2 forecasting protocol definitions and data preparation."""
|
||||
|
||||
from chronos.chronos2.vllm.protocol.forecast import (
|
||||
ForecastParameters,
|
||||
ForecastPrediction,
|
||||
ForecastRequest,
|
||||
ForecastResponse,
|
||||
TimeSeriesInput,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TimeSeriesInput",
|
||||
"ForecastParameters",
|
||||
"ForecastRequest",
|
||||
"ForecastPrediction",
|
||||
"ForecastResponse",
|
||||
]
|
||||
136
src/chronos/chronos2/vllm/protocol/data_prep.py
Normal file
136
src/chronos/chronos2/vllm/protocol/data_prep.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""Data preparation for Chronos-2 model input tensors.
|
||||
|
||||
Converts validated TimeSeriesInput objects into batched tensors
|
||||
ready for the model by delegating to ``chronos.chronos2.dataset``.
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from chronos.chronos2.dataset import Chronos2Dataset, DatasetMode
|
||||
from chronos.chronos2.vllm.protocol.forecast import ForecastParameters, TimeSeriesInput
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreparedBatch:
|
||||
"""A single batch of prepared model input tensors."""
|
||||
|
||||
context: torch.Tensor # (batch_rows, padded_ctx_len)
|
||||
future_covariates: torch.Tensor # (batch_rows, prediction_length)
|
||||
group_ids: torch.Tensor # (batch_rows,)
|
||||
num_output_patches: int
|
||||
target_idx_ranges: list[tuple[int, int]] # per-input-series (start, end) in this batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreparedRequest:
|
||||
"""All batches and metadata for a single forecast request."""
|
||||
|
||||
batches: list[PreparedBatch]
|
||||
item_ids: list[str | None]
|
||||
parameters: ForecastParameters
|
||||
|
||||
|
||||
def _timeseries_input_to_dict(ts: TimeSeriesInput) -> dict[str, Any]:
|
||||
"""Convert a Pydantic TimeSeriesInput to the dict format expected by Chronos2Dataset.
|
||||
|
||||
The dataset expects:
|
||||
- ``target``: np.ndarray of shape (history_length,) or (n_variates, history_length)
|
||||
- ``past_covariates``: dict[str, np.ndarray]
|
||||
- ``future_covariates``: dict[str, np.ndarray]
|
||||
"""
|
||||
target = np.array(ts.target, dtype=np.float32)
|
||||
|
||||
entry: dict[str, Any] = {"target": target}
|
||||
|
||||
if ts.past_covariates:
|
||||
past_covs: dict[str, np.ndarray] = {}
|
||||
for key, vals in ts.past_covariates.items():
|
||||
arr = np.array([v if v is not None else np.nan for v in vals])
|
||||
# Keep string arrays as object dtype for categorical encoding
|
||||
if any(isinstance(v, str) for v in vals if v is not None):
|
||||
arr = np.array([str(v) if v is not None else "nan" for v in vals])
|
||||
else:
|
||||
arr = arr.astype(np.float32)
|
||||
past_covs[key] = arr
|
||||
entry["past_covariates"] = past_covs
|
||||
|
||||
if ts.future_covariates:
|
||||
future_covs: dict[str, np.ndarray] = {}
|
||||
for key, vals in ts.future_covariates.items():
|
||||
arr = np.array([v if v is not None else np.nan for v in vals])
|
||||
if any(isinstance(v, str) for v in vals if v is not None):
|
||||
arr = np.array([str(v) if v is not None else "nan" for v in vals])
|
||||
else:
|
||||
arr = arr.astype(np.float32)
|
||||
future_covs[key] = arr
|
||||
entry["future_covariates"] = future_covs
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
def prepare_request(
|
||||
inputs: list[TimeSeriesInput],
|
||||
parameters: ForecastParameters,
|
||||
context_length: int = 8192,
|
||||
output_patch_size: int = 16,
|
||||
) -> PreparedRequest:
|
||||
"""Convert validated inputs into batched model-ready tensors.
|
||||
|
||||
Delegates to ``Chronos2Dataset`` in TEST mode for data preparation,
|
||||
covariate encoding, batching, and group ID construction.
|
||||
|
||||
Args:
|
||||
inputs: validated time series inputs
|
||||
parameters: forecast parameters
|
||||
context_length: maximum context length (from model config)
|
||||
output_patch_size: model's output patch size (from model config)
|
||||
|
||||
Returns:
|
||||
PreparedRequest with batched tensors and metadata
|
||||
"""
|
||||
# Convert Pydantic models to dicts for Chronos2Dataset
|
||||
raw_inputs = [_timeseries_input_to_dict(ts) for ts in inputs]
|
||||
item_ids = [ts.item_id for ts in inputs]
|
||||
|
||||
pred_len = parameters.prediction_length
|
||||
|
||||
# Use Chronos2Dataset in TEST mode — handles validation, covariate encoding,
|
||||
# left-padding, and group_id construction. We use a very large batch_size
|
||||
# to always produce exactly one batch; row-level chunking (if needed for
|
||||
# memory) can be handled by the caller or the model's forward pass.
|
||||
dataset = Chronos2Dataset(
|
||||
inputs=raw_inputs,
|
||||
context_length=context_length,
|
||||
prediction_length=pred_len,
|
||||
batch_size=2**31 - 1, # large enough to fit all series in one batch
|
||||
output_patch_size=output_patch_size,
|
||||
mode=DatasetMode.TEST,
|
||||
convert_inputs=True,
|
||||
)
|
||||
|
||||
batches: list[PreparedBatch] = []
|
||||
for batch_dict in dataset:
|
||||
group_ids = batch_dict["group_ids"]
|
||||
if parameters.cross_learning:
|
||||
group_ids = torch.zeros_like(group_ids)
|
||||
|
||||
batches.append(
|
||||
PreparedBatch(
|
||||
context=batch_dict["context"],
|
||||
future_covariates=batch_dict["future_covariates"],
|
||||
group_ids=group_ids,
|
||||
num_output_patches=batch_dict["num_output_patches"],
|
||||
target_idx_ranges=batch_dict["target_idx_ranges"],
|
||||
)
|
||||
)
|
||||
|
||||
return PreparedRequest(
|
||||
batches=batches,
|
||||
item_ids=item_ids,
|
||||
parameters=parameters,
|
||||
)
|
||||
218
src/chronos/chronos2/vllm/protocol/forecast.py
Normal file
218
src/chronos/chronos2/vllm/protocol/forecast.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
import math
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
try:
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
|
||||
except ImportError:
|
||||
OpenAIBaseModel = BaseModel # type: ignore[misc,assignment]
|
||||
|
||||
from chronos.chronos2.vllm.protocol.validation import (
|
||||
MAX_NUM_TIME_SERIES,
|
||||
validate_quantile_levels,
|
||||
validate_single_series_covariates,
|
||||
validate_start_timestamp,
|
||||
validate_target,
|
||||
)
|
||||
|
||||
|
||||
class TimeSeriesInput(BaseModel):
|
||||
"""Input time series data for forecasting."""
|
||||
|
||||
target: list[float] | list[list[float]] = Field(
|
||||
...,
|
||||
description="Historical time series values. "
|
||||
"1-D array for univariate, 2-D array for multivariate",
|
||||
)
|
||||
|
||||
item_id: str | None = Field(default=None, description="Unique identifier for the time series")
|
||||
|
||||
start: str | None = Field(
|
||||
default=None,
|
||||
description="Start timestamp in ISO format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS)",
|
||||
)
|
||||
|
||||
# Values can be NaN (numeric) or None/NaN mixed in string arrays.
|
||||
past_covariates: dict[str, list[Union[float, str, None]]] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Dictionary of past covariate arrays (numeric or categorical). "
|
||||
"Each array must match the length of the target"
|
||||
),
|
||||
)
|
||||
|
||||
# Values can be NaN (numeric) or None/NaN mixed in string arrays.
|
||||
future_covariates: dict[str, list[Union[float, str, None]]] | None = Field(
|
||||
default=None,
|
||||
description="Dictionary of known future covariate arrays. "
|
||||
"Keys must be a subset of past_covariates. "
|
||||
"Each array must match prediction_length",
|
||||
)
|
||||
|
||||
@field_validator("past_covariates", "future_covariates", mode="before")
|
||||
@classmethod
|
||||
def _sanitize_nan_in_covariates(
|
||||
cls, v: dict[str, list[Any]] | None
|
||||
) -> dict[str, list[Any]] | None:
|
||||
"""Convert NaN floats to None in covariate arrays.
|
||||
|
||||
Datasets often encode missing categorical values as float NaN.
|
||||
Pydantic rejects NaN in string-typed lists, so we normalize
|
||||
NaN → None before validation.
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
sanitized: dict[str, list[Any]] = {}
|
||||
for key, arr in v.items():
|
||||
sanitized[key] = [None if isinstance(x, float) and math.isnan(x) else x for x in arr]
|
||||
return sanitized
|
||||
|
||||
@field_validator("target")
|
||||
@classmethod
|
||||
def _validate_target(
|
||||
cls, v: list[float] | list[list[float]]
|
||||
) -> list[float] | list[list[float]]:
|
||||
return validate_target(v)
|
||||
|
||||
@field_validator("start")
|
||||
@classmethod
|
||||
def _validate_start(cls, v: str | None) -> str | None:
|
||||
return validate_start_timestamp(v)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_covariates(self) -> "TimeSeriesInput":
|
||||
validate_single_series_covariates(self.target, self.past_covariates, self.future_covariates)
|
||||
return self
|
||||
|
||||
|
||||
class ForecastParameters(BaseModel):
|
||||
"""Parameters for time series forecasting."""
|
||||
|
||||
prediction_length: int = Field(
|
||||
default=1, ge=1, le=1024, description="Number of future steps to forecast"
|
||||
)
|
||||
|
||||
quantile_levels: list[float] = Field(
|
||||
default=[0.1, 0.5, 0.9],
|
||||
description="Quantile levels for uncertainty quantification. "
|
||||
"Each value must be between 0 and 1 (exclusive)",
|
||||
)
|
||||
|
||||
freq: str | None = Field(
|
||||
default=None,
|
||||
description="Pandas frequency string (e.g., 'D' for daily, 'H' for hourly). "
|
||||
"Required if 'start' is provided in inputs",
|
||||
)
|
||||
|
||||
batch_size: int = Field(
|
||||
default=256,
|
||||
ge=1,
|
||||
description="Internal row batch size for model inference. "
|
||||
"Controls how many rows (series + covariates) are processed "
|
||||
"in a single model forward pass. Rows are chunked internally "
|
||||
"respecting series boundaries via group_ids.",
|
||||
)
|
||||
|
||||
cross_learning: bool = Field(
|
||||
default=False,
|
||||
description="Enable information sharing across time series in batch",
|
||||
)
|
||||
|
||||
@field_validator("quantile_levels")
|
||||
@classmethod
|
||||
def _validate_quantiles(cls, v: list[float]) -> list[float]:
|
||||
return validate_quantile_levels(v)
|
||||
|
||||
|
||||
class ForecastRequest(OpenAIBaseModel):
|
||||
"""Request format for time series forecasting via pooling API."""
|
||||
|
||||
model: str = Field(..., description="Model name to use for forecasting")
|
||||
|
||||
task: Literal["forecast"] = Field(
|
||||
default="forecast", description="Task type, must be 'forecast'"
|
||||
)
|
||||
|
||||
data: dict[str, Any] = Field(
|
||||
...,
|
||||
description=("Forecast request data containing 'inputs' and optional 'parameters'"),
|
||||
)
|
||||
|
||||
@field_validator("data")
|
||||
@classmethod
|
||||
def validate_data_structure(cls, v: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate the data field contains required structure."""
|
||||
if "inputs" not in v:
|
||||
raise ValueError("data must contain 'inputs' field")
|
||||
|
||||
if not isinstance(v["inputs"], list):
|
||||
raise ValueError("data.inputs must be a list")
|
||||
|
||||
if len(v["inputs"]) == 0:
|
||||
raise ValueError("data.inputs cannot be empty")
|
||||
|
||||
if len(v["inputs"]) > MAX_NUM_TIME_SERIES:
|
||||
raise ValueError(
|
||||
f"data.inputs may contain at most {MAX_NUM_TIME_SERIES} time series "
|
||||
f"(received {len(v['inputs'])})"
|
||||
)
|
||||
|
||||
# Validate each input as TimeSeriesInput
|
||||
validated_inputs = []
|
||||
for i, ts_input in enumerate(v["inputs"]):
|
||||
try:
|
||||
validated_input = TimeSeriesInput(**ts_input)
|
||||
validated_inputs.append(validated_input)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid time series input at index {i}: {e}") from e
|
||||
|
||||
# Validate parameters if present
|
||||
validated_params = None
|
||||
if "parameters" in v and v["parameters"] is not None:
|
||||
try:
|
||||
validated_params = ForecastParameters(**v["parameters"])
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid parameters: {e}") from e
|
||||
|
||||
# Cross-validate future_covariates length with prediction_length
|
||||
if validated_params is not None:
|
||||
prediction_length = validated_params.prediction_length
|
||||
for i, ts_input in enumerate(validated_inputs):
|
||||
if ts_input.future_covariates is not None:
|
||||
for key, values in ts_input.future_covariates.items():
|
||||
if len(values) != prediction_length:
|
||||
raise ValueError(
|
||||
f"Input {i}: future_covariate '{key}' length "
|
||||
f"({len(values)}) must match prediction_length "
|
||||
f"({prediction_length})"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class ForecastPrediction(BaseModel):
|
||||
"""Single time series forecast result with quantile forecasts."""
|
||||
|
||||
mean: list[float] | list[list[float]] = Field(
|
||||
..., description="Point forecast (mean). Shape matches input target"
|
||||
)
|
||||
|
||||
item_id: str | None = Field(default=None, description="Echoed from input if provided")
|
||||
|
||||
start: str | None = Field(default=None, description="Start timestamp of forecast horizon")
|
||||
|
||||
class Config:
|
||||
extra = "allow" # Allow dynamic quantile fields like "0.1", "0.5", "0.9"
|
||||
|
||||
|
||||
class ForecastResponse(OpenAIBaseModel):
|
||||
"""Response format for time series forecasting."""
|
||||
|
||||
request_id: str = Field(..., description="Request identifier")
|
||||
|
||||
created_at: int = Field(..., description="Unix timestamp when the response was created")
|
||||
|
||||
data: dict[str, list[ForecastPrediction]] = Field(
|
||||
..., description="Forecast results with 'predictions' key"
|
||||
)
|
||||
242
src/chronos/chronos2/vllm/protocol/validation.py
Normal file
242
src/chronos/chronos2/vllm/protocol/validation.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""Centralized validation for Chronos-2 forecast requests.
|
||||
|
||||
All validation rules live here so that both Pydantic model validators
|
||||
(in ``protocol.forecast``) and the IOProcessor can delegate to a single
|
||||
source of truth. The design mirrors the SageMaker endpoint's
|
||||
``validate_payload`` / ``validate_covariates`` utilities.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chronos.chronos2.vllm.protocol.forecast import ( # noqa: F811
|
||||
ForecastParameters,
|
||||
TimeSeriesInput,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants — match the SageMaker endpoint constraints
|
||||
# TODO: Get some of the values from model config
|
||||
# ---------------------------------------------------------------------------
|
||||
MIN_TARGET_LENGTH: int = 5
|
||||
MAX_PREDICTION_LENGTH: int = 1024
|
||||
MAX_NUM_TIME_SERIES: int = 1024
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-input validators (called from Pydantic field/model validators)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def validate_target(
|
||||
target: list[float] | list[list[float]],
|
||||
) -> list[float] | list[list[float]]:
|
||||
"""Validate minimum target length and multivariate row consistency.
|
||||
|
||||
Raises ``ValueError`` if:
|
||||
- The target is empty.
|
||||
- The target has fewer than ``MIN_TARGET_LENGTH`` observations.
|
||||
- For multivariate targets, rows have inconsistent lengths.
|
||||
"""
|
||||
if not target:
|
||||
raise ValueError("Target must not be empty")
|
||||
if isinstance(target[0], list):
|
||||
first_len = len(target[0])
|
||||
if first_len < MIN_TARGET_LENGTH:
|
||||
raise ValueError(
|
||||
f"Target must contain at least {MIN_TARGET_LENGTH} "
|
||||
f"observations (received {first_len})"
|
||||
)
|
||||
for i, dim in enumerate(target[1:], start=1):
|
||||
if len(dim) != first_len: # type: ignore[arg-type]
|
||||
raise ValueError(
|
||||
f"All target dimensions must have same length. "
|
||||
f"Dimension 0 has {first_len} observations, "
|
||||
f"dimension {i} has {len(dim)}" # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
if len(target) < MIN_TARGET_LENGTH:
|
||||
raise ValueError(
|
||||
f"Target must contain at least {MIN_TARGET_LENGTH} "
|
||||
f"observations (received {len(target)})"
|
||||
)
|
||||
return target
|
||||
|
||||
|
||||
def validate_start_timestamp(start: str | None) -> str | None:
|
||||
"""Validate that *start* is a valid ISO-8601 string (or ``None``)."""
|
||||
if start is not None:
|
||||
try:
|
||||
datetime.fromisoformat(start.replace("Z", "+00:00"))
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid start timestamp format: {start}. "
|
||||
f"Expected ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS)"
|
||||
) from e
|
||||
return start
|
||||
|
||||
|
||||
def validate_quantile_levels(quantile_levels: list[float]) -> list[float]:
|
||||
"""Validate that every quantile level is in the open interval (0, 1)."""
|
||||
for q in quantile_levels:
|
||||
if not (0 < q < 1):
|
||||
raise ValueError(f"Quantile levels must be between 0 and 1 (exclusive), got {q}")
|
||||
return quantile_levels
|
||||
|
||||
|
||||
def validate_single_series_covariates(
|
||||
target: list[float] | list[list[float]],
|
||||
past_covariates: dict[str, list] | None,
|
||||
future_covariates: dict[str, list] | None,
|
||||
) -> None:
|
||||
"""Validate covariate constraints for a **single** time series.
|
||||
|
||||
Raises ``ValueError`` if any of the following are violated:
|
||||
|
||||
- ``'target'`` is used as a covariate name.
|
||||
- ``future_covariates`` is provided without ``past_covariates``.
|
||||
- Future covariate keys are not a subset of past covariate keys.
|
||||
- Past covariate array lengths do not match the target length.
|
||||
"""
|
||||
if not target:
|
||||
raise ValueError("Target must not be empty")
|
||||
target_len = len(target[0]) if isinstance(target[0], list) else len(target)
|
||||
|
||||
# 'target' must not be used as a covariate name
|
||||
for label, covariates in [
|
||||
("past_covariates", past_covariates),
|
||||
("future_covariates", future_covariates),
|
||||
]:
|
||||
if covariates is not None and "target" in covariates:
|
||||
raise ValueError("Covariate with name 'target' is not allowed")
|
||||
|
||||
# future_covariates requires past_covariates
|
||||
if future_covariates is not None and past_covariates is None:
|
||||
raise ValueError(
|
||||
"Both 'past_covariates' and 'future_covariates' must be provided "
|
||||
"together. Got 'future_covariates' without 'past_covariates'"
|
||||
)
|
||||
|
||||
# future keys ⊆ past keys
|
||||
if past_covariates is not None and future_covariates is not None:
|
||||
past_keys = set(past_covariates.keys())
|
||||
future_keys = set(future_covariates.keys())
|
||||
if not future_keys.issubset(past_keys):
|
||||
extra = future_keys - past_keys
|
||||
raise ValueError(
|
||||
f"All future covariate keys must be present in past covariates. "
|
||||
f"Keys {extra} are in 'future_covariates' but not in "
|
||||
f"'past_covariates'"
|
||||
)
|
||||
|
||||
# past covariate lengths must match target length
|
||||
if past_covariates is not None:
|
||||
for key, values in past_covariates.items():
|
||||
if len(values) != target_len:
|
||||
raise ValueError(
|
||||
f"Past covariate '{key}' length ({len(values)}) "
|
||||
f"must match target length ({target_len})"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cross-series validators (called from IOProcessor.parse_request)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def validate_cross_series(
|
||||
inputs: list[TimeSeriesInput],
|
||||
parameters: ForecastParameters,
|
||||
) -> None:
|
||||
"""Validate constraints that span across multiple time series.
|
||||
|
||||
Mirrors the SageMaker endpoint's ``validate_payload`` and
|
||||
``validate_covariates`` utilities.
|
||||
|
||||
Raises ``ValueError`` if any of the following are violated:
|
||||
|
||||
- ``item_id`` provided for some but not all inputs.
|
||||
- ``item_id`` values are not unique.
|
||||
- ``start`` provided for some but not all inputs.
|
||||
- ``start`` is provided without ``freq`` (or vice-versa).
|
||||
- Covariate keys are not identical across all series.
|
||||
- ``future_covariates`` array lengths don't match ``prediction_length``.
|
||||
"""
|
||||
_validate_item_ids(inputs)
|
||||
_validate_start_freq(inputs, parameters)
|
||||
_validate_covariate_consistency(inputs)
|
||||
_validate_future_covariate_lengths(inputs, parameters)
|
||||
|
||||
|
||||
# -- helpers (private) -------------------------------------------------------
|
||||
|
||||
|
||||
def _validate_item_ids(inputs: list[TimeSeriesInput]) -> None:
|
||||
item_ids = [ts.item_id for ts in inputs]
|
||||
has_none = any(x is None for x in item_ids)
|
||||
has_value = any(x is not None for x in item_ids)
|
||||
if has_none and has_value:
|
||||
raise ValueError(
|
||||
"If 'item_id' is provided for at least one time series in "
|
||||
"'inputs', it should be provided for all time series"
|
||||
)
|
||||
if has_value and len(item_ids) != len(set(item_ids)):
|
||||
raise ValueError("'item_id' must be unique for all time series in 'inputs'")
|
||||
|
||||
|
||||
def _validate_start_freq(
|
||||
inputs: list[TimeSeriesInput],
|
||||
parameters: ForecastParameters,
|
||||
) -> None:
|
||||
starts = [ts.start for ts in inputs]
|
||||
has_none = any(x is None for x in starts)
|
||||
has_value = any(x is not None for x in starts)
|
||||
if has_none and has_value:
|
||||
raise ValueError(
|
||||
"If 'start' is provided for at least one time series in "
|
||||
"'inputs', it should be provided for all time series"
|
||||
)
|
||||
if has_value and parameters.freq is None:
|
||||
raise ValueError(
|
||||
"If 'start' is provided, then 'freq' must also be provided " "in 'parameters'"
|
||||
)
|
||||
if parameters.freq is not None and not has_value:
|
||||
raise ValueError(
|
||||
"If 'freq' is provided in 'parameters', then 'start' must "
|
||||
"also be provided for all time series in 'inputs'"
|
||||
)
|
||||
|
||||
|
||||
def _validate_covariate_consistency(inputs: list[TimeSeriesInput]) -> None:
|
||||
key_sets: list[frozenset[str] | None] = []
|
||||
for ts in inputs:
|
||||
if ts.past_covariates is not None:
|
||||
key_sets.append(frozenset(ts.past_covariates.keys()))
|
||||
else:
|
||||
key_sets.append(None)
|
||||
|
||||
if len(set(key_sets)) > 1:
|
||||
raise ValueError(
|
||||
"If 'past_covariates' and 'future_covariates' are provided "
|
||||
"for at least one time series in 'inputs', the same "
|
||||
"covariates should be provided for all time series"
|
||||
)
|
||||
|
||||
|
||||
def _validate_future_covariate_lengths(
|
||||
inputs: list[TimeSeriesInput],
|
||||
parameters: ForecastParameters,
|
||||
) -> None:
|
||||
prediction_length = parameters.prediction_length
|
||||
for i, ts in enumerate(inputs):
|
||||
if ts.future_covariates is not None:
|
||||
for key, values in ts.future_covariates.items():
|
||||
if len(values) != prediction_length:
|
||||
raise ValueError(
|
||||
f"Input {i}: length of future covariate '{key}' "
|
||||
f"({len(values)}) must equal prediction_length "
|
||||
f"({prediction_length})"
|
||||
)
|
||||
1
src/chronos/chronos2/vllm/utils/__init__.py
Normal file
1
src/chronos/chronos2/vllm/utils/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Utility modules for Chronos-2 vLLM plugin."""
|
||||
35
src/chronos/chronos2/vllm/utils/helpers.py
Normal file
35
src/chronos/chronos2/vllm/utils/helpers.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""Common helper functions for Chronos-2 vLLM plugin."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def tensor_to_list(tensor: torch.Tensor) -> list[float] | list[list[float]]:
|
||||
"""Convert 2-D tensor to list (univariate) or list of lists (multivariate).
|
||||
|
||||
Args:
|
||||
tensor: shape (n_variates, horizon)
|
||||
|
||||
Returns:
|
||||
list[float] if univariate (n_variates == 1), else list[list[float]]
|
||||
"""
|
||||
assert tensor.ndim == 2
|
||||
return tensor[0].tolist() if tensor.shape[0] == 1 else [row.tolist() for row in tensor]
|
||||
|
||||
|
||||
def empty_prediction(
|
||||
prediction_length: int,
|
||||
quantile_levels: list[float],
|
||||
) -> dict[str, Any]:
|
||||
"""Return a zero-filled prediction dict for error cases.
|
||||
|
||||
Args:
|
||||
prediction_length: number of forecast steps
|
||||
quantile_levels: list of quantile levels to include
|
||||
"""
|
||||
zeros = [0.0] * prediction_length
|
||||
pred: dict[str, Any] = {"mean": zeros}
|
||||
for q in quantile_levels:
|
||||
pred[str(q)] = zeros
|
||||
return pred
|
||||
40
src/chronos/chronos2/vllm/utils/quantiles.py
Normal file
40
src/chronos/chronos2/vllm/utils/quantiles.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Quantile selection and interpolation utilities."""
|
||||
|
||||
import torch
|
||||
|
||||
from chronos.utils import interpolate_quantiles
|
||||
|
||||
|
||||
def select_quantiles(
|
||||
predictions: list[torch.Tensor],
|
||||
model_quantiles: list[float],
|
||||
requested_levels: list[float],
|
||||
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||
"""Select or interpolate quantiles from model output.
|
||||
|
||||
Args:
|
||||
predictions: list of tensors, each shape (..., num_quantiles, horizon)
|
||||
model_quantiles: the quantile levels the model was trained on (sorted ascending)
|
||||
requested_levels: the quantile levels the caller wants
|
||||
|
||||
Returns:
|
||||
(quantiles, mean) where:
|
||||
- quantiles: list of tensors, each shape (..., horizon, num_requested_quantiles)
|
||||
- mean: list of tensors, each shape (..., horizon) — the median
|
||||
"""
|
||||
# Swap quantile and time axes: [... q h] -> [... h q]
|
||||
swapped = [pred.permute(*range(pred.ndim - 2), -1, -2) for pred in predictions]
|
||||
|
||||
if set(requested_levels).issubset(model_quantiles):
|
||||
indices = [model_quantiles.index(q) for q in requested_levels]
|
||||
quantiles = [pred[..., indices] for pred in swapped]
|
||||
else:
|
||||
quantiles = [
|
||||
interpolate_quantiles(requested_levels, model_quantiles, pred) for pred in swapped
|
||||
]
|
||||
|
||||
# Median as mean (Chronos-2 convention)
|
||||
median_idx = model_quantiles.index(0.5) if 0.5 in model_quantiles else len(model_quantiles) // 2
|
||||
mean = [pred[..., median_idx] for pred in swapped]
|
||||
|
||||
return quantiles, mean
|
||||
1
test/chronos2/vllm/__init__.py
Normal file
1
test/chronos2/vllm/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Chronos-2 plugin tests
|
||||
360
test/chronos2/vllm/test_chronos2_plugin.py
Normal file
360
test/chronos2/vllm/test_chronos2_plugin.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
"""Tests for Chronos-2 plugin validation logic."""
|
||||
|
||||
import pytest
|
||||
|
||||
from chronos.chronos2.vllm.protocol.forecast import (
|
||||
ForecastParameters,
|
||||
ForecastRequest,
|
||||
TimeSeriesInput,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# TimeSeriesInput — per-input validation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTimeSeriesInputTargetValidation:
|
||||
"""Tests for target field validation."""
|
||||
|
||||
def test_valid_univariate_target(self):
|
||||
ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
assert len(ts.target) == 5
|
||||
|
||||
def test_valid_multivariate_target(self):
|
||||
ts = TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])
|
||||
assert len(ts.target) == 2
|
||||
|
||||
def test_univariate_target_too_short(self):
|
||||
with pytest.raises(ValueError, match="at least 5 observations"):
|
||||
TimeSeriesInput(target=[1.0, 2.0, 3.0])
|
||||
|
||||
def test_multivariate_target_too_short(self):
|
||||
with pytest.raises(ValueError, match="at least 5 observations"):
|
||||
TimeSeriesInput(target=[[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
def test_multivariate_inconsistent_lengths(self):
|
||||
with pytest.raises(ValueError, match="All target dimensions must have same length"):
|
||||
TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
|
||||
|
||||
|
||||
class TestTimeSeriesInputStartTimestamp:
|
||||
"""Tests for start timestamp validation."""
|
||||
|
||||
def test_valid_date_format(self):
|
||||
ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], start="2024-01-01")
|
||||
assert ts.start == "2024-01-01"
|
||||
|
||||
def test_valid_datetime_format(self):
|
||||
ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], start="2024-01-01T12:00:00")
|
||||
assert ts.start == "2024-01-01T12:00:00"
|
||||
|
||||
def test_invalid_timestamp_format(self):
|
||||
with pytest.raises(ValueError, match="Invalid start timestamp format"):
|
||||
TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], start="not-a-date")
|
||||
|
||||
|
||||
class TestTimeSeriesInputCovariateValidation:
|
||||
"""Tests for per-input covariate validation."""
|
||||
|
||||
def test_valid_past_covariates(self):
|
||||
ts = TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]},
|
||||
)
|
||||
assert ts.past_covariates is not None
|
||||
|
||||
def test_past_covariate_length_mismatch(self):
|
||||
with pytest.raises(ValueError, match="must match target length"):
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"feature_a": [1.0, 2.0, 3.0]},
|
||||
)
|
||||
|
||||
def test_target_name_in_past_covariates_rejected(self):
|
||||
with pytest.raises(ValueError, match="Covariate with name 'target' is not allowed"):
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"target": [1.0, 2.0, 3.0, 4.0, 5.0]},
|
||||
)
|
||||
|
||||
def test_target_name_in_future_covariates_rejected(self):
|
||||
with pytest.raises(ValueError, match="Covariate with name 'target' is not allowed"):
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]},
|
||||
future_covariates={"target": [1.0]},
|
||||
)
|
||||
|
||||
def test_future_covariates_without_past_rejected(self):
|
||||
with pytest.raises(
|
||||
ValueError, match="Both 'past_covariates' and 'future_covariates' must be provided"
|
||||
):
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
future_covariates={"feature_a": [1.0]},
|
||||
)
|
||||
|
||||
def test_future_covariate_keys_must_be_subset_of_past(self):
|
||||
with pytest.raises(
|
||||
ValueError, match="All future covariate keys must be present in past covariates"
|
||||
):
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]},
|
||||
future_covariates={"feature_b": [1.0]},
|
||||
)
|
||||
|
||||
def test_future_covariate_keys_subset_is_valid(self):
|
||||
"""Past covariates may have more keys than future — that's fine."""
|
||||
ts = TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={
|
||||
"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
"feature_b": [1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
},
|
||||
future_covariates={"feature_a": [1.0]},
|
||||
)
|
||||
assert ts.future_covariates is not None
|
||||
|
||||
def test_past_only_covariates_valid(self):
|
||||
"""Having past_covariates without future_covariates is acceptable."""
|
||||
ts = TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]},
|
||||
)
|
||||
assert ts.past_covariates is not None
|
||||
assert ts.future_covariates is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ForecastParameters validation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestForecastParameters:
|
||||
"""Tests for ForecastParameters validation."""
|
||||
|
||||
def test_default_values(self):
|
||||
params = ForecastParameters()
|
||||
assert params.prediction_length == 1
|
||||
assert params.quantile_levels == [0.1, 0.5, 0.9]
|
||||
assert params.freq is None
|
||||
assert params.batch_size == 256
|
||||
assert params.cross_learning is False
|
||||
|
||||
def test_prediction_length_too_large(self):
|
||||
with pytest.raises(ValueError):
|
||||
ForecastParameters(prediction_length=2000)
|
||||
|
||||
def test_prediction_length_zero(self):
|
||||
with pytest.raises(ValueError):
|
||||
ForecastParameters(prediction_length=0)
|
||||
|
||||
def test_quantile_out_of_range(self):
|
||||
with pytest.raises(ValueError, match="between 0 and 1"):
|
||||
ForecastParameters(quantile_levels=[0.0, 0.5])
|
||||
|
||||
def test_quantile_one(self):
|
||||
with pytest.raises(ValueError, match="between 0 and 1"):
|
||||
ForecastParameters(quantile_levels=[0.5, 1.0])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Cross-series validation (via validation module)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestCrossSeriesValidation:
|
||||
"""Tests for cross-series validation logic in validation module."""
|
||||
|
||||
@staticmethod
|
||||
def _validate(inputs, parameters=None):
|
||||
"""Helper that calls validate_cross_series from the validation module."""
|
||||
from chronos.chronos2.vllm.protocol.validation import validate_cross_series
|
||||
|
||||
if parameters is None:
|
||||
parameters = ForecastParameters()
|
||||
validate_cross_series(inputs, parameters)
|
||||
|
||||
def test_item_id_all_provided(self):
|
||||
"""All item_ids present — valid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0] * 5, item_id="A"),
|
||||
TimeSeriesInput(target=[2.0] * 5, item_id="B"),
|
||||
]
|
||||
self._validate(inputs) # should not raise
|
||||
|
||||
def test_item_id_none_provided(self):
|
||||
"""No item_ids at all — valid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0] * 5),
|
||||
TimeSeriesInput(target=[2.0] * 5),
|
||||
]
|
||||
self._validate(inputs) # should not raise
|
||||
|
||||
def test_item_id_partial(self):
|
||||
"""Some have item_id, some don't — invalid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0] * 5, item_id="A"),
|
||||
TimeSeriesInput(target=[2.0] * 5),
|
||||
]
|
||||
with pytest.raises(ValueError, match="item_id.*provided for all time series"):
|
||||
self._validate(inputs)
|
||||
|
||||
def test_item_id_not_unique(self):
|
||||
"""Duplicate item_ids — invalid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0] * 5, item_id="A"),
|
||||
TimeSeriesInput(target=[2.0] * 5, item_id="A"),
|
||||
]
|
||||
with pytest.raises(ValueError, match="item_id.*must be unique"):
|
||||
self._validate(inputs)
|
||||
|
||||
def test_start_all_provided_with_freq(self):
|
||||
"""All starts present with freq — valid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0] * 5, start="2024-01-01"),
|
||||
TimeSeriesInput(target=[2.0] * 5, start="2024-02-01"),
|
||||
]
|
||||
params = ForecastParameters(freq="D")
|
||||
self._validate(inputs, params) # should not raise
|
||||
|
||||
def test_start_partial(self):
|
||||
"""Some have start, some don't — invalid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0] * 5, start="2024-01-01"),
|
||||
TimeSeriesInput(target=[2.0] * 5),
|
||||
]
|
||||
params = ForecastParameters(freq="D")
|
||||
with pytest.raises(ValueError, match="start.*provided for all time series"):
|
||||
self._validate(inputs, params)
|
||||
|
||||
def test_start_without_freq(self):
|
||||
"""start provided but freq missing — invalid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0] * 5, start="2024-01-01"),
|
||||
]
|
||||
params = ForecastParameters() # freq is None
|
||||
with pytest.raises(ValueError, match="freq.*must also be provided"):
|
||||
self._validate(inputs, params)
|
||||
|
||||
def test_freq_without_start(self):
|
||||
"""freq provided but no start on inputs — invalid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0] * 5),
|
||||
]
|
||||
params = ForecastParameters(freq="D")
|
||||
with pytest.raises(ValueError, match="start.*must also be provided"):
|
||||
self._validate(inputs, params)
|
||||
|
||||
def test_covariate_keys_consistent(self):
|
||||
"""Same covariate keys on all series — valid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(
|
||||
target=[1.0] * 5,
|
||||
past_covariates={"feat": [1.0] * 5},
|
||||
),
|
||||
TimeSeriesInput(
|
||||
target=[2.0] * 5,
|
||||
past_covariates={"feat": [2.0] * 5},
|
||||
),
|
||||
]
|
||||
self._validate(inputs) # should not raise
|
||||
|
||||
def test_covariate_keys_inconsistent(self):
|
||||
"""Different covariate keys across series — invalid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(
|
||||
target=[1.0] * 5,
|
||||
past_covariates={"feat_a": [1.0] * 5},
|
||||
),
|
||||
TimeSeriesInput(
|
||||
target=[2.0] * 5,
|
||||
past_covariates={"feat_b": [2.0] * 5},
|
||||
),
|
||||
]
|
||||
with pytest.raises(ValueError, match="same covariates should be provided"):
|
||||
self._validate(inputs)
|
||||
|
||||
def test_covariate_some_have_none(self):
|
||||
"""One series has covariates, other doesn't — invalid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(
|
||||
target=[1.0] * 5,
|
||||
past_covariates={"feat": [1.0] * 5},
|
||||
),
|
||||
TimeSeriesInput(target=[2.0] * 5),
|
||||
]
|
||||
with pytest.raises(ValueError, match="same covariates should be provided"):
|
||||
self._validate(inputs)
|
||||
|
||||
def test_future_covariate_length_matches_prediction_length(self):
|
||||
"""future_covariates length equals prediction_length — valid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(
|
||||
target=[1.0] * 5,
|
||||
past_covariates={"feat": [1.0] * 5},
|
||||
future_covariates={"feat": [1.0, 2.0, 3.0]},
|
||||
),
|
||||
]
|
||||
params = ForecastParameters(prediction_length=3)
|
||||
self._validate(inputs, params) # should not raise
|
||||
|
||||
def test_future_covariate_length_mismatch(self):
|
||||
"""future_covariates length != prediction_length — invalid."""
|
||||
inputs = [
|
||||
TimeSeriesInput(
|
||||
target=[1.0] * 5,
|
||||
past_covariates={"feat": [1.0] * 5},
|
||||
future_covariates={"feat": [1.0, 2.0]},
|
||||
),
|
||||
]
|
||||
params = ForecastParameters(prediction_length=3)
|
||||
with pytest.raises(ValueError, match="must equal prediction_length"):
|
||||
self._validate(inputs, params)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ForecastRequest — end-to-end request validation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestForecastRequest:
|
||||
"""Tests for ForecastRequest end-to-end validation."""
|
||||
|
||||
def test_valid_minimal_request(self):
|
||||
req = ForecastRequest(
|
||||
model="chronos-v2",
|
||||
data={"inputs": [{"target": [1.0, 2.0, 3.0, 4.0, 5.0]}]},
|
||||
)
|
||||
assert req.data["inputs"] is not None
|
||||
|
||||
def test_empty_inputs_rejected(self):
|
||||
with pytest.raises(ValueError, match="cannot be empty"):
|
||||
ForecastRequest(model="chronos-v2", data={"inputs": []})
|
||||
|
||||
def test_too_many_inputs_rejected(self):
|
||||
many_inputs = [{"target": [1.0] * 5} for _ in range(1025)]
|
||||
with pytest.raises(ValueError, match="at most 1024"):
|
||||
ForecastRequest(model="chronos-v2", data={"inputs": many_inputs})
|
||||
|
||||
def test_missing_inputs_rejected(self):
|
||||
with pytest.raises(ValueError, match="inputs"):
|
||||
ForecastRequest(model="chronos-v2", data={"parameters": {}})
|
||||
|
||||
def test_future_covariate_length_cross_validated(self):
|
||||
"""ForecastRequest validates future_covariates against prediction_length."""
|
||||
with pytest.raises(ValueError, match="prediction_length"):
|
||||
ForecastRequest(
|
||||
model="chronos-v2",
|
||||
data={
|
||||
"inputs": [
|
||||
{
|
||||
"target": [1.0] * 5,
|
||||
"past_covariates": {"feat": [1.0] * 5},
|
||||
"future_covariates": {"feat": [1.0, 2.0]},
|
||||
}
|
||||
],
|
||||
"parameters": {"prediction_length": 5},
|
||||
},
|
||||
)
|
||||
248
test/chronos2/vllm/test_data_prep.py
Normal file
248
test/chronos2/vllm/test_data_prep.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
"""Unit tests for chronos.chronos2.vllm.protocol.data_prep."""
|
||||
|
||||
import pytest
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
from chronos.chronos2.vllm.protocol.data_prep import ( # noqa: E402
|
||||
PreparedRequest,
|
||||
prepare_request,
|
||||
)
|
||||
from chronos.chronos2.vllm.protocol.forecast import ( # noqa: E402
|
||||
ForecastParameters,
|
||||
TimeSeriesInput,
|
||||
)
|
||||
|
||||
|
||||
class TestPrepareRequestUnivariate:
|
||||
"""Tests for prepare_request with univariate inputs."""
|
||||
|
||||
def test_single_series(self):
|
||||
inputs = [TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])]
|
||||
params = ForecastParameters(prediction_length=3)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
assert isinstance(result, PreparedRequest)
|
||||
assert len(result.batches) >= 1
|
||||
assert len(result.item_ids) == 1
|
||||
assert result.item_ids[0] is None
|
||||
|
||||
batch = result.batches[0]
|
||||
assert batch.context.shape[0] == 1
|
||||
assert batch.context.shape[1] == 5
|
||||
assert batch.group_ids.shape == (1,)
|
||||
assert batch.target_idx_ranges == [(0, 1)]
|
||||
|
||||
def test_multiple_series(self):
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], item_id="a"),
|
||||
TimeSeriesInput(target=[6.0, 7.0, 8.0, 9.0, 10.0], item_id="b"),
|
||||
]
|
||||
params = ForecastParameters(prediction_length=2)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
# Both series should fit in one batch with default batch_size
|
||||
assert len(result.batches) >= 1
|
||||
# Count total rows across all batches
|
||||
total_rows = sum(b.context.shape[0] for b in result.batches)
|
||||
assert total_rows == 2
|
||||
assert result.item_ids == ["a", "b"]
|
||||
|
||||
def test_different_lengths_padded(self):
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]),
|
||||
TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]),
|
||||
]
|
||||
params = ForecastParameters(prediction_length=1)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
batch = result.batches[0]
|
||||
# Padded to max length = 10
|
||||
assert batch.context.shape == (2, 10)
|
||||
# First series: right-aligned, left NaN-padded
|
||||
assert torch.isnan(batch.context[0, :5]).all()
|
||||
assert not torch.isnan(batch.context[0, 5:]).any()
|
||||
|
||||
def test_truncation_to_context_length(self):
|
||||
long_target = list(range(100))
|
||||
inputs = [TimeSeriesInput(target=long_target)]
|
||||
params = ForecastParameters(prediction_length=1)
|
||||
result = prepare_request(inputs, params, context_length=20)
|
||||
|
||||
batch = result.batches[0]
|
||||
assert batch.context.shape == (1, 20)
|
||||
|
||||
def test_future_covariates_nan_for_targets(self):
|
||||
inputs = [TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])]
|
||||
params = ForecastParameters(prediction_length=3)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
# Target rows should have NaN future covariates
|
||||
assert torch.isnan(result.batches[0].future_covariates[0]).all()
|
||||
|
||||
|
||||
class TestPrepareRequestMultivariate:
|
||||
"""Tests for prepare_request with multivariate inputs."""
|
||||
|
||||
def test_bivariate(self):
|
||||
inputs = [TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])]
|
||||
params = ForecastParameters(prediction_length=2)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
batch = result.batches[0]
|
||||
# 2 variates = 2 rows
|
||||
assert batch.context.shape == (2, 5)
|
||||
assert batch.target_idx_ranges == [(0, 2)]
|
||||
assert batch.group_ids.tolist() == [0, 0]
|
||||
|
||||
|
||||
class TestPrepareRequestCovariates:
|
||||
"""Tests for prepare_request with covariates."""
|
||||
|
||||
def test_past_covariates(self):
|
||||
inputs = [
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"temp": [10.0, 20.0, 30.0, 40.0, 50.0]},
|
||||
)
|
||||
]
|
||||
params = ForecastParameters(prediction_length=2)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
batch = result.batches[0]
|
||||
# 1 target + 1 covariate = 2 rows
|
||||
assert batch.context.shape == (2, 5)
|
||||
# Only target row is in target_idx_ranges
|
||||
assert batch.target_idx_ranges == [(0, 1)]
|
||||
assert batch.group_ids.tolist() == [0, 0]
|
||||
|
||||
def test_future_covariates(self):
|
||||
inputs = [
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"temp": [10.0, 20.0, 30.0, 40.0, 50.0]},
|
||||
future_covariates={"temp": [60.0, 70.0]},
|
||||
)
|
||||
]
|
||||
params = ForecastParameters(prediction_length=2)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
batch = result.batches[0]
|
||||
# Target row: NaN future covariates
|
||||
assert torch.isnan(batch.future_covariates[0]).all()
|
||||
# Covariate row: actual future values
|
||||
assert batch.future_covariates[1, 0].item() == 60.0
|
||||
assert batch.future_covariates[1, 1].item() == 70.0
|
||||
|
||||
|
||||
class TestPrepareRequestCrossLearning:
|
||||
"""Tests for cross-learning mode."""
|
||||
|
||||
def test_cross_learning_zeros_group_ids(self):
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]),
|
||||
TimeSeriesInput(target=[6.0, 7.0, 8.0, 9.0, 10.0]),
|
||||
]
|
||||
params = ForecastParameters(prediction_length=1, cross_learning=True)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
assert result.batches[0].group_ids.tolist() == [0, 0]
|
||||
|
||||
def test_no_cross_learning_distinct_group_ids(self):
|
||||
inputs = [
|
||||
TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]),
|
||||
TimeSeriesInput(target=[6.0, 7.0, 8.0, 9.0, 10.0]),
|
||||
]
|
||||
params = ForecastParameters(prediction_length=1, cross_learning=False)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
assert result.batches[0].group_ids.tolist() == [0, 1]
|
||||
|
||||
|
||||
class TestPrepareRequestNanCovariates:
|
||||
"""Tests for prepare_request handling of NaN-sanitized covariates."""
|
||||
|
||||
def test_numeric_covariates_with_none(self):
|
||||
"""Numeric covariates with None values should become NaN in tensors."""
|
||||
inputs = [
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"temp": [10.0, None, 30.0, None, 50.0]},
|
||||
future_covariates={"temp": [None, 70.0, None]},
|
||||
)
|
||||
]
|
||||
params = ForecastParameters(prediction_length=3)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
batch = result.batches[0]
|
||||
# 1 target + 1 covariate = 2 rows
|
||||
assert batch.context.shape == (2, 5)
|
||||
# Covariate row: None → NaN in context tensor
|
||||
assert torch.isnan(batch.context[1, 1])
|
||||
assert torch.isnan(batch.context[1, 3])
|
||||
assert batch.context[1, 0].item() == 10.0
|
||||
assert batch.context[1, 2].item() == 30.0
|
||||
# Future covariates: None → NaN
|
||||
assert torch.isnan(batch.future_covariates[1, 0])
|
||||
assert batch.future_covariates[1, 1].item() == 70.0
|
||||
|
||||
def test_string_covariates_encoded_in_tensors(self):
|
||||
"""String covariates (like holidays) should be encoded and included in tensors."""
|
||||
inputs = [
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"holiday": ["A", None, "B", None, "C"]},
|
||||
)
|
||||
]
|
||||
params = ForecastParameters(prediction_length=2)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
batch = result.batches[0]
|
||||
# 1 target + 1 encoded categorical covariate = 2 rows
|
||||
assert batch.context.shape == (2, 5)
|
||||
# Encoded values should be finite floats (not NaN) for non-None entries
|
||||
assert not torch.isnan(batch.context[1, 0]) # "A" encoded
|
||||
assert not torch.isnan(batch.context[1, 2]) # "B" encoded
|
||||
assert not torch.isnan(batch.context[1, 4]) # "C" encoded
|
||||
|
||||
def test_categorical_with_future(self):
|
||||
"""Categorical covariates with future values should be encoded consistently."""
|
||||
inputs = [
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"holiday": ["A", "B", "A", "B", "A"]},
|
||||
future_covariates={"holiday": ["A", "B"]},
|
||||
)
|
||||
]
|
||||
params = ForecastParameters(prediction_length=2)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
batch = result.batches[0]
|
||||
# 1 target + 1 encoded categorical = 2 rows
|
||||
assert batch.context.shape == (2, 5)
|
||||
# Future covariates should be filled for the categorical row
|
||||
assert not torch.isnan(batch.future_covariates[1, 0])
|
||||
assert not torch.isnan(batch.future_covariates[1, 1])
|
||||
|
||||
|
||||
class TestPrepareRequestBatching:
|
||||
"""Tests for batch output — always a single batch (chunking deferred to model forward)."""
|
||||
|
||||
def test_always_single_batch(self):
|
||||
"""prepare_request always returns one batch regardless of batch_size parameter."""
|
||||
inputs = [TimeSeriesInput(target=[float(i)] * 5) for i in range(5)]
|
||||
params = ForecastParameters(prediction_length=1, batch_size=2)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
# Always 1 batch — model forward() handles chunking if needed
|
||||
assert len(result.batches) == 1
|
||||
assert result.batches[0].context.shape[0] == 5
|
||||
assert len(result.batches[0].target_idx_ranges) == 5
|
||||
|
||||
def test_single_series_single_batch(self):
|
||||
inputs = [TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])]
|
||||
params = ForecastParameters(prediction_length=1, batch_size=256)
|
||||
result = prepare_request(inputs, params)
|
||||
|
||||
assert len(result.batches) == 1
|
||||
assert result.batches[0].context.shape[0] == 1
|
||||
53
test/chronos2/vllm/test_helpers.py
Normal file
53
test/chronos2/vllm/test_helpers.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Unit tests for chronos.chronos2.vllm.utils.helpers."""
|
||||
|
||||
import pytest
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
from chronos.chronos2.vllm.utils.helpers import empty_prediction, tensor_to_list # noqa: E402
|
||||
|
||||
|
||||
class TestTensorToList:
|
||||
"""Tests for tensor_to_list."""
|
||||
|
||||
def test_univariate(self):
|
||||
t = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = tensor_to_list(t)
|
||||
assert result == [1.0, 2.0, 3.0]
|
||||
assert isinstance(result, list)
|
||||
assert isinstance(result[0], float)
|
||||
|
||||
def test_multivariate(self):
|
||||
t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
|
||||
result = tensor_to_list(t)
|
||||
assert result == [[1.0, 2.0], [3.0, 4.0]]
|
||||
assert isinstance(result[0], list)
|
||||
|
||||
def test_single_value(self):
|
||||
t = torch.tensor([[42.0]])
|
||||
assert tensor_to_list(t) == [42.0]
|
||||
|
||||
def test_wrong_dims_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
tensor_to_list(torch.tensor([1.0, 2.0])) # 1-D
|
||||
|
||||
|
||||
class TestEmptyPrediction:
|
||||
"""Tests for empty_prediction."""
|
||||
|
||||
def test_basic(self):
|
||||
pred = empty_prediction(3, [0.1, 0.5, 0.9])
|
||||
assert pred["mean"] == [0.0, 0.0, 0.0]
|
||||
assert pred["0.1"] == [0.0, 0.0, 0.0]
|
||||
assert pred["0.5"] == [0.0, 0.0, 0.0]
|
||||
assert pred["0.9"] == [0.0, 0.0, 0.0]
|
||||
|
||||
def test_single_quantile(self):
|
||||
pred = empty_prediction(2, [0.5])
|
||||
assert "mean" in pred
|
||||
assert "0.5" in pred
|
||||
assert len(pred["mean"]) == 2
|
||||
|
||||
def test_empty_quantiles(self):
|
||||
pred = empty_prediction(1, [])
|
||||
assert pred == {"mean": [0.0]}
|
||||
178
test/chronos2/vllm/test_protocol.py
Normal file
178
test/chronos2/vllm/test_protocol.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
"""Unit tests for chronos.chronos2.vllm.protocol.forecast Pydantic models."""
|
||||
|
||||
import pytest
|
||||
|
||||
from chronos.chronos2.vllm.protocol.forecast import (
|
||||
ForecastParameters,
|
||||
ForecastPrediction,
|
||||
TimeSeriesInput,
|
||||
)
|
||||
|
||||
|
||||
class TestTimeSeriesInput:
|
||||
"""Tests for TimeSeriesInput Pydantic model."""
|
||||
|
||||
def test_minimal(self):
|
||||
ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
assert ts.target == [1.0, 2.0, 3.0, 4.0, 5.0]
|
||||
assert ts.item_id is None
|
||||
assert ts.start is None
|
||||
assert ts.past_covariates is None
|
||||
assert ts.future_covariates is None
|
||||
|
||||
def test_with_item_id(self):
|
||||
ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], item_id="series_1")
|
||||
assert ts.item_id == "series_1"
|
||||
|
||||
def test_multivariate_target(self):
|
||||
ts = TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])
|
||||
assert len(ts.target) == 2
|
||||
|
||||
def test_with_covariates(self):
|
||||
ts = TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"temp": [10.0, 20.0, 30.0, 40.0, 50.0]},
|
||||
future_covariates={"temp": [60.0]},
|
||||
)
|
||||
assert "temp" in ts.past_covariates
|
||||
assert "temp" in ts.future_covariates
|
||||
|
||||
def test_too_short_target_rejected(self):
|
||||
with pytest.raises(Exception):
|
||||
TimeSeriesInput(target=[1.0, 2.0])
|
||||
|
||||
def test_covariate_length_mismatch_rejected(self):
|
||||
with pytest.raises(Exception):
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"temp": [10.0, 20.0]}, # wrong length
|
||||
)
|
||||
|
||||
def test_future_without_past_rejected(self):
|
||||
with pytest.raises(Exception):
|
||||
TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
future_covariates={"temp": [60.0]},
|
||||
)
|
||||
|
||||
|
||||
class TestTimeSeriesInputNanCovariates:
|
||||
"""Tests for NaN handling in covariate arrays (favorita_stores_1D scenario)."""
|
||||
|
||||
def test_nan_in_string_future_covariates_sanitized(self):
|
||||
"""NaN floats in string covariate lists should be sanitized to None."""
|
||||
ts = TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={
|
||||
"holiday": ["New Year", "MLK Day", float("nan"), float("nan"), "Easter"]
|
||||
},
|
||||
future_covariates={"holiday": [float("nan")]},
|
||||
)
|
||||
# NaN should be converted to None
|
||||
assert ts.past_covariates["holiday"][2] is None
|
||||
assert ts.past_covariates["holiday"][3] is None
|
||||
assert ts.future_covariates["holiday"][0] is None
|
||||
# Strings should remain unchanged
|
||||
assert ts.past_covariates["holiday"][0] == "New Year"
|
||||
assert ts.past_covariates["holiday"][4] == "Easter"
|
||||
|
||||
def test_nan_in_numeric_covariates_preserved(self):
|
||||
"""NaN in numeric covariates should be sanitized to None too (consistent)."""
|
||||
ts = TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"temp": [10.0, float("nan"), 30.0, 40.0, 50.0]},
|
||||
future_covariates={"temp": [float("nan")]},
|
||||
)
|
||||
# NaN floats → None after sanitization
|
||||
assert ts.past_covariates["temp"][1] is None
|
||||
assert ts.future_covariates["temp"][0] is None
|
||||
|
||||
def test_all_nan_string_covariates(self):
|
||||
"""All-NaN covariate arrays should be accepted."""
|
||||
ts = TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"holiday": [float("nan")] * 5},
|
||||
future_covariates={"holiday": [float("nan")]},
|
||||
)
|
||||
assert all(v is None for v in ts.past_covariates["holiday"])
|
||||
assert all(v is None for v in ts.future_covariates["holiday"])
|
||||
|
||||
def test_mixed_string_and_none_accepted(self):
|
||||
"""Mix of strings and None should be accepted."""
|
||||
ts = TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"holiday": ["A", None, "B", None, "C"]},
|
||||
)
|
||||
assert ts.past_covariates["holiday"] == ["A", None, "B", None, "C"]
|
||||
|
||||
def test_clean_string_covariates_unchanged(self):
|
||||
"""String covariates without NaN should pass through unchanged."""
|
||||
ts = TimeSeriesInput(
|
||||
target=[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
past_covariates={"category": ["A", "B", "C", "D", "E"]},
|
||||
)
|
||||
assert ts.past_covariates["category"] == ["A", "B", "C", "D", "E"]
|
||||
|
||||
|
||||
class TestForecastParameters:
|
||||
"""Tests for ForecastParameters Pydantic model."""
|
||||
|
||||
def test_defaults(self):
|
||||
params = ForecastParameters()
|
||||
assert params.prediction_length == 1
|
||||
assert params.quantile_levels == [0.1, 0.5, 0.9]
|
||||
assert params.freq is None
|
||||
assert params.batch_size == 256
|
||||
assert params.cross_learning is False
|
||||
|
||||
def test_custom_values(self):
|
||||
params = ForecastParameters(
|
||||
prediction_length=24,
|
||||
quantile_levels=[0.25, 0.5, 0.75],
|
||||
batch_size=100,
|
||||
cross_learning=True,
|
||||
)
|
||||
assert params.prediction_length == 24
|
||||
assert params.quantile_levels == [0.25, 0.5, 0.75]
|
||||
assert params.batch_size == 100
|
||||
assert params.cross_learning is True
|
||||
|
||||
def test_invalid_prediction_length(self):
|
||||
with pytest.raises(Exception):
|
||||
ForecastParameters(prediction_length=0)
|
||||
|
||||
def test_invalid_quantile_level(self):
|
||||
with pytest.raises(Exception):
|
||||
ForecastParameters(quantile_levels=[0.0, 0.5])
|
||||
|
||||
def test_prediction_length_max(self):
|
||||
with pytest.raises(Exception):
|
||||
ForecastParameters(prediction_length=2000)
|
||||
|
||||
|
||||
class TestForecastPrediction:
|
||||
"""Tests for ForecastPrediction Pydantic model."""
|
||||
|
||||
def test_minimal(self):
|
||||
pred = ForecastPrediction(mean=[1.0, 2.0, 3.0])
|
||||
assert pred.mean == [1.0, 2.0, 3.0]
|
||||
assert pred.item_id is None
|
||||
|
||||
def test_with_quantiles(self):
|
||||
pred = ForecastPrediction(
|
||||
mean=[1.0, 2.0],
|
||||
item_id="s1",
|
||||
**{"0.1": [0.5, 1.5], "0.9": [1.5, 2.5]},
|
||||
)
|
||||
assert pred.item_id == "s1"
|
||||
|
||||
def test_multivariate_mean(self):
|
||||
pred = ForecastPrediction(mean=[[1.0, 2.0], [3.0, 4.0]])
|
||||
assert len(pred.mean) == 2
|
||||
|
||||
def test_extra_fields_allowed(self):
|
||||
"""ForecastPrediction allows dynamic quantile fields."""
|
||||
pred = ForecastPrediction(mean=[1.0], **{"0.5": [1.0], "0.1": [0.5]})
|
||||
d = pred.model_dump()
|
||||
assert "0.5" in d
|
||||
assert "0.1" in d
|
||||
98
test/chronos2/vllm/test_quantiles.py
Normal file
98
test/chronos2/vllm/test_quantiles.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""Unit tests for chronos.chronos2.vllm.utils.quantiles and chronos.utils.interpolate_quantiles."""
|
||||
|
||||
import pytest
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
from chronos.chronos2.vllm.utils.quantiles import select_quantiles # noqa: E402
|
||||
from chronos.utils import interpolate_quantiles # noqa: E402
|
||||
|
||||
MODEL_QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
||||
|
||||
|
||||
class TestSelectQuantiles:
|
||||
"""Tests for select_quantiles."""
|
||||
|
||||
def test_subset_selection(self):
|
||||
"""Requesting a subset of model quantiles should do direct indexing."""
|
||||
# Shape: (1, 9_quantiles, 4_horizon)
|
||||
pred = torch.randn(1, len(MODEL_QUANTILES), 4)
|
||||
quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.1, 0.5, 0.9])
|
||||
|
||||
assert len(quantiles) == 1
|
||||
assert len(mean) == 1
|
||||
# After swap: (1, 4, 3)
|
||||
assert quantiles[0].shape == (1, 4, 3)
|
||||
# Mean is median (0.5 = index 4)
|
||||
assert mean[0].shape == (1, 4)
|
||||
|
||||
def test_direct_selection_values(self):
|
||||
"""Verify selected values match expected indices."""
|
||||
pred = torch.arange(36, dtype=torch.float32).reshape(1, 9, 4)
|
||||
quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.1, 0.9])
|
||||
|
||||
# q=0.1 is index 0, q=0.9 is index 8
|
||||
# After swap: pred[..., h, q] → values at (h, q_idx)
|
||||
q = quantiles[0] # (1, 4, 2)
|
||||
assert q[0, 0, 0].item() == 0.0 # q=0.1, h=0 → row 0, col 0
|
||||
assert q[0, 0, 1].item() == 32.0 # q=0.9, h=0 → row 8, col 0
|
||||
|
||||
def test_interpolation_triggered(self):
|
||||
"""Non-subset quantile levels should trigger interpolation."""
|
||||
pred = torch.randn(1, len(MODEL_QUANTILES), 4)
|
||||
quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.15, 0.5])
|
||||
|
||||
assert quantiles[0].shape == (1, 4, 2)
|
||||
|
||||
def test_multiple_predictions(self):
|
||||
"""Works with multiple prediction tensors."""
|
||||
preds = [torch.randn(1, 9, 4) for _ in range(3)]
|
||||
quantiles, mean = select_quantiles(preds, MODEL_QUANTILES, [0.5])
|
||||
|
||||
assert len(quantiles) == 3
|
||||
assert len(mean) == 3
|
||||
|
||||
def test_multivariate(self):
|
||||
"""Works with multivariate predictions (2 variates)."""
|
||||
pred = torch.randn(2, len(MODEL_QUANTILES), 4)
|
||||
quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.1, 0.5, 0.9])
|
||||
|
||||
assert quantiles[0].shape == (2, 4, 3)
|
||||
assert mean[0].shape == (2, 4)
|
||||
|
||||
|
||||
class TestInterpolateQuantiles:
|
||||
"""Tests for chronos.utils.interpolate_quantiles."""
|
||||
|
||||
def test_exact_match(self):
|
||||
"""Exact match should return the original values."""
|
||||
values = torch.tensor([[1.0, 2.0, 3.0]]) # (1, 3) for levels [0.1, 0.5, 0.9]
|
||||
result = interpolate_quantiles([0.5], [0.1, 0.5, 0.9], values)
|
||||
assert result.shape == (1, 1)
|
||||
assert result[0, 0].item() == 2.0
|
||||
|
||||
def test_midpoint_interpolation(self):
|
||||
"""Midpoint between two levels should be the average."""
|
||||
values = torch.tensor([[0.0, 10.0]]) # (1, 2) for levels [0.2, 0.8]
|
||||
result = interpolate_quantiles([0.5], [0.2, 0.8], values)
|
||||
assert abs(result[0, 0].item() - 5.0) < 1e-5
|
||||
|
||||
def test_clamping(self):
|
||||
"""Query below/above model range should clamp to boundary."""
|
||||
values = torch.tensor([[1.0, 2.0, 3.0]]) # levels [0.1, 0.5, 0.9]
|
||||
low = interpolate_quantiles([0.01], [0.1, 0.5, 0.9], values)
|
||||
high = interpolate_quantiles([0.99], [0.1, 0.5, 0.9], values)
|
||||
assert low[0, 0].item() == pytest.approx(1.0, abs=1e-5) # clamped to 0.1
|
||||
assert high[0, 0].item() == pytest.approx(3.0, abs=1e-5) # clamped to 0.9
|
||||
|
||||
def test_multiple_queries(self):
|
||||
"""Multiple query levels should produce correct shape."""
|
||||
values = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = interpolate_quantiles([0.1, 0.3, 0.5, 0.9], [0.1, 0.5, 0.9], values)
|
||||
assert result.shape == (1, 4)
|
||||
|
||||
def test_batch_dimension(self):
|
||||
"""Works with batch dimension."""
|
||||
values = torch.randn(5, 9) # 5 time steps, 9 quantiles
|
||||
result = interpolate_quantiles([0.25], MODEL_QUANTILES, values)
|
||||
assert result.shape == (5, 1)
|
||||
163
test/chronos2/vllm/test_validation.py
Normal file
163
test/chronos2/vllm/test_validation.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
"""Unit tests for chronos.chronos2.vllm.protocol.validation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from chronos.chronos2.vllm.protocol.forecast import ForecastParameters, TimeSeriesInput
|
||||
from chronos.chronos2.vllm.protocol.validation import (
|
||||
validate_cross_series,
|
||||
validate_quantile_levels,
|
||||
validate_single_series_covariates,
|
||||
validate_start_timestamp,
|
||||
validate_target,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateTarget:
|
||||
"""Tests for validate_target."""
|
||||
|
||||
def test_valid_univariate(self):
|
||||
result = validate_target([1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
assert len(result) == 5
|
||||
|
||||
def test_valid_multivariate(self):
|
||||
result = validate_target([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])
|
||||
assert len(result) == 2
|
||||
|
||||
def test_too_short_univariate(self):
|
||||
with pytest.raises(ValueError, match="at least 5"):
|
||||
validate_target([1.0, 2.0])
|
||||
|
||||
def test_too_short_multivariate(self):
|
||||
with pytest.raises(ValueError, match="at least 5"):
|
||||
validate_target([[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
def test_inconsistent_multivariate_lengths(self):
|
||||
with pytest.raises(ValueError, match="same length"):
|
||||
validate_target([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0]])
|
||||
|
||||
|
||||
class TestValidateStartTimestamp:
|
||||
"""Tests for validate_start_timestamp."""
|
||||
|
||||
def test_none(self):
|
||||
assert validate_start_timestamp(None) is None
|
||||
|
||||
def test_valid_date(self):
|
||||
assert validate_start_timestamp("2024-01-01") == "2024-01-01"
|
||||
|
||||
def test_valid_datetime(self):
|
||||
assert validate_start_timestamp("2024-01-01T12:00:00") == "2024-01-01T12:00:00"
|
||||
|
||||
def test_invalid_format(self):
|
||||
with pytest.raises(ValueError, match="Invalid start timestamp"):
|
||||
validate_start_timestamp("not-a-date")
|
||||
|
||||
|
||||
class TestValidateQuantileLevels:
|
||||
"""Tests for validate_quantile_levels."""
|
||||
|
||||
def test_valid(self):
|
||||
result = validate_quantile_levels([0.1, 0.5, 0.9])
|
||||
assert result == [0.1, 0.5, 0.9]
|
||||
|
||||
def test_zero_invalid(self):
|
||||
with pytest.raises(ValueError, match="between 0 and 1"):
|
||||
validate_quantile_levels([0.0, 0.5])
|
||||
|
||||
def test_one_invalid(self):
|
||||
with pytest.raises(ValueError, match="between 0 and 1"):
|
||||
validate_quantile_levels([0.5, 1.0])
|
||||
|
||||
def test_negative_invalid(self):
|
||||
with pytest.raises(ValueError, match="between 0 and 1"):
|
||||
validate_quantile_levels([-0.1])
|
||||
|
||||
|
||||
class TestValidateSingleSeriesCovariates:
|
||||
"""Tests for validate_single_series_covariates."""
|
||||
|
||||
def test_no_covariates(self):
|
||||
validate_single_series_covariates([1.0, 2.0, 3.0, 4.0, 5.0], None, None)
|
||||
|
||||
def test_valid_past_covariates(self):
|
||||
target = [1.0, 2.0, 3.0, 4.0, 5.0]
|
||||
past = {"temp": [10.0, 20.0, 30.0, 40.0, 50.0]}
|
||||
validate_single_series_covariates(target, past, None)
|
||||
|
||||
def test_future_without_past_raises(self):
|
||||
target = [1.0, 2.0, 3.0, 4.0, 5.0]
|
||||
future = {"temp": [60.0]}
|
||||
with pytest.raises(ValueError, match="together"):
|
||||
validate_single_series_covariates(target, None, future)
|
||||
|
||||
def test_future_key_not_in_past_raises(self):
|
||||
target = [1.0, 2.0, 3.0, 4.0, 5.0]
|
||||
past = {"temp": [10.0, 20.0, 30.0, 40.0, 50.0]}
|
||||
future = {"wind": [1.0]}
|
||||
with pytest.raises(ValueError, match="not in"):
|
||||
validate_single_series_covariates(target, past, future)
|
||||
|
||||
def test_past_length_mismatch(self):
|
||||
target = [1.0, 2.0, 3.0, 4.0, 5.0]
|
||||
past = {"temp": [10.0, 20.0]}
|
||||
with pytest.raises(ValueError, match="must match target length"):
|
||||
validate_single_series_covariates(target, past, None)
|
||||
|
||||
def test_target_name_forbidden(self):
|
||||
target = [1.0, 2.0, 3.0, 4.0, 5.0]
|
||||
past = {"target": [10.0, 20.0, 30.0, 40.0, 50.0]}
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
validate_single_series_covariates(target, past, None)
|
||||
|
||||
|
||||
class TestValidateCrossSeries:
|
||||
"""Tests for validate_cross_series."""
|
||||
|
||||
def _make_input(self, target=None, item_id=None, start=None, past_cov=None, future_cov=None):
|
||||
return TimeSeriesInput(
|
||||
target=target or [1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
item_id=item_id,
|
||||
start=start,
|
||||
past_covariates=past_cov,
|
||||
future_covariates=future_cov,
|
||||
)
|
||||
|
||||
def test_valid_simple(self):
|
||||
inputs = [self._make_input(), self._make_input()]
|
||||
params = ForecastParameters(prediction_length=3)
|
||||
validate_cross_series(inputs, params)
|
||||
|
||||
def test_inconsistent_item_ids(self):
|
||||
inputs = [self._make_input(item_id="a"), self._make_input(item_id=None)]
|
||||
params = ForecastParameters()
|
||||
with pytest.raises(ValueError, match="item_id"):
|
||||
validate_cross_series(inputs, params)
|
||||
|
||||
def test_duplicate_item_ids(self):
|
||||
inputs = [self._make_input(item_id="a"), self._make_input(item_id="a")]
|
||||
params = ForecastParameters()
|
||||
with pytest.raises(ValueError, match="unique"):
|
||||
validate_cross_series(inputs, params)
|
||||
|
||||
def test_start_without_freq(self):
|
||||
inputs = [self._make_input(start="2024-01-01")]
|
||||
params = ForecastParameters(prediction_length=1)
|
||||
with pytest.raises(ValueError, match="freq"):
|
||||
validate_cross_series(inputs, params)
|
||||
|
||||
def test_freq_without_start(self):
|
||||
inputs = [self._make_input()]
|
||||
params = ForecastParameters(prediction_length=1, freq="D")
|
||||
with pytest.raises(ValueError, match="start"):
|
||||
validate_cross_series(inputs, params)
|
||||
|
||||
def test_future_cov_length_mismatch(self):
|
||||
inputs = [
|
||||
self._make_input(
|
||||
past_cov={"temp": [1.0, 2.0, 3.0, 4.0, 5.0]},
|
||||
future_cov={"temp": [10.0, 20.0]}, # length 2 != prediction_length 3
|
||||
)
|
||||
]
|
||||
params = ForecastParameters(prediction_length=3)
|
||||
with pytest.raises(ValueError, match="prediction_length"):
|
||||
validate_cross_series(inputs, params)
|
||||
Loading…
Reference in a new issue