mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Update BaseChronosPipeline docstrings
This commit is contained in:
parent
822c773424
commit
d2ff57ff40
1 changed files with 221 additions and 49 deletions
|
|
@ -42,25 +42,84 @@ class PipelineRegistry(type):
|
|||
|
||||
|
||||
class BaseChronosPipeline(metaclass=PipelineRegistry):
|
||||
"""
|
||||
Abstract base class for Chronos pretrained time series forecasting pipelines.
|
||||
|
||||
This class defines the common interface for all Chronos models. The package provides
|
||||
multiple pipeline implementations with different forecasting approaches and architectures:
|
||||
|
||||
- ChronosPipeline: Sample-based forecasting with scaling and quantization based tokenization
|
||||
- ChronosBoltPipeline: Quantile-based forecasting with patching
|
||||
- Chronos2Pipeline (recommended): Quantile-based forecasting with support for multivariate and covariate-informed forecasting
|
||||
|
||||
Each subclass implements the abstract methods and properties defined here,
|
||||
potentially with different parameter signatures and return types depending
|
||||
on the model architecture and forecasting approach.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
forecast_type
|
||||
Enum indicating whether the pipeline produces samples or quantiles
|
||||
inner_model
|
||||
The underlying HuggingFace transformers model
|
||||
|
||||
See Also
|
||||
--------
|
||||
ChronosPipeline: Sample-based forecasting with scaling and quantization based tokenization
|
||||
ChronosBoltPipeline: Quantile-based forecasting with patching
|
||||
Chronos2Pipeline (recommended): Quantile-based forecasting with support for multivariate and covariate-informed forecasting
|
||||
"""
|
||||
forecast_type: ForecastType
|
||||
dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32}
|
||||
|
||||
def __init__(self, inner_model: "PreTrainedModel"):
|
||||
"""
|
||||
Initialize the base pipeline with a pretrained model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inner_model : PreTrainedModel
|
||||
A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration
|
||||
inner_model
|
||||
A HuggingFace transformers PreTrainedModel that serves as the
|
||||
underlying forecasting model (e.g., T5ForConditionalGeneration)
|
||||
"""
|
||||
# for easy access to the inner HF-style model
|
||||
self.inner_model = inner_model
|
||||
|
||||
@property
|
||||
def model_context_length(self) -> int:
|
||||
"""
|
||||
Maximum number of time steps the model can use as context.
|
||||
|
||||
This is an abstract property that must be implemented by subclasses.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Maximum context length supported by the model
|
||||
|
||||
Notes
|
||||
-----
|
||||
Subclasses must implement this property based on their specific
|
||||
model architecture and configuration.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def model_prediction_length(self) -> int:
|
||||
"""
|
||||
Default prediction horizon for the model.
|
||||
|
||||
This is an abstract property that must be implemented by subclasses.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Default prediction horizon
|
||||
|
||||
Notes
|
||||
-----
|
||||
Subclasses must implement this property based on their specific model architecture and configuration.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _prepare_and_validate_context(self, context: Union[torch.Tensor, List[torch.Tensor]]):
|
||||
|
|
@ -75,25 +134,35 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
|
||||
def predict(self, inputs: Union[torch.Tensor, List[torch.Tensor]], prediction_length: Optional[int] = None):
|
||||
"""
|
||||
Get forecasts for the given time series. Predictions will be
|
||||
returned in fp32 on the cpu.
|
||||
|
||||
Generate forecasts for the given time series.
|
||||
|
||||
This is an abstract method that must be implemented by subclasses.
|
||||
Each subclass may have different parameters and return types depending
|
||||
on the model architecture and forecasting approach. Predictions are
|
||||
typically returned in fp32 on the CPU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs
|
||||
Input series. This is either a 1D tensor, or a list
|
||||
of 1D tensors, or a 2D tensor whose first dimension
|
||||
is batch. In the latter case, use left-padding with
|
||||
``torch.nan`` to align series of different lengths.
|
||||
Input time series. Can be a 1D tensor (single series), a list
|
||||
of 1D tensors (multiple series of varying lengths), or a 2D tensor
|
||||
where the first dimension is batch size. For 2D tensors, use
|
||||
left-padding with torch.nan to align series of different lengths.
|
||||
prediction_length
|
||||
Time steps to predict. Defaults to a model-dependent
|
||||
value if not given.
|
||||
|
||||
Number of time steps to forecast. If not provided, defaults to
|
||||
the model's default prediction length.
|
||||
|
||||
Returns
|
||||
-------
|
||||
forecasts
|
||||
Tensor containing forecasts. The layout and meaning
|
||||
of the forecasts values depends on ``self.forecast_type``.
|
||||
torch.Tensor
|
||||
Forecasts tensor. The shape and interpretation depend on the
|
||||
subclass's forecast_type (samples or quantiles).
|
||||
|
||||
Notes
|
||||
-----
|
||||
Subclasses may extend this interface with additional parameters
|
||||
specific to their forecasting approach. Refer to specific subclass
|
||||
documentation for complete parameter lists and return value details.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
@ -105,30 +174,42 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get quantile and mean forecasts for given time series.
|
||||
Predictions will be returned in fp32 on the cpu.
|
||||
|
||||
Generate quantile and mean forecasts for given time series.
|
||||
|
||||
This is an abstract method that must be implemented by subclasses.
|
||||
Each subclass may have different parameters depending on the model
|
||||
architecture. Predictions are typically returned in fp32 on the CPU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs : Union[torch.Tensor, List[torch.Tensor]]
|
||||
Input series. This is either a 1D tensor, or a list
|
||||
of 1D tensors, or a 2D tensor whose first dimension
|
||||
is batch. In the latter case, use left-padding with
|
||||
``torch.nan`` to align series of different lengths.
|
||||
prediction_length : Optional[int], optional
|
||||
Time steps to predict. Defaults to a model-dependent
|
||||
value if not given.
|
||||
quantile_levels : List[float], optional
|
||||
Quantile levels to compute, by default [0.1, 0.2, ..., 0.9]
|
||||
|
||||
inputs
|
||||
Input time series. Can be a 1D tensor (single series), a list
|
||||
of 1D tensors (multiple series of varying lengths), or a 2D tensor
|
||||
where the first dimension is batch size. For 2D tensors, use
|
||||
left-padding with torch.nan to align series of different lengths.
|
||||
prediction_length
|
||||
Number of time steps to forecast. If not provided, defaults to
|
||||
the model's default prediction length.
|
||||
quantile_levels
|
||||
List of quantile levels to compute, each between 0 and 1.
|
||||
Default is [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].
|
||||
**kwargs
|
||||
Additional keyword arguments that may be used by subclass implementations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
quantiles
|
||||
Tensor containing quantile forecasts. Shape
|
||||
Tensor of quantile forecasts with shape
|
||||
(batch_size, prediction_length, num_quantiles)
|
||||
mean
|
||||
Tensor containing mean (point) forecasts. Shape
|
||||
Tensor of mean (point) forecasts with shape
|
||||
(batch_size, prediction_length)
|
||||
|
||||
Notes
|
||||
-----
|
||||
Subclasses may extend this interface with additional parameters
|
||||
specific to their forecasting approach. Refer to specific subclass
|
||||
documentation for complete parameter lists and implementation details.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
@ -145,8 +226,12 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
**predict_kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
Perform forecasting on time series data in a long-format pandas DataFrame.
|
||||
|
||||
Generate forecasts for time series data in a pandas DataFrame.
|
||||
|
||||
This method provides a convenient interface for forecasting on long-format
|
||||
pandas DataFrames containing multiple time series. It handles data conversion,
|
||||
batching, and result formatting automatically.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df
|
||||
|
|
@ -171,12 +256,30 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
|
||||
Returns
|
||||
-------
|
||||
The forecasts dataframe generated by the model with the following columns
|
||||
- `id_column`: The time series ID
|
||||
- `timestamp_column`: Future timestamps
|
||||
- "target_name": The name of the target column
|
||||
- "predictions": The point predictions generated by the model
|
||||
- One column for predictions at each quantile level in `quantile_levels`
|
||||
pd.DataFrame
|
||||
Forecast results in long format with the following columns:
|
||||
- Column named by id_column: Time series identifiers
|
||||
- Column named by timestamp_column: Future timestamps
|
||||
- "target_name": Name of the forecasted target variable
|
||||
- "predictions": Point forecasts (mean predictions)
|
||||
- One column per quantile level (e.g., "0.1", "0.5", "0.9")
|
||||
|
||||
Raises
|
||||
------
|
||||
ImportError
|
||||
If pandas is not installed.
|
||||
ValueError
|
||||
If target is not a string (multivariate forecasting not supported).
|
||||
|
||||
Notes
|
||||
-----
|
||||
This method requires pandas to be installed. Install with `pip install pandas`.
|
||||
|
||||
The method internally converts the DataFrame to tensor format, generates
|
||||
forecasts using predict_quantiles, and converts results back to DataFrame format.
|
||||
|
||||
Subclasses may have additional parameters or behavior. Refer to specific
|
||||
subclass documentation for implementation details.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
|
|
@ -245,23 +348,43 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
self, task: "fev.Task", batch_size: int = 32, **kwargs
|
||||
) -> tuple[list["datasets.DatasetDict"], float]:
|
||||
"""
|
||||
Make predictions for evaluation on a fev.Task.
|
||||
|
||||
Generate predictions for evaluation on a fev benchmark task.
|
||||
|
||||
This method provides integration with the fev (Forecasting Evaluation)
|
||||
library for standardized benchmark evaluation. It handles batching,
|
||||
timing, and formatting predictions according to the task requirements.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task
|
||||
Benchmark task on which the evaluation should be done.
|
||||
A fev.Task object defining the benchmark evaluation task, including
|
||||
the dataset, horizon, quantile levels, and evaluation metric.
|
||||
batch_size
|
||||
Batch size used during evaluation.
|
||||
Number of time series to process in each batch during inference.
|
||||
Larger batch sizes may improve throughput but require more memory.
|
||||
Default is 32.
|
||||
**kwargs
|
||||
Additional keyword arguments that will be forwarded to `self.predict_quantiles`.
|
||||
|
||||
Additional keyword arguments forwarded to the predict_quantiles method.
|
||||
These may include model-specific parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
predictions_per_window
|
||||
Predictions for each window, each stored as a DatasetDict
|
||||
List of DatasetDict objects, one for each evaluation window in the task.
|
||||
Each DatasetDict contains predictions formatted according to fev requirements.
|
||||
inference_time_s
|
||||
Total time that it took to make predictions for all windows (in seconds).
|
||||
Total inference time in seconds across all windows, excluding data
|
||||
loading and preprocessing time.
|
||||
|
||||
Raises
|
||||
------
|
||||
ImportError
|
||||
If the fev library is not installed.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This method requires the fev library to be installed. Install with
|
||||
`pip install fev`.
|
||||
"""
|
||||
import datasets
|
||||
|
||||
|
|
@ -334,8 +457,57 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load the model, either from a local path, S3 prefix, or from the HuggingFace Hub.
|
||||
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
|
||||
Load a pretrained Chronos pipeline from various sources.
|
||||
|
||||
This class method loads a pretrained model from a local path, S3 bucket,
|
||||
or the HuggingFace Hub. It automatically detects the appropriate pipeline
|
||||
class based on the model configuration and instantiates it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pretrained_model_name_or_path
|
||||
Path or identifier for the pretrained model. Can be:
|
||||
- A local directory path containing model files
|
||||
- An S3 URI (s3://bucket/prefix)
|
||||
- A HuggingFace Hub model identifier (e.g., "amazon/chronos-t5-small")
|
||||
*model_args
|
||||
Additional positional arguments passed to the model constructor.
|
||||
force_s3_download
|
||||
When True, forces re-downloading from S3 even if cached locally.
|
||||
Only applicable for S3 URIs. Default is False.
|
||||
**kwargs
|
||||
Additional keyword arguments passed to AutoConfig and the model
|
||||
constructor. Common options include:
|
||||
- torch_dtype: Data type for model weights ("auto", "float32", "bfloat16")
|
||||
- device_map: Device placement strategy for model layers
|
||||
- Other transformers AutoConfig and AutoModel arguments
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseChronosPipeline
|
||||
An instance of the appropriate pipeline subclass (ChronosPipeline,
|
||||
ChronosBoltPipeline, or Chronos2Pipeline) based on the model configuration.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the configuration is not a valid Chronos config or if the
|
||||
specified pipeline class is not recognized.
|
||||
ImportError
|
||||
If required dependencies are not installed.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The method reads the model configuration to determine which pipeline
|
||||
class to instantiate. The configuration must contain either a
|
||||
`chronos_pipeline_class` or `chronos_config` attribute.
|
||||
|
||||
For S3 URIs, the model is first downloaded to a local cache directory
|
||||
before loading.
|
||||
|
||||
The torch_dtype parameter can be specified as a string ("float32", "bfloat16")
|
||||
or as a torch dtype object. When set to "auto", the dtype is determined
|
||||
from the model configuration.
|
||||
"""
|
||||
if str(pretrained_model_name_or_path).startswith("s3://"):
|
||||
from .boto_utils import cache_model_from_s3
|
||||
|
|
|
|||
Loading…
Reference in a new issue