Backports for v2.0.1 and version bump (#369)

*Issue #, if available:*

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Abdul Fatir 2025-11-06 13:18:26 +01:00 committed by GitHub
parent 7a8427d18e
commit 1a2498f238
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 40 additions and 26 deletions

View file

@ -15,7 +15,7 @@ license = { file = "LICENSE" }
requires-python = ">=3.10"
dependencies = [
"torch>=2.0,<3",
"transformers>=4.49,<5",
"transformers>=4.41,<5",
"accelerate>=0.34,<2",
"numpy>=1.21,<3",
"einops>=0.7.0,<1",

View file

@ -1 +1 @@
__version__ = "2.0.0"
__version__ = "2.0.1"

View file

@ -407,6 +407,8 @@ class ChronosBoltPipeline(BaseChronosPipeline):
def __init__(self, model: ChronosBoltModelForForecasting):
super().__init__(inner_model=model) # type: ignore
self.model = model
self.model_context_length: int = self.model.config.chronos_config["context_length"]
self.model_prediction_length: int = self.model.config.chronos_config["prediction_length"]
@property
def quantiles(self) -> List[float]:
@ -487,14 +489,12 @@ class ChronosBoltPipeline(BaseChronosPipeline):
"""
context_tensor = self._prepare_and_validate_context(context=inputs)
model_context_length: int = self.model.config.chronos_config["context_length"]
model_prediction_length: int = self.model.config.chronos_config["prediction_length"]
if prediction_length is None:
prediction_length = model_prediction_length
prediction_length = self.model_prediction_length
if prediction_length > model_prediction_length:
if prediction_length > self.model_prediction_length:
msg = (
f"We recommend keeping prediction length <= {model_prediction_length}. "
f"We recommend keeping prediction length <= {self.model_prediction_length}. "
"The quality of longer predictions may degrade since the model is not optimized for it. "
)
if limit_prediction_length:
@ -507,32 +507,46 @@ class ChronosBoltPipeline(BaseChronosPipeline):
# We truncate the context here because otherwise batches with very long
# context could take up large amounts of GPU memory unnecessarily.
if context_tensor.shape[-1] > model_context_length:
context_tensor = context_tensor[..., -model_context_length:]
if context_tensor.shape[-1] > self.model_context_length:
context_tensor = context_tensor[..., -self.model_context_length :]
# TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
# horizon that the model was trained with (i.e., 64). This results in variance collapsing
# every 64 steps.
context_tensor = context_tensor.to(
device=self.model.device,
dtype=torch.float32,
)
while remaining > 0:
with torch.no_grad():
prediction = self.model(
context=context_tensor,
).quantile_preds.to(context_tensor)
context_tensor = context_tensor.to(device=self.model.device, dtype=torch.float32)
# First block prediction
with torch.no_grad():
prediction: torch.Tensor = self.model(context=context_tensor).quantile_preds.to(context_tensor)
predictions.append(prediction)
remaining -= prediction.shape[-1]
if remaining <= 0:
break
# NOTE: The following heuristic for better prediction intervals with long-horizon forecasts
# uses all quantiles generated by the model for the first `model_prediction_length` steps,
# concatenating each quantile with the context and generating the next `model_prediction_length` steps.
# The `num_quantiles * num_quantiles` "samples" thus generated are then reduced to `num_quantiles`
# by computing empirical quantiles. Note that this option scales the batch size by `num_quantiles`
# when the `prediction_length` is greater than `model_prediction_length`.
central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()
central_prediction = prediction[:, central_idx]
if remaining > 0:
# Expand the context along quantile axis
context_tensor = context_tensor.unsqueeze(1).repeat(1, len(self.quantiles), 1)
context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
quantile_tensor = torch.tensor(self.quantiles, device=context_tensor.device)
while remaining > 0:
# Append the prediction to context
context_tensor = torch.cat([context_tensor, prediction], dim=-1)[..., -self.model_context_length :]
(batch_size, n_quantiles, context_length) = context_tensor.shape
with torch.no_grad():
# Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length)
prediction = self.model(
context=context_tensor.reshape(batch_size * n_quantiles, context_length)
).quantile_preds.to(context_tensor)
# Reshape predictions from (batch * n_quantiles, n_quantiles, model_prediction_length) to (batch, n_quantiles * n_quantiles, model_prediction_length)
prediction = prediction.reshape(batch_size, n_quantiles * n_quantiles, -1)
# Reduce `n_quantiles * n_quantiles` to n_quantiles and transpose back to (batch_size, n_quantiles, model_prediction_length)
prediction = torch.quantile(prediction, q=quantile_tensor, dim=1).transpose(0, 1)
predictions.append(prediction)
remaining -= prediction.shape[-1]
return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32, device="cpu")