This commit is contained in:
Abdul Fatir 2026-05-14 14:25:32 +02:00 committed by GitHub
commit 9530db24ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 46 additions and 9 deletions

View file

@ -371,8 +371,10 @@ class Chronos2Model(PreTrainedModel):
)
def _prepare_patched_context(
self, context: torch.Tensor, context_mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
self, context: torch.Tensor, context_mask: torch.Tensor | None = None, return_minmax: bool = False
) -> tuple[
torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor] | None
]:
context_mask = (
context_mask.to(context.dtype)
if context_mask is not None
@ -386,6 +388,14 @@ class Chronos2Model(PreTrainedModel):
context_mask = context_mask[..., -self.chronos_config.context_length :]
# scaling
context_minmax = None
if return_minmax:
context_min = torch.amin(torch.nan_to_num(context, nan=float("inf")), dim=-1, keepdim=True)
context_min = torch.nan_to_num(context_min, posinf=0.0)
context_max = torch.amax(torch.nan_to_num(context, nan=float("-inf")), dim=-1, keepdim=True)
context_max = torch.nan_to_num(context_max, neginf=0.0)
context_minmax = context_min, context_max
context, loc_scale = self.instance_norm(context)
# scaling is done in 32-bit precision, then the context is moved to model's dtype
@ -420,7 +430,7 @@ class Chronos2Model(PreTrainedModel):
# concat time encoding, context and mask along the last (feature) dim
patched_context = torch.cat([context_time_enc, patched_context, patched_mask], dim=-1)
return patched_context, attention_mask, loc_scale
return patched_context, attention_mask, loc_scale, context_minmax
def _prepare_patched_future(
self,
@ -558,6 +568,7 @@ class Chronos2Model(PreTrainedModel):
future_target: torch.Tensor | None = None,
future_target_mask: torch.Tensor | None = None,
output_attentions: bool = False,
return_minmax: bool = False,
):
self._validate_input(
context=context,
@ -571,8 +582,8 @@ class Chronos2Model(PreTrainedModel):
)
batch_size = context.shape[0]
patched_context, attention_mask, loc_scale = self._prepare_patched_context(
context=context, context_mask=context_mask
patched_context, attention_mask, loc_scale, context_minmax = self._prepare_patched_context(
context=context, context_mask=context_mask, return_minmax=return_minmax
)
num_context_patches = attention_mask.shape[-1]
@ -613,7 +624,7 @@ class Chronos2Model(PreTrainedModel):
group_ids=group_ids,
output_attentions=output_attentions,
)
return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches, context_minmax
def forward(
self,
@ -626,6 +637,7 @@ class Chronos2Model(PreTrainedModel):
future_target: torch.Tensor | None = None,
future_target_mask: torch.Tensor | None = None,
output_attentions: bool = False,
clip_factor: float | None = None,
) -> Chronos2Output:
"""Forward pass of the Chronos2 model.
@ -694,7 +706,7 @@ class Chronos2Model(PreTrainedModel):
- enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
"""
batch_size = context.shape[0]
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches, context_minmax = self.encode(
context=context,
context_mask=context_mask,
group_ids=group_ids,
@ -704,6 +716,7 @@ class Chronos2Model(PreTrainedModel):
future_target=future_target,
future_target_mask=future_target_mask,
output_attentions=output_attentions,
return_minmax=clip_factor is not None,
)
hidden_states: torch.Tensor = encoder_outputs[0]
assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
@ -741,6 +754,13 @@ class Chronos2Model(PreTrainedModel):
h=num_output_patches * self.chronos_config.output_patch_size,
)
quantile_preds = self.instance_norm.inverse(quantile_preds, loc_scale)
if clip_factor is not None:
assert context_minmax is not None
clamp_min = context_minmax[0] - clip_factor * loc_scale[1]
clamp_max = context_minmax[1] + clip_factor * loc_scale[1]
quantile_preds = quantile_preds.clamp(min=clamp_min, max=clamp_max)
quantile_preds = rearrange(
quantile_preds,
"b (q h) -> b q h",

View file

@ -413,6 +413,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
unrolled_quantiles: torch.Tensor,
unrolled_sample_weights: torch.Tensor,
num_output_patches: int,
clip_factor: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Get unrolled_quantiles from prediction and append it to the expanded context
prediction_unrolled = interpolate_quantiles(
@ -439,6 +440,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
else None,
group_ids=rearrange(group_ids, "b n -> (b n)"),
num_output_patches=num_output_patches,
clip_factor=clip_factor,
)
# Reshape predictions from (batch * n_paths, n_quantiles, length) to (batch, n_paths * n_quantiles, length)
prediction = rearrange(prediction, "(b n) q h -> b (n q) h", n=n_paths)
@ -463,6 +465,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
context_length: int | None = None,
cross_learning: bool = False,
limit_prediction_length: bool = False,
clip_factor: float | None = None,
**kwargs,
) -> list[torch.Tensor]:
"""
@ -647,8 +650,14 @@ class Chronos2Pipeline(BaseChronosPipeline):
prediction_length=prediction_length,
max_output_patches=max_output_patches,
target_idx_ranges=batch_target_idx_ranges,
clip_factor=clip_factor,
)
all_predictions.extend(batch_prediction)
# Snap values within 1e-5 of an integer to that integer to remove floating point noise
for item in batch_prediction:
item = torch.where(torch.abs(item - item.round()) < 1e-5, item.round(), item)
all_predictions.append(item)
after_batch_callback()
return all_predictions
@ -662,6 +671,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
prediction_length: int,
max_output_patches: int,
target_idx_ranges: list[tuple[int, int]],
clip_factor: float | None = None,
) -> list[torch.Tensor]:
context = context.to(device=self.model.device, dtype=torch.float32)
group_ids = group_ids.to(device=self.model.device)
@ -682,6 +692,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
group_ids=group_ids,
future_covariates=future_covariates,
num_output_patches=get_num_output_patches(remaining),
clip_factor=clip_factor,
)
predictions.append(prediction)
remaining -= prediction.shape[-1]
@ -707,6 +718,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
unrolled_quantiles=unrolled_quantiles_tensor,
unrolled_sample_weights=unrolled_sample_weights,
num_output_patches=get_num_output_patches(remaining),
clip_factor=clip_factor,
)
predictions.append(prediction)
remaining -= prediction.shape[-1]
@ -723,6 +735,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
group_ids: torch.Tensor,
future_covariates: torch.Tensor | None,
num_output_patches: int,
clip_factor: float | None = None,
) -> torch.Tensor:
kwargs = {}
if future_covariates is not None:
@ -741,7 +754,11 @@ class Chronos2Pipeline(BaseChronosPipeline):
kwargs["future_covariates"] = future_covariates
with torch.no_grad():
prediction: torch.Tensor = self.model(
context=context, group_ids=group_ids, num_output_patches=num_output_patches, **kwargs
context=context,
group_ids=group_ids,
num_output_patches=num_output_patches,
clip_factor=clip_factor,
**kwargs,
).quantile_preds.to(context)
return prediction