diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index 2a92f66..b3c08cf 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -354,7 +354,7 @@ class ChronosModel(nn.Module): class ChronosPipeline(BaseChronosPipeline): """ - Forecasting pipeline for the Chronos model. + Pipeline for the Chronos model. Parameters ---------- diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index 743ec06..68fcaa8 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -401,10 +401,41 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel): class ChronosBoltPipeline(BaseChronosPipeline): + """ + Pipeline for the Chronos-Bolt model. + + Parameters + ---------- + model + ChronosBoltModelForForecasting instance containing the pretrained model. + + Attributes + ---------- + model + The underlying forecasting model + forecast_type + Set to ForecastType.QUANTILES indicating this pipeline produces quantiles + default_context_length + Default context length of 2048 time steps + + See Also + -------- + ChronosPipeline : Sample-based forecasting with tokenization + Chronos2Pipeline : Advanced forecasting with covariates support + """ forecast_type: ForecastType = ForecastType.QUANTILES default_context_length: int = 2048 def __init__(self, model: ChronosBoltModelForForecasting): + """ + Initialize the ChronosBoltPipeline with a pretrained model. + + Parameters + ---------- + model + ChronosBoltModelForForecasting instance containing the pretrained + transformer model configured for quantile forecasting. + """ super().__init__(inner_model=model) # type: ignore self.model = model @@ -425,24 +456,40 @@ class ChronosBoltPipeline(BaseChronosPipeline): self, context: Union[torch.Tensor, List[torch.Tensor]] ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ - Get encoder embeddings for the given time series. - + Extract encoder embeddings for the given time series. + + This method processes the input time series through patching and instance + normalization, then extracts encoder embeddings that can be used for + downstream tasks like clustering, classification, or similarity search. + Parameters ---------- context - Input series. This is either a 1D tensor, or a list - of 1D tensors, or a 2D tensor whose first dimension - is batch. In the latter case, use left-padding with - ``torch.nan`` to align series of different lengths. - + Input time series. Can be a 1D tensor (single series), a list + of 1D tensors (multiple series of varying lengths), or a 2D tensor + where the first dimension is batch size. For 2D tensors, use + left-padding with torch.nan to align series of different lengths. + Returns ------- - embeddings, loc_scale - A tuple of two items: the encoder embeddings and the loc_scale, - i.e., the mean and std of the original time series. - The encoder embeddings are shaped (batch_size, num_patches + 1, d_model), - where num_patches is the number of patches in the time series - and the extra 1 is for the [REG] token (if used by the model). + embeddings + Encoder embeddings with shape (batch_size, num_patches + 1, d_model), + where num_patches is the number of patches created from the input + time series, and the extra 1 is for the [REG] token if used by the model. + Returned on CPU in the model's dtype. + loc_scale + Tuple of (location, scale) tensors used for instance normalization, + representing the mean and standard deviation of the original time series. + Both tensors have shape (batch_size,) and are returned on CPU. + + Notes + ----- + The embeddings are extracted after patching and instance normalization + but before the decoder. They capture the encoded representation of the + input time series in the model's latent space. + + If the input context is longer than the model's context length, it will + be automatically truncated to the most recent time steps. """ context_tensor = self._prepare_and_validate_context(context=context) model_context_length = self.model.config.chronos_config["context_length"] @@ -467,31 +514,58 @@ class ChronosBoltPipeline(BaseChronosPipeline): limit_prediction_length: bool = False, ) -> torch.Tensor: """ - Get forecasts for the given time series. - - Refer to the base method (``BaseChronosPipeline.predict``) - for details on shared parameters. - Additional parameters - --------------------- + Generate quantile forecasts for the given time series. + + This method directly predicts quantiles without generating sample trajectories. + For predictions longer than the model's built-in horizon, it uses an + autoregressive approach that expands the batch size by the number of quantiles + to generate more robust long-horizon forecasts. + + Parameters + ---------- + inputs + Input time series. Can be a 1D tensor (single series), a list + of 1D tensors (multiple series of varying lengths), or a 2D tensor + where the first dimension is batch size. For 2D tensors, use + left-padding with torch.nan to align series of different lengths. + prediction_length + Number of time steps to forecast. If not provided, uses the model's + default prediction length from the configuration. limit_prediction_length - Force prediction length smaller or equal than the - built-in prediction length from the model. False by - default. When true, fail loudly if longer predictions - are requested, otherwise longer predictions are allowed. - + When True, raises an error if prediction_length exceeds the model's + built-in prediction length. When False (default), allows longer + predictions with a warning about potential quality degradation. + Returns ------- torch.Tensor - Forecasts of shape (batch_size, num_quantiles, prediction_length) - where num_quantiles is the number of quantiles the model has been - trained to output. For official Chronos-Bolt models, the value of - num_quantiles is 9 for [0.1, 0.2, ..., 0.9]-quantiles. - + Quantile forecasts with shape (batch_size, num_quantiles, prediction_length), + where num_quantiles is the number of quantiles the model was trained on. + For official Chronos-Bolt models, num_quantiles is 9 for quantiles + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]. + Returned in fp32 on CPU. + Raises ------ ValueError - When limit_prediction_length is True and the prediction_length is - greater than model's training prediction_length. + If limit_prediction_length is True and prediction_length exceeds + the model's built-in prediction length. + + Notes + ----- + For predictions longer than the model's built-in horizon, the method uses + an autoregressive approach: + 1. Generate initial quantiles for the first chunk + 2. Expand context by num_quantiles (treating each quantile as a scenario) + 3. Generate next chunk for each scenario + 4. Compute empirical quantiles across all scenarios + 5. Repeat until desired prediction_length is reached + + This approach scales the batch size by num_quantiles for long horizons, + which may require more GPU memory but produces more robust predictions. + + If the input context is longer than the model's context length, it will + be automatically truncated to the most recent time steps. """ context_tensor = self._prepare_and_validate_context(context=inputs) @@ -564,7 +638,57 @@ class ChronosBoltPipeline(BaseChronosPipeline): **predict_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Refer to the base method (``BaseChronosPipeline.predict_quantiles``). + Generate quantile and mean forecasts for given time series. + + This method generates forecasts at the specified quantile levels. If the + requested quantiles match those the model was trained on, they are returned + directly. Otherwise, the method performs interpolation or extrapolation + to obtain the requested quantiles. + + Parameters + ---------- + inputs + Input time series. Can be a 1D tensor (single series), a list + of 1D tensors (multiple series of varying lengths), or a 2D tensor + where the first dimension is batch size. For 2D tensors, use + left-padding with torch.nan to align series of different lengths. + prediction_length + Number of time steps to forecast. If not provided, uses the model's + default prediction length from the configuration. + quantile_levels + List of quantile levels to compute, each between 0 and 1. + Default is [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]. + **predict_kwargs + Additional keyword arguments passed to the predict method, such as + limit_prediction_length. + + Returns + ------- + quantiles + Tensor of quantile forecasts with shape + (batch_size, prediction_length, num_quantiles). + Returned in fp32 on CPU. + mean + Tensor of mean forecasts with shape (batch_size, prediction_length). + This is actually the median (0.5 quantile) from the model's predictions. + Returned in fp32 on CPU. + + Notes + ----- + If the requested quantile_levels are a subset of the model's training + quantiles, they are extracted directly without interpolation. + + If quantile_levels include values outside the range of training quantiles, + the method will extrapolate using the minimum/maximum training quantiles, + which may significantly affect prediction quality. A warning will be issued + in this case. + + The interpolation/extrapolation assumes the model's training quantiles + formed an equidistant grid (e.g., 0.1, 0.2, ..., 0.9), which holds for + official Chronos-Bolt models but may not be true for custom models. + + The mean returned is actually the median (0.5 quantile) from the model's + predictions, not a true mean. """ # shape (batch_size, prediction_length, len(training_quantile_levels)) predictions = ( @@ -609,8 +733,49 @@ class ChronosBoltPipeline(BaseChronosPipeline): @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): """ - Load the model, either from a local path S3 prefix or from the HuggingFace Hub. - Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``. + Load a pretrained ChronosBoltPipeline from various sources. + + This method loads a pretrained ChronosBoltPipeline model from a local path, + S3 bucket, or the HuggingFace Hub. It automatically instantiates the + appropriate model architecture based on the configuration. + + Parameters + ---------- + pretrained_model_name_or_path + Path or identifier for the pretrained model. Can be: + - A local directory path containing model files + - An S3 URI (s3://bucket/prefix) + - A HuggingFace Hub model identifier (e.g., "amazon/chronos-bolt-small") + *args + Additional positional arguments passed to AutoConfig and the model constructor. + **kwargs + Additional keyword arguments passed to AutoConfig and the model constructor. + Common options include: + - torch_dtype: Data type for model weights ("auto", "float32", "bfloat16") + - device_map: Device placement strategy for model layers + - Other transformers AutoConfig and model arguments + + Returns + ------- + ChronosBoltPipeline + An instance of ChronosBoltPipeline with the loaded model. + + Raises + ------ + AssertionError + If the configuration is not a valid Chronos config. + + Notes + ----- + For S3 URIs, the method delegates to BaseChronosPipeline.from_pretrained + which handles S3 download and caching. + + The method automatically detects the model architecture from the configuration + and instantiates the appropriate class. If the architecture is not recognized, + it defaults to ChronosBoltModelForForecasting. + + This method supports all arguments accepted by HuggingFace's AutoConfig + and model classes. """ if str(pretrained_model_name_or_path).startswith("s3://"):