From ba47d25a04aedac33af2fb9780609afd54744d5e Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Thu, 26 Feb 2026 23:19:26 +0000 Subject: [PATCH] Add vLLM plugin for Chronos-2 inference Signed-off-by: Li Zhang --- pyproject.toml | 10 + src/chronos/chronos2/vllm/README.md | 179 +++++++++ src/chronos/chronos2/vllm/__init__.py | 48 +++ src/chronos/chronos2/vllm/io_processor.py | 247 ++++++++++++ src/chronos/chronos2/vllm/model.py | 196 ++++++++++ src/chronos/chronos2/vllm/multimodal.py | 237 ++++++++++++ .../chronos2/vllm/protocol/__init__.py | 17 + .../chronos2/vllm/protocol/data_prep.py | 136 +++++++ .../chronos2/vllm/protocol/forecast.py | 218 +++++++++++ .../chronos2/vllm/protocol/validation.py | 242 ++++++++++++ src/chronos/chronos2/vllm/utils/__init__.py | 1 + src/chronos/chronos2/vllm/utils/helpers.py | 35 ++ src/chronos/chronos2/vllm/utils/quantiles.py | 40 ++ test/chronos2/vllm/__init__.py | 1 + test/chronos2/vllm/test_chronos2_plugin.py | 360 ++++++++++++++++++ test/chronos2/vllm/test_data_prep.py | 248 ++++++++++++ test/chronos2/vllm/test_helpers.py | 53 +++ test/chronos2/vllm/test_protocol.py | 178 +++++++++ test/chronos2/vllm/test_quantiles.py | 98 +++++ test/chronos2/vllm/test_validation.py | 163 ++++++++ 20 files changed, 2707 insertions(+) create mode 100644 src/chronos/chronos2/vllm/README.md create mode 100644 src/chronos/chronos2/vllm/__init__.py create mode 100644 src/chronos/chronos2/vllm/io_processor.py create mode 100644 src/chronos/chronos2/vllm/model.py create mode 100644 src/chronos/chronos2/vllm/multimodal.py create mode 100644 src/chronos/chronos2/vllm/protocol/__init__.py create mode 100644 src/chronos/chronos2/vllm/protocol/data_prep.py create mode 100644 src/chronos/chronos2/vllm/protocol/forecast.py create mode 100644 src/chronos/chronos2/vllm/protocol/validation.py create mode 100644 src/chronos/chronos2/vllm/utils/__init__.py create mode 100644 src/chronos/chronos2/vllm/utils/helpers.py create mode 100644 src/chronos/chronos2/vllm/utils/quantiles.py create mode 100644 test/chronos2/vllm/__init__.py create mode 100644 test/chronos2/vllm/test_chronos2_plugin.py create mode 100644 test/chronos2/vllm/test_data_prep.py create mode 100644 test/chronos2/vllm/test_helpers.py create mode 100644 test/chronos2/vllm/test_protocol.py create mode 100644 test/chronos2/vllm/test_quantiles.py create mode 100644 test/chronos2/vllm/test_validation.py diff --git a/pyproject.toml b/pyproject.toml index d9e7117..c337a83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,10 @@ test = [ "fev>=0.6.1", "pandas[pyarrow]>=2.0,<2.4", ] +vllm = [ + "vllm>=0.13.0", + "pydantic>=2.0", +] typecheck = ["mypy~=1.9"] dev = [ "gluonts[pro]~=0.16", @@ -65,6 +69,12 @@ dev = [ "pandas>=2.0,<2.4", ] +[project.entry-points."vllm.general_plugins"] +chronos2 = "chronos.chronos2.vllm:register_chronos2_model" + +[project.entry-points."vllm.io_processor_plugins"] +chronos2 = "chronos.chronos2.vllm:get_chronos2_io_processor" + [project.urls] Homepage = "https://github.com/amazon-science/chronos-forecasting" Issues = "https://github.com/amazon-science/chronos-forecasting/issues" diff --git a/src/chronos/chronos2/vllm/README.md b/src/chronos/chronos2/vllm/README.md new file mode 100644 index 0000000..97a4af8 --- /dev/null +++ b/src/chronos/chronos2/vllm/README.md @@ -0,0 +1,179 @@ +# Chronos-2 vLLM Plugin + +A [vLLM](https://github.com/vllm-project/vllm) plugin that adds support for [Chronos-2](https://github.com/amazon-science/chronos-forecasting) time series forecasting via the `/pooling` API endpoint. + +## Overview + +Chronos-2 is an encoder-only time series foundation model for zero-shot forecasting. This plugin integrates it with vLLM using the **IOProcessor** plugin interface, so forecast requests are served through vLLM's standard pooling endpoint. + +### Features + +- **Zero-shot forecasting** — no fine-tuning required +- **Quantile predictions** — probabilistic forecasts with customizable quantile levels +- **Cross-series learning** — information sharing across time series in a batch +- **Covariates support** — past and future covariates (numeric and categorical) +- **Batch forecasting** — process multiple time series in a single request + +## Installation + +Requires Python 3.10+ and vLLM 0.13.0+. + +```bash +pip install chronos-forecasting[vllm] +``` + +## Quick Start + +### 1. Start the Server + +```bash +vllm serve amazon/chronos-2 \ + --io-processor-plugin chronos2 \ + --runner pooling \ + --enforce-eager \ + --no-enable-prefix-caching \ + --skip-tokenizer-init \ + --enable-mm-embeds \ + --dtype float32 \ + --max-model-len 8192 +``` + +### 2. Send a Forecast Request + +```bash +curl -X POST http://localhost:8000/pooling \ + -H "Content-Type: application/json" \ + -d '{ + "model": "amazon/chronos-2", + "task": "plugin", + "data": { + "inputs": [ + { + "target": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + "item_id": "series_1" + } + ], + "parameters": { + "prediction_length": 5, + "quantile_levels": [0.1, 0.5, 0.9] + } + } + }' +``` + +### 3. Parse the Response + +```json +{ + "request_id": null, + "created_at": 1739397600, + "data": { + "predictions": [ + { + "mean": [11.0, 12.1, 13.0, 14.2, 15.1], + "0.1": [9.5, 10.3, 11.0, 11.8, 12.5], + "0.5": [11.0, 12.0, 13.0, 14.0, 15.0], + "0.9": [12.5, 13.8, 15.0, 16.3, 17.5], + "item_id": "series_1" + } + ] + } +} +``` + +## API Reference + +### Request Format + +| Field | Type | Required | Description | +|---|---|---|---| +| `model` | `str` | ✅ | Model name (e.g., `"amazon/chronos-2"`) | +| `task` | `str` | ✅ | Must be `"plugin"` | +| `data.inputs` | `list` | ✅ | List of time series inputs (1–1024) | +| `data.parameters` | `dict` | | Forecast parameters | + +#### Time Series Input (`data.inputs[*]`) + +| Field | Type | Required | Description | +|---|---|---|---| +| `target` | `list[float]` or `list[list[float]]` | ✅ | Historical values (min 5 observations). 1-D for univariate, 2-D for multivariate. | +| `item_id` | `str` | | Identifier echoed in response | +| `start` | `str` | | ISO 8601 timestamp | +| `past_covariates` | `dict[str, list]` | | Past covariate arrays (must match target length) | +| `future_covariates` | `dict[str, list]` | | Future covariate arrays (must match `prediction_length`) | + +#### Parameters (`data.parameters`) + +| Field | Type | Default | Description | +|---|---|---|---| +| `prediction_length` | `int` | `1` | Forecast horizon (1–1024) | +| `quantile_levels` | `list[float]` | `[0.1, 0.5, 0.9]` | Quantile levels in (0, 1) | +| `freq` | `str` | `null` | Pandas frequency string (e.g., `"D"`, `"H"`) | +| `batch_size` | `int` | `256` | Inference batch size | +| `cross_learning` | `bool` | `false` | Enable cross-series learning | + +### Response Format + +Each prediction in `data.predictions` contains: + +| Field | Type | Description | +|---|---|---| +| `mean` | `list[float]` | Point forecast (mean/median) | +| `"0.1"`, `"0.5"`, etc. | `list[float]` | Named quantile columns matching `quantile_levels` | +| `item_id` | `str` | Echoed from input (if provided) | + +## Architecture + +The vLLM model wrapper (`Chronos2ForForecasting`) is a thin adapter that delegates all computation to the existing `chronos.chronos2.model.Chronos2Model`. No model architecture is duplicated. + +### Module Structure + +``` +src/chronos/chronos2/vllm/ +├── __init__.py # Plugin entry point & registration +├── model.py # Chronos2ForForecasting (thin vLLM wrapper) +├── multimodal.py # MM pipeline for "timeseries" modality +├── io_processor.py # Chronos2IOProcessor (request/response handling) +├── protocol/ +│ ├── __init__.py +│ ├── forecast.py # Pydantic models (TimeSeriesInput, ForecastParameters, etc.) +│ ├── validation.py # Input validation logic +│ └── data_prep.py # Tensor preparation from validated inputs +└── utils/ + ├── __init__.py + ├── helpers.py # Utility functions + └── quantiles.py # Quantile selection & interpolation +``` + +### Key Classes + +| Class | File | Purpose | +|---|---|---| +| `Chronos2ForForecasting` | `model.py` | Thin vLLM wrapper — delegates to `chronos.chronos2.model.Chronos2Model` | +| `Chronos2IOProcessor` | `io_processor.py` | Request parsing, validation, pre/post processing | +| `ForecastParameters` | `protocol/forecast.py` | Pydantic validation for forecast parameters | +| `TimeSeriesInput` | `protocol/forecast.py` | Pydantic validation for time series inputs | +| `ForecastPrediction` | `protocol/forecast.py` | Pydantic model for forecast output | + +## Troubleshooting + +### Server Flags + +The following flags are required for Chronos-2: + +```bash +--io-processor-plugin chronos2 # Enable the forecast IOProcessor +--enforce-eager # Chronos-2 doesn't support CUDA graphs +--no-enable-prefix-caching # Not applicable for time series +--skip-tokenizer-init # Chronos-2 doesn't use a text tokenizer +``` + +### Plugin Not Loading + +1. Verify installation: `pip list | grep chronos-forecasting` +2. Check entry points: + ```python + from importlib.metadata import entry_points + print(list(entry_points(group='vllm.general_plugins'))) + ``` +3. Enable debug logging: `VLLM_LOGGING_LEVEL=DEBUG vllm serve ...` \ No newline at end of file diff --git a/src/chronos/chronos2/vllm/__init__.py b/src/chronos/chronos2/vllm/__init__.py new file mode 100644 index 0000000..970ad7c --- /dev/null +++ b/src/chronos/chronos2/vllm/__init__.py @@ -0,0 +1,48 @@ +"""vLLM Chronos-2 Time Series Forecasting Model Plugin. + +This plugin registers the Chronos-2 model with vLLM's ModelRegistry, +allowing it to be used with vLLM's inference engine for time series forecasting. +""" + + +def register_chronos2_model() -> None: + """Register Chronos-2 models with vLLM's ModelRegistry. + + This function is called automatically when the plugin is loaded + through vLLM's plugin discovery mechanism. + """ + try: + from vllm.logger import init_logger + from vllm.model_executor.models.registry import ModelRegistry + + logger = init_logger(__name__) + + ModelRegistry.register_model( + "Chronos2Model", + "chronos.chronos2.vllm.model:Chronos2ForForecasting", + ) + + logger.info("Successfully registered Chronos-2 model with vLLM") + + except Exception as e: + from vllm.logger import init_logger + + logger = init_logger(__name__) + logger.error(f"Failed to register Chronos-2 model: {e}") + raise + + +def get_chronos2_io_processor(): + """ + Factory function for IOProcessor plugin registration. + + This function is called by vLLM when --io-processor-plugin is set to 'chronos2'. + Returns the fully qualified name (module.Class format) as a string. + """ + return "chronos.chronos2.vllm.io_processor.Chronos2IOProcessor" + + +__all__ = [ + "register_chronos2_model", + "get_chronos2_io_processor", +] \ No newline at end of file diff --git a/src/chronos/chronos2/vllm/io_processor.py b/src/chronos/chronos2/vllm/io_processor.py new file mode 100644 index 0000000..c267f30 --- /dev/null +++ b/src/chronos/chronos2/vllm/io_processor.py @@ -0,0 +1,247 @@ +"""IOProcessor for Chronos-2 time series forecasting. + +Thin orchestrator that delegates to: + - protocol.data_prep: tensor preparation from validated inputs (via Chronos2Dataset) + - protocol.validation: cross-series validation + - multimodal.MODALITY: MM prompt construction + - utils: quantile selection, helpers + +Flow: + 1. parse_request → validate inputs via Pydantic models + 2. pre_process → prepare_request() → timeseries MM prompts (one per batch) + 3. post_process → select_quantiles() → per-series predictions +""" + +import time +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any + +import torch +from vllm.config import VllmConfig +from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.outputs import PoolingRequestOutput +from vllm.plugins.io_processors.interface import IOProcessor +from vllm.pooling_params import PoolingParams + +from chronos.chronos2.vllm.multimodal import MODALITY +from chronos.chronos2.vllm.protocol.data_prep import PreparedRequest, prepare_request +from chronos.chronos2.vllm.protocol.forecast import ( + ForecastParameters, + ForecastPrediction, + TimeSeriesInput, +) +from chronos.chronos2.vllm.protocol.validation import validate_cross_series +from chronos.chronos2.vllm.utils.helpers import empty_prediction, tensor_to_list +from chronos.chronos2.vllm.utils.quantiles import select_quantiles + +logger = init_logger(__name__) + + +@dataclass +class _PostProcessInfo: + """Lightweight cache of data needed for post_process — avoids storing full tensors.""" + + item_ids: list[str | None] + parameters: ForecastParameters + target_idx_ranges: list[list[tuple[int, int]]] # per-batch list of (start, end) ranges + + +class Chronos2IOProcessor(IOProcessor[dict, dict]): + """IOProcessor for Chronos-2 time series forecasting.""" + + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + config = vllm_config.model_config.hf_config + config.is_encoder_decoder = False + # Set projection_dim=0 so vLLM uses IdentityPooler (no projection layer). + # This avoids requiring --hf-overrides '{"projection_dim": 0}' on the CLI. + if not hasattr(config, "projection_dim") or config.projection_dim != 0: + config.projection_dim = 0 + + cc = getattr(config, "chronos_config", {}) + self._chronos_config = ( + cc if isinstance(cc, dict) else (cc.__dict__ if hasattr(cc, "__dict__") else {}) + ) + self._context_length = self._chronos_config.get("context_length", 8192) + self._output_patch_size = self._chronos_config.get("output_patch_size", 16) + self._model_quantiles = self._chronos_config.get( + "quantiles", + [ + 0.01, + 0.05, + 0.1, + 0.15, + 0.2, + 0.25, + 0.3, + 0.35, + 0.4, + 0.45, + 0.5, + 0.55, + 0.6, + 0.65, + 0.7, + 0.75, + 0.8, + 0.85, + 0.9, + 0.95, + 0.99, + ], + ) + self._request_info: dict[str, _PostProcessInfo] = {} + logger.info("Initialized Chronos2IOProcessor") + + # ----------------------------------------------------------- + # Request parsing + # ----------------------------------------------------------- + + def parse_request(self, request: Any) -> dict: + data = request.data if hasattr(request, "data") else request + if not isinstance(data, dict): + raise ValueError(f"Expected dict, got {type(data)}") + if "inputs" not in data: + raise ValueError("Request must contain 'inputs' field") + raw_inputs = data["inputs"] + if not isinstance(raw_inputs, list) or len(raw_inputs) == 0: + raise ValueError("'inputs' must be a non-empty list") + + validated_inputs = [] + for i, ts in enumerate(raw_inputs): + try: + validated_inputs.append(TimeSeriesInput(**ts)) + except Exception as e: + raise ValueError(f"Invalid input at index {i}: {e}") from e + + try: + validated_params = ForecastParameters(**data.get("parameters", {})) + except Exception as e: + raise ValueError(f"Invalid parameters: {e}") from e + + validate_cross_series(validated_inputs, validated_params) + return {"inputs": validated_inputs, "parameters": validated_params} + + # ----------------------------------------------------------- + # Pre-process: inputs → MM prompts + # ----------------------------------------------------------- + + def pre_process( + self, + prompt: dict, + request_id: str | None = None, + **kwargs: Any, + ) -> PromptType | Sequence[PromptType]: + prepared = prepare_request( + inputs=prompt["inputs"], + parameters=prompt["parameters"], + context_length=self._context_length, + output_patch_size=self._output_patch_size, + ) + + # Cache only lightweight post-processing metadata + if request_id is not None: + self._request_info[request_id] = _PostProcessInfo( + item_ids=prepared.item_ids, + parameters=prepared.parameters, + target_idx_ranges=[batch.target_idx_ranges for batch in prepared.batches], + ) + + # Build one MM prompt per batch (Chronos2Dataset handles batch splitting) + prompts: list[PromptType] = [] + for batch in prepared.batches: + prompts.append( + { + "prompt_token_ids": [1], + "multi_modal_data": { + MODALITY: { + "context": batch.context, + "future_covariates": batch.future_covariates, + "group_ids": batch.group_ids, + "num_output_patches": batch.num_output_patches, + } + }, + } + ) + + return prompts if len(prompts) > 1 else prompts[0] + + # ----------------------------------------------------------- + # Post-process: model output → predictions + # ----------------------------------------------------------- + + def post_process( + self, + model_output: Sequence[PoolingRequestOutput], + request_id: str | None = None, + **kwargs: Any, + ) -> dict: + info = self._request_info.pop(request_id, None) if request_id else None + if info is None: + logger.error("No pending request for request_id=%s", request_id) + return {"predictions": []} + + parameters = info.parameters + item_ids = info.item_ids + + try: + # Collect predictions across all batches, trimmed to exact prediction_length + pred_len = parameters.prediction_length + all_predictions: list[torch.Tensor] = [] + for batch_idx, output in enumerate(model_output): + tensor = output.outputs.data + while tensor.ndim > 3: + tensor = tensor.squeeze(0) + + batch_ranges = info.target_idx_ranges[batch_idx] + for start, end in batch_ranges: + all_predictions.append(tensor[start:end, :, :pred_len]) + + quantiles_out, mean_out = select_quantiles( + all_predictions, self._model_quantiles, parameters.quantile_levels + ) + + result: dict[str, Any] = {"predictions": [], "request_id": request_id} + for i, (q_tensor, m_tensor) in enumerate(zip(quantiles_out, mean_out)): + pred: dict[str, Any] = {"mean": tensor_to_list(m_tensor)} + for q, q_vals in zip(parameters.quantile_levels, q_tensor.unbind(dim=-1)): + pred[str(q)] = tensor_to_list(q_vals) + if i < len(item_ids) and item_ids[i] is not None: + pred["item_id"] = item_ids[i] + result["predictions"].append(pred) + + except Exception as e: + logger.error("Failed to decode predictions: %s", e, exc_info=True) + result = { + "predictions": [ + empty_prediction(parameters.prediction_length, parameters.quantile_levels) + for _ in item_ids + ], + "request_id": request_id, + } + + return result + + # ----------------------------------------------------------- + # Response formatting + # ----------------------------------------------------------- + + def validate_or_generate_params(self, params: Any = None) -> PoolingParams: + return PoolingParams(task="plugin") + + def output_to_response(self, plugin_output: dict) -> IOProcessorResponse: + validated = [] + for pred in plugin_output.get("predictions", []): + try: + validated.append(ForecastPrediction(**pred).model_dump(exclude_none=True)) + except Exception as e: + logger.warning("Failed to validate prediction: %s", e) + validated.append(pred) + return IOProcessorResponse( + request_id=plugin_output.get("request_id"), + created_at=int(time.time()), + data={"predictions": validated}, + ) \ No newline at end of file diff --git a/src/chronos/chronos2/vllm/model.py b/src/chronos/chronos2/vllm/model.py new file mode 100644 index 0000000..131b722 --- /dev/null +++ b/src/chronos/chronos2/vllm/model.py @@ -0,0 +1,196 @@ +"""Chronos-2 vLLM model wrapper. + +Thin wrapper around the existing ``chronos.chronos2.model.Chronos2Model`` +that plugs into vLLM's multimodal (MM) interface. No model architecture +is duplicated — all computation is delegated to the upstream implementation. + +Architecture: + pre_process → IOProcessor prepares context/covariates/group_ids tensors + → Chronos2Dataset handles batching and group_id construction + → returns MM prompt(s) with timeseries data + forward() → receives pre-batched context/future_covariates/group_ids as kwargs + → delegates to chronos.chronos2.model.Chronos2Model + pooler → IdentityPooler passes output through unchanged + post_process → IOProcessor selects/interpolates quantiles +""" + +import math +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import IdentityPooler +from vllm.model_executor.models.interfaces import ( + IsAttentionFree, + MultiModalEmbeddings, + SupportsMultiModal, +) +from vllm.model_executor.models.interfaces_base import attn_type +from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors +from vllm.utils import length_from_prompt_token_ids_or_embeds + +from chronos.chronos2.model import Chronos2Model + +from .multimodal import ( + MODALITY, + ChronosInputBuilder, + ChronosMultiModalProcessor, + ChronosProcessingInfo, +) + +logger = init_logger(__name__) + + +@attn_type("attention_free") +@MULTIMODAL_REGISTRY.register_processor( + ChronosMultiModalProcessor, + info=ChronosProcessingInfo, + dummy_inputs=ChronosInputBuilder, +) +class Chronos2ForForecasting(nn.Module, IsAttentionFree, SupportsMultiModal): + """Chronos-2 forecasting model for vLLM. + + Delegates all computation to ``chronos.chronos2.model.Chronos2Model``. + Receives pre-batched tensors (context, future_covariates, group_ids) + directly in forward() via the timeseries MM pipeline. Batching is + handled upstream by ``Chronos2Dataset`` in the IOProcessor. + """ + + supports_multimodal_raw_input_only = True + is_pooling_model = True + + # Required by VllmModel interface + packed_modules_mapping: dict[str, Any] = {} + supported_lora_modules: list[str] = [] + embedding_modules: dict[str, str] = {} + embedding_padding_modules: list[str] = [] + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith(MODALITY): + return None + raise ValueError(f"Only {MODALITY} modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + config.is_encoder_decoder = False + if not hasattr(config, "projection_dim") or config.projection_dim != 0: + config.projection_dim = 0 + vllm_config.model_config.hf_text_config = None + + self.config = config + self.model_name = vllm_config.model_config.model + self.d_model = getattr(config, "d_model", 768) + + cc = getattr(config, "chronos_config", {}) + self.chronos_config = ( + cc if isinstance(cc, dict) else (cc.__dict__ if hasattr(cc, "__dict__") else {}) + ) + self.output_patch_size = self.chronos_config.get("output_patch_size", 16) + self.max_output_patches = self.chronos_config.get("max_output_patches", 64) + + # Instantiate the upstream Chronos2Model — reuses all existing layers + self.model = Chronos2Model(config) + + self.pooler = IdentityPooler() + + logger.info( + "Initialized Chronos2ForForecasting (d_model=%d, delegating to chronos.chronos2.model)", + self.d_model, + ) + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + """Return empty embeddings — Chronos-2 has no token vocabulary.""" + return torch.empty((input_ids.shape[0], 0)) + + @torch.inference_mode() + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor: + """Run Chronos-2 inference by delegating to the upstream model. + + Receives pre-batched context, future_covariates, group_ids from + the timeseries MM pipeline. Batching is handled upstream by + Chronos2Dataset in the IOProcessor. + """ + input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds) + + context: torch.Tensor | None = kwargs.get("context") # type: ignore[assignment] + future_covariates: torch.Tensor | None = kwargs.get("future_covariates") # type: ignore[assignment] + group_ids: torch.Tensor | None = kwargs.get("group_ids") # type: ignore[assignment] + num_output_patches: int | None = kwargs.get("num_output_patches") # type: ignore[assignment] + + if context is None: + # Warmup/profiling pass — return zeros + return torch.zeros(input_len, 0, device=positions.device, dtype=torch.float32) + + # Determine num_output_patches from preprocessor or fall back to computing from future_covariates + if num_output_patches is None: + prediction_length = future_covariates.shape[1] if future_covariates is not None else 0 + if prediction_length == 0: + prediction_length = self.output_patch_size * self.max_output_patches + num_output_patches = int(math.ceil(prediction_length / self.output_patch_size)) + num_output_patches = min(num_output_patches, self.max_output_patches) + + prediction_length = num_output_patches * self.output_patch_size + + # Pad or trim future_covariates to match output_size + if future_covariates is not None: + if prediction_length > future_covariates.shape[1]: + pad_size = prediction_length - future_covariates.shape[1] + pad_tensor = torch.full( + (future_covariates.shape[0], pad_size), + fill_value=float("nan"), + device=future_covariates.device, + ) + future_covariates = torch.cat([future_covariates, pad_tensor], dim=1) + else: + future_covariates = future_covariates[:, :prediction_length] + + # Delegate to the upstream Chronos2Model — single forward pass + model_kwargs: dict[str, Any] = { + "context": context, + "num_output_patches": num_output_patches, + } + if group_ids is not None: + model_kwargs["group_ids"] = group_ids + if future_covariates is not None: + model_kwargs["future_covariates"] = future_covariates + + output = self.model(**model_kwargs) + batch_prediction = output.quantile_preds[..., :prediction_length] + + # Expand to match input_len for vLLM pipeline compatibility. + # IdentityPooler passes through unchanged; post_process squeezes. + hidden_states = batch_prediction[None].expand( + input_len, *(-1 for _ in range(batch_prediction.ndim)) + ) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights into the upstream Chronos2Model. + + Prefix checkpoint names with 'model.' to match our wrapping. + Uses AutoWeightsLoader for robust weight loading. + """ + prefixed = [(f"model.{name}", tensor) for name, tensor in weights] + loader = AutoWeightsLoader(self) + return loader.load_weights(prefixed) \ No newline at end of file diff --git a/src/chronos/chronos2/vllm/multimodal.py b/src/chronos/chronos2/vllm/multimodal.py new file mode 100644 index 0000000..19fd72b --- /dev/null +++ b/src/chronos/chronos2/vllm/multimodal.py @@ -0,0 +1,237 @@ +"""Multimodal boilerplate for the "timeseries" modality in vLLM. + +Provides the MM processing pipeline classes that route timeseries +dict data (context, future_covariates, group_ids) through vLLM's +multimodal infrastructure. +""" + +import hashlib +import time +from typing import Any, Mapping, Sequence + +import torch +from transformers import BatchFeature +from vllm.config.multimodal import BaseDummyOptions +from vllm.multimodal.cache import MultiModalProcessorOnlyCache +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalFieldElem, + MultiModalInputs, + MultiModalKwargsItem, + MultiModalKwargsItems, + PlaceholderRange, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptUpdate, +) + +# The modality name used throughout the MM pipeline. +MODALITY = "timeseries" + +# Field names expected in the timeseries MM data dict. +REQUIRED_FIELDS = frozenset({"context", "future_covariates", "group_ids", "num_output_patches"}) + + +def _field_config() -> dict[str, MultiModalFieldConfig]: + """Shared field config for all timeseries fields.""" + return { + "context": MultiModalFieldConfig.shared(MODALITY, batch_size=1), + "future_covariates": MultiModalFieldConfig.shared(MODALITY, batch_size=1), + "group_ids": MultiModalFieldConfig.shared(MODALITY, batch_size=1), + "num_output_patches": MultiModalFieldConfig.shared(MODALITY, batch_size=1), + } + + +# ------------------------------------------------------------------- +# Processing info: tells vLLM what modalities we support +# ------------------------------------------------------------------- + + +class ChronosProcessingInfo(BaseProcessingInfo): + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {MODALITY: None} + + def get_data_parser(self) -> MultiModalDataParser: + return ChronosMultiModalDataParser() + + def build_data_parser(self) -> MultiModalDataParser: + return ChronosMultiModalDataParser() + + @property # type: ignore[override] + def data_parser(self) -> MultiModalDataParser: + if not hasattr(self, "_data_parser"): + self._data_parser = ChronosMultiModalDataParser() + return self._data_parser + + +# ------------------------------------------------------------------- +# Dummy input builder: provides profiling data for GPU warmup +# ------------------------------------------------------------------- + + +class ChronosInputBuilder(BaseDummyInputsBuilder[ChronosProcessingInfo]): + """Provides dummy data for vLLM's GPU profiling/warmup pass.""" + + def __init__(self, info: ChronosProcessingInfo): + super().__init__(info) + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + **kwargs: Any, + ) -> MultiModalDataDict: + return { + MODALITY: { + "context": torch.ones(100, 2048, dtype=torch.float32), + "future_covariates": torch.ones(100, 1024, dtype=torch.float32), + "group_ids": torch.zeros(100, dtype=torch.long), + "num_output_patches": 64, + } + } + + +# ------------------------------------------------------------------- +# Data parser: routes timeseries dict data through the MM pipeline +# ------------------------------------------------------------------- + + +class ChronosMultiModalDataParser(MultiModalDataParser): + """Parses timeseries dict data for vLLM's MM pipeline.""" + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + def _parse_timeseries_data( + self, + data: dict[str, torch.Tensor], + ) -> ModalityDataItems[Any, Any] | None: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality=MODALITY, + required_fields=REQUIRED_FIELDS, + fields_factory=lambda _: _field_config(), + ) + return None + + def _get_subparsers(self) -> Mapping[str, Any]: + return {MODALITY: self._parse_timeseries_data} + + def parse_mm_data(self, mm_data: MultiModalDataDict, **kwargs: Any) -> MultiModalDataItems: + if MODALITY not in mm_data: + mm_data = {MODALITY: mm_data} + + ts_data = mm_data[MODALITY] + items = self._parse_timeseries_data(ts_data) + if items is None: + raise ValueError("Failed to parse timeseries data") + + return MultiModalDataItems({MODALITY: items}) + + +# ------------------------------------------------------------------- +# Processor: converts parsed data into MultiModalInputs +# ------------------------------------------------------------------- + + +class ChronosMultiModalProcessor(BaseMultiModalProcessor): + """Processes timeseries MM data into MultiModalInputs for vLLM.""" + + def __init__( + self, + info: ChronosProcessingInfo, + dummy_inputs: BaseDummyInputsBuilder[ChronosProcessingInfo], + *, + cache: MultiModalProcessorOnlyCache | None = None, + ) -> None: + super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _field_config() + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + return [] + + def apply(self, *args: Any, **kwargs: Any) -> MultiModalInputs: + # Handle both old and new vLLM calling conventions: + # Old: apply(prompt, mm_items, hf_processor_mm_kwargs, ...) + # New: apply(processor_inputs, timing_ctx) + # Internal (from get_dummy_mm_inputs): apply(prompt=..., mm_items=..., ...) + + ts_data: dict[str, Any] = {} + + if args and hasattr(args[0], "prompt"): + # New vLLM: first arg is ProcessorInputs dataclass + processor_inputs = args[0] + mm_items = getattr(processor_inputs, "mm_items", None) or getattr( + processor_inputs, "mm_data_items", None + ) + if mm_items is not None and isinstance(mm_items, MultiModalDataItems): + if MODALITY in mm_items: + ts_items = mm_items[MODALITY] + ts_data = ts_items.data if hasattr(ts_items, "data") else {} + else: + # Old vLLM / direct call: extract from positional/keyword args + mm_items = args[1] if len(args) > 1 else kwargs.get("mm_items") + mm_data = kwargs.get("mm_data") + + if mm_items is not None and isinstance(mm_items, MultiModalDataItems): + if MODALITY in mm_items: + ts_items = mm_items[MODALITY] + ts_data = ts_items.data if hasattr(ts_items, "data") else {} + elif mm_data is not None and isinstance(mm_data, dict): + ts_data = mm_data.get(MODALITY, mm_data) + + mm_placeholders = {MODALITY: [PlaceholderRange(offset=0, length=0)]} + + # Build MultiModalKwargsItems directly to ensure proper modality keying. + # from_hf_inputs + BatchFeature can produce empty modalities when + # the data doesn't match the expected shared field batch structure. + field_config = _field_config() + mm_item_dict: dict[str, MultiModalFieldElem] = {} + for key, config in field_config.items(): + tensor = ts_data.get(key) if isinstance(ts_data, dict) else None + if tensor is not None: + mm_item_dict[key] = MultiModalFieldElem( + data=tensor, + field=config.field, + ) + + mm_kwargs_items = MultiModalKwargsItems({MODALITY: [MultiModalKwargsItem(mm_item_dict)]}) + + # Unique hash per request (required by vLLM v0.16+) + ts_hash = hashlib.sha256( + str(id(ts_data)).encode() + str(time.monotonic()).encode() + ).hexdigest()[:16] + + return MultiModalInputs( + type="multimodal", + prompt_token_ids=[1], + mm_kwargs=mm_kwargs_items, + mm_hashes={MODALITY: [ts_hash]}, + mm_placeholders=mm_placeholders, + ) diff --git a/src/chronos/chronos2/vllm/protocol/__init__.py b/src/chronos/chronos2/vllm/protocol/__init__.py new file mode 100644 index 0000000..64be7f9 --- /dev/null +++ b/src/chronos/chronos2/vllm/protocol/__init__.py @@ -0,0 +1,17 @@ +"""Chronos-2 forecasting protocol definitions and data preparation.""" + +from chronos.chronos2.vllm.protocol.forecast import ( + ForecastParameters, + ForecastPrediction, + ForecastRequest, + ForecastResponse, + TimeSeriesInput, +) + +__all__ = [ + "TimeSeriesInput", + "ForecastParameters", + "ForecastRequest", + "ForecastPrediction", + "ForecastResponse", +] diff --git a/src/chronos/chronos2/vllm/protocol/data_prep.py b/src/chronos/chronos2/vllm/protocol/data_prep.py new file mode 100644 index 0000000..2e12e93 --- /dev/null +++ b/src/chronos/chronos2/vllm/protocol/data_prep.py @@ -0,0 +1,136 @@ +"""Data preparation for Chronos-2 model input tensors. + +Converts validated TimeSeriesInput objects into batched tensors +ready for the model by delegating to ``chronos.chronos2.dataset``. +""" + +import math +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from chronos.chronos2.dataset import Chronos2Dataset, DatasetMode +from chronos.chronos2.vllm.protocol.forecast import ForecastParameters, TimeSeriesInput + + +@dataclass +class PreparedBatch: + """A single batch of prepared model input tensors.""" + + context: torch.Tensor # (batch_rows, padded_ctx_len) + future_covariates: torch.Tensor # (batch_rows, prediction_length) + group_ids: torch.Tensor # (batch_rows,) + num_output_patches: int + target_idx_ranges: list[tuple[int, int]] # per-input-series (start, end) in this batch + + +@dataclass +class PreparedRequest: + """All batches and metadata for a single forecast request.""" + + batches: list[PreparedBatch] + item_ids: list[str | None] + parameters: ForecastParameters + + +def _timeseries_input_to_dict(ts: TimeSeriesInput) -> dict[str, Any]: + """Convert a Pydantic TimeSeriesInput to the dict format expected by Chronos2Dataset. + + The dataset expects: + - ``target``: np.ndarray of shape (history_length,) or (n_variates, history_length) + - ``past_covariates``: dict[str, np.ndarray] + - ``future_covariates``: dict[str, np.ndarray] + """ + target = np.array(ts.target, dtype=np.float32) + + entry: dict[str, Any] = {"target": target} + + if ts.past_covariates: + past_covs: dict[str, np.ndarray] = {} + for key, vals in ts.past_covariates.items(): + arr = np.array([v if v is not None else np.nan for v in vals]) + # Keep string arrays as object dtype for categorical encoding + if any(isinstance(v, str) for v in vals if v is not None): + arr = np.array([str(v) if v is not None else "nan" for v in vals]) + else: + arr = arr.astype(np.float32) + past_covs[key] = arr + entry["past_covariates"] = past_covs + + if ts.future_covariates: + future_covs: dict[str, np.ndarray] = {} + for key, vals in ts.future_covariates.items(): + arr = np.array([v if v is not None else np.nan for v in vals]) + if any(isinstance(v, str) for v in vals if v is not None): + arr = np.array([str(v) if v is not None else "nan" for v in vals]) + else: + arr = arr.astype(np.float32) + future_covs[key] = arr + entry["future_covariates"] = future_covs + + return entry + + +def prepare_request( + inputs: list[TimeSeriesInput], + parameters: ForecastParameters, + context_length: int = 8192, + output_patch_size: int = 16, +) -> PreparedRequest: + """Convert validated inputs into batched model-ready tensors. + + Delegates to ``Chronos2Dataset`` in TEST mode for data preparation, + covariate encoding, batching, and group ID construction. + + Args: + inputs: validated time series inputs + parameters: forecast parameters + context_length: maximum context length (from model config) + output_patch_size: model's output patch size (from model config) + + Returns: + PreparedRequest with batched tensors and metadata + """ + # Convert Pydantic models to dicts for Chronos2Dataset + raw_inputs = [_timeseries_input_to_dict(ts) for ts in inputs] + item_ids = [ts.item_id for ts in inputs] + + pred_len = parameters.prediction_length + + # Use Chronos2Dataset in TEST mode — handles validation, covariate encoding, + # left-padding, and group_id construction. We use a very large batch_size + # to always produce exactly one batch; row-level chunking (if needed for + # memory) can be handled by the caller or the model's forward pass. + dataset = Chronos2Dataset( + inputs=raw_inputs, + context_length=context_length, + prediction_length=pred_len, + batch_size=2**31 - 1, # large enough to fit all series in one batch + output_patch_size=output_patch_size, + mode=DatasetMode.TEST, + convert_inputs=True, + ) + + batches: list[PreparedBatch] = [] + for batch_dict in dataset: + group_ids = batch_dict["group_ids"] + if parameters.cross_learning: + group_ids = torch.zeros_like(group_ids) + + batches.append( + PreparedBatch( + context=batch_dict["context"], + future_covariates=batch_dict["future_covariates"], + group_ids=group_ids, + num_output_patches=batch_dict["num_output_patches"], + target_idx_ranges=batch_dict["target_idx_ranges"], + ) + ) + + return PreparedRequest( + batches=batches, + item_ids=item_ids, + parameters=parameters, + ) \ No newline at end of file diff --git a/src/chronos/chronos2/vllm/protocol/forecast.py b/src/chronos/chronos2/vllm/protocol/forecast.py new file mode 100644 index 0000000..d524157 --- /dev/null +++ b/src/chronos/chronos2/vllm/protocol/forecast.py @@ -0,0 +1,218 @@ +import math +from typing import Any, Literal, Union + +from pydantic import BaseModel, Field, field_validator, model_validator + +try: + from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel +except ImportError: + OpenAIBaseModel = BaseModel # type: ignore[misc,assignment] + +from chronos.chronos2.vllm.protocol.validation import ( + MAX_NUM_TIME_SERIES, + validate_quantile_levels, + validate_single_series_covariates, + validate_start_timestamp, + validate_target, +) + + +class TimeSeriesInput(BaseModel): + """Input time series data for forecasting.""" + + target: list[float] | list[list[float]] = Field( + ..., + description="Historical time series values. " + "1-D array for univariate, 2-D array for multivariate", + ) + + item_id: str | None = Field(default=None, description="Unique identifier for the time series") + + start: str | None = Field( + default=None, + description="Start timestamp in ISO format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS)", + ) + + # Values can be NaN (numeric) or None/NaN mixed in string arrays. + past_covariates: dict[str, list[Union[float, str, None]]] | None = Field( + default=None, + description=( + "Dictionary of past covariate arrays (numeric or categorical). " + "Each array must match the length of the target" + ), + ) + + # Values can be NaN (numeric) or None/NaN mixed in string arrays. + future_covariates: dict[str, list[Union[float, str, None]]] | None = Field( + default=None, + description="Dictionary of known future covariate arrays. " + "Keys must be a subset of past_covariates. " + "Each array must match prediction_length", + ) + + @field_validator("past_covariates", "future_covariates", mode="before") + @classmethod + def _sanitize_nan_in_covariates( + cls, v: dict[str, list[Any]] | None + ) -> dict[str, list[Any]] | None: + """Convert NaN floats to None in covariate arrays. + + Datasets often encode missing categorical values as float NaN. + Pydantic rejects NaN in string-typed lists, so we normalize + NaN → None before validation. + """ + if v is None: + return v + sanitized: dict[str, list[Any]] = {} + for key, arr in v.items(): + sanitized[key] = [None if isinstance(x, float) and math.isnan(x) else x for x in arr] + return sanitized + + @field_validator("target") + @classmethod + def _validate_target( + cls, v: list[float] | list[list[float]] + ) -> list[float] | list[list[float]]: + return validate_target(v) + + @field_validator("start") + @classmethod + def _validate_start(cls, v: str | None) -> str | None: + return validate_start_timestamp(v) + + @model_validator(mode="after") + def _validate_covariates(self) -> "TimeSeriesInput": + validate_single_series_covariates(self.target, self.past_covariates, self.future_covariates) + return self + + +class ForecastParameters(BaseModel): + """Parameters for time series forecasting.""" + + prediction_length: int = Field( + default=1, ge=1, le=1024, description="Number of future steps to forecast" + ) + + quantile_levels: list[float] = Field( + default=[0.1, 0.5, 0.9], + description="Quantile levels for uncertainty quantification. " + "Each value must be between 0 and 1 (exclusive)", + ) + + freq: str | None = Field( + default=None, + description="Pandas frequency string (e.g., 'D' for daily, 'H' for hourly). " + "Required if 'start' is provided in inputs", + ) + + batch_size: int = Field( + default=256, + ge=1, + description="Internal row batch size for model inference. " + "Controls how many rows (series + covariates) are processed " + "in a single model forward pass. Rows are chunked internally " + "respecting series boundaries via group_ids.", + ) + + cross_learning: bool = Field( + default=False, + description="Enable information sharing across time series in batch", + ) + + @field_validator("quantile_levels") + @classmethod + def _validate_quantiles(cls, v: list[float]) -> list[float]: + return validate_quantile_levels(v) + + +class ForecastRequest(OpenAIBaseModel): + """Request format for time series forecasting via pooling API.""" + + model: str = Field(..., description="Model name to use for forecasting") + + task: Literal["forecast"] = Field( + default="forecast", description="Task type, must be 'forecast'" + ) + + data: dict[str, Any] = Field( + ..., + description=("Forecast request data containing 'inputs' and optional 'parameters'"), + ) + + @field_validator("data") + @classmethod + def validate_data_structure(cls, v: dict[str, Any]) -> dict[str, Any]: + """Validate the data field contains required structure.""" + if "inputs" not in v: + raise ValueError("data must contain 'inputs' field") + + if not isinstance(v["inputs"], list): + raise ValueError("data.inputs must be a list") + + if len(v["inputs"]) == 0: + raise ValueError("data.inputs cannot be empty") + + if len(v["inputs"]) > MAX_NUM_TIME_SERIES: + raise ValueError( + f"data.inputs may contain at most {MAX_NUM_TIME_SERIES} time series " + f"(received {len(v['inputs'])})" + ) + + # Validate each input as TimeSeriesInput + validated_inputs = [] + for i, ts_input in enumerate(v["inputs"]): + try: + validated_input = TimeSeriesInput(**ts_input) + validated_inputs.append(validated_input) + except Exception as e: + raise ValueError(f"Invalid time series input at index {i}: {e}") from e + + # Validate parameters if present + validated_params = None + if "parameters" in v and v["parameters"] is not None: + try: + validated_params = ForecastParameters(**v["parameters"]) + except Exception as e: + raise ValueError(f"Invalid parameters: {e}") from e + + # Cross-validate future_covariates length with prediction_length + if validated_params is not None: + prediction_length = validated_params.prediction_length + for i, ts_input in enumerate(validated_inputs): + if ts_input.future_covariates is not None: + for key, values in ts_input.future_covariates.items(): + if len(values) != prediction_length: + raise ValueError( + f"Input {i}: future_covariate '{key}' length " + f"({len(values)}) must match prediction_length " + f"({prediction_length})" + ) + + return v + + +class ForecastPrediction(BaseModel): + """Single time series forecast result with quantile forecasts.""" + + mean: list[float] | list[list[float]] = Field( + ..., description="Point forecast (mean). Shape matches input target" + ) + + item_id: str | None = Field(default=None, description="Echoed from input if provided") + + start: str | None = Field(default=None, description="Start timestamp of forecast horizon") + + class Config: + extra = "allow" # Allow dynamic quantile fields like "0.1", "0.5", "0.9" + + +class ForecastResponse(OpenAIBaseModel): + """Response format for time series forecasting.""" + + request_id: str = Field(..., description="Request identifier") + + created_at: int = Field(..., description="Unix timestamp when the response was created") + + data: dict[str, list[ForecastPrediction]] = Field( + ..., description="Forecast results with 'predictions' key" + ) diff --git a/src/chronos/chronos2/vllm/protocol/validation.py b/src/chronos/chronos2/vllm/protocol/validation.py new file mode 100644 index 0000000..60072ea --- /dev/null +++ b/src/chronos/chronos2/vllm/protocol/validation.py @@ -0,0 +1,242 @@ +"""Centralized validation for Chronos-2 forecast requests. + +All validation rules live here so that both Pydantic model validators +(in ``protocol.forecast``) and the IOProcessor can delegate to a single +source of truth. The design mirrors the SageMaker endpoint's +``validate_payload`` / ``validate_covariates`` utilities. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from chronos.chronos2.vllm.protocol.forecast import ( # noqa: F811 + ForecastParameters, + TimeSeriesInput, + ) + +# --------------------------------------------------------------------------- +# Constants — match the SageMaker endpoint constraints +# TODO: Get some of the values from model config +# --------------------------------------------------------------------------- +MIN_TARGET_LENGTH: int = 5 +MAX_PREDICTION_LENGTH: int = 1024 +MAX_NUM_TIME_SERIES: int = 1024 + + +# --------------------------------------------------------------------------- +# Per-input validators (called from Pydantic field/model validators) +# --------------------------------------------------------------------------- + + +def validate_target( + target: list[float] | list[list[float]], +) -> list[float] | list[list[float]]: + """Validate minimum target length and multivariate row consistency. + + Raises ``ValueError`` if: + - The target is empty. + - The target has fewer than ``MIN_TARGET_LENGTH`` observations. + - For multivariate targets, rows have inconsistent lengths. + """ + if not target: + raise ValueError("Target must not be empty") + if isinstance(target[0], list): + first_len = len(target[0]) + if first_len < MIN_TARGET_LENGTH: + raise ValueError( + f"Target must contain at least {MIN_TARGET_LENGTH} " + f"observations (received {first_len})" + ) + for i, dim in enumerate(target[1:], start=1): + if len(dim) != first_len: # type: ignore[arg-type] + raise ValueError( + f"All target dimensions must have same length. " + f"Dimension 0 has {first_len} observations, " + f"dimension {i} has {len(dim)}" # type: ignore[arg-type] + ) + else: + if len(target) < MIN_TARGET_LENGTH: + raise ValueError( + f"Target must contain at least {MIN_TARGET_LENGTH} " + f"observations (received {len(target)})" + ) + return target + + +def validate_start_timestamp(start: str | None) -> str | None: + """Validate that *start* is a valid ISO-8601 string (or ``None``).""" + if start is not None: + try: + datetime.fromisoformat(start.replace("Z", "+00:00")) + except ValueError as e: + raise ValueError( + f"Invalid start timestamp format: {start}. " + f"Expected ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS)" + ) from e + return start + + +def validate_quantile_levels(quantile_levels: list[float]) -> list[float]: + """Validate that every quantile level is in the open interval (0, 1).""" + for q in quantile_levels: + if not (0 < q < 1): + raise ValueError(f"Quantile levels must be between 0 and 1 (exclusive), got {q}") + return quantile_levels + + +def validate_single_series_covariates( + target: list[float] | list[list[float]], + past_covariates: dict[str, list] | None, + future_covariates: dict[str, list] | None, +) -> None: + """Validate covariate constraints for a **single** time series. + + Raises ``ValueError`` if any of the following are violated: + + - ``'target'`` is used as a covariate name. + - ``future_covariates`` is provided without ``past_covariates``. + - Future covariate keys are not a subset of past covariate keys. + - Past covariate array lengths do not match the target length. + """ + if not target: + raise ValueError("Target must not be empty") + target_len = len(target[0]) if isinstance(target[0], list) else len(target) + + # 'target' must not be used as a covariate name + for label, covariates in [ + ("past_covariates", past_covariates), + ("future_covariates", future_covariates), + ]: + if covariates is not None and "target" in covariates: + raise ValueError("Covariate with name 'target' is not allowed") + + # future_covariates requires past_covariates + if future_covariates is not None and past_covariates is None: + raise ValueError( + "Both 'past_covariates' and 'future_covariates' must be provided " + "together. Got 'future_covariates' without 'past_covariates'" + ) + + # future keys ⊆ past keys + if past_covariates is not None and future_covariates is not None: + past_keys = set(past_covariates.keys()) + future_keys = set(future_covariates.keys()) + if not future_keys.issubset(past_keys): + extra = future_keys - past_keys + raise ValueError( + f"All future covariate keys must be present in past covariates. " + f"Keys {extra} are in 'future_covariates' but not in " + f"'past_covariates'" + ) + + # past covariate lengths must match target length + if past_covariates is not None: + for key, values in past_covariates.items(): + if len(values) != target_len: + raise ValueError( + f"Past covariate '{key}' length ({len(values)}) " + f"must match target length ({target_len})" + ) + + +# --------------------------------------------------------------------------- +# Cross-series validators (called from IOProcessor.parse_request) +# --------------------------------------------------------------------------- + + +def validate_cross_series( + inputs: list[TimeSeriesInput], + parameters: ForecastParameters, +) -> None: + """Validate constraints that span across multiple time series. + + Mirrors the SageMaker endpoint's ``validate_payload`` and + ``validate_covariates`` utilities. + + Raises ``ValueError`` if any of the following are violated: + + - ``item_id`` provided for some but not all inputs. + - ``item_id`` values are not unique. + - ``start`` provided for some but not all inputs. + - ``start`` is provided without ``freq`` (or vice-versa). + - Covariate keys are not identical across all series. + - ``future_covariates`` array lengths don't match ``prediction_length``. + """ + _validate_item_ids(inputs) + _validate_start_freq(inputs, parameters) + _validate_covariate_consistency(inputs) + _validate_future_covariate_lengths(inputs, parameters) + + +# -- helpers (private) ------------------------------------------------------- + + +def _validate_item_ids(inputs: list[TimeSeriesInput]) -> None: + item_ids = [ts.item_id for ts in inputs] + has_none = any(x is None for x in item_ids) + has_value = any(x is not None for x in item_ids) + if has_none and has_value: + raise ValueError( + "If 'item_id' is provided for at least one time series in " + "'inputs', it should be provided for all time series" + ) + if has_value and len(item_ids) != len(set(item_ids)): + raise ValueError("'item_id' must be unique for all time series in 'inputs'") + + +def _validate_start_freq( + inputs: list[TimeSeriesInput], + parameters: ForecastParameters, +) -> None: + starts = [ts.start for ts in inputs] + has_none = any(x is None for x in starts) + has_value = any(x is not None for x in starts) + if has_none and has_value: + raise ValueError( + "If 'start' is provided for at least one time series in " + "'inputs', it should be provided for all time series" + ) + if has_value and parameters.freq is None: + raise ValueError( + "If 'start' is provided, then 'freq' must also be provided " "in 'parameters'" + ) + if parameters.freq is not None and not has_value: + raise ValueError( + "If 'freq' is provided in 'parameters', then 'start' must " + "also be provided for all time series in 'inputs'" + ) + + +def _validate_covariate_consistency(inputs: list[TimeSeriesInput]) -> None: + key_sets: list[frozenset[str] | None] = [] + for ts in inputs: + if ts.past_covariates is not None: + key_sets.append(frozenset(ts.past_covariates.keys())) + else: + key_sets.append(None) + + if len(set(key_sets)) > 1: + raise ValueError( + "If 'past_covariates' and 'future_covariates' are provided " + "for at least one time series in 'inputs', the same " + "covariates should be provided for all time series" + ) + + +def _validate_future_covariate_lengths( + inputs: list[TimeSeriesInput], + parameters: ForecastParameters, +) -> None: + prediction_length = parameters.prediction_length + for i, ts in enumerate(inputs): + if ts.future_covariates is not None: + for key, values in ts.future_covariates.items(): + if len(values) != prediction_length: + raise ValueError( + f"Input {i}: length of future covariate '{key}' " + f"({len(values)}) must equal prediction_length " + f"({prediction_length})" + ) diff --git a/src/chronos/chronos2/vllm/utils/__init__.py b/src/chronos/chronos2/vllm/utils/__init__.py new file mode 100644 index 0000000..5266995 --- /dev/null +++ b/src/chronos/chronos2/vllm/utils/__init__.py @@ -0,0 +1 @@ +"""Utility modules for Chronos-2 vLLM plugin.""" diff --git a/src/chronos/chronos2/vllm/utils/helpers.py b/src/chronos/chronos2/vllm/utils/helpers.py new file mode 100644 index 0000000..69344f9 --- /dev/null +++ b/src/chronos/chronos2/vllm/utils/helpers.py @@ -0,0 +1,35 @@ +"""Common helper functions for Chronos-2 vLLM plugin.""" + +from typing import Any + +import torch + + +def tensor_to_list(tensor: torch.Tensor) -> list[float] | list[list[float]]: + """Convert 2-D tensor to list (univariate) or list of lists (multivariate). + + Args: + tensor: shape (n_variates, horizon) + + Returns: + list[float] if univariate (n_variates == 1), else list[list[float]] + """ + assert tensor.ndim == 2 + return tensor[0].tolist() if tensor.shape[0] == 1 else [row.tolist() for row in tensor] + + +def empty_prediction( + prediction_length: int, + quantile_levels: list[float], +) -> dict[str, Any]: + """Return a zero-filled prediction dict for error cases. + + Args: + prediction_length: number of forecast steps + quantile_levels: list of quantile levels to include + """ + zeros = [0.0] * prediction_length + pred: dict[str, Any] = {"mean": zeros} + for q in quantile_levels: + pred[str(q)] = zeros + return pred diff --git a/src/chronos/chronos2/vllm/utils/quantiles.py b/src/chronos/chronos2/vllm/utils/quantiles.py new file mode 100644 index 0000000..8632907 --- /dev/null +++ b/src/chronos/chronos2/vllm/utils/quantiles.py @@ -0,0 +1,40 @@ +"""Quantile selection and interpolation utilities.""" + +import torch + +from chronos.utils import interpolate_quantiles + + +def select_quantiles( + predictions: list[torch.Tensor], + model_quantiles: list[float], + requested_levels: list[float], +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Select or interpolate quantiles from model output. + + Args: + predictions: list of tensors, each shape (..., num_quantiles, horizon) + model_quantiles: the quantile levels the model was trained on (sorted ascending) + requested_levels: the quantile levels the caller wants + + Returns: + (quantiles, mean) where: + - quantiles: list of tensors, each shape (..., horizon, num_requested_quantiles) + - mean: list of tensors, each shape (..., horizon) — the median + """ + # Swap quantile and time axes: [... q h] -> [... h q] + swapped = [pred.permute(*range(pred.ndim - 2), -1, -2) for pred in predictions] + + if set(requested_levels).issubset(model_quantiles): + indices = [model_quantiles.index(q) for q in requested_levels] + quantiles = [pred[..., indices] for pred in swapped] + else: + quantiles = [ + interpolate_quantiles(requested_levels, model_quantiles, pred) for pred in swapped + ] + + # Median as mean (Chronos-2 convention) + median_idx = model_quantiles.index(0.5) if 0.5 in model_quantiles else len(model_quantiles) // 2 + mean = [pred[..., median_idx] for pred in swapped] + + return quantiles, mean \ No newline at end of file diff --git a/test/chronos2/vllm/__init__.py b/test/chronos2/vllm/__init__.py new file mode 100644 index 0000000..9db0516 --- /dev/null +++ b/test/chronos2/vllm/__init__.py @@ -0,0 +1 @@ +# Chronos-2 plugin tests diff --git a/test/chronos2/vllm/test_chronos2_plugin.py b/test/chronos2/vllm/test_chronos2_plugin.py new file mode 100644 index 0000000..680b35b --- /dev/null +++ b/test/chronos2/vllm/test_chronos2_plugin.py @@ -0,0 +1,360 @@ +"""Tests for Chronos-2 plugin validation logic.""" + +import pytest + +from chronos.chronos2.vllm.protocol.forecast import ( + ForecastParameters, + ForecastRequest, + TimeSeriesInput, +) + +# ============================================================================ +# TimeSeriesInput — per-input validation +# ============================================================================ + + +class TestTimeSeriesInputTargetValidation: + """Tests for target field validation.""" + + def test_valid_univariate_target(self): + ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]) + assert len(ts.target) == 5 + + def test_valid_multivariate_target(self): + ts = TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + assert len(ts.target) == 2 + + def test_univariate_target_too_short(self): + with pytest.raises(ValueError, match="at least 5 observations"): + TimeSeriesInput(target=[1.0, 2.0, 3.0]) + + def test_multivariate_target_too_short(self): + with pytest.raises(ValueError, match="at least 5 observations"): + TimeSeriesInput(target=[[1.0, 2.0], [3.0, 4.0]]) + + def test_multivariate_inconsistent_lengths(self): + with pytest.raises(ValueError, match="All target dimensions must have same length"): + TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]) + + +class TestTimeSeriesInputStartTimestamp: + """Tests for start timestamp validation.""" + + def test_valid_date_format(self): + ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], start="2024-01-01") + assert ts.start == "2024-01-01" + + def test_valid_datetime_format(self): + ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], start="2024-01-01T12:00:00") + assert ts.start == "2024-01-01T12:00:00" + + def test_invalid_timestamp_format(self): + with pytest.raises(ValueError, match="Invalid start timestamp format"): + TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], start="not-a-date") + + +class TestTimeSeriesInputCovariateValidation: + """Tests for per-input covariate validation.""" + + def test_valid_past_covariates(self): + ts = TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]}, + ) + assert ts.past_covariates is not None + + def test_past_covariate_length_mismatch(self): + with pytest.raises(ValueError, match="must match target length"): + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"feature_a": [1.0, 2.0, 3.0]}, + ) + + def test_target_name_in_past_covariates_rejected(self): + with pytest.raises(ValueError, match="Covariate with name 'target' is not allowed"): + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"target": [1.0, 2.0, 3.0, 4.0, 5.0]}, + ) + + def test_target_name_in_future_covariates_rejected(self): + with pytest.raises(ValueError, match="Covariate with name 'target' is not allowed"): + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]}, + future_covariates={"target": [1.0]}, + ) + + def test_future_covariates_without_past_rejected(self): + with pytest.raises( + ValueError, match="Both 'past_covariates' and 'future_covariates' must be provided" + ): + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + future_covariates={"feature_a": [1.0]}, + ) + + def test_future_covariate_keys_must_be_subset_of_past(self): + with pytest.raises( + ValueError, match="All future covariate keys must be present in past covariates" + ): + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]}, + future_covariates={"feature_b": [1.0]}, + ) + + def test_future_covariate_keys_subset_is_valid(self): + """Past covariates may have more keys than future — that's fine.""" + ts = TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={ + "feature_a": [1.0, 2.0, 3.0, 4.0, 5.0], + "feature_b": [1.0, 2.0, 3.0, 4.0, 5.0], + }, + future_covariates={"feature_a": [1.0]}, + ) + assert ts.future_covariates is not None + + def test_past_only_covariates_valid(self): + """Having past_covariates without future_covariates is acceptable.""" + ts = TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"feature_a": [1.0, 2.0, 3.0, 4.0, 5.0]}, + ) + assert ts.past_covariates is not None + assert ts.future_covariates is None + + +# ============================================================================ +# ForecastParameters validation +# ============================================================================ + + +class TestForecastParameters: + """Tests for ForecastParameters validation.""" + + def test_default_values(self): + params = ForecastParameters() + assert params.prediction_length == 1 + assert params.quantile_levels == [0.1, 0.5, 0.9] + assert params.freq is None + assert params.batch_size == 256 + assert params.cross_learning is False + + def test_prediction_length_too_large(self): + with pytest.raises(ValueError): + ForecastParameters(prediction_length=2000) + + def test_prediction_length_zero(self): + with pytest.raises(ValueError): + ForecastParameters(prediction_length=0) + + def test_quantile_out_of_range(self): + with pytest.raises(ValueError, match="between 0 and 1"): + ForecastParameters(quantile_levels=[0.0, 0.5]) + + def test_quantile_one(self): + with pytest.raises(ValueError, match="between 0 and 1"): + ForecastParameters(quantile_levels=[0.5, 1.0]) + + +# ============================================================================ +# Cross-series validation (via validation module) +# ============================================================================ + + +class TestCrossSeriesValidation: + """Tests for cross-series validation logic in validation module.""" + + @staticmethod + def _validate(inputs, parameters=None): + """Helper that calls validate_cross_series from the validation module.""" + from chronos.chronos2.vllm.protocol.validation import validate_cross_series + + if parameters is None: + parameters = ForecastParameters() + validate_cross_series(inputs, parameters) + + def test_item_id_all_provided(self): + """All item_ids present — valid.""" + inputs = [ + TimeSeriesInput(target=[1.0] * 5, item_id="A"), + TimeSeriesInput(target=[2.0] * 5, item_id="B"), + ] + self._validate(inputs) # should not raise + + def test_item_id_none_provided(self): + """No item_ids at all — valid.""" + inputs = [ + TimeSeriesInput(target=[1.0] * 5), + TimeSeriesInput(target=[2.0] * 5), + ] + self._validate(inputs) # should not raise + + def test_item_id_partial(self): + """Some have item_id, some don't — invalid.""" + inputs = [ + TimeSeriesInput(target=[1.0] * 5, item_id="A"), + TimeSeriesInput(target=[2.0] * 5), + ] + with pytest.raises(ValueError, match="item_id.*provided for all time series"): + self._validate(inputs) + + def test_item_id_not_unique(self): + """Duplicate item_ids — invalid.""" + inputs = [ + TimeSeriesInput(target=[1.0] * 5, item_id="A"), + TimeSeriesInput(target=[2.0] * 5, item_id="A"), + ] + with pytest.raises(ValueError, match="item_id.*must be unique"): + self._validate(inputs) + + def test_start_all_provided_with_freq(self): + """All starts present with freq — valid.""" + inputs = [ + TimeSeriesInput(target=[1.0] * 5, start="2024-01-01"), + TimeSeriesInput(target=[2.0] * 5, start="2024-02-01"), + ] + params = ForecastParameters(freq="D") + self._validate(inputs, params) # should not raise + + def test_start_partial(self): + """Some have start, some don't — invalid.""" + inputs = [ + TimeSeriesInput(target=[1.0] * 5, start="2024-01-01"), + TimeSeriesInput(target=[2.0] * 5), + ] + params = ForecastParameters(freq="D") + with pytest.raises(ValueError, match="start.*provided for all time series"): + self._validate(inputs, params) + + def test_start_without_freq(self): + """start provided but freq missing — invalid.""" + inputs = [ + TimeSeriesInput(target=[1.0] * 5, start="2024-01-01"), + ] + params = ForecastParameters() # freq is None + with pytest.raises(ValueError, match="freq.*must also be provided"): + self._validate(inputs, params) + + def test_freq_without_start(self): + """freq provided but no start on inputs — invalid.""" + inputs = [ + TimeSeriesInput(target=[1.0] * 5), + ] + params = ForecastParameters(freq="D") + with pytest.raises(ValueError, match="start.*must also be provided"): + self._validate(inputs, params) + + def test_covariate_keys_consistent(self): + """Same covariate keys on all series — valid.""" + inputs = [ + TimeSeriesInput( + target=[1.0] * 5, + past_covariates={"feat": [1.0] * 5}, + ), + TimeSeriesInput( + target=[2.0] * 5, + past_covariates={"feat": [2.0] * 5}, + ), + ] + self._validate(inputs) # should not raise + + def test_covariate_keys_inconsistent(self): + """Different covariate keys across series — invalid.""" + inputs = [ + TimeSeriesInput( + target=[1.0] * 5, + past_covariates={"feat_a": [1.0] * 5}, + ), + TimeSeriesInput( + target=[2.0] * 5, + past_covariates={"feat_b": [2.0] * 5}, + ), + ] + with pytest.raises(ValueError, match="same covariates should be provided"): + self._validate(inputs) + + def test_covariate_some_have_none(self): + """One series has covariates, other doesn't — invalid.""" + inputs = [ + TimeSeriesInput( + target=[1.0] * 5, + past_covariates={"feat": [1.0] * 5}, + ), + TimeSeriesInput(target=[2.0] * 5), + ] + with pytest.raises(ValueError, match="same covariates should be provided"): + self._validate(inputs) + + def test_future_covariate_length_matches_prediction_length(self): + """future_covariates length equals prediction_length — valid.""" + inputs = [ + TimeSeriesInput( + target=[1.0] * 5, + past_covariates={"feat": [1.0] * 5}, + future_covariates={"feat": [1.0, 2.0, 3.0]}, + ), + ] + params = ForecastParameters(prediction_length=3) + self._validate(inputs, params) # should not raise + + def test_future_covariate_length_mismatch(self): + """future_covariates length != prediction_length — invalid.""" + inputs = [ + TimeSeriesInput( + target=[1.0] * 5, + past_covariates={"feat": [1.0] * 5}, + future_covariates={"feat": [1.0, 2.0]}, + ), + ] + params = ForecastParameters(prediction_length=3) + with pytest.raises(ValueError, match="must equal prediction_length"): + self._validate(inputs, params) + + +# ============================================================================ +# ForecastRequest — end-to-end request validation +# ============================================================================ + + +class TestForecastRequest: + """Tests for ForecastRequest end-to-end validation.""" + + def test_valid_minimal_request(self): + req = ForecastRequest( + model="chronos-v2", + data={"inputs": [{"target": [1.0, 2.0, 3.0, 4.0, 5.0]}]}, + ) + assert req.data["inputs"] is not None + + def test_empty_inputs_rejected(self): + with pytest.raises(ValueError, match="cannot be empty"): + ForecastRequest(model="chronos-v2", data={"inputs": []}) + + def test_too_many_inputs_rejected(self): + many_inputs = [{"target": [1.0] * 5} for _ in range(1025)] + with pytest.raises(ValueError, match="at most 1024"): + ForecastRequest(model="chronos-v2", data={"inputs": many_inputs}) + + def test_missing_inputs_rejected(self): + with pytest.raises(ValueError, match="inputs"): + ForecastRequest(model="chronos-v2", data={"parameters": {}}) + + def test_future_covariate_length_cross_validated(self): + """ForecastRequest validates future_covariates against prediction_length.""" + with pytest.raises(ValueError, match="prediction_length"): + ForecastRequest( + model="chronos-v2", + data={ + "inputs": [ + { + "target": [1.0] * 5, + "past_covariates": {"feat": [1.0] * 5}, + "future_covariates": {"feat": [1.0, 2.0]}, + } + ], + "parameters": {"prediction_length": 5}, + }, + ) diff --git a/test/chronos2/vllm/test_data_prep.py b/test/chronos2/vllm/test_data_prep.py new file mode 100644 index 0000000..66ebace --- /dev/null +++ b/test/chronos2/vllm/test_data_prep.py @@ -0,0 +1,248 @@ +"""Unit tests for chronos.chronos2.vllm.protocol.data_prep.""" + +import pytest + +torch = pytest.importorskip("torch") + +from chronos.chronos2.vllm.protocol.data_prep import ( # noqa: E402 + PreparedRequest, + prepare_request, +) +from chronos.chronos2.vllm.protocol.forecast import ( # noqa: E402 + ForecastParameters, + TimeSeriesInput, +) + + +class TestPrepareRequestUnivariate: + """Tests for prepare_request with univariate inputs.""" + + def test_single_series(self): + inputs = [TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])] + params = ForecastParameters(prediction_length=3) + result = prepare_request(inputs, params) + + assert isinstance(result, PreparedRequest) + assert len(result.batches) >= 1 + assert len(result.item_ids) == 1 + assert result.item_ids[0] is None + + batch = result.batches[0] + assert batch.context.shape[0] == 1 + assert batch.context.shape[1] == 5 + assert batch.group_ids.shape == (1,) + assert batch.target_idx_ranges == [(0, 1)] + + def test_multiple_series(self): + inputs = [ + TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], item_id="a"), + TimeSeriesInput(target=[6.0, 7.0, 8.0, 9.0, 10.0], item_id="b"), + ] + params = ForecastParameters(prediction_length=2) + result = prepare_request(inputs, params) + + # Both series should fit in one batch with default batch_size + assert len(result.batches) >= 1 + # Count total rows across all batches + total_rows = sum(b.context.shape[0] for b in result.batches) + assert total_rows == 2 + assert result.item_ids == ["a", "b"] + + def test_different_lengths_padded(self): + inputs = [ + TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]), + TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]), + ] + params = ForecastParameters(prediction_length=1) + result = prepare_request(inputs, params) + + batch = result.batches[0] + # Padded to max length = 10 + assert batch.context.shape == (2, 10) + # First series: right-aligned, left NaN-padded + assert torch.isnan(batch.context[0, :5]).all() + assert not torch.isnan(batch.context[0, 5:]).any() + + def test_truncation_to_context_length(self): + long_target = list(range(100)) + inputs = [TimeSeriesInput(target=long_target)] + params = ForecastParameters(prediction_length=1) + result = prepare_request(inputs, params, context_length=20) + + batch = result.batches[0] + assert batch.context.shape == (1, 20) + + def test_future_covariates_nan_for_targets(self): + inputs = [TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])] + params = ForecastParameters(prediction_length=3) + result = prepare_request(inputs, params) + + # Target rows should have NaN future covariates + assert torch.isnan(result.batches[0].future_covariates[0]).all() + + +class TestPrepareRequestMultivariate: + """Tests for prepare_request with multivariate inputs.""" + + def test_bivariate(self): + inputs = [TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])] + params = ForecastParameters(prediction_length=2) + result = prepare_request(inputs, params) + + batch = result.batches[0] + # 2 variates = 2 rows + assert batch.context.shape == (2, 5) + assert batch.target_idx_ranges == [(0, 2)] + assert batch.group_ids.tolist() == [0, 0] + + +class TestPrepareRequestCovariates: + """Tests for prepare_request with covariates.""" + + def test_past_covariates(self): + inputs = [ + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"temp": [10.0, 20.0, 30.0, 40.0, 50.0]}, + ) + ] + params = ForecastParameters(prediction_length=2) + result = prepare_request(inputs, params) + + batch = result.batches[0] + # 1 target + 1 covariate = 2 rows + assert batch.context.shape == (2, 5) + # Only target row is in target_idx_ranges + assert batch.target_idx_ranges == [(0, 1)] + assert batch.group_ids.tolist() == [0, 0] + + def test_future_covariates(self): + inputs = [ + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"temp": [10.0, 20.0, 30.0, 40.0, 50.0]}, + future_covariates={"temp": [60.0, 70.0]}, + ) + ] + params = ForecastParameters(prediction_length=2) + result = prepare_request(inputs, params) + + batch = result.batches[0] + # Target row: NaN future covariates + assert torch.isnan(batch.future_covariates[0]).all() + # Covariate row: actual future values + assert batch.future_covariates[1, 0].item() == 60.0 + assert batch.future_covariates[1, 1].item() == 70.0 + + +class TestPrepareRequestCrossLearning: + """Tests for cross-learning mode.""" + + def test_cross_learning_zeros_group_ids(self): + inputs = [ + TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]), + TimeSeriesInput(target=[6.0, 7.0, 8.0, 9.0, 10.0]), + ] + params = ForecastParameters(prediction_length=1, cross_learning=True) + result = prepare_request(inputs, params) + + assert result.batches[0].group_ids.tolist() == [0, 0] + + def test_no_cross_learning_distinct_group_ids(self): + inputs = [ + TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]), + TimeSeriesInput(target=[6.0, 7.0, 8.0, 9.0, 10.0]), + ] + params = ForecastParameters(prediction_length=1, cross_learning=False) + result = prepare_request(inputs, params) + + assert result.batches[0].group_ids.tolist() == [0, 1] + + +class TestPrepareRequestNanCovariates: + """Tests for prepare_request handling of NaN-sanitized covariates.""" + + def test_numeric_covariates_with_none(self): + """Numeric covariates with None values should become NaN in tensors.""" + inputs = [ + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"temp": [10.0, None, 30.0, None, 50.0]}, + future_covariates={"temp": [None, 70.0, None]}, + ) + ] + params = ForecastParameters(prediction_length=3) + result = prepare_request(inputs, params) + + batch = result.batches[0] + # 1 target + 1 covariate = 2 rows + assert batch.context.shape == (2, 5) + # Covariate row: None → NaN in context tensor + assert torch.isnan(batch.context[1, 1]) + assert torch.isnan(batch.context[1, 3]) + assert batch.context[1, 0].item() == 10.0 + assert batch.context[1, 2].item() == 30.0 + # Future covariates: None → NaN + assert torch.isnan(batch.future_covariates[1, 0]) + assert batch.future_covariates[1, 1].item() == 70.0 + + def test_string_covariates_encoded_in_tensors(self): + """String covariates (like holidays) should be encoded and included in tensors.""" + inputs = [ + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"holiday": ["A", None, "B", None, "C"]}, + ) + ] + params = ForecastParameters(prediction_length=2) + result = prepare_request(inputs, params) + + batch = result.batches[0] + # 1 target + 1 encoded categorical covariate = 2 rows + assert batch.context.shape == (2, 5) + # Encoded values should be finite floats (not NaN) for non-None entries + assert not torch.isnan(batch.context[1, 0]) # "A" encoded + assert not torch.isnan(batch.context[1, 2]) # "B" encoded + assert not torch.isnan(batch.context[1, 4]) # "C" encoded + + def test_categorical_with_future(self): + """Categorical covariates with future values should be encoded consistently.""" + inputs = [ + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"holiday": ["A", "B", "A", "B", "A"]}, + future_covariates={"holiday": ["A", "B"]}, + ) + ] + params = ForecastParameters(prediction_length=2) + result = prepare_request(inputs, params) + + batch = result.batches[0] + # 1 target + 1 encoded categorical = 2 rows + assert batch.context.shape == (2, 5) + # Future covariates should be filled for the categorical row + assert not torch.isnan(batch.future_covariates[1, 0]) + assert not torch.isnan(batch.future_covariates[1, 1]) + + +class TestPrepareRequestBatching: + """Tests for batch output — always a single batch (chunking deferred to model forward).""" + + def test_always_single_batch(self): + """prepare_request always returns one batch regardless of batch_size parameter.""" + inputs = [TimeSeriesInput(target=[float(i)] * 5) for i in range(5)] + params = ForecastParameters(prediction_length=1, batch_size=2) + result = prepare_request(inputs, params) + + # Always 1 batch — model forward() handles chunking if needed + assert len(result.batches) == 1 + assert result.batches[0].context.shape[0] == 5 + assert len(result.batches[0].target_idx_ranges) == 5 + + def test_single_series_single_batch(self): + inputs = [TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0])] + params = ForecastParameters(prediction_length=1, batch_size=256) + result = prepare_request(inputs, params) + + assert len(result.batches) == 1 + assert result.batches[0].context.shape[0] == 1 \ No newline at end of file diff --git a/test/chronos2/vllm/test_helpers.py b/test/chronos2/vllm/test_helpers.py new file mode 100644 index 0000000..67cb099 --- /dev/null +++ b/test/chronos2/vllm/test_helpers.py @@ -0,0 +1,53 @@ +"""Unit tests for chronos.chronos2.vllm.utils.helpers.""" + +import pytest + +torch = pytest.importorskip("torch") + +from chronos.chronos2.vllm.utils.helpers import empty_prediction, tensor_to_list # noqa: E402 + + +class TestTensorToList: + """Tests for tensor_to_list.""" + + def test_univariate(self): + t = torch.tensor([[1.0, 2.0, 3.0]]) + result = tensor_to_list(t) + assert result == [1.0, 2.0, 3.0] + assert isinstance(result, list) + assert isinstance(result[0], float) + + def test_multivariate(self): + t = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + result = tensor_to_list(t) + assert result == [[1.0, 2.0], [3.0, 4.0]] + assert isinstance(result[0], list) + + def test_single_value(self): + t = torch.tensor([[42.0]]) + assert tensor_to_list(t) == [42.0] + + def test_wrong_dims_raises(self): + with pytest.raises(AssertionError): + tensor_to_list(torch.tensor([1.0, 2.0])) # 1-D + + +class TestEmptyPrediction: + """Tests for empty_prediction.""" + + def test_basic(self): + pred = empty_prediction(3, [0.1, 0.5, 0.9]) + assert pred["mean"] == [0.0, 0.0, 0.0] + assert pred["0.1"] == [0.0, 0.0, 0.0] + assert pred["0.5"] == [0.0, 0.0, 0.0] + assert pred["0.9"] == [0.0, 0.0, 0.0] + + def test_single_quantile(self): + pred = empty_prediction(2, [0.5]) + assert "mean" in pred + assert "0.5" in pred + assert len(pred["mean"]) == 2 + + def test_empty_quantiles(self): + pred = empty_prediction(1, []) + assert pred == {"mean": [0.0]} diff --git a/test/chronos2/vllm/test_protocol.py b/test/chronos2/vllm/test_protocol.py new file mode 100644 index 0000000..65b34f7 --- /dev/null +++ b/test/chronos2/vllm/test_protocol.py @@ -0,0 +1,178 @@ +"""Unit tests for chronos.chronos2.vllm.protocol.forecast Pydantic models.""" + +import pytest + +from chronos.chronos2.vllm.protocol.forecast import ( + ForecastParameters, + ForecastPrediction, + TimeSeriesInput, +) + + +class TestTimeSeriesInput: + """Tests for TimeSeriesInput Pydantic model.""" + + def test_minimal(self): + ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0]) + assert ts.target == [1.0, 2.0, 3.0, 4.0, 5.0] + assert ts.item_id is None + assert ts.start is None + assert ts.past_covariates is None + assert ts.future_covariates is None + + def test_with_item_id(self): + ts = TimeSeriesInput(target=[1.0, 2.0, 3.0, 4.0, 5.0], item_id="series_1") + assert ts.item_id == "series_1" + + def test_multivariate_target(self): + ts = TimeSeriesInput(target=[[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + assert len(ts.target) == 2 + + def test_with_covariates(self): + ts = TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"temp": [10.0, 20.0, 30.0, 40.0, 50.0]}, + future_covariates={"temp": [60.0]}, + ) + assert "temp" in ts.past_covariates + assert "temp" in ts.future_covariates + + def test_too_short_target_rejected(self): + with pytest.raises(Exception): + TimeSeriesInput(target=[1.0, 2.0]) + + def test_covariate_length_mismatch_rejected(self): + with pytest.raises(Exception): + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"temp": [10.0, 20.0]}, # wrong length + ) + + def test_future_without_past_rejected(self): + with pytest.raises(Exception): + TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + future_covariates={"temp": [60.0]}, + ) + + +class TestTimeSeriesInputNanCovariates: + """Tests for NaN handling in covariate arrays (favorita_stores_1D scenario).""" + + def test_nan_in_string_future_covariates_sanitized(self): + """NaN floats in string covariate lists should be sanitized to None.""" + ts = TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={ + "holiday": ["New Year", "MLK Day", float("nan"), float("nan"), "Easter"] + }, + future_covariates={"holiday": [float("nan")]}, + ) + # NaN should be converted to None + assert ts.past_covariates["holiday"][2] is None + assert ts.past_covariates["holiday"][3] is None + assert ts.future_covariates["holiday"][0] is None + # Strings should remain unchanged + assert ts.past_covariates["holiday"][0] == "New Year" + assert ts.past_covariates["holiday"][4] == "Easter" + + def test_nan_in_numeric_covariates_preserved(self): + """NaN in numeric covariates should be sanitized to None too (consistent).""" + ts = TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"temp": [10.0, float("nan"), 30.0, 40.0, 50.0]}, + future_covariates={"temp": [float("nan")]}, + ) + # NaN floats → None after sanitization + assert ts.past_covariates["temp"][1] is None + assert ts.future_covariates["temp"][0] is None + + def test_all_nan_string_covariates(self): + """All-NaN covariate arrays should be accepted.""" + ts = TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"holiday": [float("nan")] * 5}, + future_covariates={"holiday": [float("nan")]}, + ) + assert all(v is None for v in ts.past_covariates["holiday"]) + assert all(v is None for v in ts.future_covariates["holiday"]) + + def test_mixed_string_and_none_accepted(self): + """Mix of strings and None should be accepted.""" + ts = TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"holiday": ["A", None, "B", None, "C"]}, + ) + assert ts.past_covariates["holiday"] == ["A", None, "B", None, "C"] + + def test_clean_string_covariates_unchanged(self): + """String covariates without NaN should pass through unchanged.""" + ts = TimeSeriesInput( + target=[1.0, 2.0, 3.0, 4.0, 5.0], + past_covariates={"category": ["A", "B", "C", "D", "E"]}, + ) + assert ts.past_covariates["category"] == ["A", "B", "C", "D", "E"] + + +class TestForecastParameters: + """Tests for ForecastParameters Pydantic model.""" + + def test_defaults(self): + params = ForecastParameters() + assert params.prediction_length == 1 + assert params.quantile_levels == [0.1, 0.5, 0.9] + assert params.freq is None + assert params.batch_size == 256 + assert params.cross_learning is False + + def test_custom_values(self): + params = ForecastParameters( + prediction_length=24, + quantile_levels=[0.25, 0.5, 0.75], + batch_size=100, + cross_learning=True, + ) + assert params.prediction_length == 24 + assert params.quantile_levels == [0.25, 0.5, 0.75] + assert params.batch_size == 100 + assert params.cross_learning is True + + def test_invalid_prediction_length(self): + with pytest.raises(Exception): + ForecastParameters(prediction_length=0) + + def test_invalid_quantile_level(self): + with pytest.raises(Exception): + ForecastParameters(quantile_levels=[0.0, 0.5]) + + def test_prediction_length_max(self): + with pytest.raises(Exception): + ForecastParameters(prediction_length=2000) + + +class TestForecastPrediction: + """Tests for ForecastPrediction Pydantic model.""" + + def test_minimal(self): + pred = ForecastPrediction(mean=[1.0, 2.0, 3.0]) + assert pred.mean == [1.0, 2.0, 3.0] + assert pred.item_id is None + + def test_with_quantiles(self): + pred = ForecastPrediction( + mean=[1.0, 2.0], + item_id="s1", + **{"0.1": [0.5, 1.5], "0.9": [1.5, 2.5]}, + ) + assert pred.item_id == "s1" + + def test_multivariate_mean(self): + pred = ForecastPrediction(mean=[[1.0, 2.0], [3.0, 4.0]]) + assert len(pred.mean) == 2 + + def test_extra_fields_allowed(self): + """ForecastPrediction allows dynamic quantile fields.""" + pred = ForecastPrediction(mean=[1.0], **{"0.5": [1.0], "0.1": [0.5]}) + d = pred.model_dump() + assert "0.5" in d + assert "0.1" in d diff --git a/test/chronos2/vllm/test_quantiles.py b/test/chronos2/vllm/test_quantiles.py new file mode 100644 index 0000000..375ff8c --- /dev/null +++ b/test/chronos2/vllm/test_quantiles.py @@ -0,0 +1,98 @@ +"""Unit tests for chronos.chronos2.vllm.utils.quantiles and chronos.utils.interpolate_quantiles.""" + +import pytest + +torch = pytest.importorskip("torch") + +from chronos.chronos2.vllm.utils.quantiles import select_quantiles # noqa: E402 +from chronos.utils import interpolate_quantiles # noqa: E402 + +MODEL_QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + + +class TestSelectQuantiles: + """Tests for select_quantiles.""" + + def test_subset_selection(self): + """Requesting a subset of model quantiles should do direct indexing.""" + # Shape: (1, 9_quantiles, 4_horizon) + pred = torch.randn(1, len(MODEL_QUANTILES), 4) + quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.1, 0.5, 0.9]) + + assert len(quantiles) == 1 + assert len(mean) == 1 + # After swap: (1, 4, 3) + assert quantiles[0].shape == (1, 4, 3) + # Mean is median (0.5 = index 4) + assert mean[0].shape == (1, 4) + + def test_direct_selection_values(self): + """Verify selected values match expected indices.""" + pred = torch.arange(36, dtype=torch.float32).reshape(1, 9, 4) + quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.1, 0.9]) + + # q=0.1 is index 0, q=0.9 is index 8 + # After swap: pred[..., h, q] → values at (h, q_idx) + q = quantiles[0] # (1, 4, 2) + assert q[0, 0, 0].item() == 0.0 # q=0.1, h=0 → row 0, col 0 + assert q[0, 0, 1].item() == 32.0 # q=0.9, h=0 → row 8, col 0 + + def test_interpolation_triggered(self): + """Non-subset quantile levels should trigger interpolation.""" + pred = torch.randn(1, len(MODEL_QUANTILES), 4) + quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.15, 0.5]) + + assert quantiles[0].shape == (1, 4, 2) + + def test_multiple_predictions(self): + """Works with multiple prediction tensors.""" + preds = [torch.randn(1, 9, 4) for _ in range(3)] + quantiles, mean = select_quantiles(preds, MODEL_QUANTILES, [0.5]) + + assert len(quantiles) == 3 + assert len(mean) == 3 + + def test_multivariate(self): + """Works with multivariate predictions (2 variates).""" + pred = torch.randn(2, len(MODEL_QUANTILES), 4) + quantiles, mean = select_quantiles([pred], MODEL_QUANTILES, [0.1, 0.5, 0.9]) + + assert quantiles[0].shape == (2, 4, 3) + assert mean[0].shape == (2, 4) + + +class TestInterpolateQuantiles: + """Tests for chronos.utils.interpolate_quantiles.""" + + def test_exact_match(self): + """Exact match should return the original values.""" + values = torch.tensor([[1.0, 2.0, 3.0]]) # (1, 3) for levels [0.1, 0.5, 0.9] + result = interpolate_quantiles([0.5], [0.1, 0.5, 0.9], values) + assert result.shape == (1, 1) + assert result[0, 0].item() == 2.0 + + def test_midpoint_interpolation(self): + """Midpoint between two levels should be the average.""" + values = torch.tensor([[0.0, 10.0]]) # (1, 2) for levels [0.2, 0.8] + result = interpolate_quantiles([0.5], [0.2, 0.8], values) + assert abs(result[0, 0].item() - 5.0) < 1e-5 + + def test_clamping(self): + """Query below/above model range should clamp to boundary.""" + values = torch.tensor([[1.0, 2.0, 3.0]]) # levels [0.1, 0.5, 0.9] + low = interpolate_quantiles([0.01], [0.1, 0.5, 0.9], values) + high = interpolate_quantiles([0.99], [0.1, 0.5, 0.9], values) + assert low[0, 0].item() == pytest.approx(1.0, abs=1e-5) # clamped to 0.1 + assert high[0, 0].item() == pytest.approx(3.0, abs=1e-5) # clamped to 0.9 + + def test_multiple_queries(self): + """Multiple query levels should produce correct shape.""" + values = torch.tensor([[1.0, 2.0, 3.0]]) + result = interpolate_quantiles([0.1, 0.3, 0.5, 0.9], [0.1, 0.5, 0.9], values) + assert result.shape == (1, 4) + + def test_batch_dimension(self): + """Works with batch dimension.""" + values = torch.randn(5, 9) # 5 time steps, 9 quantiles + result = interpolate_quantiles([0.25], MODEL_QUANTILES, values) + assert result.shape == (5, 1) \ No newline at end of file diff --git a/test/chronos2/vllm/test_validation.py b/test/chronos2/vllm/test_validation.py new file mode 100644 index 0000000..24045d5 --- /dev/null +++ b/test/chronos2/vllm/test_validation.py @@ -0,0 +1,163 @@ +"""Unit tests for chronos.chronos2.vllm.protocol.validation.""" + +import pytest + +from chronos.chronos2.vllm.protocol.forecast import ForecastParameters, TimeSeriesInput +from chronos.chronos2.vllm.protocol.validation import ( + validate_cross_series, + validate_quantile_levels, + validate_single_series_covariates, + validate_start_timestamp, + validate_target, +) + + +class TestValidateTarget: + """Tests for validate_target.""" + + def test_valid_univariate(self): + result = validate_target([1.0, 2.0, 3.0, 4.0, 5.0]) + assert len(result) == 5 + + def test_valid_multivariate(self): + result = validate_target([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + assert len(result) == 2 + + def test_too_short_univariate(self): + with pytest.raises(ValueError, match="at least 5"): + validate_target([1.0, 2.0]) + + def test_too_short_multivariate(self): + with pytest.raises(ValueError, match="at least 5"): + validate_target([[1.0, 2.0], [3.0, 4.0]]) + + def test_inconsistent_multivariate_lengths(self): + with pytest.raises(ValueError, match="same length"): + validate_target([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0]]) + + +class TestValidateStartTimestamp: + """Tests for validate_start_timestamp.""" + + def test_none(self): + assert validate_start_timestamp(None) is None + + def test_valid_date(self): + assert validate_start_timestamp("2024-01-01") == "2024-01-01" + + def test_valid_datetime(self): + assert validate_start_timestamp("2024-01-01T12:00:00") == "2024-01-01T12:00:00" + + def test_invalid_format(self): + with pytest.raises(ValueError, match="Invalid start timestamp"): + validate_start_timestamp("not-a-date") + + +class TestValidateQuantileLevels: + """Tests for validate_quantile_levels.""" + + def test_valid(self): + result = validate_quantile_levels([0.1, 0.5, 0.9]) + assert result == [0.1, 0.5, 0.9] + + def test_zero_invalid(self): + with pytest.raises(ValueError, match="between 0 and 1"): + validate_quantile_levels([0.0, 0.5]) + + def test_one_invalid(self): + with pytest.raises(ValueError, match="between 0 and 1"): + validate_quantile_levels([0.5, 1.0]) + + def test_negative_invalid(self): + with pytest.raises(ValueError, match="between 0 and 1"): + validate_quantile_levels([-0.1]) + + +class TestValidateSingleSeriesCovariates: + """Tests for validate_single_series_covariates.""" + + def test_no_covariates(self): + validate_single_series_covariates([1.0, 2.0, 3.0, 4.0, 5.0], None, None) + + def test_valid_past_covariates(self): + target = [1.0, 2.0, 3.0, 4.0, 5.0] + past = {"temp": [10.0, 20.0, 30.0, 40.0, 50.0]} + validate_single_series_covariates(target, past, None) + + def test_future_without_past_raises(self): + target = [1.0, 2.0, 3.0, 4.0, 5.0] + future = {"temp": [60.0]} + with pytest.raises(ValueError, match="together"): + validate_single_series_covariates(target, None, future) + + def test_future_key_not_in_past_raises(self): + target = [1.0, 2.0, 3.0, 4.0, 5.0] + past = {"temp": [10.0, 20.0, 30.0, 40.0, 50.0]} + future = {"wind": [1.0]} + with pytest.raises(ValueError, match="not in"): + validate_single_series_covariates(target, past, future) + + def test_past_length_mismatch(self): + target = [1.0, 2.0, 3.0, 4.0, 5.0] + past = {"temp": [10.0, 20.0]} + with pytest.raises(ValueError, match="must match target length"): + validate_single_series_covariates(target, past, None) + + def test_target_name_forbidden(self): + target = [1.0, 2.0, 3.0, 4.0, 5.0] + past = {"target": [10.0, 20.0, 30.0, 40.0, 50.0]} + with pytest.raises(ValueError, match="not allowed"): + validate_single_series_covariates(target, past, None) + + +class TestValidateCrossSeries: + """Tests for validate_cross_series.""" + + def _make_input(self, target=None, item_id=None, start=None, past_cov=None, future_cov=None): + return TimeSeriesInput( + target=target or [1.0, 2.0, 3.0, 4.0, 5.0], + item_id=item_id, + start=start, + past_covariates=past_cov, + future_covariates=future_cov, + ) + + def test_valid_simple(self): + inputs = [self._make_input(), self._make_input()] + params = ForecastParameters(prediction_length=3) + validate_cross_series(inputs, params) + + def test_inconsistent_item_ids(self): + inputs = [self._make_input(item_id="a"), self._make_input(item_id=None)] + params = ForecastParameters() + with pytest.raises(ValueError, match="item_id"): + validate_cross_series(inputs, params) + + def test_duplicate_item_ids(self): + inputs = [self._make_input(item_id="a"), self._make_input(item_id="a")] + params = ForecastParameters() + with pytest.raises(ValueError, match="unique"): + validate_cross_series(inputs, params) + + def test_start_without_freq(self): + inputs = [self._make_input(start="2024-01-01")] + params = ForecastParameters(prediction_length=1) + with pytest.raises(ValueError, match="freq"): + validate_cross_series(inputs, params) + + def test_freq_without_start(self): + inputs = [self._make_input()] + params = ForecastParameters(prediction_length=1, freq="D") + with pytest.raises(ValueError, match="start"): + validate_cross_series(inputs, params) + + def test_future_cov_length_mismatch(self): + inputs = [ + self._make_input( + past_cov={"temp": [1.0, 2.0, 3.0, 4.0, 5.0]}, + future_cov={"temp": [10.0, 20.0]}, # length 2 != prediction_length 3 + ) + ] + params = ForecastParameters(prediction_length=3) + with pytest.raises(ValueError, match="prediction_length"): + validate_cross_series(inputs, params)