Chronos-Bolt

This commit is contained in:
Abdul Fatir Ansari 2025-12-19 10:34:42 +01:00
parent 11c43c1206
commit 51d5abea81
2 changed files with 200 additions and 35 deletions

View file

@ -354,7 +354,7 @@ class ChronosModel(nn.Module):
class ChronosPipeline(BaseChronosPipeline):
"""
Forecasting pipeline for the Chronos model.
Pipeline for the Chronos model.
Parameters
----------

View file

@ -401,10 +401,41 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
class ChronosBoltPipeline(BaseChronosPipeline):
"""
Pipeline for the Chronos-Bolt model.
Parameters
----------
model
ChronosBoltModelForForecasting instance containing the pretrained model.
Attributes
----------
model
The underlying forecasting model
forecast_type
Set to ForecastType.QUANTILES indicating this pipeline produces quantiles
default_context_length
Default context length of 2048 time steps
See Also
--------
ChronosPipeline : Sample-based forecasting with tokenization
Chronos2Pipeline : Advanced forecasting with covariate support
"""
forecast_type: ForecastType = ForecastType.QUANTILES
default_context_length: int = 2048
def __init__(self, model: ChronosBoltModelForForecasting):
"""
Initialize the ChronosBoltPipeline with a pretrained model.
Parameters
----------
model
ChronosBoltModelForForecasting instance containing the pretrained
transformer model configured for quantile forecasting.
"""
super().__init__(inner_model=model) # type: ignore
self.model = model
@ -425,24 +456,40 @@ class ChronosBoltPipeline(BaseChronosPipeline):
self, context: Union[torch.Tensor, List[torch.Tensor]]
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Get encoder embeddings for the given time series.
Extract encoder embeddings for the given time series.
This method processes the input time series through patching and instance
normalization, then extracts encoder embeddings that can be used for
downstream tasks like clustering, classification, or similarity search.
Parameters
----------
context
Input series. This is either a 1D tensor, or a list
of 1D tensors, or a 2D tensor whose first dimension
is batch. In the latter case, use left-padding with
``torch.nan`` to align series of different lengths.
Input time series. Can be a 1D tensor (single series), a list
of 1D tensors (multiple series of varying lengths), or a 2D tensor
where the first dimension is batch size. For 2D tensors, use
left-padding with torch.nan to align series of different lengths.
Returns
-------
embeddings, loc_scale
A tuple of two items: the encoder embeddings and the loc_scale,
i.e., the mean and std of the original time series.
The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
where num_patches is the number of patches in the time series
and the extra 1 is for the [REG] token (if used by the model).
embeddings
Encoder embeddings with shape (batch_size, num_patches + 1, d_model),
where num_patches is the number of patches created from the input
time series, and the extra 1 is for the [REG] token if used by the model.
Returned on CPU in the model's dtype.
loc_scale
Tuple of (location, scale) tensors used for instance normalization,
representing the mean and standard deviation of the original time series.
Both tensors have shape (batch_size,) and are returned on CPU.
Notes
-----
The embeddings are extracted after patching and instance normalization
but before the decoder. They capture the encoded representation of the
input time series in the model's latent space.
If the input context is longer than the model's context length, it will
be automatically truncated to the most recent time steps.
"""
context_tensor = self._prepare_and_validate_context(context=context)
model_context_length = self.model.config.chronos_config["context_length"]
@ -467,31 +514,58 @@ class ChronosBoltPipeline(BaseChronosPipeline):
limit_prediction_length: bool = False,
) -> torch.Tensor:
"""
Get forecasts for the given time series.
Refer to the base method (``BaseChronosPipeline.predict``)
for details on shared parameters.
Additional parameters
---------------------
Generate quantile forecasts for the given time series.
This method directly predicts quantiles without generating sample trajectories.
For predictions longer than the model's built-in horizon, it uses an
autoregressive approach that expands the batch size by the number of quantiles
to generate more robust long-horizon forecasts.
Parameters
----------
inputs
Input time series. Can be a 1D tensor (single series), a list
of 1D tensors (multiple series of varying lengths), or a 2D tensor
where the first dimension is batch size. For 2D tensors, use
left-padding with torch.nan to align series of different lengths.
prediction_length
Number of time steps to forecast. If not provided, uses the model's
default prediction length from the configuration.
limit_prediction_length
Force the prediction length to be smaller than or equal to the
built-in prediction length of the model. False by
default. When True, fail loudly if longer predictions
are requested; otherwise, longer predictions are allowed.
When True, raises an error if prediction_length exceeds the model's
built-in prediction length. When False (default), allows longer
predictions with a warning about potential quality degradation.
Returns
-------
torch.Tensor
Forecasts of shape (batch_size, num_quantiles, prediction_length)
where num_quantiles is the number of quantiles the model has been
trained to output. For official Chronos-Bolt models, the value of
num_quantiles is 9 for [0.1, 0.2, ..., 0.9]-quantiles.
Quantile forecasts with shape (batch_size, num_quantiles, prediction_length),
where num_quantiles is the number of quantiles the model was trained on.
For official Chronos-Bolt models, num_quantiles is 9 for quantiles
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].
Returned in fp32 on CPU.
Raises
------
ValueError
When limit_prediction_length is True and the prediction_length is
greater than the model's training prediction_length.
If limit_prediction_length is True and prediction_length exceeds
the model's built-in prediction length.
Notes
-----
For predictions longer than the model's built-in horizon, the method uses
an autoregressive approach:
1. Generate initial quantiles for the first chunk
2. Expand context by num_quantiles (treating each quantile as a scenario)
3. Generate next chunk for each scenario
4. Compute empirical quantiles across all scenarios
5. Repeat until desired prediction_length is reached
This approach scales the batch size by num_quantiles for long horizons,
which may require more GPU memory but produces more robust predictions.
If the input context is longer than the model's context length, it will
be automatically truncated to the most recent time steps.
"""
context_tensor = self._prepare_and_validate_context(context=inputs)
@ -564,7 +638,57 @@ class ChronosBoltPipeline(BaseChronosPipeline):
**predict_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Refer to the base method (``BaseChronosPipeline.predict_quantiles``).
Generate quantile and mean forecasts for given time series.
This method generates forecasts at the specified quantile levels. If the
requested quantiles match those the model was trained on, they are returned
directly. Otherwise, the method performs interpolation or extrapolation
to obtain the requested quantiles.
Parameters
----------
inputs
Input time series. Can be a 1D tensor (single series), a list
of 1D tensors (multiple series of varying lengths), or a 2D tensor
where the first dimension is batch size. For 2D tensors, use
left-padding with torch.nan to align series of different lengths.
prediction_length
Number of time steps to forecast. If not provided, uses the model's
default prediction length from the configuration.
quantile_levels
List of quantile levels to compute, each between 0 and 1.
Default is [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].
**predict_kwargs
Additional keyword arguments passed to the predict method, such as
limit_prediction_length.
Returns
-------
quantiles
Tensor of quantile forecasts with shape
(batch_size, prediction_length, num_quantiles).
Returned in fp32 on CPU.
mean
Tensor of mean forecasts with shape (batch_size, prediction_length).
This is actually the median (0.5 quantile) from the model's predictions.
Returned in fp32 on CPU.
Notes
-----
If the requested quantile_levels are a subset of the model's training
quantiles, they are extracted directly without interpolation.
If quantile_levels include values outside the range of training quantiles,
the method will extrapolate using the minimum/maximum training quantiles,
which may significantly affect prediction quality. A warning will be issued
in this case.
The interpolation/extrapolation assumes the model's training quantiles
form an equidistant grid (e.g., 0.1, 0.2, ..., 0.9), which holds for
official Chronos-Bolt models but may not be true for custom models.
The mean returned is actually the median (0.5 quantile) from the model's
predictions, not a true mean.
"""
# shape (batch_size, prediction_length, len(training_quantile_levels))
predictions = (
@ -609,8 +733,49 @@ class ChronosBoltPipeline(BaseChronosPipeline):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
"""
Load the model from a local path, an S3 prefix, or the HuggingFace Hub.
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
Load a pretrained ChronosBoltPipeline from various sources.
This method loads a pretrained ChronosBoltPipeline model from a local path,
S3 bucket, or the HuggingFace Hub. It automatically instantiates the
appropriate model architecture based on the configuration.
Parameters
----------
pretrained_model_name_or_path
Path or identifier for the pretrained model. Can be:
- A local directory path containing model files
- An S3 URI (s3://bucket/prefix)
- A HuggingFace Hub model identifier (e.g., "amazon/chronos-bolt-small")
*args
Additional positional arguments passed to AutoConfig and the model constructor.
**kwargs
Additional keyword arguments passed to AutoConfig and the model constructor.
Common options include:
- torch_dtype: Data type for model weights ("auto", "float32", "bfloat16")
- device_map: Device placement strategy for model layers
- Other transformers AutoConfig and model arguments
Returns
-------
ChronosBoltPipeline
An instance of ChronosBoltPipeline with the loaded model.
Raises
------
AssertionError
If the configuration is not a valid Chronos config.
Notes
-----
For S3 URIs, the method delegates to BaseChronosPipeline.from_pretrained
which handles S3 download and caching.
The method automatically detects the model architecture from the configuration
and instantiates the appropriate class. If the architecture is not recognized,
it defaults to ChronosBoltModelForForecasting.
This method supports all arguments accepted by HuggingFace's AutoConfig
and model classes.
"""
if str(pretrained_model_name_or_path).startswith("s3://"):