Add vLLM plugin for Chronos-2 inference

Signed-off-by: Li Zhang <lzhanga@amazon.com>
Li Zhang 2026-02-26 23:19:26 +00:00
parent f951d9aefa
commit ba47d25a04
20 changed files with 2707 additions and 0 deletions


@ -52,6 +52,10 @@ test = [
"fev>=0.6.1",
"pandas[pyarrow]>=2.0,<2.4",
]
vllm = [
"vllm>=0.13.0",
"pydantic>=2.0",
]
typecheck = ["mypy~=1.9"]
dev = [
"gluonts[pro]~=0.16",
@ -65,6 +69,12 @@ dev = [
"pandas>=2.0,<2.4",
]
[project.entry-points."vllm.general_plugins"]
chronos2 = "chronos.chronos2.vllm:register_chronos2_model"
[project.entry-points."vllm.io_processor_plugins"]
chronos2 = "chronos.chronos2.vllm:get_chronos2_io_processor"
[project.urls]
Homepage = "https://github.com/amazon-science/chronos-forecasting"
Issues = "https://github.com/amazon-science/chronos-forecasting/issues"


@ -0,0 +1,179 @@
# Chronos-2 vLLM Plugin
A [vLLM](https://github.com/vllm-project/vllm) plugin that adds support for [Chronos-2](https://github.com/amazon-science/chronos-forecasting) time series forecasting via the `/pooling` API endpoint.
## Overview
Chronos-2 is an encoder-only time series foundation model for zero-shot forecasting. This plugin integrates it with vLLM using the **IOProcessor** plugin interface, so forecast requests are served through vLLM's standard pooling endpoint.
### Features
- **Zero-shot forecasting** — no fine-tuning required
- **Quantile predictions** — probabilistic forecasts with customizable quantile levels
- **Cross-series learning** — information sharing across time series in a batch
- **Covariates support** — past and future covariates (numeric and categorical)
- **Batch forecasting** — process multiple time series in a single request
## Installation
Requires Python 3.10+ and vLLM 0.13.0+.
```bash
pip install chronos-forecasting[vllm]
```
## Quick Start
### 1. Start the Server
```bash
vllm serve amazon/chronos-2 \
--io-processor-plugin chronos2 \
--runner pooling \
--enforce-eager \
--no-enable-prefix-caching \
--skip-tokenizer-init \
--enable-mm-embeds \
--dtype float32 \
--max-model-len 8192
```
### 2. Send a Forecast Request
```bash
curl -X POST http://localhost:8000/pooling \
-H "Content-Type: application/json" \
-d '{
"model": "amazon/chronos-2",
"task": "plugin",
"data": {
"inputs": [
{
"target": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
"item_id": "series_1"
}
],
"parameters": {
"prediction_length": 5,
"quantile_levels": [0.1, 0.5, 0.9]
}
}
}'
```
### 3. Parse the Response
```json
{
"request_id": null,
"created_at": 1739397600,
"data": {
"predictions": [
{
"mean": [11.0, 12.1, 13.0, 14.2, 15.1],
"0.1": [9.5, 10.3, 11.0, 11.8, 12.5],
"0.5": [11.0, 12.0, 13.0, 14.0, 15.0],
"0.9": [12.5, 13.8, 15.0, 16.3, 17.5],
"item_id": "series_1"
}
]
}
}
```
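The same request can also be sent from Python. The following is a minimal sketch using only the standard library; the helper names `build_forecast_payload` and `post_forecast` are illustrative and not part of the plugin, and the URL assumes the server from step 1 is listening on `localhost:8000`.

```python
import json
import urllib.request


def build_forecast_payload(series, prediction_length=5, quantile_levels=(0.1, 0.5, 0.9)):
    """Build the JSON body for POST /pooling from a {item_id: values} mapping."""
    return {
        "model": "amazon/chronos-2",
        "task": "plugin",
        "data": {
            "inputs": [
                {"target": [float(v) for v in values], "item_id": item_id}
                for item_id, values in series.items()
            ],
            "parameters": {
                "prediction_length": prediction_length,
                "quantile_levels": list(quantile_levels),
            },
        },
    }


def post_forecast(payload, url="http://localhost:8000/pooling", timeout=60.0):
    """Send the payload to the pooling endpoint and return the decoded JSON response."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.load(response)


payload = build_forecast_payload({"series_1": range(1, 11)})
# response = post_forecast(payload)  # requires a running server
```

Building the payload separately from sending it makes it easy to inspect or log the request body before it reaches the server.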
## API Reference
### Request Format
| Field | Type | Required | Description |
|---|---|---|---|
| `model` | `str` | ✅ | Model name (e.g., `"amazon/chronos-2"`) |
| `task` | `str` | ✅ | Must be `"plugin"` |
| `data.inputs` | `list` | ✅ | List of time series inputs (1–1024) |
| `data.parameters` | `dict` | | Forecast parameters |
#### Time Series Input (`data.inputs[*]`)
| Field | Type | Required | Description |
|---|---|---|---|
| `target` | `list[float]` or `list[list[float]]` | ✅ | Historical values (min 5 observations). 1-D for univariate, 2-D for multivariate. |
| `item_id` | `str` | | Identifier echoed in response |
| `start` | `str` | | ISO 8601 timestamp |
| `past_covariates` | `dict[str, list]` | | Past covariate arrays (must match target length) |
| `future_covariates` | `dict[str, list]` | | Future covariate arrays (must match `prediction_length`) |
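
The length constraints in this table can be checked on the client before a request is sent. The sketch below is a hypothetical helper (`check_series_input` is not part of the plugin, and the server-side rules in `protocol/validation.py` may be stricter); it only encodes the three constraints stated above.

```python
def check_series_input(series, prediction_length):
    """Return a list of problems found in one `data.inputs` entry (empty if none)."""
    target = series.get("target")
    if not target:
        return ["target is required and must be non-empty"]

    problems = []
    # A 2-D target is a list of variates, so the history length is the inner length.
    is_multivariate = isinstance(target[0], list)
    history_length = len(target[0]) if is_multivariate else len(target)
    if history_length < 5:
        problems.append(f"target has {history_length} observations, need at least 5")

    # Past covariates must cover exactly the same time steps as the target.
    for name, values in (series.get("past_covariates") or {}).items():
        if len(values) != history_length:
            problems.append(
                f"past covariate {name!r} has length {len(values)}, expected {history_length}"
            )

    # Future covariates must cover exactly the forecast horizon.
    for name, values in (series.get("future_covariates") or {}).items():
        if len(values) != prediction_length:
            problems.append(
                f"future covariate {name!r} has length {len(values)}, expected {prediction_length}"
            )
    return problems


example = {
    "target": [1.0] * 10,
    "past_covariates": {"temperature": [20.0] * 10},
    "future_covariates": {"temperature": [21.0] * 5},
}
print(check_series_input(example, prediction_length=5))
```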
#### Parameters (`data.parameters`)
| Field | Type | Default | Description |
|---|---|---|---|
| `prediction_length` | `int` | `1` | Forecast horizon (1–1024) |
| `quantile_levels` | `list[float]` | `[0.1, 0.5, 0.9]` | Quantile levels in (0, 1) |
| `freq` | `str` | `null` | Pandas frequency string (e.g., `"D"`, `"H"`) |
| `batch_size` | `int` | `256` | Inference batch size |
| `cross_learning` | `bool` | `false` | Enable cross-series learning |
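
Requested `quantile_levels` do not have to coincide with the levels the model was trained on; `utils/quantiles.py` is described as handling selection and interpolation. That module's exact logic is not reproduced here. As an illustration of the idea only, the sketch below linearly interpolates a requested level between its two neighbouring model quantiles for a single time step. The grid is the default from the IOProcessor; the function name and the choice to reject levels outside the grid are assumptions.

```python
from bisect import bisect_left

# Default quantile grid used by the IOProcessor when the model config has none.
MODEL_QUANTILES = [
    0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
    0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99,
]


def interpolate_quantile(level, model_quantiles, model_values):
    """Linearly interpolate the forecast at `level` for one time step.

    `model_values[i]` is the forecast value at `model_quantiles[i]`, and
    `model_quantiles` must be sorted in ascending order.
    """
    if not model_quantiles[0] <= level <= model_quantiles[-1]:
        raise ValueError(f"level {level} is outside the model's quantile grid")
    upper = bisect_left(model_quantiles, level)
    if model_quantiles[upper] == level:
        # The requested level is on the grid: select it directly.
        return model_values[upper]
    lower = upper - 1
    weight = (level - model_quantiles[lower]) / (
        model_quantiles[upper] - model_quantiles[lower]
    )
    return model_values[lower] + weight * (model_values[upper] - model_values[lower])


# Toy values: the forecast at quantile q is 100 * q.
toy_values = [100.0 * q for q in MODEL_QUANTILES]
print(interpolate_quantile(0.125, MODEL_QUANTILES, toy_values))
```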
### Response Format
Each prediction in `data.predictions` contains:
| Field | Type | Description |
|---|---|---|
| `mean` | `list[float]` | Point forecast (mean/median) |
| `"0.1"`, `"0.5"`, etc. | `list[float]` | Named quantile columns matching `quantile_levels` |
| `item_id` | `str` | Echoed from input (if provided) |
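
Because the quantile columns are keyed by their string level, a client usually wants to regroup them. The following sketch (the helper name `predictions_by_item` is illustrative) assumes every key other than `mean` and `item_id` is a quantile level, and falls back to the positional index when no `item_id` was sent.

```python
def predictions_by_item(response):
    """Map item_id (or positional index) to {"mean": [...], "quantiles": {level: [...]}}."""
    result = {}
    for index, prediction in enumerate(response["data"]["predictions"]):
        key = prediction.get("item_id", index)
        quantiles = {
            float(name): values
            for name, values in prediction.items()
            if name not in ("mean", "item_id")
        }
        result[key] = {"mean": prediction["mean"], "quantiles": quantiles}
    return result


# The sample response from the Quick Start section.
sample_response = {
    "request_id": None,
    "created_at": 1739397600,
    "data": {
        "predictions": [
            {
                "mean": [11.0, 12.1, 13.0, 14.2, 15.1],
                "0.1": [9.5, 10.3, 11.0, 11.8, 12.5],
                "0.5": [11.0, 12.0, 13.0, 14.0, 15.0],
                "0.9": [12.5, 13.8, 15.0, 16.3, 17.5],
                "item_id": "series_1",
            }
        ]
    },
}
parsed = predictions_by_item(sample_response)
print(parsed["series_1"]["quantiles"][0.5])
```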
## Architecture
The vLLM model wrapper (`Chronos2ForForecasting`) is a thin adapter that delegates all computation to the existing `chronos.chronos2.model.Chronos2Model`. No model architecture is duplicated.
### Module Structure
```
src/chronos/chronos2/vllm/
├── __init__.py          # Plugin entry point & registration
├── model.py             # Chronos2ForForecasting (thin vLLM wrapper)
├── multimodal.py        # MM pipeline for "timeseries" modality
├── io_processor.py      # Chronos2IOProcessor (request/response handling)
├── protocol/
│   ├── __init__.py
│   ├── forecast.py      # Pydantic models (TimeSeriesInput, ForecastParameters, etc.)
│   ├── validation.py    # Input validation logic
│   └── data_prep.py     # Tensor preparation from validated inputs
└── utils/
    ├── __init__.py
    ├── helpers.py       # Utility functions
    └── quantiles.py     # Quantile selection & interpolation
```
### Key Classes
| Class | File | Purpose |
|---|---|---|
| `Chronos2ForForecasting` | `model.py` | Thin vLLM wrapper — delegates to `chronos.chronos2.model.Chronos2Model` |
| `Chronos2IOProcessor` | `io_processor.py` | Request parsing, validation, pre/post processing |
| `ForecastParameters` | `protocol/forecast.py` | Pydantic validation for forecast parameters |
| `TimeSeriesInput` | `protocol/forecast.py` | Pydantic validation for time series inputs |
| `ForecastPrediction` | `protocol/forecast.py` | Pydantic model for forecast output |
## Troubleshooting
### Server Flags
The following flags are required for Chronos-2:
```bash
--io-processor-plugin chronos2 # Enable the forecast IOProcessor
--enforce-eager # Chronos-2 doesn't support CUDA graphs
--no-enable-prefix-caching # Not applicable for time series
--skip-tokenizer-init # Chronos-2 doesn't use a text tokenizer
```
### Plugin Not Loading
1. Verify installation: `pip list | grep chronos-forecasting`
2. Check entry points:
```python
from importlib.metadata import entry_points
print(list(entry_points(group='vllm.general_plugins')))
```
3. Enable debug logging: `VLLM_LOGGING_LEVEL=DEBUG vllm serve ...`


@ -0,0 +1,48 @@
"""vLLM Chronos-2 Time Series Forecasting Model Plugin.
This plugin registers the Chronos-2 model with vLLM's ModelRegistry,
allowing it to be used with vLLM's inference engine for time series forecasting.
"""


def register_chronos2_model() -> None:
    """Register Chronos-2 models with vLLM's ModelRegistry.

    This function is called automatically when the plugin is loaded
    through vLLM's plugin discovery mechanism.
    """
    # Imported inside the function so that importing this module alone
    # does not require vLLM; done before the try block so the error
    # handler does not need to import the logger a second time.
    from vllm.logger import init_logger

    logger = init_logger(__name__)
    try:
        from vllm.model_executor.models.registry import ModelRegistry

        ModelRegistry.register_model(
            "Chronos2Model",
            "chronos.chronos2.vllm.model:Chronos2ForForecasting",
        )
        logger.info("Successfully registered Chronos-2 model with vLLM")
    except Exception as e:
        logger.error("Failed to register Chronos-2 model: %s", e)
        raise


def get_chronos2_io_processor() -> str:
    """Factory function for IOProcessor plugin registration.

    This function is called by vLLM when --io-processor-plugin is set to 'chronos2'.
    Returns the fully qualified name (module.Class format) as a string.
    """
    return "chronos.chronos2.vllm.io_processor.Chronos2IOProcessor"


__all__ = [
    "register_chronos2_model",
    "get_chronos2_io_processor",
]


@ -0,0 +1,247 @@
"""IOProcessor for Chronos-2 time series forecasting.
Thin orchestrator that delegates to:
- protocol.data_prep: tensor preparation from validated inputs (via Chronos2Dataset)
- protocol.validation: cross-series validation
- multimodal.MODALITY: MM prompt construction
- utils: quantile selection, helpers
Flow:
1. parse_request → validate inputs via Pydantic models
2. pre_process → prepare_request() → timeseries MM prompts (one per batch)
3. post_process → select_quantiles() → per-series predictions
"""
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.pooling_params import PoolingParams
from chronos.chronos2.vllm.multimodal import MODALITY
from chronos.chronos2.vllm.protocol.data_prep import PreparedRequest, prepare_request
from chronos.chronos2.vllm.protocol.forecast import (
ForecastParameters,
ForecastPrediction,
TimeSeriesInput,
)
from chronos.chronos2.vllm.protocol.validation import validate_cross_series
from chronos.chronos2.vllm.utils.helpers import empty_prediction, tensor_to_list
from chronos.chronos2.vllm.utils.quantiles import select_quantiles
logger = init_logger(__name__)
@dataclass
class _PostProcessInfo:
"""Lightweight cache of data needed for post_process — avoids storing full tensors."""
item_ids: list[str | None]
parameters: ForecastParameters
target_idx_ranges: list[list[tuple[int, int]]] # per-batch list of (start, end) ranges
class Chronos2IOProcessor(IOProcessor[dict, dict]):
"""IOProcessor for Chronos-2 time series forecasting."""
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
config = vllm_config.model_config.hf_config
config.is_encoder_decoder = False
# Set projection_dim=0 so vLLM uses IdentityPooler (no projection layer).
# This avoids requiring --hf-overrides '{"projection_dim": 0}' on the CLI.
if not hasattr(config, "projection_dim") or config.projection_dim != 0:
config.projection_dim = 0
cc = getattr(config, "chronos_config", {})
self._chronos_config = (
cc if isinstance(cc, dict) else (cc.__dict__ if hasattr(cc, "__dict__") else {})
)
self._context_length = self._chronos_config.get("context_length", 8192)
self._output_patch_size = self._chronos_config.get("output_patch_size", 16)
self._model_quantiles = self._chronos_config.get(
"quantiles",
[
0.01,
0.05,
0.1,
0.15,
0.2,
0.25,
0.3,
0.35,
0.4,
0.45,
0.5,
0.55,
0.6,
0.65,
0.7,
0.75,
0.8,
0.85,
0.9,
0.95,
0.99,
],
)
self._request_info: dict[str, _PostProcessInfo] = {}
logger.info("Initialized Chronos2IOProcessor")
# -----------------------------------------------------------
# Request parsing
# -----------------------------------------------------------
def parse_request(self, request: Any) -> dict:
data = request.data if hasattr(request, "data") else request
if not isinstance(data, dict):
raise ValueError(f"Expected dict, got {type(data)}")
if "inputs" not in data:
raise ValueError("Request must contain 'inputs' field")
raw_inputs = data["inputs"]
if not isinstance(raw_inputs, list) or len(raw_inputs) == 0:
raise ValueError("'inputs' must be a non-empty list")
validated_inputs = []
for i, ts in enumerate(raw_inputs):
try:
validated_inputs.append(TimeSeriesInput(**ts))
except Exception as e:
raise ValueError(f"Invalid input at index {i}: {e}") from e
try:
validated_params = ForecastParameters(**data.get("parameters", {}))
except Exception as e:
raise ValueError(f"Invalid parameters: {e}") from e
validate_cross_series(validated_inputs, validated_params)
return {"inputs": validated_inputs, "parameters": validated_params}
# -----------------------------------------------------------
# Pre-process: inputs → MM prompts
# -----------------------------------------------------------
def pre_process(
self,
prompt: dict,
request_id: str | None = None,
**kwargs: Any,
) -> PromptType | Sequence[PromptType]:
prepared = prepare_request(
inputs=prompt["inputs"],
parameters=prompt["parameters"],
context_length=self._context_length,
output_patch_size=self._output_patch_size,
)
# Cache only lightweight post-processing metadata
if request_id is not None:
self._request_info[request_id] = _PostProcessInfo(
item_ids=prepared.item_ids,
parameters=prepared.parameters,
target_idx_ranges=[batch.target_idx_ranges for batch in prepared.batches],
)
# Build one MM prompt per batch (Chronos2Dataset handles batch splitting)
prompts: list[PromptType] = []
for batch in prepared.batches:
prompts.append(
{
"prompt_token_ids": [1],
"multi_modal_data": {
MODALITY: {
"context": batch.context,
"future_covariates": batch.future_covariates,
"group_ids": batch.group_ids,
"num_output_patches": batch.num_output_patches,
}
},
}
)
return prompts if len(prompts) > 1 else prompts[0]
# -----------------------------------------------------------
# Post-process: model output → predictions
# -----------------------------------------------------------
def post_process(
self,
model_output: Sequence[PoolingRequestOutput],
request_id: str | None = None,
**kwargs: Any,
) -> dict:
info = self._request_info.pop(request_id, None) if request_id else None
if info is None:
logger.error("No pending request for request_id=%s", request_id)
return {"predictions": []}
parameters = info.parameters
item_ids = info.item_ids
try:
# Collect predictions across all batches, trimmed to exact prediction_length
pred_len = parameters.prediction_length
all_predictions: list[torch.Tensor] = []
for batch_idx, output in enumerate(model_output):
tensor = output.outputs.data
while tensor.ndim > 3:
tensor = tensor.squeeze(0)
batch_ranges = info.target_idx_ranges[batch_idx]
for start, end in batch_ranges:
all_predictions.append(tensor[start:end, :, :pred_len])
quantiles_out, mean_out = select_quantiles(
all_predictions, self._model_quantiles, parameters.quantile_levels
)
result: dict[str, Any] = {"predictions": [], "request_id": request_id}
for i, (q_tensor, m_tensor) in enumerate(zip(quantiles_out, mean_out)):
pred: dict[str, Any] = {"mean": tensor_to_list(m_tensor)}
for q, q_vals in zip(parameters.quantile_levels, q_tensor.unbind(dim=-1)):
pred[str(q)] = tensor_to_list(q_vals)
if i < len(item_ids) and item_ids[i] is not None:
pred["item_id"] = item_ids[i]
result["predictions"].append(pred)
except Exception as e:
logger.error("Failed to decode predictions: %s", e, exc_info=True)
result = {
"predictions": [
empty_prediction(parameters.prediction_length, parameters.quantile_levels)
for _ in item_ids
],
"request_id": request_id,
}
return result
# -----------------------------------------------------------
# Response formatting
# -----------------------------------------------------------
def validate_or_generate_params(self, params: Any = None) -> PoolingParams:
return PoolingParams(task="plugin")
def output_to_response(self, plugin_output: dict) -> IOProcessorResponse:
validated = []
for pred in plugin_output.get("predictions", []):
try:
validated.append(ForecastPrediction(**pred).model_dump(exclude_none=True))
except Exception as e:
logger.warning("Failed to validate prediction: %s", e)
validated.append(pred)
return IOProcessorResponse(
request_id=plugin_output.get("request_id"),
created_at=int(time.time()),
data={"predictions": validated},
)


@ -0,0 +1,196 @@
"""Chronos-2 vLLM model wrapper.
Thin wrapper around the existing ``chronos.chronos2.model.Chronos2Model``
that plugs into vLLM's multimodal (MM) interface. No model architecture
is duplicated; all computation is delegated to the upstream implementation.
Architecture:
    pre_process  → IOProcessor prepares context/covariates/group_ids tensors
                 → Chronos2Dataset handles batching and group_id construction
                 → returns MM prompt(s) with timeseries data
    forward()    → receives pre-batched context/future_covariates/group_ids as kwargs
                 → delegates to chronos.chronos2.model.Chronos2Model
    pooler       → IdentityPooler passes output through unchanged
    post_process → IOProcessor selects/interpolates quantiles
"""
import math
from collections.abc import Iterable
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import IdentityPooler
from vllm.model_executor.models.interfaces import (
IsAttentionFree,
MultiModalEmbeddings,
SupportsMultiModal,
)
from vllm.model_executor.models.interfaces_base import attn_type
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.utils import length_from_prompt_token_ids_or_embeds
from chronos.chronos2.model import Chronos2Model
from .multimodal import (
MODALITY,
ChronosInputBuilder,
ChronosMultiModalProcessor,
ChronosProcessingInfo,
)
logger = init_logger(__name__)
@attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor(
ChronosMultiModalProcessor,
info=ChronosProcessingInfo,
dummy_inputs=ChronosInputBuilder,
)
class Chronos2ForForecasting(nn.Module, IsAttentionFree, SupportsMultiModal):
"""Chronos-2 forecasting model for vLLM.
Delegates all computation to ``chronos.chronos2.model.Chronos2Model``.
Receives pre-batched tensors (context, future_covariates, group_ids)
directly in forward() via the timeseries MM pipeline. Batching is
handled upstream by ``Chronos2Dataset`` in the IOProcessor.
"""
supports_multimodal_raw_input_only = True
is_pooling_model = True
# Required by VllmModel interface
packed_modules_mapping: dict[str, Any] = {}
supported_lora_modules: list[str] = []
embedding_modules: dict[str, str] = {}
embedding_padding_modules: list[str] = []
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith(MODALITY):
return None
raise ValueError(f"Only {MODALITY} modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
config.is_encoder_decoder = False
if not hasattr(config, "projection_dim") or config.projection_dim != 0:
config.projection_dim = 0
vllm_config.model_config.hf_text_config = None
self.config = config
self.model_name = vllm_config.model_config.model
self.d_model = getattr(config, "d_model", 768)
cc = getattr(config, "chronos_config", {})
self.chronos_config = (
cc if isinstance(cc, dict) else (cc.__dict__ if hasattr(cc, "__dict__") else {})
)
self.output_patch_size = self.chronos_config.get("output_patch_size", 16)
self.max_output_patches = self.chronos_config.get("max_output_patches", 64)
# Instantiate the upstream Chronos2Model — reuses all existing layers
self.model = Chronos2Model(config)
self.pooler = IdentityPooler()
logger.info(
"Initialized Chronos2ForForecasting (d_model=%d, delegating to chronos.chronos2.model)",
self.d_model,
)
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
"""Return empty embeddings — Chronos-2 has no token vocabulary."""
return torch.empty((input_ids.shape[0], 0))
@torch.inference_mode()
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor:
"""Run Chronos-2 inference by delegating to the upstream model.
Receives pre-batched context, future_covariates, group_ids from
the timeseries MM pipeline. Batching is handled upstream by
Chronos2Dataset in the IOProcessor.
"""
input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)
context: torch.Tensor | None = kwargs.get("context") # type: ignore[assignment]
future_covariates: torch.Tensor | None = kwargs.get("future_covariates") # type: ignore[assignment]
group_ids: torch.Tensor | None = kwargs.get("group_ids") # type: ignore[assignment]
num_output_patches: int | None = kwargs.get("num_output_patches") # type: ignore[assignment]
if context is None:
# Warmup/profiling pass — return zeros
return torch.zeros(input_len, 0, device=positions.device, dtype=torch.float32)
# Determine num_output_patches from preprocessor or fall back to computing from future_covariates
if num_output_patches is None:
prediction_length = future_covariates.shape[1] if future_covariates is not None else 0
if prediction_length == 0:
prediction_length = self.output_patch_size * self.max_output_patches
num_output_patches = int(math.ceil(prediction_length / self.output_patch_size))
num_output_patches = min(num_output_patches, self.max_output_patches)
prediction_length = num_output_patches * self.output_patch_size
# Pad or trim future_covariates to match output_size
if future_covariates is not None:
if prediction_length > future_covariates.shape[1]:
pad_size = prediction_length - future_covariates.shape[1]
pad_tensor = torch.full(
(future_covariates.shape[0], pad_size),
fill_value=float("nan"),
device=future_covariates.device,
)
future_covariates = torch.cat([future_covariates, pad_tensor], dim=1)
else:
future_covariates = future_covariates[:, :prediction_length]
# Delegate to the upstream Chronos2Model — single forward pass
model_kwargs: dict[str, Any] = {
"context": context,
"num_output_patches": num_output_patches,
}
if group_ids is not None:
model_kwargs["group_ids"] = group_ids
if future_covariates is not None:
model_kwargs["future_covariates"] = future_covariates
output = self.model(**model_kwargs)
batch_prediction = output.quantile_preds[..., :prediction_length]
# Expand to match input_len for vLLM pipeline compatibility.
# IdentityPooler passes through unchanged; post_process squeezes.
hidden_states = batch_prediction[None].expand(
input_len, *(-1 for _ in range(batch_prediction.ndim))
)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights into the upstream Chronos2Model.
Prefix checkpoint names with 'model.' to match our wrapping.
Uses AutoWeightsLoader for robust weight loading.
"""
prefixed = [(f"model.{name}", tensor) for name, tensor in weights]
loader = AutoWeightsLoader(self)
return loader.load_weights(prefixed)


@ -0,0 +1,237 @@
"""Multimodal boilerplate for the "timeseries" modality in vLLM.
Provides the MM processing pipeline classes that route timeseries
dict data (context, future_covariates, group_ids) through vLLM's
multimodal infrastructure.
"""
import hashlib
import time
from typing import Any, Mapping, Sequence
import torch
from transformers import BatchFeature
from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalFieldElem,
MultiModalInputs,
MultiModalKwargsItem,
MultiModalKwargsItems,
PlaceholderRange,
)
from vllm.multimodal.parse import (
DictEmbeddingItems,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptUpdate,
)
# The modality name used throughout the MM pipeline.
MODALITY = "timeseries"
# Field names expected in the timeseries MM data dict.
REQUIRED_FIELDS = frozenset({"context", "future_covariates", "group_ids", "num_output_patches"})
def _field_config() -> dict[str, MultiModalFieldConfig]:
"""Shared field config for all timeseries fields."""
return {
"context": MultiModalFieldConfig.shared(MODALITY, batch_size=1),
"future_covariates": MultiModalFieldConfig.shared(MODALITY, batch_size=1),
"group_ids": MultiModalFieldConfig.shared(MODALITY, batch_size=1),
"num_output_patches": MultiModalFieldConfig.shared(MODALITY, batch_size=1),
}
# -------------------------------------------------------------------
# Processing info: tells vLLM what modalities we support
# -------------------------------------------------------------------
class ChronosProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {MODALITY: None}
def get_data_parser(self) -> MultiModalDataParser:
return ChronosMultiModalDataParser()
def build_data_parser(self) -> MultiModalDataParser:
return ChronosMultiModalDataParser()
@property # type: ignore[override]
def data_parser(self) -> MultiModalDataParser:
if not hasattr(self, "_data_parser"):
self._data_parser = ChronosMultiModalDataParser()
return self._data_parser
# -------------------------------------------------------------------
# Dummy input builder: provides profiling data for GPU warmup
# -------------------------------------------------------------------
class ChronosInputBuilder(BaseDummyInputsBuilder[ChronosProcessingInfo]):
"""Provides dummy data for vLLM's GPU profiling/warmup pass."""
def __init__(self, info: ChronosProcessingInfo):
super().__init__(info)
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
**kwargs: Any,
) -> MultiModalDataDict:
return {
MODALITY: {
"context": torch.ones(100, 2048, dtype=torch.float32),
"future_covariates": torch.ones(100, 1024, dtype=torch.float32),
"group_ids": torch.zeros(100, dtype=torch.long),
"num_output_patches": 64,
}
}
# -------------------------------------------------------------------
# Data parser: routes timeseries dict data through the MM pipeline
# -------------------------------------------------------------------
class ChronosMultiModalDataParser(MultiModalDataParser):
"""Parses timeseries dict data for vLLM's MM pipeline."""
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
def _parse_timeseries_data(
self,
data: dict[str, torch.Tensor],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality=MODALITY,
required_fields=REQUIRED_FIELDS,
fields_factory=lambda _: _field_config(),
)
return None
def _get_subparsers(self) -> Mapping[str, Any]:
return {MODALITY: self._parse_timeseries_data}
def parse_mm_data(self, mm_data: MultiModalDataDict, **kwargs: Any) -> MultiModalDataItems:
if MODALITY not in mm_data:
mm_data = {MODALITY: mm_data}
ts_data = mm_data[MODALITY]
items = self._parse_timeseries_data(ts_data)
if items is None:
raise ValueError("Failed to parse timeseries data")
return MultiModalDataItems({MODALITY: items})
# -------------------------------------------------------------------
# Processor: converts parsed data into MultiModalInputs
# -------------------------------------------------------------------
class ChronosMultiModalProcessor(BaseMultiModalProcessor):
"""Processes timeseries MM data into MultiModalInputs for vLLM."""
def __init__(
self,
info: ChronosProcessingInfo,
dummy_inputs: BaseDummyInputsBuilder[ChronosProcessingInfo],
*,
cache: MultiModalProcessorOnlyCache | None = None,
) -> None:
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _field_config()
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
return []
def apply(self, *args: Any, **kwargs: Any) -> MultiModalInputs:
# Handle both old and new vLLM calling conventions:
# Old: apply(prompt, mm_items, hf_processor_mm_kwargs, ...)
# New: apply(processor_inputs, timing_ctx)
# Internal (from get_dummy_mm_inputs): apply(prompt=..., mm_items=..., ...)
ts_data: dict[str, Any] = {}
if args and hasattr(args[0], "prompt"):
# New vLLM: first arg is ProcessorInputs dataclass
processor_inputs = args[0]
mm_items = getattr(processor_inputs, "mm_items", None) or getattr(
processor_inputs, "mm_data_items", None
)
if mm_items is not None and isinstance(mm_items, MultiModalDataItems):
if MODALITY in mm_items:
ts_items = mm_items[MODALITY]
ts_data = ts_items.data if hasattr(ts_items, "data") else {}
else:
# Old vLLM / direct call: extract from positional/keyword args
mm_items = args[1] if len(args) > 1 else kwargs.get("mm_items")
mm_data = kwargs.get("mm_data")
if mm_items is not None and isinstance(mm_items, MultiModalDataItems):
if MODALITY in mm_items:
ts_items = mm_items[MODALITY]
ts_data = ts_items.data if hasattr(ts_items, "data") else {}
elif mm_data is not None and isinstance(mm_data, dict):
ts_data = mm_data.get(MODALITY, mm_data)
mm_placeholders = {MODALITY: [PlaceholderRange(offset=0, length=0)]}
# Build MultiModalKwargsItems directly to ensure proper modality keying.
# from_hf_inputs + BatchFeature can produce empty modalities when
# the data doesn't match the expected shared field batch structure.
field_config = _field_config()
mm_item_dict: dict[str, MultiModalFieldElem] = {}
for key, config in field_config.items():
tensor = ts_data.get(key) if isinstance(ts_data, dict) else None
if tensor is not None:
mm_item_dict[key] = MultiModalFieldElem(
data=tensor,
field=config.field,
)
mm_kwargs_items = MultiModalKwargsItems({MODALITY: [MultiModalKwargsItem(mm_item_dict)]})
# Unique hash per request (required by vLLM v0.16+)
ts_hash = hashlib.sha256(
str(id(ts_data)).encode() + str(time.monotonic()).encode()
).hexdigest()[:16]
return MultiModalInputs(
type="multimodal",
prompt_token_ids=[1],
mm_kwargs=mm_kwargs_items,
mm_hashes={MODALITY: [ts_hash]},
mm_placeholders=mm_placeholders,
)


@ -0,0 +1,17 @@
"""Chronos-2 forecasting protocol definitions and data preparation."""
from chronos.chronos2.vllm.protocol.forecast import (
ForecastParameters,
ForecastPrediction,
ForecastRequest,
ForecastResponse,
TimeSeriesInput,
)
__all__ = [
"TimeSeriesInput",
"ForecastParameters",
"ForecastRequest",
"ForecastPrediction",
"ForecastResponse",
]


@ -0,0 +1,136 @@
"""Data preparation for Chronos-2 model input tensors.
Converts validated TimeSeriesInput objects into batched tensors
ready for the model by delegating to ``chronos.chronos2.dataset``.
"""
import math
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from chronos.chronos2.dataset import Chronos2Dataset, DatasetMode
from chronos.chronos2.vllm.protocol.forecast import ForecastParameters, TimeSeriesInput
@dataclass
class PreparedBatch:
"""A single batch of prepared model input tensors."""
context: torch.Tensor # (batch_rows, padded_ctx_len)
future_covariates: torch.Tensor # (batch_rows, prediction_length)
group_ids: torch.Tensor # (batch_rows,)
num_output_patches: int
target_idx_ranges: list[tuple[int, int]] # per-input-series (start, end) in this batch
@dataclass
class PreparedRequest:
"""All batches and metadata for a single forecast request."""
batches: list[PreparedBatch]
item_ids: list[str | None]
parameters: ForecastParameters
def _timeseries_input_to_dict(ts: TimeSeriesInput) -> dict[str, Any]:
"""Convert a Pydantic TimeSeriesInput to the dict format expected by Chronos2Dataset.
The dataset expects:
- ``target``: np.ndarray of shape (history_length,) or (n_variates, history_length)
- ``past_covariates``: dict[str, np.ndarray]
- ``future_covariates``: dict[str, np.ndarray]
"""
target = np.array(ts.target, dtype=np.float32)
entry: dict[str, Any] = {"target": target}
if ts.past_covariates:
past_covs: dict[str, np.ndarray] = {}
for key, vals in ts.past_covariates.items():
arr = np.array([v if v is not None else np.nan for v in vals])
# Keep string-valued covariates as a NumPy string array (not float32) for categorical encoding
if any(isinstance(v, str) for v in vals if v is not None):
arr = np.array([str(v) if v is not None else "nan" for v in vals])
else:
arr = arr.astype(np.float32)
past_covs[key] = arr
entry["past_covariates"] = past_covs
if ts.future_covariates:
future_covs: dict[str, np.ndarray] = {}
for key, vals in ts.future_covariates.items():
arr = np.array([v if v is not None else np.nan for v in vals])
if any(isinstance(v, str) for v in vals if v is not None):
arr = np.array([str(v) if v is not None else "nan" for v in vals])
else:
arr = arr.astype(np.float32)
future_covs[key] = arr
entry["future_covariates"] = future_covs
return entry
def prepare_request(
inputs: list[TimeSeriesInput],
parameters: ForecastParameters,
context_length: int = 8192,
output_patch_size: int = 16,
) -> PreparedRequest:
"""Convert validated inputs into batched model-ready tensors.
Delegates to ``Chronos2Dataset`` in TEST mode for data preparation,
covariate encoding, batching, and group ID construction.
Args:
inputs: validated time series inputs
parameters: forecast parameters
context_length: maximum context length (from model config)
output_patch_size: model's output patch size (from model config)
Returns:
PreparedRequest with batched tensors and metadata
"""
# Convert Pydantic models to dicts for Chronos2Dataset
raw_inputs = [_timeseries_input_to_dict(ts) for ts in inputs]
item_ids = [ts.item_id for ts in inputs]
pred_len = parameters.prediction_length
# Use Chronos2Dataset in TEST mode — handles validation, covariate encoding,
# left-padding, and group_id construction. We use a very large batch_size
# to always produce exactly one batch; row-level chunking (if needed for
# memory) can be handled by the caller or the model's forward pass.
dataset = Chronos2Dataset(
inputs=raw_inputs,
context_length=context_length,
prediction_length=pred_len,
batch_size=2**31 - 1, # large enough to fit all series in one batch
output_patch_size=output_patch_size,
mode=DatasetMode.TEST,
convert_inputs=True,
)
batches: list[PreparedBatch] = []
for batch_dict in dataset:
group_ids = batch_dict["group_ids"]
if parameters.cross_learning:
group_ids = torch.zeros_like(group_ids)
batches.append(
PreparedBatch(
context=batch_dict["context"],
future_covariates=batch_dict["future_covariates"],
group_ids=group_ids,
num_output_patches=batch_dict["num_output_patches"],
target_idx_ranges=batch_dict["target_idx_ranges"],
)
)
return PreparedRequest(
batches=batches,
item_ids=item_ids,
parameters=parameters,
)
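
A minimal sketch of the missing-value rule that `_timeseries_input_to_dict` applies to covariates, written with plain lists instead of NumPy arrays so it runs without extra dependencies. The helper name `covariate_to_values` is hypothetical and is not part of the plugin.

```python
import math


def covariate_to_values(vals):
    """Hypothetical list-based mirror of the covariate branch above."""
    if any(isinstance(v, str) for v in vals):
        # Categorical covariate: every entry becomes a string and
        # missing values are encoded as the literal string "nan".
        return [str(v) if v is not None else "nan" for v in vals]
    # Numeric covariate: missing values are encoded as float NaN.
    return [float(v) if v is not None else math.nan for v in vals]


categorical = covariate_to_values(["sunny", None, "rain"])
numeric = covariate_to_values([1, None, 3.5])
```

In this sketch a single string anywhere in the list switches the whole covariate to the categorical path, which matches the `any(isinstance(v, str) ...)` check in the original function.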

@ -0,0 +1,218 @@
import math
from typing import Any, Literal, Union
from pydantic import BaseModel, Field, field_validator, model_validator
try:
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
except ImportError:
OpenAIBaseModel = BaseModel # type: ignore[misc,assignment]
from chronos.chronos2.vllm.protocol.validation import (
MAX_NUM_TIME_SERIES,
validate_quantile_levels,
validate_single_series_covariates,
validate_start_timestamp,
validate_target,
)
class TimeSeriesInput(BaseModel):
"""Input time series data for forecasting."""
target: list[float] | list[list[float]] = Field(
...,
description="Historical time series values. "
"1-D array for univariate, 2-D array for multivariate",
)
item_id: str | None = Field(default=None, description="Unique identifier for the time series")
start: str | None = Field(
default=None,
description="Start timestamp in ISO format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS)",
)
# Values can be NaN (numeric) or None/NaN mixed in string arrays.
past_covariates: dict[str, list[Union[float, str, None]]] | None = Field(
default=None,
description=(
"Dictionary of past covariate arrays (numeric or categorical). "
"Each array must match the length of the target"
),
)
# Values can be NaN (numeric) or None/NaN mixed in string arrays.
future_covariates: dict[str, list[Union[float, str, None]]] | None = Field(
default=None,
description="Dictionary of known future covariate arrays. "
"Keys must be a subset of past_covariates. "
"Each array must match prediction_length",
)
@field_validator("past_covariates", "future_covariates", mode="before")
@classmethod
def _sanitize_nan_in_covariates(
cls, v: dict[str, list[Any]] | None
) -> dict[str, list[Any]] | None:
"""Convert NaN floats to None in covariate arrays.
Datasets often encode missing categorical values as float NaN.
Pydantic rejects NaN in string-typed lists, so we normalize
NaN to None before validation.
"""
if v is None:
return v
sanitized: dict[str, list[Any]] = {}
for key, arr in v.items():
sanitized[key] = [None if isinstance(x, float) and math.isnan(x) else x for x in arr]
return sanitized
@field_validator("target")
@classmethod
def _validate_target(
cls, v: list[float] | list[list[float]]
) -> list[float] | list[list[float]]:
return validate_target(v)
@field_validator("start")
@classmethod
def _validate_start(cls, v: str | None) -> str | None:
return validate_start_timestamp(v)
@model_validator(mode="after")
def _validate_covariates(self) -> "TimeSeriesInput":
validate_single_series_covariates(self.target, self.past_covariates, self.future_covariates)
return self
class ForecastParameters(BaseModel):
"""Parameters for time series forecasting."""
prediction_length: int = Field(
default=1, ge=1, le=1024, description="Number of future steps to forecast"
)
quantile_levels: list[float] = Field(
default=[0.1, 0.5, 0.9],
description="Quantile levels for uncertainty quantification. "
"Each value must be between 0 and 1 (exclusive)",
)
freq: str | None = Field(
default=None,
description="Pandas frequency string (e.g., 'D' for daily, 'H' for hourly). "
"Required if 'start' is provided in inputs",
)
batch_size: int = Field(
default=256,
ge=1,
description="Internal row batch size for model inference. "
"Controls how many rows (series + covariates) are processed "
"in a single model forward pass. Rows are chunked internally "
"respecting series boundaries via group_ids.",
)
cross_learning: bool = Field(
default=False,
description="Enable information sharing across time series in batch",
)
@field_validator("quantile_levels")
@classmethod
def _validate_quantiles(cls, v: list[float]) -> list[float]:
return validate_quantile_levels(v)
class ForecastRequest(OpenAIBaseModel):
"""Request format for time series forecasting via pooling API."""
model: str = Field(..., description="Model name to use for forecasting")
task: Literal["forecast"] = Field(
default="forecast", description="Task type, must be 'forecast'"
)
data: dict[str, Any] = Field(
...,
description=("Forecast request data containing 'inputs' and optional 'parameters'"),
)
@field_validator("data")
@classmethod
def validate_data_structure(cls, v: dict[str, Any]) -> dict[str, Any]:
"""Validate the data field contains required structure."""
if "inputs" not in v:
raise ValueError("data must contain 'inputs' field")
if not isinstance(v["inputs"], list):
raise ValueError("data.inputs must be a list")
if len(v["inputs"]) == 0:
raise ValueError("data.inputs cannot be empty")
if len(v["inputs"]) > MAX_NUM_TIME_SERIES:
raise ValueError(
f"data.inputs may contain at most {MAX_NUM_TIME_SERIES} time series "
f"(received {len(v['inputs'])})"
)
# Validate each input as TimeSeriesInput
validated_inputs = []
for i, ts_input in enumerate(v["inputs"]):
try:
validated_input = TimeSeriesInput(**ts_input)
validated_inputs.append(validated_input)
except Exception as e:
raise ValueError(f"Invalid time series input at index {i}: {e}") from e
# Validate parameters if present
validated_params = None
if "parameters" in v and v["parameters"] is not None:
try:
validated_params = ForecastParameters(**v["parameters"])
except Exception as e:
raise ValueError(f"Invalid parameters: {e}") from e
# Cross-validate future_covariates length with prediction_length
if validated_params is not None:
prediction_length = validated_params.prediction_length
for i, ts_input in enumerate(validated_inputs):
if ts_input.future_covariates is not None:
for key, values in ts_input.future_covariates.items():
if len(values) != prediction_length:
raise ValueError(
f"Input {i}: future_covariate '{key}' length "
f"({len(values)}) must match prediction_length "
f"({prediction_length})"
)
return v
class ForecastPrediction(BaseModel):
"""Single time series forecast result with quantile forecasts."""
mean: list[float] | list[list[float]] = Field(
..., description="Point forecast (the median, reported as the mean). 1-D for univariate targets, 2-D for multivariate"
)
item_id: str | None = Field(default=None, description="Echoed from input if provided")
start: str | None = Field(default=None, description="Start timestamp of forecast horizon")
# Pydantic v2 style config (class-based ``Config`` is deprecated).
model_config = {"extra": "allow"} # Allow dynamic quantile fields like "0.1", "0.5", "0.9"
class ForecastResponse(OpenAIBaseModel):
"""Response format for time series forecasting."""
request_id: str = Field(..., description="Request identifier")
created_at: int = Field(..., description="Unix timestamp when the response was created")
data: dict[str, list[ForecastPrediction]] = Field(
..., description="Forecast results with 'predictions' key"
)
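
A sketch of a request body that should satisfy the `ForecastRequest`, `TimeSeriesInput`, and `ForecastParameters` rules defined above. The item id, covariate name, and values are made up for illustration; the model name is taken from the README's `vllm serve` command.

```python
import json

# Illustrative payload for the /pooling endpoint (values are invented).
payload = {
    "model": "amazon/chronos-2",
    "task": "forecast",
    "data": {
        "inputs": [
            {
                "item_id": "store_1",
                "start": "2024-01-01",
                # At least 5 observations are required.
                "target": [10.0, 12.0, 11.0, 13.0, 14.0, 15.0],
                # Past covariates must match the target length.
                "past_covariates": {"promo": [0.0, 0.0, 1.0, 0.0, 0.0, 1.0]},
                # Future covariates must match prediction_length.
                "future_covariates": {"promo": [1.0, 0.0, 0.0]},
            }
        ],
        "parameters": {
            "prediction_length": 3,
            "quantile_levels": [0.1, 0.5, 0.9],
            # 'freq' and 'start' have to be provided together.
            "freq": "D",
        },
    },
}

body = json.dumps(payload)
```

The `body` string is what a client would send as the JSON request body; the transport layer (HTTP client, headers) is left out here.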

@ -0,0 +1,242 @@
"""Centralized validation for Chronos-2 forecast requests.
All validation rules live here so that both Pydantic model validators
(in ``protocol.forecast``) and the IOProcessor can delegate to a single
source of truth. The design mirrors the SageMaker endpoint's
``validate_payload`` / ``validate_covariates`` utilities.
"""
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from chronos.chronos2.vllm.protocol.forecast import ( # noqa: F811
ForecastParameters,
TimeSeriesInput,
)
# ---------------------------------------------------------------------------
# Constants — match the SageMaker endpoint constraints
# TODO: Get some of the values from model config
# ---------------------------------------------------------------------------
MIN_TARGET_LENGTH: int = 5
MAX_PREDICTION_LENGTH: int = 1024
MAX_NUM_TIME_SERIES: int = 1024
# ---------------------------------------------------------------------------
# Per-input validators (called from Pydantic field/model validators)
# ---------------------------------------------------------------------------
def validate_target(
target: list[float] | list[list[float]],
) -> list[float] | list[list[float]]:
"""Validate minimum target length and multivariate row consistency.
Raises ``ValueError`` if:
- The target is empty.
- The target has fewer than ``MIN_TARGET_LENGTH`` observations.
- For multivariate targets, rows have inconsistent lengths.
"""
if not target:
raise ValueError("Target must not be empty")
if isinstance(target[0], list):
first_len = len(target[0])
if first_len < MIN_TARGET_LENGTH:
raise ValueError(
f"Target must contain at least {MIN_TARGET_LENGTH} "
f"observations (received {first_len})"
)
for i, dim in enumerate(target[1:], start=1):
if len(dim) != first_len: # type: ignore[arg-type]
raise ValueError(
f"All target dimensions must have same length. "
f"Dimension 0 has {first_len} observations, "
f"dimension {i} has {len(dim)}" # type: ignore[arg-type]
)
else:
if len(target) < MIN_TARGET_LENGTH:
raise ValueError(
f"Target must contain at least {MIN_TARGET_LENGTH} "
f"observations (received {len(target)})"
)
return target
def validate_start_timestamp(start: str | None) -> str | None:
"""Validate that *start* is a valid ISO-8601 string (or ``None``)."""
if start is not None:
try:
datetime.fromisoformat(start.replace("Z", "+00:00"))
except ValueError as e:
raise ValueError(
f"Invalid start timestamp format: {start}. "
f"Expected ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS)"
) from e
return start
def validate_quantile_levels(quantile_levels: list[float]) -> list[float]:
"""Validate that every quantile level is in the open interval (0, 1)."""
for q in quantile_levels:
if not (0 < q < 1):
raise ValueError(f"Quantile levels must be between 0 and 1 (exclusive), got {q}")
return quantile_levels
def validate_single_series_covariates(
target: list[float] | list[list[float]],
past_covariates: dict[str, list] | None,
future_covariates: dict[str, list] | None,
) -> None:
"""Validate covariate constraints for a **single** time series.
Raises ``ValueError`` if any of the following are violated:
- ``'target'`` is used as a covariate name.
- ``future_covariates`` is provided without ``past_covariates``.
- Future covariate keys are not a subset of past covariate keys.
- Past covariate array lengths do not match the target length.
"""
if not target:
raise ValueError("Target must not be empty")
target_len = len(target[0]) if isinstance(target[0], list) else len(target)
# 'target' must not be used as a covariate name
for label, covariates in [
("past_covariates", past_covariates),
("future_covariates", future_covariates),
]:
if covariates is not None and "target" in covariates:
raise ValueError(f"Covariate with name 'target' is not allowed (found in '{label}')")
# future_covariates requires past_covariates
if future_covariates is not None and past_covariates is None:
raise ValueError(
"Both 'past_covariates' and 'future_covariates' must be provided "
"together. Got 'future_covariates' without 'past_covariates'"
)
# future keys ⊆ past keys
if past_covariates is not None and future_covariates is not None:
past_keys = set(past_covariates.keys())
future_keys = set(future_covariates.keys())
if not future_keys.issubset(past_keys):
extra = future_keys - past_keys
raise ValueError(
f"All future covariate keys must be present in past covariates. "
f"Keys {extra} are in 'future_covariates' but not in "
f"'past_covariates'"
)
# past covariate lengths must match target length
if past_covariates is not None:
for key, values in past_covariates.items():
if len(values) != target_len:
raise ValueError(
f"Past covariate '{key}' length ({len(values)}) "
f"must match target length ({target_len})"
)
# ---------------------------------------------------------------------------
# Cross-series validators (called from IOProcessor.parse_request)
# ---------------------------------------------------------------------------
def validate_cross_series(
inputs: list[TimeSeriesInput],
parameters: ForecastParameters,
) -> None:
"""Validate constraints that span across multiple time series.
Mirrors the SageMaker endpoint's ``validate_payload`` and
``validate_covariates`` utilities.
Raises ``ValueError`` if any of the following are violated:
- ``item_id`` provided for some but not all inputs.
- ``item_id`` values are not unique.
- ``start`` provided for some but not all inputs.
- ``start`` is provided without ``freq`` (or vice-versa).
- Covariate keys are not identical across all series.
- ``future_covariates`` array lengths don't match ``prediction_length``.
"""
_validate_item_ids(inputs)
_validate_start_freq(inputs, parameters)
_validate_covariate_consistency(inputs)
_validate_future_covariate_lengths(inputs, parameters)
# -- helpers (private) -------------------------------------------------------
def _validate_item_ids(inputs: list[TimeSeriesInput]) -> None:
item_ids = [ts.item_id for ts in inputs]
has_none = any(x is None for x in item_ids)
has_value = any(x is not None for x in item_ids)
if has_none and has_value:
raise ValueError(
"If 'item_id' is provided for at least one time series in "
"'inputs', it should be provided for all time series"
)
if has_value and len(item_ids) != len(set(item_ids)):
raise ValueError("'item_id' must be unique for all time series in 'inputs'")
def _validate_start_freq(
inputs: list[TimeSeriesInput],
parameters: ForecastParameters,
) -> None:
starts = [ts.start for ts in inputs]
has_none = any(x is None for x in starts)
has_value = any(x is not None for x in starts)
if has_none and has_value:
raise ValueError(
"If 'start' is provided for at least one time series in "
"'inputs', it should be provided for all time series"
)
if has_value and parameters.freq is None:
raise ValueError(
"If 'start' is provided, then 'freq' must also be provided in 'parameters'"
)
if parameters.freq is not None and not has_value:
raise ValueError(
"If 'freq' is provided in 'parameters', then 'start' must "
"also be provided for all time series in 'inputs'"
)
def _validate_covariate_consistency(inputs: list[TimeSeriesInput]) -> None:
key_sets: list[frozenset[str] | None] = []
for ts in inputs:
if ts.past_covariates is not None:
key_sets.append(frozenset(ts.past_covariates.keys()))
else:
key_sets.append(None)
if len(set(key_sets)) > 1:
raise ValueError(
"If 'past_covariates' and 'future_covariates' are provided "
"for at least one time series in 'inputs', the same "
"covariates should be provided for all time series"
)
def _validate_future_covariate_lengths(
inputs: list[TimeSeriesInput],
parameters: ForecastParameters,
) -> None:
prediction_length = parameters.prediction_length
for i, ts in enumerate(inputs):
if ts.future_covariates is not None:
for key, values in ts.future_covariates.items():
if len(values) != prediction_length:
raise ValueError(
f"Input {i}: length of future covariate '{key}' "
f"({len(values)}) must equal prediction_length "
f"({prediction_length})"
)
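
Both `_validate_item_ids` and `_validate_start_freq` rely on the same all-or-none pattern: a field may be set on every series or on none, but not on a subset. A standalone sketch of that pattern follows; the helper name `check_all_or_none` and its error message are hypothetical and do not appear in the plugin.

```python
def check_all_or_none(values, field):
    """Raise if `field` is set on some series but missing on others."""
    has_none = any(v is None for v in values)
    has_value = any(v is not None for v in values)
    if has_none and has_value:
        raise ValueError(
            f"'{field}' must be provided for all time series or for none"
        )


# Accepted: the field is set everywhere, or nowhere.
check_all_or_none(["A", "B"], "item_id")
check_all_or_none([None, None], "item_id")
```

A mixed list such as `["A", None]` raises `ValueError`, which is the case the `test_item_id_partial` and `test_start_partial` tests exercise against the real validators.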

@ -0,0 +1 @@
"""Utility modules for Chronos-2 vLLM plugin."""

@ -0,0 +1,35 @@
"""Common helper functions for Chronos-2 vLLM plugin."""
from typing import Any
import torch
def tensor_to_list(tensor: torch.Tensor) -> list[float] | list[list[float]]:
"""Convert 2-D tensor to list (univariate) or list of lists (multivariate).
Args:
tensor: shape (n_variates, horizon)
Returns:
list[float] if univariate (n_variates == 1), else list[list[float]]
"""
assert tensor.ndim == 2
return tensor[0].tolist() if tensor.shape[0] == 1 else [row.tolist() for row in tensor]
def empty_prediction(
prediction_length: int,
quantile_levels: list[float],
) -> dict[str, Any]:
"""Return a zero-filled prediction dict for error cases.
Args:
prediction_length: number of forecast steps
quantile_levels: list of quantile levels to include
"""
zeros = [0.0] * prediction_length
pred: dict[str, Any] = {"mean": zeros}
for q in quantile_levels:
pred[str(q)] = list(zeros) # copy so the entries do not alias one shared list
return pred
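
A sketch of how a prediction dict with dynamic quantile keys could be assembled, using nested lists in place of tensors. `rows_to_list` imitates the unwrapping rule of `tensor_to_list`; `build_prediction` is a hypothetical helper and not part of the plugin.

```python
def rows_to_list(rows):
    """List-based analogue of tensor_to_list: unwrap a single variate."""
    return list(rows[0]) if len(rows) == 1 else [list(r) for r in rows]


def build_prediction(mean_rows, quantile_rows, quantile_levels):
    """Assemble {"mean": ..., "<level>": ...} from (n_variates, horizon) rows."""
    pred = {"mean": rows_to_list(mean_rows)}
    for level, rows in zip(quantile_levels, quantile_rows):
        # Quantile keys are the string form of the level, e.g. "0.1".
        pred[str(level)] = rows_to_list(rows)
    return pred


univariate = build_prediction(
    mean_rows=[[1.0, 2.0]],
    quantile_rows=[[[0.5, 1.5]], [[1.0, 2.0]]],
    quantile_levels=[0.1, 0.5],
)
```

With one variate the values are flat lists; with several variates each key would hold a list of lists, one per variate.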

@ -0,0 +1,40 @@
"""Quantile selection and interpolation utilities."""
import torch
from chronos.utils import interpolate_quantiles
def select_quantiles(
predictions: list[torch.Tensor],
model_quantiles: list[float],
requested_levels: list[float],
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Select or interpolate quantiles from model output.
Args:
predictions: list of tensors, each shape (..., num_quantiles, horizon)
model_quantiles: the quantile levels the model was trained on (sorted ascending)
requested_levels: the quantile levels the caller wants
Returns:
(quantiles, mean) where:
- quantiles: list of tensors, each shape (..., horizon, num_requested_quantiles)
- mean: list of tensors, each shape (..., horizon); the median is returned as the mean
"""
# Swap quantile and time axes: [... q h] -> [... h q]
swapped = [pred.permute(*range(pred.ndim - 2), -1, -2) for pred in predictions]
if set(requested_levels).issubset(model_quantiles):
indices = [model_quantiles.index(q) for q in requested_levels]
quantiles = [pred[..., indices] for pred in swapped]
else:
quantiles = [
interpolate_quantiles(requested_levels, model_quantiles, pred) for pred in swapped
]
# Median as mean (Chronos-2 convention)
median_idx = model_quantiles.index(0.5) if 0.5 in model_quantiles else len(model_quantiles) // 2
mean = [pred[..., median_idx] for pred in swapped]
return quantiles, mean

@ -0,0 +1 @@
# Chronos-2 plugin tests

@ -0,0 +1,360 @@
"""Tests for Chronos-2 plugin validation logic."""
import pytest
from chronos.chronos2.vllm.protocol.forecast import (
ForecastParameters,
ForecastRequest,
TimeSeriesInput,
)
# ============================================================================
# TimeSeriesInput — per-input validation
# ============================================================================
class TestTimeSeriesInputTargetValidation:
"""Tests for target field validation."""
def test_valid_univariate_target(self):
ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])
assert len(ts.target) == 5
def test_valid_multivariate_target(self):
ts = TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])
assert len(ts.target) == 2
def test_univariate_target_too_short(self):
with pytest.raises(ValueError, match="at least 5 observations"):
TimeSeriesInput(target=[1.0, 2.0, 3.0])
def test_multivariate_target_too_short(self):
with pytest.raises(ValueError, match="at least 5 observations"):
TimeSeriesInput(target=[[1.0, 2.0], [3.0, 4.0]])
def test_multivariate_inconsistent_lengths(self):
with pytest.raises(ValueError, match="All target dimensions must have same length"):
TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
class TestTimeSeriesInputStartTimestamp:
"""Tests for start timestamp validation."""
def test_valid_date_format(self):
ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], start="2024-01-01")
assert ts.start == "2024-01-01"
def test_valid_datetime_format(self):
ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], start="2024-01-01T12:00:00")
assert ts.start == "2024-01-01T12:00:00"
def test_invalid_timestamp_format(self):
with pytest.raises(ValueError, match="Invalid start timestamp format"):
TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], start="not-a-date")
class TestTimeSeriesInputCovariateValidation:
"""Tests for per-input covariate validation."""
def test_valid_past_covariates(self):
ts = TimeSeriesInput(
target=[1.0, 2.0, 3.0, 4.0, 5.0],
past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]},
)
assert ts.past_covariates is not None
def test_past_covariate_length_mismatch(self):
with pytest.raises(ValueError, match="must match target length"):
TimeSeriesInput(
target=[1.0, 2.0, 3.0, 4.0, 5.0],
past_covariates={"feature_a": [1.0, 2.0, 3.0]},
)
def test_target_name_in_past_covariates_rejected(self):
with pytest.raises(ValueError, match="Covariate with name 'target' is not allowed"):
TimeSeriesInput(
target=[1.0, 2.0, 3.0, 4.0, 5.0],
past_covariates={"target": [1.0, 2.0, 3.0, 4.0, 5.0]},
)
def test_target_name_in_future_covariates_rejected(self):
with pytest.raises(ValueError, match="Covariate with name 'target' is not allowed"):
TimeSeriesInput(
target=[1.0, 2.0, 3.0, 4.0, 5.0],
past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]},
future_covariates={"target": [1.0]},
)
def test_future_covariates_without_past_rejected(self):
with pytest.raises(
ValueError, match="Both 'past_covariates' and 'future_covariates' must be provided"
):
TimeSeriesInput(
target=[1.0, 2.0, 3.0, 4.0, 5.0],
future_covariates={"feature_a": [1.0]},
)
def test_future_covariate_keys_must_be_subset_of_past(self):
with pytest.raises(
ValueError, match="All future covariate keys must be present in past covariates"
):
TimeSeriesInput(
target=[1.0, 2.0, 3.0, 4.0, 5.0],
past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]},
future_covariates={"feature_b": [1.0]},
)
def test_future_covariate_keys_subset_is_valid(self):
"""Past covariates may have more keys than future — that's fine."""
ts = TimeSeriesInput(
target=[1.0, 2.0, 3.0, 4.0, 5.0],
past_covariates={
"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0],
"feature_b": [1.0, 2.0, 3.0, 4.0, 5.0],
},
future_covariates={"feature_a": [1.0]},
)
assert ts.future_covariates is not None
def test_past_only_covariates_valid(self):
"""Having past_covariates without future_covariates is acceptable."""
ts = TimeSeriesInput(
target=[1.0, 2.0, 3.0, 4.0, 5.0],
past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]},
)
assert ts.past_covariates is not None
assert ts.future_covariates is None
# ============================================================================
# ForecastParameters validation
# ============================================================================
class TestForecastParameters:
"""Tests for ForecastParameters validation."""
def test_default_values(self):
params = ForecastParameters()
assert params.prediction_length == 1
assert params.quantile_levels == [0.1, 0.5, 0.9]
assert params.freq is None
assert params.batch_size == 256
assert params.cross_learning is False
def test_prediction_length_too_large(self):
with pytest.raises(ValueError):
ForecastParameters(prediction_length=2000)
def test_prediction_length_zero(self):
with pytest.raises(ValueError):
ForecastParameters(prediction_length=0)
def test_quantile_out_of_range(self):
with pytest.raises(ValueError, match="between 0 and 1"):
ForecastParameters(quantile_levels=[0.0, 0.5])
def test_quantile_one(self):
with pytest.raises(ValueError, match="between 0 and 1"):
ForecastParameters(quantile_levels=[0.5, 1.0])
# ============================================================================
# Cross-series validation (via validation module)
# ============================================================================
class TestCrossSeriesValidation:
"""Tests for cross-series validation logic in validation module."""
@staticmethod
def _validate(inputs, parameters=None):
"""Helper that calls validate_cross_series from the validation module."""
from chronos.chronos2.vllm.protocol.validation import validate_cross_series
if parameters is None:
parameters = ForecastParameters()
validate_cross_series(inputs, parameters)
def test_item_id_all_provided(self):
"""All item_ids present — valid."""
inputs = [
TimeSeriesInput(target=[1.0] * 5, item_id="A"),
TimeSeriesInput(target=[2.0] * 5, item_id="B"),
]
self._validate(inputs) # should not raise
def test_item_id_none_provided(self):
"""No item_ids at all — valid."""
inputs = [
TimeSeriesInput(target=[1.0] * 5),
TimeSeriesInput(target=[2.0] * 5),
]
self._validate(inputs) # should not raise
def test_item_id_partial(self):
"""Some have item_id, some don't — invalid."""
inputs = [
TimeSeriesInput(target=[1.0] * 5, item_id="A"),
TimeSeriesInput(target=[2.0] * 5),
]
with pytest.raises(ValueError, match="item_id.*provided for all time series"):
self._validate(inputs)
def test_item_id_not_unique(self):
"""Duplicate item_ids — invalid."""
inputs = [
TimeSeriesInput(target=[1.0] * 5, item_id="A"),
TimeSeriesInput(target=[2.0] * 5, item_id="A"),
]
with pytest.raises(ValueError, match="item_id.*must be unique"):
self._validate(inputs)
def test_start_all_provided_with_freq(self):
"""All starts present with freq — valid."""
inputs = [
TimeSeriesInput(target=[1.0] * 5, start="2024-01-01"),
TimeSeriesInput(target=[2.0] * 5, start="2024-02-01"),
]
params = ForecastParameters(freq="D")
self._validate(inputs, params) # should not raise
def test_start_partial(self):
"""Some have start, some don't — invalid."""
inputs = [
TimeSeriesInput(target=[1.0] * 5, start="2024-01-01"),
TimeSeriesInput(target=[2.0] * 5),
]
params = ForecastParameters(freq="D")
with pytest.raises(ValueError, match="start.*provided for all time series"):
self._validate(inputs, params)
def test_start_without_freq(self):
"""start provided but freq missing — invalid."""
inputs = [
TimeSeriesInput(target=[1.0] * 5, start="2024-01-01"),
]
params = ForecastParameters() # freq is None
with pytest.raises(ValueError, match="freq.*must also be provided"):
self._validate(inputs, params)
def test_freq_without_start(self):
"""freq provided but no start on inputs — invalid."""
inputs = [
TimeSeriesInput(target=[1.0] * 5),
]
params = ForecastParameters(freq="D")
with pytest.raises(ValueError, match="start.*must also be provided"):
self._validate(inputs, params)
def test_covariate_keys_consistent(self):
"""Same covariate keys on all series — valid."""
inputs = [
TimeSeriesInput(
target=[1.0] * 5,
past_covariates={"feat": [1.0] * 5},
),
TimeSeriesInput(
target=[2.0] * 5,
past_covariates={"feat": [2.0] * 5},
),
]
self._validate(inputs) # should not raise
def test_covariate_keys_inconsistent(self):
"""Different covariate keys across series — invalid."""
inputs = [
TimeSeriesInput(
target=[1.0] * 5,
past_covariates={"feat_a": [1.0] * 5},
),
TimeSeriesInput(
target=[2.0] * 5,
past_covariates={"feat_b": [2.0] * 5},
),
]
with pytest.raises(ValueError, match="same covariates should be provided"):
self._validate(inputs)
def test_covariate_some_have_none(self):
"""One series has covariates, other doesn't — invalid."""
inputs = [
TimeSeriesInput(
target=[1.0] * 5,
past_covariates={"feat": [1.0] * 5},
),
TimeSeriesInput(target=[2.0] * 5),
]
with pytest.raises(ValueError, match="same covariates should be provided"):
self._validate(inputs)
def test_future_covariate_length_matches_prediction_length(self):
"""future_covariates length equals prediction_length — valid."""
inputs = [
TimeSeriesInput(
target=[1.0] * 5,
past_covariates={"feat": [1.0] * 5},
future_covariates={"feat": [1.0, 2.0, 3.0]},
),
]
params = ForecastParameters(prediction_length=3)
self._validate(inputs, params) # should not raise
def test_future_covariate_length_mismatch(self):
"""future_covariates length != prediction_length — invalid."""
inputs = [
TimeSeriesInput(
target=[1.0] * 5,
past_covariates={"feat": [1.0] * 5},
future_covariates={"feat": [1.0, 2.0]},
),
]
params = ForecastParameters(prediction_length=3)
with pytest.raises(ValueError, match="must equal prediction_length"):
self._validate(inputs, params)
# ============================================================================
# ForecastRequest — end-to-end request validation
# ============================================================================
class TestForecastRequest:
"""Tests for ForecastRequest end-to-end validation."""
def test_valid_minimal_request(self):
req = ForecastRequest(
model="chronos-v2",
data={"inputs": [{"target": [1.0, 2.0, 3.0, 4.0, 5.0]}]},
)
assert req.data["inputs"] is not None
def test_empty_inputs_rejected(self):
with pytest.raises(ValueError, match="cannot be empty"):
ForecastRequest(model="chronos-v2", data={"inputs": []})
def test_too_many_inputs_rejected(self):
many_inputs = [{"target": [1.0] * 5} for _ in range(1025)]
with pytest.raises(ValueError, match="at most 1024"):
ForecastRequest(model="chronos-v2", data={"inputs": many_inputs})
def test_missing_inputs_rejected(self):
with pytest.raises(ValueError, match="inputs"):
ForecastRequest(model="chronos-v2", data={"parameters": {}})
def test_future_covariate_length_cross_validated(self):
"""ForecastRequest validates future_covariates against prediction_length."""
with pytest.raises(ValueError, match="prediction_length"):
ForecastRequest(
model="chronos-v2",
data={
"inputs": [
{
"target": [1.0] * 5,
"past_covariates": {"feat": [1.0] * 5},
"future_covariates": {"feat": [1.0, 2.0]},
}
],
"parameters": {"prediction_length": 5},
},
)

@ -0,0 +1,248 @@
"""Unit tests for chronos.chronos2.vllm.protocol.data_prep."""
import pytest
torch = pytest.importorskip("torch")
from chronos.chronos2.vllm.protocol.data_prep import ( # noqa: E402
PreparedRequest,
prepare_request,
)
from chronos.chronos2.vllm.protocol.forecast import ( # noqa: E402
ForecastParameters,
TimeSeriesInput,
)
class TestPrepareRequestUnivariate:
"""Tests for prepare_request with univariate inputs."""
def test_single_series(self):
inputs = [TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])]
params = ForecastParameters(prediction_length=3)
result = prepare_request(inputs, params)
assert isinstance(result, PreparedRequest)
assert len(result.batches) >= 1
assert len(result.item_ids) == 1
assert result.item_ids[0] is None
batch = result.batches[0]
assert batch.context.shape[0] == 1
assert batch.context.shape[1] == 5
assert batch.group_ids.shape == (1,)
assert batch.target_idx_ranges == [(0, 1)]
def test_multiple_series(self):
inputs = [
TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], item_id="a"),
TimeSeriesInput(target=[6.0, 7.0, 8.0, 9.0, 10.0], item_id="b"),
]
params = ForecastParameters(prediction_length=2)
result = prepare_request(inputs, params)
# prepare_request passes a very large batch_size, so both series land in one batch
assert len(result.batches) >= 1
# Count total rows across all batches
total_rows = sum(b.context.shape[0] for b in result.batches)
assert total_rows == 2
assert result.item_ids == ["a", "b"]
def test_different_lengths_padded(self):
inputs = [
TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]),
TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]),
]
params = ForecastParameters(prediction_length=1)
result = prepare_request(inputs, params)
batch = result.batches[0]
# Padded to max length = 10
assert batch.context.shape == (2, 10)
# First series: right-aligned, left NaN-padded
assert torch.isnan(batch.context[0, :5]).all()
assert not torch.isnan(batch.context[0, 5:]).any()
def test_truncation_to_context_length(self):
long_target = list(range(100))
inputs = [TimeSeriesInput(target=long_target)]
params = ForecastParameters(prediction_length=1)
result = prepare_request(inputs, params, context_length=20)
batch = result.batches[0]
assert batch.context.shape == (1, 20)
def test_future_covariates_nan_for_targets(self):
inputs = [TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])]
params = ForecastParameters(prediction_length=3)
result = prepare_request(inputs, params)
# Target rows should have NaN future covariates
assert torch.isnan(result.batches[0].future_covariates[0]).all()


class TestPrepareRequestMultivariate:
    """Tests for prepare_request with multivariate inputs."""

    def test_bivariate(self):
        inputs = [TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])]
        params = ForecastParameters(prediction_length=2)
        result = prepare_request(inputs, params)
        batch = result.batches[0]
        # 2 variates = 2 rows
        assert batch.context.shape == (2, 5)
        assert batch.target_idx_ranges == [(0, 2)]
        assert batch.group_ids.tolist() == [0, 0]


class TestPrepareRequestCovariates:
    """Tests for prepare_request with covariates."""

    def test_past_covariates(self):
        inputs = [
            TimeSeriesInput(
                target=[1.0, 2.0, 3.0, 4.0, 5.0],
                past_covariates={"temp": [10.0, 20.0, 30.0, 40.0, 50.0]},
            )
        ]
        params = ForecastParameters(prediction_length=2)
        result = prepare_request(inputs, params)
        batch = result.batches[0]
        # 1 target + 1 covariate = 2 rows
        assert batch.context.shape == (2, 5)
        # Only target row is in target_idx_ranges
        assert batch.target_idx_ranges == [(0, 1)]
        assert batch.group_ids.tolist() == [0, 0]

    def test_future_covariates(self):
        inputs = [
            TimeSeriesInput(
                target=[1.0, 2.0, 3.0, 4.0, 5.0],
                past_covariates={"temp": [10.0, 20.0, 30.0, 40.0, 50.0]},
                future_covariates={"temp": [60.0, 70.0]},
            )
        ]
        params = ForecastParameters(prediction_length=2)
        result = prepare_request(inputs, params)
        batch = result.batches[0]
        # Target row: NaN future covariates
        assert torch.isnan(batch.future_covariates[0]).all()
        # Covariate row: actual future values
        assert batch.future_covariates[1, 0].item() == 60.0
        assert batch.future_covariates[1, 1].item() == 70.0


class TestPrepareRequestCrossLearning:
    """Tests for cross-learning mode."""

    def test_cross_learning_zeros_group_ids(self):
        inputs = [
            TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]),
            TimeSeriesInput(target=[6.0, 7.0, 8.0, 9.0, 10.0]),
        ]
        params = ForecastParameters(prediction_length=1, cross_learning=True)
        result = prepare_request(inputs, params)
        assert result.batches[0].group_ids.tolist() == [0, 0]

    def test_no_cross_learning_distinct_group_ids(self):
        inputs = [
            TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]),
            TimeSeriesInput(target=[6.0, 7.0, 8.0, 9.0, 10.0]),
        ]
        params = ForecastParameters(prediction_length=1, cross_learning=False)
        result = prepare_request(inputs, params)
        assert result.batches[0].group_ids.tolist() == [0, 1]


class TestPrepareRequestNanCovariates:
    """Tests for prepare_request handling of NaN-sanitized covariates."""

    def test_numeric_covariates_with_none(self):
        """Numeric covariates with None values should become NaN in tensors."""
        inputs = [
            TimeSeriesInput(
                target=[1.0, 2.0, 3.0, 4.0, 5.0],
                past_covariates={"temp": [10.0, None, 30.0, None, 50.0]},
                future_covariates={"temp": [None, 70.0, None]},
            )
        ]
        params = ForecastParameters(prediction_length=3)
        result = prepare_request(inputs, params)
        batch = result.batches[0]
        # 1 target + 1 covariate = 2 rows
        assert batch.context.shape == (2, 5)
        # Covariate row: None → NaN in context tensor
        assert torch.isnan(batch.context[1, 1])
        assert torch.isnan(batch.context[1, 3])
        assert batch.context[1, 0].item() == 10.0
        assert batch.context[1, 2].item() == 30.0
        # Future covariates: None → NaN
        assert torch.isnan(batch.future_covariates[1, 0])
        assert batch.future_covariates[1, 1].item() == 70.0

    def test_string_covariates_encoded_in_tensors(self):
        """String covariates (like holidays) should be encoded and included in tensors."""
        inputs = [
            TimeSeriesInput(
                target=[1.0, 2.0, 3.0, 4.0, 5.0],
                past_covariates={"holiday": ["A", None, "B", None, "C"]},
            )
        ]
        params = ForecastParameters(prediction_length=2)
        result = prepare_request(inputs, params)
        batch = result.batches[0]
        # 1 target + 1 encoded categorical covariate = 2 rows
        assert batch.context.shape == (2, 5)
        # Encoded values should be finite floats (not NaN) for non-None entries
        assert not torch.isnan(batch.context[1, 0])  # "A" encoded
        assert not torch.isnan(batch.context[1, 2])  # "B" encoded
        assert not torch.isnan(batch.context[1, 4])  # "C" encoded

    def test_categorical_with_future(self):
        """Categorical covariates with future values should be encoded consistently."""
        inputs = [
            TimeSeriesInput(
                target=[1.0, 2.0, 3.0, 4.0, 5.0],
                past_covariates={"holiday": ["A", "B", "A", "B", "A"]},
                future_covariates={"holiday": ["A", "B"]},
            )
        ]
        params = ForecastParameters(prediction_length=2)
        result = prepare_request(inputs, params)
        batch = result.batches[0]
        # 1 target + 1 encoded categorical = 2 rows
        assert batch.context.shape == (2, 5)
        # Future covariates should be filled for the categorical row
        assert not torch.isnan(batch.future_covariates[1, 0])
        assert not torch.isnan(batch.future_covariates[1, 1])


class TestPrepareRequestBatching:
    """Tests for batch output: always a single batch (chunking deferred to model forward)."""

    def test_always_single_batch(self):
        """prepare_request always returns one batch regardless of batch_size parameter."""
        inputs = [TimeSeriesInput(target=[float(i)] * 5) for i in range(5)]
        params = ForecastParameters(prediction_length=1, batch_size=2)
        result = prepare_request(inputs, params)
        # Always 1 batch; model forward() handles chunking if needed
        assert len(result.batches) == 1
        assert result.batches[0].context.shape[0] == 5
        assert len(result.batches[0].target_idx_ranges) == 5

    def test_single_series_single_batch(self):
        inputs = [TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])]
        params = ForecastParameters(prediction_length=1, batch_size=256)
        result = prepare_request(inputs, params)
        assert len(result.batches) == 1
        assert result.batches[0].context.shape[0] == 1
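

# Illustrative sketch (not part of the plugin): the padding test in this file
# expects shorter series to be right-aligned and left-padded with NaN up to the
# longest series in the batch. The helper name `_left_pad_with_nan` is
# hypothetical and works on plain Python lists; prepare_request itself may be
# implemented differently.
import math  # noqa: E402


def _left_pad_with_nan(series_list):
    """Right-align each series by prepending NaN up to the longest length."""
    max_len = max(len(s) for s in series_list)
    return [[math.nan] * (max_len - len(s)) + list(s) for s in series_list]


# Usage: a 2-step series padded against a 4-step series keeps its values at the end.
_padded = _left_pad_with_nan([[1.0, 2.0], [3.0, 4.0, 5.0, 6.0]])
assert all(math.isnan(v) for v in _padded[0][:2])
assert _padded[0][2:] == [1.0, 2.0]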

@ -0,0 +1,53 @@
"""Unit tests for chronos.chronos2.vllm.utils.helpers."""

import pytest

torch = pytest.importorskip("torch")

from chronos.chronos2.vllm.utils.helpers import empty_prediction, tensor_to_list  # noqa: E402


class TestTensorToList:
    """Tests for tensor_to_list."""

    def test_univariate(self):
        t = torch.tensor([[1.0, 2.0, 3.0]])
        result = tensor_to_list(t)
        assert result == [1.0, 2.0, 3.0]
        assert isinstance(result, list)
        assert isinstance(result[0], float)

    def test_multivariate(self):
        t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        result = tensor_to_list(t)
        assert result == [[1.0, 2.0], [3.0, 4.0]]
        assert isinstance(result[0], list)

    def test_single_value(self):
        t = torch.tensor([[42.0]])
        assert tensor_to_list(t) == [42.0]

    def test_wrong_dims_raises(self):
        with pytest.raises(AssertionError):
            tensor_to_list(torch.tensor([1.0, 2.0]))  # 1-D


class TestEmptyPrediction:
    """Tests for empty_prediction."""

    def test_basic(self):
        pred = empty_prediction(3, [0.1, 0.5, 0.9])
        assert pred["mean"] == [0.0, 0.0, 0.0]
        assert pred["0.1"] == [0.0, 0.0, 0.0]
        assert pred["0.5"] == [0.0, 0.0, 0.0]
        assert pred["0.9"] == [0.0, 0.0, 0.0]

    def test_single_quantile(self):
        pred = empty_prediction(2, [0.5])
        assert "mean" in pred
        assert "0.5" in pred
        assert len(pred["mean"]) == 2

    def test_empty_quantiles(self):
        pred = empty_prediction(1, [])
        assert pred == {"mean": [0.0]}

@ -0,0 +1,178 @@
"""Unit tests for chronos.chronos2.vllm.protocol.forecast Pydantic models."""

import pytest

from chronos.chronos2.vllm.protocol.forecast import (
    ForecastParameters,
    ForecastPrediction,
    TimeSeriesInput,
)


class TestTimeSeriesInput:
    """Tests for TimeSeriesInput Pydantic model."""

    def test_minimal(self):
        ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])
        assert ts.target == [1.0, 2.0, 3.0, 4.0, 5.0]
        assert ts.item_id is None
        assert ts.start is None
        assert ts.past_covariates is None
        assert ts.future_covariates is None

    def test_with_item_id(self):
        ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], item_id="series_1")
        assert ts.item_id == "series_1"

    def test_multivariate_target(self):
        ts = TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])
        assert len(ts.target) == 2

    def test_with_covariates(self):
        ts = TimeSeriesInput(
            target=[1.0, 2.0, 3.0, 4.0, 5.0],
            past_covariates={"temp": [10.0, 20.0, 30.0, 40.0, 50.0]},
            future_covariates={"temp": [60.0]},
        )
        assert "temp" in ts.past_covariates
        assert "temp" in ts.future_covariates

    def test_too_short_target_rejected(self):
        # pydantic v2's ValidationError subclasses ValueError
        with pytest.raises(ValueError):
            TimeSeriesInput(target=[1.0, 2.0])

    def test_covariate_length_mismatch_rejected(self):
        with pytest.raises(ValueError):
            TimeSeriesInput(
                target=[1.0, 2.0, 3.0, 4.0, 5.0],
                past_covariates={"temp": [10.0, 20.0]},  # wrong length
            )

    def test_future_without_past_rejected(self):
        with pytest.raises(ValueError):
            TimeSeriesInput(
                target=[1.0, 2.0, 3.0, 4.0, 5.0],
                future_covariates={"temp": [60.0]},
            )


class TestTimeSeriesInputNanCovariates:
    """Tests for NaN handling in covariate arrays (favorita_stores_1D scenario)."""

    def test_nan_in_string_future_covariates_sanitized(self):
        """NaN floats in string covariate lists should be sanitized to None."""
        ts = TimeSeriesInput(
            target=[1.0, 2.0, 3.0, 4.0, 5.0],
            past_covariates={
                "holiday": ["New Year", "MLK Day", float("nan"), float("nan"), "Easter"]
            },
            future_covariates={"holiday": [float("nan")]},
        )
        # NaN should be converted to None
        assert ts.past_covariates["holiday"][2] is None
        assert ts.past_covariates["holiday"][3] is None
        assert ts.future_covariates["holiday"][0] is None
        # Strings should remain unchanged
        assert ts.past_covariates["holiday"][0] == "New Year"
        assert ts.past_covariates["holiday"][4] == "Easter"

    def test_nan_in_numeric_covariates_sanitized(self):
        """NaN in numeric covariates should also be sanitized to None, for consistency."""
        ts = TimeSeriesInput(
            target=[1.0, 2.0, 3.0, 4.0, 5.0],
            past_covariates={"temp": [10.0, float("nan"), 30.0, 40.0, 50.0]},
            future_covariates={"temp": [float("nan")]},
        )
        # NaN floats → None after sanitization
        assert ts.past_covariates["temp"][1] is None
        assert ts.future_covariates["temp"][0] is None

    def test_all_nan_string_covariates(self):
        """All-NaN covariate arrays should be accepted."""
        ts = TimeSeriesInput(
            target=[1.0, 2.0, 3.0, 4.0, 5.0],
            past_covariates={"holiday": [float("nan")] * 5},
            future_covariates={"holiday": [float("nan")]},
        )
        assert all(v is None for v in ts.past_covariates["holiday"])
        assert all(v is None for v in ts.future_covariates["holiday"])

    def test_mixed_string_and_none_accepted(self):
        """Mix of strings and None should be accepted."""
        ts = TimeSeriesInput(
            target=[1.0, 2.0, 3.0, 4.0, 5.0],
            past_covariates={"holiday": ["A", None, "B", None, "C"]},
        )
        assert ts.past_covariates["holiday"] == ["A", None, "B", None, "C"]

    def test_clean_string_covariates_unchanged(self):
        """String covariates without NaN should pass through unchanged."""
        ts = TimeSeriesInput(
            target=[1.0, 2.0, 3.0, 4.0, 5.0],
            past_covariates={"category": ["A", "B", "C", "D", "E"]},
        )
        assert ts.past_covariates["category"] == ["A", "B", "C", "D", "E"]


class TestForecastParameters:
    """Tests for ForecastParameters Pydantic model."""

    def test_defaults(self):
        params = ForecastParameters()
        assert params.prediction_length == 1
        assert params.quantile_levels == [0.1, 0.5, 0.9]
        assert params.freq is None
        assert params.batch_size == 256
        assert params.cross_learning is False

    def test_custom_values(self):
        params = ForecastParameters(
            prediction_length=24,
            quantile_levels=[0.25, 0.5, 0.75],
            batch_size=100,
            cross_learning=True,
        )
        assert params.prediction_length == 24
        assert params.quantile_levels == [0.25, 0.5, 0.75]
        assert params.batch_size == 100
        assert params.cross_learning is True

    def test_invalid_prediction_length(self):
        # pydantic v2's ValidationError subclasses ValueError
        with pytest.raises(ValueError):
            ForecastParameters(prediction_length=0)

    def test_invalid_quantile_level(self):
        with pytest.raises(ValueError):
            ForecastParameters(quantile_levels=[0.0, 0.5])

    def test_prediction_length_max(self):
        with pytest.raises(ValueError):
            ForecastParameters(prediction_length=2000)


class TestForecastPrediction:
    """Tests for ForecastPrediction Pydantic model."""

    def test_minimal(self):
        pred = ForecastPrediction(mean=[1.0, 2.0, 3.0])
        assert pred.mean == [1.0, 2.0, 3.0]
        assert pred.item_id is None

    def test_with_quantiles(self):
        pred = ForecastPrediction(
            mean=[1.0, 2.0],
            item_id="s1",
            **{"0.1": [0.5, 1.5], "0.9": [1.5, 2.5]},
        )
        assert pred.item_id == "s1"

    def test_multivariate_mean(self):
        pred = ForecastPrediction(mean=[[1.0, 2.0], [3.0, 4.0]])
        assert len(pred.mean) == 2

    def test_extra_fields_allowed(self):
        """ForecastPrediction allows dynamic quantile fields."""
        pred = ForecastPrediction(mean=[1.0], **{"0.5": [1.0], "0.1": [0.5]})
        d = pred.model_dump()
        assert "0.5" in d
        assert "0.1" in d

@ -0,0 +1,98 @@
"""Unit tests for chronos.chronos2.vllm.utils.quantiles and chronos.utils.interpolate_quantiles."""

import pytest

torch = pytest.importorskip("torch")

from chronos.chronos2.vllm.utils.quantiles import select_quantiles  # noqa: E402
from chronos.utils import interpolate_quantiles  # noqa: E402

MODEL_QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]


class TestSelectQuantiles:
    """Tests for select_quantiles."""

    def test_subset_selection(self):
        """Requesting a subset of model quantiles should do direct indexing."""
        # Shape: (1, 9_quantiles, 4_horizon)
        pred = torch.randn(1, len(MODEL_QUANTILES), 4)
        quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.1, 0.5, 0.9])
        assert len(quantiles) == 1
        assert len(mean) == 1
        # After swap: (1, 4, 3)
        assert quantiles[0].shape == (1, 4, 3)
        # Mean is median (0.5 = index 4)
        assert mean[0].shape == (1, 4)

    def test_direct_selection_values(self):
        """Verify selected values match expected indices."""
        pred = torch.arange(36, dtype=torch.float32).reshape(1, 9, 4)
        quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.1, 0.9])
        # q=0.1 is index 0, q=0.9 is index 8
        # After swap: pred[..., h, q] → values at (h, q_idx)
        q = quantiles[0]  # (1, 4, 2)
        assert q[0, 0, 0].item() == 0.0  # q=0.1, h=0 → row 0, col 0
        assert q[0, 0, 1].item() == 32.0  # q=0.9, h=0 → row 8, col 0

    def test_interpolation_triggered(self):
        """Non-subset quantile levels should trigger interpolation."""
        pred = torch.randn(1, len(MODEL_QUANTILES), 4)
        quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.15, 0.5])
        assert quantiles[0].shape == (1, 4, 2)

    def test_multiple_predictions(self):
        """Works with multiple prediction tensors."""
        preds = [torch.randn(1, 9, 4) for _ in range(3)]
        quantiles, mean = select_quantiles(preds, MODEL_QUANTILES, [0.5])
        assert len(quantiles) == 3
        assert len(mean) == 3

    def test_multivariate(self):
        """Works with multivariate predictions (2 variates)."""
        pred = torch.randn(2, len(MODEL_QUANTILES), 4)
        quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.1, 0.5, 0.9])
        assert quantiles[0].shape == (2, 4, 3)
        assert mean[0].shape == (2, 4)


class TestInterpolateQuantiles:
    """Tests for chronos.utils.interpolate_quantiles."""

    def test_exact_match(self):
        """Exact match should return the original values."""
        values = torch.tensor([[1.0, 2.0, 3.0]])  # (1, 3) for levels [0.1, 0.5, 0.9]
        result = interpolate_quantiles([0.5], [0.1, 0.5, 0.9], values)
        assert result.shape == (1, 1)
        assert result[0, 0].item() == 2.0

    def test_midpoint_interpolation(self):
        """Midpoint between two levels should be the average."""
        values = torch.tensor([[0.0, 10.0]])  # (1, 2) for levels [0.2, 0.8]
        result = interpolate_quantiles([0.5], [0.2, 0.8], values)
        assert abs(result[0, 0].item() - 5.0) < 1e-5

    def test_clamping(self):
        """Query below/above model range should clamp to boundary."""
        values = torch.tensor([[1.0, 2.0, 3.0]])  # levels [0.1, 0.5, 0.9]
        low = interpolate_quantiles([0.01], [0.1, 0.5, 0.9], values)
        high = interpolate_quantiles([0.99], [0.1, 0.5, 0.9], values)
        assert low[0, 0].item() == pytest.approx(1.0, abs=1e-5)  # clamped to 0.1
        assert high[0, 0].item() == pytest.approx(3.0, abs=1e-5)  # clamped to 0.9

    def test_multiple_queries(self):
        """Multiple query levels should produce correct shape."""
        values = torch.tensor([[1.0, 2.0, 3.0]])
        result = interpolate_quantiles([0.1, 0.3, 0.5, 0.9], [0.1, 0.5, 0.9], values)
        assert result.shape == (1, 4)

    def test_batch_dimension(self):
        """Works with batch dimension."""
        values = torch.randn(5, 9)  # 5 time steps, 9 quantiles
        result = interpolate_quantiles([0.25], MODEL_QUANTILES, values)
        assert result.shape == (5, 1)
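

# Illustrative sketch (not the chronos.utils implementation): the interpolation
# tests in this file are consistent with linear interpolation between
# neighbouring quantile levels, clamped to the lowest or highest level when the
# query falls outside the model's range. The helper name
# `_reference_interpolate` is hypothetical and works on plain Python lists.
from bisect import bisect_left  # noqa: E402


def _reference_interpolate(query_levels, model_levels, values):
    """Linearly interpolate `values` (one per sorted model level) at each query level."""
    out = []
    for q in query_levels:
        if q <= model_levels[0]:
            out.append(values[0])  # clamp at or below the lowest level
        elif q >= model_levels[-1]:
            out.append(values[-1])  # clamp at or above the highest level
        else:
            hi = bisect_left(model_levels, q)  # first index whose level is >= q
            lo = hi - 1
            w = (q - model_levels[lo]) / (model_levels[hi] - model_levels[lo])
            out.append(values[lo] + w * (values[hi] - values[lo]))
    return out


# Usage: an out-of-range query clamps to the boundary value.
assert _reference_interpolate([0.01], [0.1, 0.5, 0.9], [1.0, 2.0, 3.0]) == [1.0]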

@ -0,0 +1,163 @@
"""Unit tests for chronos.chronos2.vllm.protocol.validation."""

import pytest

from chronos.chronos2.vllm.protocol.forecast import ForecastParameters, TimeSeriesInput
from chronos.chronos2.vllm.protocol.validation import (
    validate_cross_series,
    validate_quantile_levels,
    validate_single_series_covariates,
    validate_start_timestamp,
    validate_target,
)


class TestValidateTarget:
    """Tests for validate_target."""

    def test_valid_univariate(self):
        result = validate_target([1.0, 2.0, 3.0, 4.0, 5.0])
        assert len(result) == 5

    def test_valid_multivariate(self):
        result = validate_target([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])
        assert len(result) == 2

    def test_too_short_univariate(self):
        with pytest.raises(ValueError, match="at least 5"):
            validate_target([1.0, 2.0])

    def test_too_short_multivariate(self):
        with pytest.raises(ValueError, match="at least 5"):
            validate_target([[1.0, 2.0], [3.0, 4.0]])

    def test_inconsistent_multivariate_lengths(self):
        with pytest.raises(ValueError, match="same length"):
            validate_target([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0]])


class TestValidateStartTimestamp:
    """Tests for validate_start_timestamp."""

    def test_none(self):
        assert validate_start_timestamp(None) is None

    def test_valid_date(self):
        assert validate_start_timestamp("2024-01-01") == "2024-01-01"

    def test_valid_datetime(self):
        assert validate_start_timestamp("2024-01-01T12:00:00") == "2024-01-01T12:00:00"

    def test_invalid_format(self):
        with pytest.raises(ValueError, match="Invalid start timestamp"):
            validate_start_timestamp("not-a-date")


class TestValidateQuantileLevels:
    """Tests for validate_quantile_levels."""

    def test_valid(self):
        result = validate_quantile_levels([0.1, 0.5, 0.9])
        assert result == [0.1, 0.5, 0.9]

    def test_zero_invalid(self):
        with pytest.raises(ValueError, match="between 0 and 1"):
            validate_quantile_levels([0.0, 0.5])

    def test_one_invalid(self):
        with pytest.raises(ValueError, match="between 0 and 1"):
            validate_quantile_levels([0.5, 1.0])

    def test_negative_invalid(self):
        with pytest.raises(ValueError, match="between 0 and 1"):
            validate_quantile_levels([-0.1])


class TestValidateSingleSeriesCovariates:
    """Tests for validate_single_series_covariates."""

    def test_no_covariates(self):
        validate_single_series_covariates([1.0, 2.0, 3.0, 4.0, 5.0], None, None)

    def test_valid_past_covariates(self):
        target = [1.0, 2.0, 3.0, 4.0, 5.0]
        past = {"temp": [10.0, 20.0, 30.0, 40.0, 50.0]}
        validate_single_series_covariates(target, past, None)

    def test_future_without_past_raises(self):
        target = [1.0, 2.0, 3.0, 4.0, 5.0]
        future = {"temp": [60.0]}
        with pytest.raises(ValueError, match="together"):
            validate_single_series_covariates(target, None, future)

    def test_future_key_not_in_past_raises(self):
        target = [1.0, 2.0, 3.0, 4.0, 5.0]
        past = {"temp": [10.0, 20.0, 30.0, 40.0, 50.0]}
        future = {"wind": [1.0]}
        with pytest.raises(ValueError, match="not in"):
            validate_single_series_covariates(target, past, future)

    def test_past_length_mismatch(self):
        target = [1.0, 2.0, 3.0, 4.0, 5.0]
        past = {"temp": [10.0, 20.0]}
        with pytest.raises(ValueError, match="must match target length"):
            validate_single_series_covariates(target, past, None)

    def test_target_name_forbidden(self):
        target = [1.0, 2.0, 3.0, 4.0, 5.0]
        past = {"target": [10.0, 20.0, 30.0, 40.0, 50.0]}
        with pytest.raises(ValueError, match="not allowed"):
            validate_single_series_covariates(target, past, None)


class TestValidateCrossSeries:
    """Tests for validate_cross_series."""

    def _make_input(self, target=None, item_id=None, start=None, past_cov=None, future_cov=None):
        return TimeSeriesInput(
            target=target or [1.0, 2.0, 3.0, 4.0, 5.0],
            item_id=item_id,
            start=start,
            past_covariates=past_cov,
            future_covariates=future_cov,
        )

    def test_valid_simple(self):
        inputs = [self._make_input(), self._make_input()]
        params = ForecastParameters(prediction_length=3)
        validate_cross_series(inputs, params)

    def test_inconsistent_item_ids(self):
        inputs = [self._make_input(item_id="a"), self._make_input(item_id=None)]
        params = ForecastParameters()
        with pytest.raises(ValueError, match="item_id"):
            validate_cross_series(inputs, params)

    def test_duplicate_item_ids(self):
        inputs = [self._make_input(item_id="a"), self._make_input(item_id="a")]
        params = ForecastParameters()
        with pytest.raises(ValueError, match="unique"):
            validate_cross_series(inputs, params)

    def test_start_without_freq(self):
        inputs = [self._make_input(start="2024-01-01")]
        params = ForecastParameters(prediction_length=1)
        with pytest.raises(ValueError, match="freq"):
            validate_cross_series(inputs, params)

    def test_freq_without_start(self):
        inputs = [self._make_input()]
        params = ForecastParameters(prediction_length=1, freq="D")
        with pytest.raises(ValueError, match="start"):
            validate_cross_series(inputs, params)

    def test_future_cov_length_mismatch(self):
        inputs = [
            self._make_input(
                past_cov={"temp": [1.0, 2.0, 3.0, 4.0, 5.0]},
                future_cov={"temp": [10.0, 20.0]},  # length 2 != prediction_length 3
            )
        ]
        params = ForecastParameters(prediction_length=3)
        with pytest.raises(ValueError, match="prediction_length"):
            validate_cross_series(inputs, params)