Chronos-2

This commit is contained in:
Abdul Fatir Ansari 2025-12-19 10:45:44 +01:00
parent 51d5abea81
commit 1fea716236

View file

@ -37,10 +37,26 @@ logger = logging.getLogger(__name__)
class Chronos2Pipeline(BaseChronosPipeline):
"""
Pipeline for the Chronos-2 model.
See Also
--------
ChronosPipeline: Sample-based forecasting with scaling and quantization based tokenization
ChronosBoltPipeline: Quantile-based forecasting with patching
"""
forecast_type: ForecastType = ForecastType.QUANTILES
default_context_length: int = 2048
def __init__(self, model: Chronos2Model):
    """
    Build a Chronos-2 pipeline around an already-loaded model.

    The model is forwarded to the base pipeline initialiser as ``inner_model``
    and is additionally kept on the instance as ``self.model``.

    Parameters
    ----------
    model
        A pretrained Chronos2Model instance
    """
    # Base-class setup runs first; only afterwards is the model exposed as an attribute.
    super().__init__(inner_model=model)
    self.model = model
@ -55,13 +71,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
Parameters
----------
quantile_levels : torch.Tensor
quantile_levels
The quantile levels, must be strictly in (0, 1)
Returns
-------
torch.Tensor
The normalized probability mass per quantile
The normalized probability mass per quantile
"""
assert quantile_levels.ndim == 1
assert quantile_levels.min() > 0.0 and quantile_levels.max() < 1.0
@ -75,22 +90,57 @@ class Chronos2Pipeline(BaseChronosPipeline):
@property
def model_context_length(self) -> int:
    """
    Longest history, in time steps, that the model can use as context.

    Returns
    -------
    The ``context_length`` entry of the model's ``chronos_config``
    """
    config = self.model.chronos_config
    return config.context_length
@property
def model_output_patch_size(self) -> int:
    """
    Length of a single output patch produced by the model.

    Returns
    -------
    The ``output_patch_size`` entry of the model's ``chronos_config``
    """
    config = self.model.chronos_config
    return config.output_patch_size
@property
def model_prediction_length(self) -> int:
    """
    Default prediction horizon for the model.

    Returns
    -------
    The product of ``max_output_patches`` and ``output_patch_size`` taken
    from the model's ``chronos_config``
    """
    config = self.model.chronos_config
    patch_count = config.max_output_patches
    patch_size = config.output_patch_size
    return patch_count * patch_size
@property
def quantiles(self) -> list[float]:
    """
    Quantile levels the model was trained to predict.

    Returns
    -------
    The ``quantiles`` entry of the model's ``chronos_config``, returned as
    stored (not copied)
    """
    config = self.model.chronos_config
    return config.quantiles
@property
def max_output_patches(self) -> int:
    """
    Upper bound on the number of output patches generated in a single forward pass.

    Returns
    -------
    The ``max_output_patches`` entry of the model's ``chronos_config``
    """
    config = self.model.chronos_config
    return config.max_output_patches
def fit(
@ -764,9 +814,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
**predict_kwargs,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
Generate quantile and mean forecasts for given time series.
Refer to ``Chronos2Pipeline.predict`` for shared parameters.
Additional parameters
Additional Parameters
---------------------
quantile_levels
Quantile levels to compute, by default [0.1, 0.2, ..., 0.9]
@ -774,11 +826,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
Returns
-------
quantiles
A list of torch tensors containing quantile forecasts. Each element of the list has shape (n_variates, prediction_length, len(quantile_levels))
and the number of elements are equal to the number of target time series (univariate or multivariate) in the `inputs`.
A list of torch tensors containing quantile forecasts. Each element has shape (n_variates, prediction_length, len(quantile_levels))
and the number of elements equals the number of target time series (univariate or multivariate) in the inputs.
mean
A list of torch tensors containing containing mean (point) forecasts. Each element of the list has shape (n_variates, prediction_length)
and the number of elements are equal to the number of target time series (univariate or multivariate) in the `inputs`.
A list of torch tensors containing mean (point) forecasts. Each element has shape (n_variates, prediction_length)
and the number of elements equals the number of target time series (univariate or multivariate) in the inputs.
"""
training_quantile_levels = self.quantiles
@ -840,29 +892,31 @@ class Chronos2Pipeline(BaseChronosPipeline):
Future covariates data with an id column, a timestamp, and any number of covariate columns,
all of these columns will be treated as known future covariates
id_column
The name of the column which contains the unique time series identifiers, by default "item_id"
The name of the column which contains the unique time series identifiers
timestamp_column
The name of the column which contains timestamps, by default "timestamp"
All time series in the dataframe must have regular timestamps with the same frequency (no gaps)
The name of the column which contains timestamps. All time series in the dataframe must have
regular timestamps with the same frequency (no gaps)
target
The name of the column(s) which contain the target variables to be forecasted, by default "target"
The name of the column(s) which contain the target variables to be forecasted
prediction_length
Number of steps to predict for each time series
quantile_levels
Quantile levels to compute
batch_size
The batch size used for prediction. Note that the batch size here means the number of time series, including target(s) and covariates,
which are input into the model. If your data has multiple target and/or covariates, the effective number of time series tasks in a batch
will be lower than this value, by default 256
The batch size used for prediction. Note that the batch size here means the number of time series,
including target(s) and covariates, which are input into the model. If your data has multiple targets
and/or covariates, the effective number of time series tasks in a batch will be lower than this value
context_length
The maximum context length used during for inference, by default set to the model's default context length
The maximum context length used during inference, by default set to the model's default context length
cross_learning
If True, cross-learning is enabled, i.e., all the tasks in `inputs` will be predicted jointly and the model will share information across all inputs, by default False
The following must be noted when using cross-learning:
If True, cross-learning is enabled, i.e., all the tasks in inputs will be predicted jointly and the
model will share information across all inputs. The following must be noted when using cross-learning:
- Cross-learning doesn't always improve forecast accuracy and must be tested for individual use cases.
- Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining.
For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
- Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
- Results become dependent on batch size. Very large batch sizes may not provide benefits as they
deviate from the maximum group size used during pretraining. For optimal results, consider using a
batch size around 100 (as used in the Chronos-2 technical report).
- Cross-learning is most helpful when individual time series have limited historical context, as the
model can leverage patterns from related series in the batch.
validate_inputs
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
@ -871,12 +925,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
Returns
-------
The forecasts dataframe generated by the model with the following columns
- `id_column`: The time series ID
- `timestamp_column`: Future timestamps
The forecasts dataframe generated by the model with the following columns:
- id_column: The time series ID
- timestamp_column: Future timestamps
- "target_name": The name of the target column
- "predictions": The point predictions generated by the model
- One column for predictions at each quantile level in `quantile_levels`
- One column for predictions at each quantile level in quantile_levels
"""
try:
import pandas as pd
@ -1099,24 +1153,26 @@ class Chronos2Pipeline(BaseChronosPipeline):
----------
inputs
The time series to get embeddings for, can be one of:
- A 3-dimensional `torch.Tensor` or `np.ndarray` of shape (batch, n_variates, history_length). When `n_variates > 1`, information
will be shared among the different variates of each time series in the batch.
- A list of `torch.Tensor` or `np.ndarray` where each element can either be 1-dimensional of shape (history_length,)
or 2-dimensional of shape (n_variates, history_length). The history_lengths may be different across elements; left-padding
will be applied, if needed.
- A 3-dimensional torch.Tensor or np.ndarray of shape (batch, n_variates, history_length). When n_variates > 1,
information will be shared among the different variates of each time series in the batch.
- A list of torch.Tensor or np.ndarray where each element can either be 1-dimensional of shape (history_length,)
or 2-dimensional of shape (n_variates, history_length). The history_lengths may be different across elements;
left-padding will be applied, if needed.
batch_size
The batch size used for generating embeddings. Note that the batch size here means the total number of time series which are input into the model.
If your data has multiple variates, the effective number of time series tasks in a batch will be lower than this value, by default 256
The batch size used for generating embeddings. Note that the batch size here means the total number of time series
which are input into the model. If your data has multiple variates, the effective number of time series tasks in a
batch will be lower than this value
context_length
The maximum context length used during for inference, by default set to the model's default context length
The maximum context length used during inference, by default set to the model's default context length
Returns
-------
embeddings
a list of `torch.Tensor` where each element has shape (n_variates, num_patches + 2, d_model) and the number of elements are equal to the number
of target time series (univariate or multivariate) in the `inputs`. The extra +2 is due to embeddings of the [REG] token and a masked output patch token.
A list of torch.Tensor where each element has shape (n_variates, num_patches + 2, d_model) and the number of
elements equals the number of target time series (univariate or multivariate) in the inputs. The extra +2 is due
to embeddings of the [REG] token and a masked output patch token.
loc_scale
a list of tuples with the mean and standard deviation of each time series.
A list of tuples with the mean and standard deviation of each time series.
"""
if context_length is None:
context_length = self.model_context_length
@ -1171,8 +1227,31 @@ class Chronos2Pipeline(BaseChronosPipeline):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
"""
Load the model, either from a local path, S3 prefix or from the HuggingFace Hub.
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
Load the model from a local path, S3 prefix, or HuggingFace Hub.
Supports loading base models and LoRA adapters. When loading a LoRA adapter,
it will be automatically merged with the base model.
Parameters
----------
pretrained_model_name_or_path
Path to the pretrained model. Can be:
- A local directory path
- An S3 URI (s3://...)
- A HuggingFace Hub model ID
*args
Additional positional arguments passed to AutoConfig and AutoModel
**kwargs
Additional keyword arguments passed to AutoConfig and AutoModel
Returns
-------
A Chronos2Pipeline instance with the loaded model
Notes
-----
Supports the same arguments as AutoConfig and AutoModel from transformers.
When loading LoRA adapters, the peft library must be installed.
"""
# Check if the model is on S3 and cache it locally first
@ -1209,6 +1288,15 @@ class Chronos2Pipeline(BaseChronosPipeline):
def save_pretrained(self, save_directory: str | Path, *args, **kwargs):
"""
Save the underlying model to a local directory or to HuggingFace Hub.
Save the underlying model to a local directory or HuggingFace Hub.
Parameters
----------
save_directory
Directory where the model will be saved
*args
Additional positional arguments passed to the model's save_pretrained method
**kwargs
Additional keyword arguments passed to the model's save_pretrained method
"""
self.model.save_pretrained(save_directory, *args, **kwargs)