mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Chronos-2
This commit is contained in:
parent
51d5abea81
commit
1fea716236
1 changed files with 127 additions and 39 deletions
|
|
@ -37,10 +37,26 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class Chronos2Pipeline(BaseChronosPipeline):
|
||||
"""
|
||||
Pipeline for the Chronos-2 model.
|
||||
|
||||
See Also
|
||||
--------
|
||||
ChronosPipeline: Sample-based forecasting with scaling and quantization based tokenization
|
||||
ChronosBoltPipeline: Quantile-based forecasting with patching
|
||||
"""
|
||||
forecast_type: ForecastType = ForecastType.QUANTILES
|
||||
default_context_length: int = 2048
|
||||
|
||||
def __init__(self, model: Chronos2Model):
|
||||
"""
|
||||
Initialize the Chronos-2 pipeline with a pretrained model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model
|
||||
A pretrained Chronos2Model instance
|
||||
"""
|
||||
super().__init__(inner_model=model)
|
||||
self.model = model
|
||||
|
||||
|
|
@ -55,13 +71,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
quantile_levels : torch.Tensor
|
||||
quantile_levels
|
||||
The quantile levels, must be strictly in (0, 1)
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The normalized probability mass per quantile
|
||||
The normalized probability mass per quantile
|
||||
"""
|
||||
assert quantile_levels.ndim == 1
|
||||
assert quantile_levels.min() > 0.0 and quantile_levels.max() < 1.0
|
||||
|
|
@ -75,22 +90,57 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
|
||||
@property
|
||||
def model_context_length(self) -> int:
|
||||
"""
|
||||
Maximum number of time steps the model can use as context.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Maximum context length supported by the model
|
||||
"""
|
||||
return self.model.chronos_config.context_length
|
||||
|
||||
@property
|
||||
def model_output_patch_size(self) -> int:
|
||||
"""
|
||||
Size of each output patch produced by the model.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Output patch size
|
||||
"""
|
||||
return self.model.chronos_config.output_patch_size
|
||||
|
||||
@property
|
||||
def model_prediction_length(self) -> int:
|
||||
"""
|
||||
Default prediction horizon for the model.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Default prediction horizon (max_output_patches * output_patch_size)
|
||||
"""
|
||||
return self.model.chronos_config.max_output_patches * self.model.chronos_config.output_patch_size
|
||||
|
||||
@property
|
||||
def quantiles(self) -> list[float]:
|
||||
"""
|
||||
Quantile levels the model was trained to predict.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List of quantile levels
|
||||
"""
|
||||
return self.model.chronos_config.quantiles
|
||||
|
||||
@property
|
||||
def max_output_patches(self) -> int:
|
||||
"""
|
||||
Maximum number of output patches the model can generate in a single forward pass.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Maximum number of output patches
|
||||
"""
|
||||
return self.model.chronos_config.max_output_patches
|
||||
|
||||
def fit(
|
||||
|
|
@ -764,9 +814,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
**predict_kwargs,
|
||||
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||
"""
|
||||
Generate quantile and mean forecasts for given time series.
|
||||
|
||||
Refer to ``Chronos2Pipeline.predict`` for shared parameters.
|
||||
|
||||
Additional parameters
|
||||
Additional Parameters
|
||||
---------------------
|
||||
quantile_levels
|
||||
Quantile levels to compute, by default [0.1, 0.2, ..., 0.9]
|
||||
|
|
@ -774,11 +826,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
Returns
|
||||
-------
|
||||
quantiles
|
||||
A list of torch tensors containing quantile forecasts. Each element of the list has shape (n_variates, prediction_length, len(quantile_levels))
|
||||
and the number of elements are equal to the number of target time series (univariate or multivariate) in the `inputs`.
|
||||
A list of torch tensors containing quantile forecasts. Each element has shape (n_variates, prediction_length, len(quantile_levels))
|
||||
and the number of elements equals the number of target time series (univariate or multivariate) in the inputs.
|
||||
mean
|
||||
A list of torch tensors containing containing mean (point) forecasts. Each element of the list has shape (n_variates, prediction_length)
|
||||
and the number of elements are equal to the number of target time series (univariate or multivariate) in the `inputs`.
|
||||
A list of torch tensors containing mean (point) forecasts. Each element has shape (n_variates, prediction_length)
|
||||
and the number of elements equals the number of target time series (univariate or multivariate) in the inputs.
|
||||
"""
|
||||
training_quantile_levels = self.quantiles
|
||||
|
||||
|
|
@ -840,29 +892,31 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
Future covariates data with an id column, a timestamp, and any number of covariate columns,
|
||||
all of these columns will be treated as known future covariates
|
||||
id_column
|
||||
The name of the column which contains the unique time series identifiers, by default "item_id"
|
||||
The name of the column which contains the unique time series identifiers
|
||||
timestamp_column
|
||||
The name of the column which contains timestamps, by default "timestamp"
|
||||
All time series in the dataframe must have regular timestamps with the same frequency (no gaps)
|
||||
The name of the column which contains timestamps. All time series in the dataframe must have
|
||||
regular timestamps with the same frequency (no gaps)
|
||||
target
|
||||
The name of the column(s) which contain the target variables to be forecasted, by default "target"
|
||||
The name of the column(s) which contain the target variables to be forecasted
|
||||
prediction_length
|
||||
Number of steps to predict for each time series
|
||||
quantile_levels
|
||||
Quantile levels to compute
|
||||
batch_size
|
||||
The batch size used for prediction. Note that the batch size here means the number of time series, including target(s) and covariates,
|
||||
which are input into the model. If your data has multiple target and/or covariates, the effective number of time series tasks in a batch
|
||||
will be lower than this value, by default 256
|
||||
The batch size used for prediction. Note that the batch size here means the number of time series,
|
||||
including target(s) and covariates, which are input into the model. If your data has multiple target
|
||||
and/or covariates, the effective number of time series tasks in a batch will be lower than this value
|
||||
context_length
|
||||
The maximum context length used during for inference, by default set to the model's default context length
|
||||
The maximum context length used during inference, by default set to the model's default context length
|
||||
cross_learning
|
||||
If True, cross-learning is enabled, i.e., all the tasks in `inputs` will be predicted jointly and the model will share information across all inputs, by default False
|
||||
The following must be noted when using cross-learning:
|
||||
If True, cross-learning is enabled, i.e., all the tasks in inputs will be predicted jointly and the
|
||||
model will share information across all inputs. The following must be noted when using cross-learning:
|
||||
- Cross-learning doesn't always improve forecast accuracy and must be tested for individual use cases.
|
||||
- Results become dependent on batch size. Very large batch sizes may not provide benefits as they deviate from the maximum group size used during pretraining.
|
||||
For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
|
||||
- Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
|
||||
- Results become dependent on batch size. Very large batch sizes may not provide benefits as they
|
||||
deviate from the maximum group size used during pretraining. For optimal results, consider using a
|
||||
batch size around 100 (as used in the Chronos-2 technical report).
|
||||
- Cross-learning is most helpful when individual time series have limited historical context, as the
|
||||
model can leverage patterns from related series in the batch.
|
||||
validate_inputs
|
||||
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
|
||||
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
|
||||
|
|
@ -871,12 +925,12 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
|
||||
Returns
|
||||
-------
|
||||
The forecasts dataframe generated by the model with the following columns
|
||||
- `id_column`: The time series ID
|
||||
- `timestamp_column`: Future timestamps
|
||||
The forecasts dataframe generated by the model with the following columns:
|
||||
- id_column: The time series ID
|
||||
- timestamp_column: Future timestamps
|
||||
- "target_name": The name of the target column
|
||||
- "predictions": The point predictions generated by the model
|
||||
- One column for predictions at each quantile level in `quantile_levels`
|
||||
- One column for predictions at each quantile level in quantile_levels
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
|
|
@ -1099,24 +1153,26 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
----------
|
||||
inputs
|
||||
The time series to get embeddings for, can be one of:
|
||||
- A 3-dimensional `torch.Tensor` or `np.ndarray` of shape (batch, n_variates, history_length). When `n_variates > 1`, information
|
||||
will be shared among the different variates of each time series in the batch.
|
||||
- A list of `torch.Tensor` or `np.ndarray` where each element can either be 1-dimensional of shape (history_length,)
|
||||
or 2-dimensional of shape (n_variates, history_length). The history_lengths may be different across elements; left-padding
|
||||
will be applied, if needed.
|
||||
- A 3-dimensional torch.Tensor or np.ndarray of shape (batch, n_variates, history_length). When n_variates > 1,
|
||||
information will be shared among the different variates of each time series in the batch.
|
||||
- A list of torch.Tensor or np.ndarray where each element can either be 1-dimensional of shape (history_length,)
|
||||
or 2-dimensional of shape (n_variates, history_length). The history_lengths may be different across elements;
|
||||
left-padding will be applied, if needed.
|
||||
batch_size
|
||||
The batch size used for generating embeddings. Note that the batch size here means the total number of time series which are input into the model.
|
||||
If your data has multiple variates, the effective number of time series tasks in a batch will be lower than this value, by default 256
|
||||
The batch size used for generating embeddings. Note that the batch size here means the total number of time series
|
||||
which are input into the model. If your data has multiple variates, the effective number of time series tasks in a
|
||||
batch will be lower than this value
|
||||
context_length
|
||||
The maximum context length used during for inference, by default set to the model's default context length
|
||||
The maximum context length used during inference, by default set to the model's default context length
|
||||
|
||||
Returns
|
||||
-------
|
||||
embeddings
|
||||
a list of `torch.Tensor` where each element has shape (n_variates, num_patches + 2, d_model) and the number of elements are equal to the number
|
||||
of target time series (univariate or multivariate) in the `inputs`. The extra +2 is due to embeddings of the [REG] token and a masked output patch token.
|
||||
A list of torch.Tensor where each element has shape (n_variates, num_patches + 2, d_model) and the number of
|
||||
elements equals the number of target time series (univariate or multivariate) in the inputs. The extra +2 is due
|
||||
to embeddings of the [REG] token and a masked output patch token.
|
||||
loc_scale
|
||||
a list of tuples with the mean and standard deviation of each time series.
|
||||
A list of tuples with the mean and standard deviation of each time series.
|
||||
"""
|
||||
if context_length is None:
|
||||
context_length = self.model_context_length
|
||||
|
|
@ -1171,8 +1227,31 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||
"""
|
||||
Load the model, either from a local path, S3 prefix or from the HuggingFace Hub.
|
||||
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
|
||||
Load the model from a local path, S3 prefix, or HuggingFace Hub.
|
||||
|
||||
Supports loading base models and LoRA adapters. When loading a LoRA adapter,
|
||||
it will be automatically merged with the base model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pretrained_model_name_or_path
|
||||
Path to the pretrained model. Can be:
|
||||
- A local directory path
|
||||
- An S3 URI (s3://...)
|
||||
- A HuggingFace Hub model ID
|
||||
*args
|
||||
Additional positional arguments passed to AutoConfig and AutoModel
|
||||
**kwargs
|
||||
Additional keyword arguments passed to AutoConfig and AutoModel
|
||||
|
||||
Returns
|
||||
-------
|
||||
A Chronos2Pipeline instance with the loaded model
|
||||
|
||||
Notes
|
||||
-----
|
||||
Supports the same arguments as AutoConfig and AutoModel from transformers.
|
||||
When loading LoRA adapters, the peft library must be installed.
|
||||
"""
|
||||
|
||||
# Check if the model is on S3 and cache it locally first
|
||||
|
|
@ -1209,6 +1288,15 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
|
||||
def save_pretrained(self, save_directory: str | Path, *args, **kwargs):
|
||||
"""
|
||||
Save the underlying model to a local directory or to HuggingFace Hub.
|
||||
Save the underlying model to a local directory or HuggingFace Hub.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_directory
|
||||
Directory where the model will be saved
|
||||
*args
|
||||
Additional positional arguments passed to the model's save_pretrained method
|
||||
**kwargs
|
||||
Additional keyword arguments passed to the model's save_pretrained method
|
||||
"""
|
||||
self.model.save_pretrained(save_directory, *args, **kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue