mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Chronos-Bolt
This commit is contained in:
parent
11c43c1206
commit
51d5abea81
2 changed files with 200 additions and 35 deletions
|
|
@ -354,7 +354,7 @@ class ChronosModel(nn.Module):
|
|||
|
||||
class ChronosPipeline(BaseChronosPipeline):
|
||||
"""
|
||||
Forecasting pipeline for the Chronos model.
|
||||
Pipeline for the Chronos model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
|
|||
|
|
@ -401,10 +401,41 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
|||
|
||||
|
||||
class ChronosBoltPipeline(BaseChronosPipeline):
|
||||
"""
|
||||
Pipeline for the Chronos-Bolt model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model
|
||||
ChronosBoltModelForForecasting instance containing the pretrained model.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
model
|
||||
The underlying forecasting model
|
||||
forecast_type
|
||||
Set to ForecastType.QUANTILES indicating this pipeline produces quantiles
|
||||
default_context_length
|
||||
Default context length of 2048 time steps
|
||||
|
||||
See Also
|
||||
--------
|
||||
ChronosPipeline : Sample-based forecasting with tokenization
|
||||
Chronos2Pipeline : Advanced forecasting with covariates support
|
||||
"""
|
||||
forecast_type: ForecastType = ForecastType.QUANTILES
|
||||
default_context_length: int = 2048
|
||||
|
||||
def __init__(self, model: ChronosBoltModelForForecasting):
|
||||
"""
|
||||
Initialize the ChronosBoltPipeline with a pretrained model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model
|
||||
ChronosBoltModelForForecasting instance containing the pretrained
|
||||
transformer model configured for quantile forecasting.
|
||||
"""
|
||||
super().__init__(inner_model=model) # type: ignore
|
||||
self.model = model
|
||||
|
||||
|
|
@ -425,24 +456,40 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
self, context: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Get encoder embeddings for the given time series.
|
||||
|
||||
Extract encoder embeddings for the given time series.
|
||||
|
||||
This method processes the input time series through patching and instance
|
||||
normalization, then extracts encoder embeddings that can be used for
|
||||
downstream tasks like clustering, classification, or similarity search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
context
|
||||
Input series. This is either a 1D tensor, or a list
|
||||
of 1D tensors, or a 2D tensor whose first dimension
|
||||
is batch. In the latter case, use left-padding with
|
||||
``torch.nan`` to align series of different lengths.
|
||||
|
||||
Input time series. Can be a 1D tensor (single series), a list
|
||||
of 1D tensors (multiple series of varying lengths), or a 2D tensor
|
||||
where the first dimension is batch size. For 2D tensors, use
|
||||
left-padding with torch.nan to align series of different lengths.
|
||||
|
||||
Returns
|
||||
-------
|
||||
embeddings, loc_scale
|
||||
A tuple of two items: the encoder embeddings and the loc_scale,
|
||||
i.e., the mean and std of the original time series.
|
||||
The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
|
||||
where num_patches is the number of patches in the time series
|
||||
and the extra 1 is for the [REG] token (if used by the model).
|
||||
embeddings
|
||||
Encoder embeddings with shape (batch_size, num_patches + 1, d_model),
|
||||
where num_patches is the number of patches created from the input
|
||||
time series, and the extra 1 is for the [REG] token if used by the model.
|
||||
Returned on CPU in the model's dtype.
|
||||
loc_scale
|
||||
Tuple of (location, scale) tensors used for instance normalization,
|
||||
representing the mean and standard deviation of the original time series.
|
||||
Both tensors have shape (batch_size,) and are returned on CPU.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The embeddings are extracted after patching and instance normalization
|
||||
but before the decoder. They capture the encoded representation of the
|
||||
input time series in the model's latent space.
|
||||
|
||||
If the input context is longer than the model's context length, it will
|
||||
be automatically truncated to the most recent time steps.
|
||||
"""
|
||||
context_tensor = self._prepare_and_validate_context(context=context)
|
||||
model_context_length = self.model.config.chronos_config["context_length"]
|
||||
|
|
@ -467,31 +514,58 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
limit_prediction_length: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Get forecasts for the given time series.
|
||||
|
||||
Refer to the base method (``BaseChronosPipeline.predict``)
|
||||
for details on shared parameters.
|
||||
Additional parameters
|
||||
---------------------
|
||||
Generate quantile forecasts for the given time series.
|
||||
|
||||
This method directly predicts quantiles without generating sample trajectories.
|
||||
For predictions longer than the model's built-in horizon, it uses an
|
||||
autoregressive approach that expands the batch size by the number of quantiles
|
||||
to generate more robust long-horizon forecasts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs
|
||||
Input time series. Can be a 1D tensor (single series), a list
|
||||
of 1D tensors (multiple series of varying lengths), or a 2D tensor
|
||||
where the first dimension is batch size. For 2D tensors, use
|
||||
left-padding with torch.nan to align series of different lengths.
|
||||
prediction_length
|
||||
Number of time steps to forecast. If not provided, uses the model's
|
||||
default prediction length from the configuration.
|
||||
limit_prediction_length
|
||||
Force prediction length to be smaller than or equal to the
|
||||
built-in prediction length from the model. False by
|
||||
default. When true, fail loudly if longer predictions
|
||||
are requested, otherwise longer predictions are allowed.
|
||||
|
||||
When True, raises an error if prediction_length exceeds the model's
|
||||
built-in prediction length. When False (default), allows longer
|
||||
predictions with a warning about potential quality degradation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Forecasts of shape (batch_size, num_quantiles, prediction_length)
|
||||
where num_quantiles is the number of quantiles the model has been
|
||||
trained to output. For official Chronos-Bolt models, the value of
|
||||
num_quantiles is 9 for [0.1, 0.2, ..., 0.9]-quantiles.
|
||||
|
||||
Quantile forecasts with shape (batch_size, num_quantiles, prediction_length),
|
||||
where num_quantiles is the number of quantiles the model was trained on.
|
||||
For official Chronos-Bolt models, num_quantiles is 9 for quantiles
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].
|
||||
Returned in fp32 on CPU.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
When limit_prediction_length is True and the prediction_length is
|
||||
greater than the model's training prediction_length.
|
||||
If limit_prediction_length is True and prediction_length exceeds
|
||||
the model's built-in prediction length.
|
||||
|
||||
Notes
|
||||
-----
|
||||
For predictions longer than the model's built-in horizon, the method uses
|
||||
an autoregressive approach:
|
||||
1. Generate initial quantiles for the first chunk
|
||||
2. Expand context by num_quantiles (treating each quantile as a scenario)
|
||||
3. Generate next chunk for each scenario
|
||||
4. Compute empirical quantiles across all scenarios
|
||||
5. Repeat until desired prediction_length is reached
|
||||
|
||||
This approach scales the batch size by num_quantiles for long horizons,
|
||||
which may require more GPU memory but produces more robust predictions.
|
||||
|
||||
If the input context is longer than the model's context length, it will
|
||||
be automatically truncated to the most recent time steps.
|
||||
"""
|
||||
context_tensor = self._prepare_and_validate_context(context=inputs)
|
||||
|
||||
|
|
@ -564,7 +638,57 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
**predict_kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Refer to the base method (``BaseChronosPipeline.predict_quantiles``).
|
||||
Generate quantile and mean forecasts for given time series.
|
||||
|
||||
This method generates forecasts at the specified quantile levels. If the
|
||||
requested quantiles match those the model was trained on, they are returned
|
||||
directly. Otherwise, the method performs interpolation or extrapolation
|
||||
to obtain the requested quantiles.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs
|
||||
Input time series. Can be a 1D tensor (single series), a list
|
||||
of 1D tensors (multiple series of varying lengths), or a 2D tensor
|
||||
where the first dimension is batch size. For 2D tensors, use
|
||||
left-padding with torch.nan to align series of different lengths.
|
||||
prediction_length
|
||||
Number of time steps to forecast. If not provided, uses the model's
|
||||
default prediction length from the configuration.
|
||||
quantile_levels
|
||||
List of quantile levels to compute, each between 0 and 1.
|
||||
Default is [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].
|
||||
**predict_kwargs
|
||||
Additional keyword arguments passed to the predict method, such as
|
||||
limit_prediction_length.
|
||||
|
||||
Returns
|
||||
-------
|
||||
quantiles
|
||||
Tensor of quantile forecasts with shape
|
||||
(batch_size, prediction_length, num_quantiles).
|
||||
Returned in fp32 on CPU.
|
||||
mean
|
||||
Tensor of mean forecasts with shape (batch_size, prediction_length).
|
||||
This is actually the median (0.5 quantile) from the model's predictions.
|
||||
Returned in fp32 on CPU.
|
||||
|
||||
Notes
|
||||
-----
|
||||
If the requested quantile_levels are a subset of the model's training
|
||||
quantiles, they are extracted directly without interpolation.
|
||||
|
||||
If quantile_levels include values outside the range of training quantiles,
|
||||
the method will extrapolate using the minimum/maximum training quantiles,
|
||||
which may significantly affect prediction quality. A warning will be issued
|
||||
in this case.
|
||||
|
||||
The interpolation/extrapolation assumes the model's training quantiles
|
||||
formed an equidistant grid (e.g., 0.1, 0.2, ..., 0.9), which holds for
|
||||
official Chronos-Bolt models but may not be true for custom models.
|
||||
|
||||
The mean returned is actually the median (0.5 quantile) from the model's
|
||||
predictions, not a true mean.
|
||||
"""
|
||||
# shape (batch_size, prediction_length, len(training_quantile_levels))
|
||||
predictions = (
|
||||
|
|
@ -609,8 +733,49 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||
"""
|
||||
Load the model, either from a local path, an S3 prefix, or the HuggingFace Hub.
|
||||
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
|
||||
Load a pretrained ChronosBoltPipeline from various sources.
|
||||
|
||||
This method loads a pretrained ChronosBoltPipeline model from a local path,
|
||||
S3 bucket, or the HuggingFace Hub. It automatically instantiates the
|
||||
appropriate model architecture based on the configuration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pretrained_model_name_or_path
|
||||
Path or identifier for the pretrained model. Can be:
|
||||
- A local directory path containing model files
|
||||
- An S3 URI (s3://bucket/prefix)
|
||||
- A HuggingFace Hub model identifier (e.g., "amazon/chronos-bolt-small")
|
||||
*args
|
||||
Additional positional arguments passed to AutoConfig and the model constructor.
|
||||
**kwargs
|
||||
Additional keyword arguments passed to AutoConfig and the model constructor.
|
||||
Common options include:
|
||||
- torch_dtype: Data type for model weights ("auto", "float32", "bfloat16")
|
||||
- device_map: Device placement strategy for model layers
|
||||
- Other transformers AutoConfig and model arguments
|
||||
|
||||
Returns
|
||||
-------
|
||||
ChronosBoltPipeline
|
||||
An instance of ChronosBoltPipeline with the loaded model.
|
||||
|
||||
Raises
|
||||
------
|
||||
AssertionError
|
||||
If the configuration is not a valid Chronos config.
|
||||
|
||||
Notes
|
||||
-----
|
||||
For S3 URIs, the method delegates to BaseChronosPipeline.from_pretrained
|
||||
which handles S3 download and caching.
|
||||
|
||||
The method automatically detects the model architecture from the configuration
|
||||
and instantiates the appropriate class. If the architecture is not recognized,
|
||||
it defaults to ChronosBoltModelForForecasting.
|
||||
|
||||
This method supports all arguments accepted by HuggingFace's AutoConfig
|
||||
and model classes.
|
||||
"""
|
||||
|
||||
if str(pretrained_model_name_or_path).startswith("s3://"):
|
||||
|
|
|
|||
Loading…
Reference in a new issue