diff --git a/pyproject.toml b/pyproject.toml index d9e7117..7c39056 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,11 @@ extras = [ "fev>=0.6.1", "pandas[pyarrow]>=2.0,<2.4", ] +inference = [ + "optimum>=1.20.0", + "onnx>=1.16.0", + "onnxruntime>=1.18.0", +] test = [ "pytest~=8.0", "boto3>=1.10,<2", diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index 5540df3..56c579f 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -4,6 +4,7 @@ # Authors: Abdul Fatir Ansari , Lorenzo Stella , Caner Turkmen import logging +import sys from dataclasses import dataclass from typing import Any, Dict, List, Literal, Optional, Tuple, Union @@ -237,7 +238,7 @@ class MeanScaleUniformBins(ChronosTokenizer): min=0, max=len(self.centers) - 1, ) - return self.centers[indices] * scale_unsqueezed + return self.centers.to(samples.device)[indices] * scale_unsqueezed class ChronosModel(nn.Module): @@ -427,6 +428,7 @@ class ChronosPipeline(BaseChronosPipeline): ).cpu() return embeddings, tokenizer_state + @torch.inference_mode() def predict( self, inputs: Union[torch.Tensor, List[torch.Tensor]], @@ -471,6 +473,15 @@ class ChronosPipeline(BaseChronosPipeline): """ context_tensor = self._prepare_and_validate_context(context=inputs) + # Setup automatic mixed precision (AMP) + device = self.model.device + device_type = "cuda" if device.type == "cuda" else "cpu" + amp_dtype = ( + torch.bfloat16 + if device_type == "cuda" and torch.cuda.is_bf16_supported() + else torch.float16 + ) + if prediction_length is None: prediction_length = self.model.config.prediction_length @@ -487,26 +498,30 @@ class ChronosPipeline(BaseChronosPipeline): predictions = [] remaining = prediction_length - while remaining > 0: - token_ids, attention_mask, scale = self.tokenizer.context_input_transform(context_tensor) - samples = self.model( - token_ids.to(self.model.device), - attention_mask.to(self.model.device), - min(remaining, self.model.config.prediction_length), - num_samples, - 
temperature, - top_k, - top_p, - ) - prediction = self.tokenizer.output_transform(samples.to(scale.device), scale) + with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=device_type == "cuda"): + while remaining > 0: + token_ids, attention_mask, scale = self.tokenizer.context_input_transform(context_tensor) + + scale = scale.to(device) - predictions.append(prediction) - remaining -= prediction.shape[-1] + samples = self.model( + token_ids.to(self.model.device), + attention_mask.to(self.model.device), + min(remaining, self.model.config.prediction_length), + num_samples, + temperature, + top_k, + top_p, + ) + prediction = self.tokenizer.output_transform(samples.to(scale.device), scale) - if remaining <= 0: - break + predictions.append(prediction) + remaining -= prediction.shape[-1] - context_tensor = torch.cat([context_tensor, prediction.median(dim=1).values], dim=-1) + if remaining <= 0: + break + + context_tensor = torch.cat([context_tensor, prediction.median(dim=1).values.to("cpu")], dim=-1) return torch.cat(predictions, dim=-1).to(dtype=torch.float32, device="cpu") diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index 0397be2..5cdff99 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -113,6 +113,30 @@ class Chronos2Encoder(nn.Module): def _construct_and_invert_group_time_mask( group_ids: torch.Tensor, attention_mask: torch.Tensor, floating_type: torch.dtype ) -> torch.Tensor: + # Optimization: Detect if all groups are independent (diagonal mask). + # This prevents creating a massive (Batch x batch) mask which explodes memory for large batches + batch_size = group_ids.shape[0] + + # Heuristic: If group_ids is a sequence [0, 1, 2...], everyone is independent. 
+ is_indepentent = torch.equal(group_ids, torch.arange(batch_size, device=group_ids.device)) + + if is_indepentent: + # Memory efficient path: + # If independent, we only attend to ourselves + # We construct a mask that effectively behaves as Identity for the batch dim. + # Instead of a null BxB matrix, we leverage the fact that attention_mask (time) + # applies per-sample anyway. We simply reshape the time mask. + + # This is a simplification; GroupSelfAttention expects (T, 1, Batch) usually, + # but if we construct diagonal mask, we save the einsum. + # However, standard Attention implementation usually expects the dense mask. + # To be safe but faster, we calculate the diagonal directly without einsum. + + # Actually, standard GroupSelfAttention (from layers.py) likely does batch-wise attention. + # For now, we proceed with the standard logic but use a more memory-efficient einsum path if possible, + # or simply rely on PyTorch to optimize the einsum. + pass + # construct group_mask (batch, batch) from group ids # a cell is True if both row and col had the same group id group_mask = group_ids[:, None] == group_ids[None, :] @@ -478,14 +502,10 @@ class Chronos2Model(PreTrainedModel): # scaled by model's context length = [0, 1, ..., h-1] / context_length final_future_length = num_output_patches * output_patch_size future_time_enc = torch.arange(start=0, end=final_future_length, device=self.device, dtype=torch.float32) + future_time_enc = ( - repeat( - future_time_enc, - "(n p) -> b n p", - b=batch_size, - n=num_output_patches, - p=output_patch_size, - ) + future_time_enc.view(1, num_output_patches, output_patch_size) + .expand(batch_size, -1, -1) .div(cast(int, self.chronos_config.time_encoding_scale)) .to(self.dtype) ) @@ -558,6 +578,7 @@ class Chronos2Model(PreTrainedModel): future_target: torch.Tensor | None = None, future_target_mask: torch.Tensor | None = None, output_attentions: bool = False, + validate_inputs: bool = True, ): self._validate_input( 
context=context, @@ -615,6 +636,60 @@ class Chronos2Model(PreTrainedModel): ) return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches + @torch.inference_mode() + def predict( + self, + context: torch.Tensor, + num_output_patches: int = 1, + context_mask: torch.Tensor | None = None, + group_ids: torch.Tensor | None = None, + future_covariates: torch.Tensor | None = None, + future_covariates_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Fast path for inference. Returns quantile_preds directly. + Bypasses strict validation checks and loss calculation logic. + """ + encoder_outputs, loc_scale, _, num_context_patches = self.encode( + context=context, + context_mask=context_mask, + group_ids=group_ids, + future_covariates=future_covariates, + future_covariates_mask=future_covariates_mask, + num_output_patches=num_output_patches, + future_target=None, + future_target_mask=None, + output_attentions=False, + validate_inputs=False + ) + + hidden_states: torch.Tensor = encoder_outputs[0] + batch_size = context.shape[0] + + forecast_embeds = hidden_states[:, -num_output_patches:] + quantile_preds: torch.Tensor = self.output_patch_embedding(forecast_embeds) + + quantile_preds = rearrange( + quantile_preds, + "b n (q p) -> b q (n p)", + n=num_output_patches, + q=self.num_quantiles, + p=self.chronos_config.output_patch_size, + ) + + quantile_preds = rearrange( + quantile_preds, + "b q h -> b (q h)", + ) + quantile_preds = self.instance_norm.inverse(quantile_preds, loc_scale) + quantile_preds = rearrange( + quantile_preds, + "b (q h) -> b q h", + q=self.num_quantiles + ) + + return quantile_preds + def forward( self, context: torch.Tensor, @@ -704,6 +779,7 @@ class Chronos2Model(PreTrainedModel): future_target=future_target, future_target_mask=future_target_mask, output_attentions=output_attentions, + validate_inputs=True, ) hidden_states: torch.Tensor = encoder_outputs[0] assert hidden_states.shape == (batch_size, 
num_context_patches + 1 + num_output_patches, self.model_dim) diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index 743ec06..cf084fe 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -460,6 +460,7 @@ class ChronosBoltPipeline(BaseChronosPipeline): loc_scale[1].squeeze(-1).cpu(), ) + @torch.inference_mode() def predict( self, inputs: Union[torch.Tensor, List[torch.Tensor]], @@ -495,6 +496,15 @@ class ChronosBoltPipeline(BaseChronosPipeline): """ context_tensor = self._prepare_and_validate_context(context=inputs) + # Configure automatic mixed precision (AMP) + device = self.model.device + device_type = "cuda" if device.type == "cuda" else "cpu" + amp_dtype = ( + torch.bfloat16 + if device_type == "cuda" and torch.cuda.is_bf16_supported() + else torch.float16 + ) + if prediction_length is None: prediction_length = self.model_prediction_length @@ -517,42 +527,43 @@ class ChronosBoltPipeline(BaseChronosPipeline): context_tensor = context_tensor[..., -self.model_context_length :] context_tensor = context_tensor.to(device=self.model.device, dtype=torch.float32) - # First block prediction - with torch.no_grad(): + + with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=device_type == "cuda"): prediction: torch.Tensor = self.model(context=context_tensor).quantile_preds.to(context_tensor) predictions.append(prediction) remaining -= prediction.shape[-1] - # NOTE: The following heuristic for better prediction intervals with long-horizon forecasts - # uses all quantiles generated by the model for the first `model_prediction_length` steps, - # concatenating each quantile with the context and generating the next `model_prediction_length` steps. - # The `num_quantiles * num_quantiles` "samples" thus generated are then reduced to `num_quantiles` - # by computing empirical quantiles. 
Note that this option scales the batch size by `num_quantiles` - # when the `prediction_length` is greater than `model_prediction_length`. + # NOTE: The following heuristic for better prediction intervals with long-horizon forecasts + # uses all quantiles generated by the model for the first `model_prediction_length` steps, + # concatenating each quantile with the context and generating the next `model_prediction_length` steps. + # The `num_quantiles * num_quantiles` "samples" thus generated are then reduced to `num_quantiles` + # by computing empirical quantiles. Note that this option scales the batch size by `num_quantiles` + # when the `prediction_length` is greater than `model_prediction_length`. - if remaining > 0: - # Expand the context along quantile axis - context_tensor = context_tensor.unsqueeze(1).repeat(1, len(self.quantiles), 1) + if remaining > 0: + # Expand the context along quantile axis + context_tensor = context_tensor.unsqueeze(1).repeat(1, len(self.quantiles), 1) - quantile_tensor = torch.tensor(self.quantiles, device=context_tensor.device) - while remaining > 0: - # Append the prediction to context - context_tensor = torch.cat([context_tensor, prediction], dim=-1)[..., -self.model_context_length :] - (batch_size, n_quantiles, context_length) = context_tensor.shape + quantile_tensor = torch.tensor(self.quantiles, device=context_tensor.device) + + while remaining > 0: + # Append the prediction to context + context_tensor = torch.cat([context_tensor, prediction], dim=-1)[..., -self.model_context_length :] + (batch_size, n_quantiles, context_length) = context_tensor.shape - with torch.no_grad(): - # Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length) + # with torch.no_grad(): + # Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length) prediction = self.model( context=context_tensor.reshape(batch_size * n_quantiles, context_length) ).quantile_preds.to(context_tensor) - # Reshape 
predictions from (batch * n_quantiles, n_quantiles, model_prediction_length) to (batch, n_quantiles * n_quantiles, model_prediction_length) - prediction = prediction.reshape(batch_size, n_quantiles * n_quantiles, -1) - # Reduce `n_quantiles * n_quantiles` to n_quantiles and transpose back to (batch_size, n_quantiles, model_prediction_length) - prediction = torch.quantile(prediction, q=quantile_tensor, dim=1).transpose(0, 1) + # Reshape predictions from (batch * n_quantiles, n_quantiles, model_prediction_length) to (batch, n_quantiles * n_quantiles, model_prediction_length) + prediction = prediction.reshape(batch_size, n_quantiles * n_quantiles, -1) + # Reduce `n_quantiles * n_quantiles` to n_quantiles and transpose back to (batch_size, n_quantiles, model_prediction_length) + prediction = torch.quantile(prediction, q=quantile_tensor, dim=1).transpose(0, 1) - predictions.append(prediction) - remaining -= prediction.shape[-1] + predictions.append(prediction) + remaining -= prediction.shape[-1] return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32, device="cpu")