mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Update long horizon heuristic for Chronos-Bolt (#366)
*Issue #, if available:*
*Description of changes:* In light of the planned AG dependency on this
package, this PR updates the long horizon heuristic of Chronos-Bolt to
what it was in AG (introduced in
https://github.com/autogluon/autogluon/pull/5177). Code has been
copy-pasted [from
AG](5fcebaa493/timeseries/src/autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py (L441)).
By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
parent
78bd1c90ca
commit
46b20c2d54
1 changed files with 38 additions and 24 deletions
|
|
@ -407,6 +407,8 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
def __init__(self, model: ChronosBoltModelForForecasting):
|
||||
super().__init__(inner_model=model) # type: ignore
|
||||
self.model = model
|
||||
self.model_context_length: int = self.model.config.chronos_config["context_length"]
|
||||
self.model_prediction_length: int = self.model.config.chronos_config["prediction_length"]
|
||||
|
||||
@property
|
||||
def quantiles(self) -> List[float]:
|
||||
|
|
@ -487,14 +489,12 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
"""
|
||||
context_tensor = self._prepare_and_validate_context(context=inputs)
|
||||
|
||||
model_context_length: int = self.model.config.chronos_config["context_length"]
|
||||
model_prediction_length: int = self.model.config.chronos_config["prediction_length"]
|
||||
if prediction_length is None:
|
||||
prediction_length = model_prediction_length
|
||||
prediction_length = self.model_prediction_length
|
||||
|
||||
if prediction_length > model_prediction_length:
|
||||
if prediction_length > self.model_prediction_length:
|
||||
msg = (
|
||||
f"We recommend keeping prediction length <= {model_prediction_length}. "
|
||||
f"We recommend keeping prediction length <= {self.model_prediction_length}. "
|
||||
"The quality of longer predictions may degrade since the model is not optimized for it. "
|
||||
)
|
||||
if limit_prediction_length:
|
||||
|
|
@ -507,32 +507,46 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
|
||||
# We truncate the context here because otherwise batches with very long
|
||||
# context could take up large amounts of GPU memory unnecessarily.
|
||||
if context_tensor.shape[-1] > model_context_length:
|
||||
context_tensor = context_tensor[..., -model_context_length:]
|
||||
if context_tensor.shape[-1] > self.model_context_length:
|
||||
context_tensor = context_tensor[..., -self.model_context_length :]
|
||||
|
||||
# TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
|
||||
# horizon that the model was trained with (i.e., 64). This results in variance collapsing
|
||||
# every 64 steps.
|
||||
context_tensor = context_tensor.to(
|
||||
device=self.model.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
while remaining > 0:
|
||||
with torch.no_grad():
|
||||
prediction = self.model(
|
||||
context=context_tensor,
|
||||
).quantile_preds.to(context_tensor)
|
||||
context_tensor = context_tensor.to(device=self.model.device, dtype=torch.float32)
|
||||
# First block prediction
|
||||
with torch.no_grad():
|
||||
prediction: torch.Tensor = self.model(context=context_tensor).quantile_preds.to(context_tensor)
|
||||
|
||||
predictions.append(prediction)
|
||||
remaining -= prediction.shape[-1]
|
||||
|
||||
if remaining <= 0:
|
||||
break
|
||||
# NOTE: The following heuristic for better prediction intervals with long-horizon forecasts
|
||||
# uses all quantiles generated by the model for the first `model_prediction_length` steps,
|
||||
# concatenating each quantile with the context and generating the next `model_prediction_length` steps.
|
||||
# The `num_quantiles * num_quantiles` "samples" thus generated are then reduced to `num_quantiles`
|
||||
# by computing empirical quantiles. Note that this option scales the batch size by `num_quantiles`
|
||||
# when the `prediction_length` is greater than `model_prediction_length`.
|
||||
|
||||
central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()
|
||||
central_prediction = prediction[:, central_idx]
|
||||
if remaining > 0:
|
||||
# Expand the context along quantile axis
|
||||
context_tensor = context_tensor.unsqueeze(1).repeat(1, len(self.quantiles), 1)
|
||||
|
||||
context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
|
||||
quantile_tensor = torch.tensor(self.quantiles, device=context_tensor.device)
|
||||
while remaining > 0:
|
||||
# Append the prediction to context
|
||||
context_tensor = torch.cat([context_tensor, prediction], dim=-1)[..., -self.model_context_length :]
|
||||
(batch_size, n_quantiles, context_length) = context_tensor.shape
|
||||
|
||||
with torch.no_grad():
|
||||
# Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length)
|
||||
prediction = self.model(
|
||||
context=context_tensor.reshape(batch_size * n_quantiles, context_length)
|
||||
).quantile_preds.to(context_tensor)
|
||||
# Reshape predictions from (batch * n_quantiles, n_quantiles, model_prediction_length) to (batch, n_quantiles * n_quantiles, model_prediction_length)
|
||||
prediction = prediction.reshape(batch_size, n_quantiles * n_quantiles, -1)
|
||||
# Reduce `n_quantiles * n_quantiles` to n_quantiles and transpose back to (batch_size, n_quantiles, model_prediction_length)
|
||||
prediction = torch.quantile(prediction, q=quantile_tensor, dim=1).transpose(0, 1)
|
||||
|
||||
predictions.append(prediction)
|
||||
remaining -= prediction.shape[-1]
|
||||
|
||||
return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32, device="cpu")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue