mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Merge 4b5e273ced into 32111085d8
This commit is contained in:
commit
9530db24ee
2 changed files with 46 additions and 9 deletions
|
|
@ -371,8 +371,10 @@ class Chronos2Model(PreTrainedModel):
|
|||
)
|
||||
|
||||
def _prepare_patched_context(
|
||||
self, context: torch.Tensor, context_mask: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
self, context: torch.Tensor, context_mask: torch.Tensor | None = None, return_minmax: bool = False
|
||||
) -> tuple[
|
||||
torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor] | None
|
||||
]:
|
||||
context_mask = (
|
||||
context_mask.to(context.dtype)
|
||||
if context_mask is not None
|
||||
|
|
@ -386,6 +388,14 @@ class Chronos2Model(PreTrainedModel):
|
|||
context_mask = context_mask[..., -self.chronos_config.context_length :]
|
||||
|
||||
# scaling
|
||||
context_minmax = None
|
||||
if return_minmax:
|
||||
context_min = torch.amin(torch.nan_to_num(context, nan=float("inf")), dim=-1, keepdim=True)
|
||||
context_min = torch.nan_to_num(context_min, posinf=0.0)
|
||||
context_max = torch.amax(torch.nan_to_num(context, nan=float("-inf")), dim=-1, keepdim=True)
|
||||
context_max = torch.nan_to_num(context_max, neginf=0.0)
|
||||
context_minmax = context_min, context_max
|
||||
|
||||
context, loc_scale = self.instance_norm(context)
|
||||
|
||||
# scaling is done in 32-bit precision, then the context is moved to model's dtype
|
||||
|
|
@ -420,7 +430,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
# concat time encoding, context and mask along the last (feature) dim
|
||||
patched_context = torch.cat([context_time_enc, patched_context, patched_mask], dim=-1)
|
||||
|
||||
return patched_context, attention_mask, loc_scale
|
||||
return patched_context, attention_mask, loc_scale, context_minmax
|
||||
|
||||
def _prepare_patched_future(
|
||||
self,
|
||||
|
|
@ -558,6 +568,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
future_target: torch.Tensor | None = None,
|
||||
future_target_mask: torch.Tensor | None = None,
|
||||
output_attentions: bool = False,
|
||||
return_minmax: bool = False,
|
||||
):
|
||||
self._validate_input(
|
||||
context=context,
|
||||
|
|
@ -571,8 +582,8 @@ class Chronos2Model(PreTrainedModel):
|
|||
)
|
||||
|
||||
batch_size = context.shape[0]
|
||||
patched_context, attention_mask, loc_scale = self._prepare_patched_context(
|
||||
context=context, context_mask=context_mask
|
||||
patched_context, attention_mask, loc_scale, context_minmax = self._prepare_patched_context(
|
||||
context=context, context_mask=context_mask, return_minmax=return_minmax
|
||||
)
|
||||
num_context_patches = attention_mask.shape[-1]
|
||||
|
||||
|
|
@ -613,7 +624,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
group_ids=group_ids,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
|
||||
return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches, context_minmax
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -626,6 +637,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
future_target: torch.Tensor | None = None,
|
||||
future_target_mask: torch.Tensor | None = None,
|
||||
output_attentions: bool = False,
|
||||
clip_factor: float | None = None,
|
||||
) -> Chronos2Output:
|
||||
"""Forward pass of the Chronos2 model.
|
||||
|
||||
|
|
@ -694,7 +706,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
- enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
|
||||
"""
|
||||
batch_size = context.shape[0]
|
||||
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
|
||||
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches, context_minmax = self.encode(
|
||||
context=context,
|
||||
context_mask=context_mask,
|
||||
group_ids=group_ids,
|
||||
|
|
@ -704,6 +716,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
future_target=future_target,
|
||||
future_target_mask=future_target_mask,
|
||||
output_attentions=output_attentions,
|
||||
return_minmax=clip_factor is not None,
|
||||
)
|
||||
hidden_states: torch.Tensor = encoder_outputs[0]
|
||||
assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
|
||||
|
|
@ -741,6 +754,13 @@ class Chronos2Model(PreTrainedModel):
|
|||
h=num_output_patches * self.chronos_config.output_patch_size,
|
||||
)
|
||||
quantile_preds = self.instance_norm.inverse(quantile_preds, loc_scale)
|
||||
|
||||
if clip_factor is not None:
|
||||
assert context_minmax is not None
|
||||
clamp_min = context_minmax[0] - clip_factor * loc_scale[1]
|
||||
clamp_max = context_minmax[1] + clip_factor * loc_scale[1]
|
||||
quantile_preds = quantile_preds.clamp(min=clamp_min, max=clamp_max)
|
||||
|
||||
quantile_preds = rearrange(
|
||||
quantile_preds,
|
||||
"b (q h) -> b q h",
|
||||
|
|
|
|||
|
|
@ -413,6 +413,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
unrolled_quantiles: torch.Tensor,
|
||||
unrolled_sample_weights: torch.Tensor,
|
||||
num_output_patches: int,
|
||||
clip_factor: float | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# Get unrolled_quantiles from prediction and append it to the expanded context
|
||||
prediction_unrolled = interpolate_quantiles(
|
||||
|
|
@ -439,6 +440,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
else None,
|
||||
group_ids=rearrange(group_ids, "b n -> (b n)"),
|
||||
num_output_patches=num_output_patches,
|
||||
clip_factor=clip_factor,
|
||||
)
|
||||
# Reshape predictions from (batch * n_paths, n_quantiles, length) to (batch, n_paths * n_quantiles, length)
|
||||
prediction = rearrange(prediction, "(b n) q h -> b (n q) h", n=n_paths)
|
||||
|
|
@ -463,6 +465,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
context_length: int | None = None,
|
||||
cross_learning: bool = False,
|
||||
limit_prediction_length: bool = False,
|
||||
clip_factor: float | None = None,
|
||||
**kwargs,
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
|
|
@ -647,8 +650,14 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
prediction_length=prediction_length,
|
||||
max_output_patches=max_output_patches,
|
||||
target_idx_ranges=batch_target_idx_ranges,
|
||||
clip_factor=clip_factor,
|
||||
)
|
||||
all_predictions.extend(batch_prediction)
|
||||
|
||||
# Remove floating point noise around integers
|
||||
for item in batch_prediction:
|
||||
item = torch.where(torch.abs(item - item.round()) < 1e-5, item.round(), item)
|
||||
all_predictions.append(item)
|
||||
|
||||
after_batch_callback()
|
||||
|
||||
return all_predictions
|
||||
|
|
@ -662,6 +671,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
prediction_length: int,
|
||||
max_output_patches: int,
|
||||
target_idx_ranges: list[tuple[int, int]],
|
||||
clip_factor: float | None = None,
|
||||
) -> list[torch.Tensor]:
|
||||
context = context.to(device=self.model.device, dtype=torch.float32)
|
||||
group_ids = group_ids.to(device=self.model.device)
|
||||
|
|
@ -682,6 +692,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
group_ids=group_ids,
|
||||
future_covariates=future_covariates,
|
||||
num_output_patches=get_num_output_patches(remaining),
|
||||
clip_factor=clip_factor,
|
||||
)
|
||||
predictions.append(prediction)
|
||||
remaining -= prediction.shape[-1]
|
||||
|
|
@ -707,6 +718,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
unrolled_quantiles=unrolled_quantiles_tensor,
|
||||
unrolled_sample_weights=unrolled_sample_weights,
|
||||
num_output_patches=get_num_output_patches(remaining),
|
||||
clip_factor=clip_factor,
|
||||
)
|
||||
predictions.append(prediction)
|
||||
remaining -= prediction.shape[-1]
|
||||
|
|
@ -723,6 +735,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
group_ids: torch.Tensor,
|
||||
future_covariates: torch.Tensor | None,
|
||||
num_output_patches: int,
|
||||
clip_factor: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
kwargs = {}
|
||||
if future_covariates is not None:
|
||||
|
|
@ -741,7 +754,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
kwargs["future_covariates"] = future_covariates
|
||||
with torch.no_grad():
|
||||
prediction: torch.Tensor = self.model(
|
||||
context=context, group_ids=group_ids, num_output_patches=num_output_patches, **kwargs
|
||||
context=context,
|
||||
group_ids=group_ids,
|
||||
num_output_patches=num_output_patches,
|
||||
clip_factor=clip_factor,
|
||||
**kwargs,
|
||||
).quantile_preds.to(context)
|
||||
|
||||
return prediction
|
||||
|
|
|
|||
Loading…
Reference in a new issue