mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Backports for v2.0.1 and version bump (#369)
*Issue #, if available:* *Description of changes:* By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
7a8427d18e
commit
1a2498f238
3 changed files with 40 additions and 26 deletions
|
|
@ -15,7 +15,7 @@ license = { file = "LICENSE" }
|
|||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.0,<3",
|
||||
"transformers>=4.49,<5",
|
||||
"transformers>=4.41,<5",
|
||||
"accelerate>=0.34,<2",
|
||||
"numpy>=1.21,<3",
|
||||
"einops>=0.7.0,<1",
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
__version__ = "2.0.0"
|
||||
__version__ = "2.0.1"
|
||||
|
|
|
|||
|
|
@ -407,6 +407,8 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
def __init__(self, model: ChronosBoltModelForForecasting):
|
||||
super().__init__(inner_model=model) # type: ignore
|
||||
self.model = model
|
||||
self.model_context_length: int = self.model.config.chronos_config["context_length"]
|
||||
self.model_prediction_length: int = self.model.config.chronos_config["prediction_length"]
|
||||
|
||||
@property
|
||||
def quantiles(self) -> List[float]:
|
||||
|
|
@ -487,14 +489,12 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
"""
|
||||
context_tensor = self._prepare_and_validate_context(context=inputs)
|
||||
|
||||
model_context_length: int = self.model.config.chronos_config["context_length"]
|
||||
model_prediction_length: int = self.model.config.chronos_config["prediction_length"]
|
||||
if prediction_length is None:
|
||||
prediction_length = model_prediction_length
|
||||
prediction_length = self.model_prediction_length
|
||||
|
||||
if prediction_length > model_prediction_length:
|
||||
if prediction_length > self.model_prediction_length:
|
||||
msg = (
|
||||
f"We recommend keeping prediction length <= {model_prediction_length}. "
|
||||
f"We recommend keeping prediction length <= {self.model_prediction_length}. "
|
||||
"The quality of longer predictions may degrade since the model is not optimized for it. "
|
||||
)
|
||||
if limit_prediction_length:
|
||||
|
|
@ -507,32 +507,46 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
|
||||
# We truncate the context here because otherwise batches with very long
|
||||
# context could take up large amounts of GPU memory unnecessarily.
|
||||
if context_tensor.shape[-1] > model_context_length:
|
||||
context_tensor = context_tensor[..., -model_context_length:]
|
||||
if context_tensor.shape[-1] > self.model_context_length:
|
||||
context_tensor = context_tensor[..., -self.model_context_length :]
|
||||
|
||||
# TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
|
||||
# horizon that the model was trained with (i.e., 64). This results in variance collapsing
|
||||
# every 64 steps.
|
||||
context_tensor = context_tensor.to(
|
||||
device=self.model.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
while remaining > 0:
|
||||
with torch.no_grad():
|
||||
prediction = self.model(
|
||||
context=context_tensor,
|
||||
).quantile_preds.to(context_tensor)
|
||||
context_tensor = context_tensor.to(device=self.model.device, dtype=torch.float32)
|
||||
# First block prediction
|
||||
with torch.no_grad():
|
||||
prediction: torch.Tensor = self.model(context=context_tensor).quantile_preds.to(context_tensor)
|
||||
|
||||
predictions.append(prediction)
|
||||
remaining -= prediction.shape[-1]
|
||||
|
||||
if remaining <= 0:
|
||||
break
|
||||
# NOTE: The following heuristic for better prediction intervals with long-horizon forecasts
|
||||
# uses all quantiles generated by the model for the first `model_prediction_length` steps,
|
||||
# concatenating each quantile with the context and generating the next `model_prediction_length` steps.
|
||||
# The `num_quantiles * num_quantiles` "samples" thus generated are then reduced to `num_quantiles`
|
||||
# by computing empirical quantiles. Note that this option scales the batch size by `num_quantiles`
|
||||
# when the `prediction_length` is greater than `model_prediction_length`.
|
||||
|
||||
central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()
|
||||
central_prediction = prediction[:, central_idx]
|
||||
if remaining > 0:
|
||||
# Expand the context along quantile axis
|
||||
context_tensor = context_tensor.unsqueeze(1).repeat(1, len(self.quantiles), 1)
|
||||
|
||||
context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
|
||||
quantile_tensor = torch.tensor(self.quantiles, device=context_tensor.device)
|
||||
while remaining > 0:
|
||||
# Append the prediction to context
|
||||
context_tensor = torch.cat([context_tensor, prediction], dim=-1)[..., -self.model_context_length :]
|
||||
(batch_size, n_quantiles, context_length) = context_tensor.shape
|
||||
|
||||
with torch.no_grad():
|
||||
# Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length)
|
||||
prediction = self.model(
|
||||
context=context_tensor.reshape(batch_size * n_quantiles, context_length)
|
||||
).quantile_preds.to(context_tensor)
|
||||
# Reshape predictions from (batch * n_quantiles, n_quantiles, model_prediction_length) to (batch, n_quantiles * n_quantiles, model_prediction_length)
|
||||
prediction = prediction.reshape(batch_size, n_quantiles * n_quantiles, -1)
|
||||
# Reduce `n_quantiles * n_quantiles` to n_quantiles and transpose back to (batch_size, n_quantiles, model_prediction_length)
|
||||
prediction = torch.quantile(prediction, q=quantile_tensor, dim=1).transpose(0, 1)
|
||||
|
||||
predictions.append(prediction)
|
||||
remaining -= prediction.shape[-1]
|
||||
|
||||
return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32, device="cpu")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue