diff --git a/pyproject.toml b/pyproject.toml index b42c37f..ae6cad6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ license = { file = "LICENSE" } requires-python = ">=3.10" dependencies = [ "torch>=2.0,<3", - "transformers>=4.49,<5", + "transformers>=4.41,<5", "accelerate>=0.34,<2", "numpy>=1.21,<3", "einops>=0.7.0,<1", diff --git a/src/chronos/__about__.py b/src/chronos/__about__.py index 8c0d5d5..159d48b 100644 --- a/src/chronos/__about__.py +++ b/src/chronos/__about__.py @@ -1 +1 @@ -__version__ = "2.0.0" +__version__ = "2.0.1" diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index 1b60a8f..f0b910d 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -407,6 +407,8 @@ class ChronosBoltPipeline(BaseChronosPipeline): def __init__(self, model: ChronosBoltModelForForecasting): super().__init__(inner_model=model) # type: ignore self.model = model + self.model_context_length: int = self.model.config.chronos_config["context_length"] + self.model_prediction_length: int = self.model.config.chronos_config["prediction_length"] @property def quantiles(self) -> List[float]: @@ -487,14 +489,12 @@ class ChronosBoltPipeline(BaseChronosPipeline): """ context_tensor = self._prepare_and_validate_context(context=inputs) - model_context_length: int = self.model.config.chronos_config["context_length"] - model_prediction_length: int = self.model.config.chronos_config["prediction_length"] if prediction_length is None: - prediction_length = model_prediction_length + prediction_length = self.model_prediction_length - if prediction_length > model_prediction_length: + if prediction_length > self.model_prediction_length: msg = ( - f"We recommend keeping prediction length <= {model_prediction_length}. " + f"We recommend keeping prediction length <= {self.model_prediction_length}. " "The quality of longer predictions may degrade since the model is not optimized for it. " ) if limit_prediction_length: @@ -507,32 +507,46 @@ class ChronosBoltPipeline(BaseChronosPipeline): # We truncate the context here because otherwise batches with very long # context could take up large amounts of GPU memory unnecessarily. - if context_tensor.shape[-1] > model_context_length: - context_tensor = context_tensor[..., -model_context_length:] + if context_tensor.shape[-1] > self.model_context_length: + context_tensor = context_tensor[..., -self.model_context_length :] - # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast - # horizon that the model was trained with (i.e., 64). This results in variance collapsing - # every 64 steps. - context_tensor = context_tensor.to( - device=self.model.device, - dtype=torch.float32, - ) - while remaining > 0: - with torch.no_grad(): - prediction = self.model( - context=context_tensor, - ).quantile_preds.to(context_tensor) + context_tensor = context_tensor.to(device=self.model.device, dtype=torch.float32) + # First block prediction + with torch.no_grad(): + prediction: torch.Tensor = self.model(context=context_tensor).quantile_preds.to(context_tensor) predictions.append(prediction) remaining -= prediction.shape[-1] - if remaining <= 0: - break + # NOTE: The following heuristic for better prediction intervals with long-horizon forecasts + # uses all quantiles generated by the model for the first `model_prediction_length` steps, + # concatenating each quantile with the context and generating the next `model_prediction_length` steps. + # The `num_quantiles * num_quantiles` "samples" thus generated are then reduced to `num_quantiles` + # by computing empirical quantiles. Note that this option scales the batch size by `num_quantiles` + # when the `prediction_length` is greater than `model_prediction_length`. - central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin() - central_prediction = prediction[:, central_idx] + if remaining > 0: + # Expand the context along quantile axis + context_tensor = context_tensor.unsqueeze(1).repeat(1, len(self.quantiles), 1) - context_tensor = torch.cat([context_tensor, central_prediction], dim=-1) + quantile_tensor = torch.tensor(self.quantiles, device=context_tensor.device) + while remaining > 0: + # Append the prediction to context + context_tensor = torch.cat([context_tensor, prediction], dim=-1)[..., -self.model_context_length :] + (batch_size, n_quantiles, context_length) = context_tensor.shape + + with torch.no_grad(): + # Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length) + prediction = self.model( + context=context_tensor.reshape(batch_size * n_quantiles, context_length) + ).quantile_preds.to(context_tensor) + # Reshape predictions from (batch * n_quantiles, n_quantiles, model_prediction_length) to (batch, n_quantiles * n_quantiles, model_prediction_length) + prediction = prediction.reshape(batch_size, n_quantiles * n_quantiles, -1) + # Reduce `n_quantiles * n_quantiles` to n_quantiles and transpose back to (batch_size, n_quantiles, model_prediction_length) + prediction = torch.quantile(prediction, q=quantile_tensor, dim=1).transpose(0, 1) + + predictions.append(prediction) + remaining -= prediction.shape[-1] return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32, device="cpu")