Update BaseChronosPipeline docstrings

This commit is contained in:
Abdul Fatir Ansari 2025-12-08 16:45:32 +01:00
parent 822c773424
commit d2ff57ff40

View file

@ -42,25 +42,84 @@ class PipelineRegistry(type):
class BaseChronosPipeline(metaclass=PipelineRegistry):
"""
Abstract base class for Chronos pretrained time series forecasting pipelines.
This class defines the common interface for all Chronos models. The package provides
multiple pipeline implementations with different forecasting approaches and architectures:
- ChronosPipeline: Sample-based forecasting with scaling and quantization based tokenization
- ChronosBoltPipeline: Quantile-based forecasting with patching
- Chronos2Pipeline (recommended): Quantile-based forecasting with support for multivariate and covariate-informed forecasting
Each subclass implements the abstract methods and properties defined here,
potentially with different parameter signatures and return types depending
on the model architecture and forecasting approach.
Attributes
----------
forecast_type
Enum indicating whether the pipeline produces samples or quantiles
inner_model
The underlying HuggingFace transformers model
See Also
--------
ChronosPipeline: Sample-based forecasting with scaling and quantization based tokenization
ChronosBoltPipeline: Quantile-based forecasting with patching
Chronos2Pipeline (recommended): Quantile-based forecasting with support for multivariate and covariate-informed forecasting
"""
forecast_type: ForecastType
dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32}
def __init__(self, inner_model: "PreTrainedModel"):
"""
Initialize the base pipeline with a pretrained model.
Parameters
----------
inner_model : PreTrainedModel
A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration
inner_model
A HuggingFace transformers PreTrainedModel that serves as the
underlying forecasting model (e.g., T5ForConditionalGeneration)
"""
# for easy access to the inner HF-style model
self.inner_model = inner_model
@property
def model_context_length(self) -> int:
"""
Maximum number of time steps the model can use as context.
This is an abstract property that must be implemented by subclasses.
Returns
-------
int
Maximum context length supported by the model
Notes
-----
Subclasses must implement this property based on their specific
model architecture and configuration.
"""
raise NotImplementedError()
@property
def model_prediction_length(self) -> int:
"""
Default prediction horizon for the model.
This is an abstract property that must be implemented by subclasses.
Returns
-------
int
Default prediction horizon
Notes
-----
Subclasses must implement this property based on their specific model architecture and configuration.
"""
raise NotImplementedError()
def _prepare_and_validate_context(self, context: Union[torch.Tensor, List[torch.Tensor]]):
@ -75,25 +134,35 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
def predict(self, inputs: Union[torch.Tensor, List[torch.Tensor]], prediction_length: Optional[int] = None):
"""
Get forecasts for the given time series. Predictions will be
returned in fp32 on the cpu.
Generate forecasts for the given time series.
This is an abstract method that must be implemented by subclasses.
Each subclass may have different parameters and return types depending
on the model architecture and forecasting approach. Predictions are
typically returned in fp32 on the CPU.
Parameters
----------
inputs
Input series. This is either a 1D tensor, or a list
of 1D tensors, or a 2D tensor whose first dimension
is batch. In the latter case, use left-padding with
``torch.nan`` to align series of different lengths.
Input time series. Can be a 1D tensor (single series), a list
of 1D tensors (multiple series of varying lengths), or a 2D tensor
where the first dimension is batch size. For 2D tensors, use
left-padding with torch.nan to align series of different lengths.
prediction_length
Time steps to predict. Defaults to a model-dependent
value if not given.
Number of time steps to forecast. If not provided, defaults to
the model's default prediction length.
Returns
-------
forecasts
Tensor containing forecasts. The layout and meaning
of the forecasts values depends on ``self.forecast_type``.
torch.Tensor
Forecasts tensor. The shape and interpretation depend on the
subclass's forecast_type (samples or quantiles).
Notes
-----
Subclasses may extend this interface with additional parameters
specific to their forecasting approach. Refer to specific subclass
documentation for complete parameter lists and return value details.
"""
raise NotImplementedError()
@ -105,30 +174,42 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get quantile and mean forecasts for given time series.
Predictions will be returned in fp32 on the cpu.
Generate quantile and mean forecasts for given time series.
This is an abstract method that must be implemented by subclasses.
Each subclass may have different parameters depending on the model
architecture. Predictions are typically returned in fp32 on the CPU.
Parameters
----------
inputs : Union[torch.Tensor, List[torch.Tensor]]
Input series. This is either a 1D tensor, or a list
of 1D tensors, or a 2D tensor whose first dimension
is batch. In the latter case, use left-padding with
``torch.nan`` to align series of different lengths.
prediction_length : Optional[int], optional
Time steps to predict. Defaults to a model-dependent
value if not given.
quantile_levels : List[float], optional
Quantile levels to compute, by default [0.1, 0.2, ..., 0.9]
inputs
Input time series. Can be a 1D tensor (single series), a list
of 1D tensors (multiple series of varying lengths), or a 2D tensor
where the first dimension is batch size. For 2D tensors, use
left-padding with torch.nan to align series of different lengths.
prediction_length
Number of time steps to forecast. If not provided, defaults to
the model's default prediction length.
quantile_levels
List of quantile levels to compute, each between 0 and 1.
Default is [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].
**kwargs
Additional keyword arguments that may be used by subclass implementations.
Returns
-------
quantiles
Tensor containing quantile forecasts. Shape
Tensor of quantile forecasts with shape
(batch_size, prediction_length, num_quantiles)
mean
Tensor containing mean (point) forecasts. Shape
Tensor of mean (point) forecasts with shape
(batch_size, prediction_length)
Notes
-----
Subclasses may extend this interface with additional parameters
specific to their forecasting approach. Refer to specific subclass
documentation for complete parameter lists and implementation details.
"""
raise NotImplementedError()
@ -145,8 +226,12 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
**predict_kwargs,
) -> "pd.DataFrame":
"""
Perform forecasting on time series data in a long-format pandas DataFrame.
Generate forecasts for time series data in a pandas DataFrame.
This method provides a convenient interface for forecasting on long-format
pandas DataFrames containing multiple time series. It handles data conversion,
batching, and result formatting automatically.
Parameters
----------
df
@ -171,12 +256,30 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
Returns
-------
The forecasts dataframe generated by the model with the following columns
- `id_column`: The time series ID
- `timestamp_column`: Future timestamps
- "target_name": The name of the target column
- "predictions": The point predictions generated by the model
- One column for predictions at each quantile level in `quantile_levels`
pd.DataFrame
Forecast results in long format with the following columns:
- Column named by id_column: Time series identifiers
- Column named by timestamp_column: Future timestamps
- "target_name": Name of the forecasted target variable
- "predictions": Point forecasts (mean predictions)
- One column per quantile level (e.g., "0.1", "0.5", "0.9")
Raises
------
ImportError
If pandas is not installed.
ValueError
If target is not a string (multivariate forecasting not supported).
Notes
-----
This method requires pandas to be installed. Install with `pip install pandas`.
The method internally converts the DataFrame to tensor format, generates
forecasts using predict_quantiles, and converts results back to DataFrame format.
Subclasses may have additional parameters or behavior. Refer to specific
subclass documentation for implementation details.
"""
try:
import pandas as pd
@ -245,23 +348,43 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
self, task: "fev.Task", batch_size: int = 32, **kwargs
) -> tuple[list["datasets.DatasetDict"], float]:
"""
Make predictions for evaluation on a fev.Task.
Generate predictions for evaluation on a fev benchmark task.
This method provides integration with the fev (Forecasting Evaluation)
library for standardized benchmark evaluation. It handles batching,
timing, and formatting predictions according to the task requirements.
Parameters
----------
task
Benchmark task on which the evaluation should be done.
A fev.Task object defining the benchmark evaluation task, including
the dataset, horizon, quantile levels, and evaluation metric.
batch_size
Batch size used during evaluation.
Number of time series to process in each batch during inference.
Larger batch sizes may improve throughput but require more memory.
Default is 32.
**kwargs
Additional keyword arguments that will be forwarded to `self.predict_quantiles`.
Additional keyword arguments forwarded to the predict_quantiles method.
These may include model-specific parameters.
Returns
-------
predictions_per_window
Predictions for each window, each stored as a DatasetDict
List of DatasetDict objects, one for each evaluation window in the task.
Each DatasetDict contains predictions formatted according to fev requirements.
inference_time_s
Total time that it took to make predictions for all windows (in seconds).
Total inference time in seconds across all windows, excluding data
loading and preprocessing time.
Raises
------
ImportError
If the fev library is not installed.
Notes
-----
This method requires the fev library to be installed. Install with
`pip install fev`.
"""
import datasets
@ -334,8 +457,57 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
**kwargs,
):
"""
Load the model, either from a local path, S3 prefix, or from the HuggingFace Hub.
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
Load a pretrained Chronos pipeline from various sources.
This class method loads a pretrained model from a local path, S3 bucket,
or the HuggingFace Hub. It automatically detects the appropriate pipeline
class based on the model configuration and instantiates it.
Parameters
----------
pretrained_model_name_or_path
Path or identifier for the pretrained model. Can be:
- A local directory path containing model files
- An S3 URI (s3://bucket/prefix)
- A HuggingFace Hub model identifier (e.g., "amazon/chronos-t5-small")
*model_args
Additional positional arguments passed to the model constructor.
force_s3_download
When True, forces re-downloading from S3 even if cached locally.
Only applicable for S3 URIs. Default is False.
**kwargs
Additional keyword arguments passed to AutoConfig and the model
constructor. Common options include:
- torch_dtype: Data type for model weights ("auto", "float32", "bfloat16")
- device_map: Device placement strategy for model layers
- Other transformers AutoConfig and AutoModel arguments
Returns
-------
BaseChronosPipeline
An instance of the appropriate pipeline subclass (ChronosPipeline,
ChronosBoltPipeline, or Chronos2Pipeline) based on the model configuration.
Raises
------
ValueError
If the configuration is not a valid Chronos config or if the
specified pipeline class is not recognized.
ImportError
If required dependencies are not installed.
Notes
-----
The method reads the model configuration to determine which pipeline
class to instantiate. The configuration must contain either a
`chronos_pipeline_class` or `chronos_config` attribute.
For S3 URIs, the model is first downloaded to a local cache directory
before loading.
The torch_dtype parameter can be specified as a string ("float32", "bfloat16")
or as a torch dtype object. When set to "auto", the dtype is determined
from the model configuration.
"""
if str(pretrained_model_name_or_path).startswith("s3://"):
from .boto_utils import cache_model_from_s3