diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index 0397be2..2b7523e 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -371,8 +371,10 @@ class Chronos2Model(PreTrainedModel): ) def _prepare_patched_context( - self, context: torch.Tensor, context_mask: torch.Tensor | None = None - ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + self, context: torch.Tensor, context_mask: torch.Tensor | None = None, return_minmax: bool = False + ) -> tuple[ + torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor] | None + ]: context_mask = ( context_mask.to(context.dtype) if context_mask is not None @@ -386,6 +388,14 @@ class Chronos2Model(PreTrainedModel): context_mask = context_mask[..., -self.chronos_config.context_length :] # scaling + context_minmax = None + if return_minmax: + context_min = torch.amin(torch.nan_to_num(context, nan=float("inf")), dim=-1, keepdim=True) + context_min = torch.nan_to_num(context_min, posinf=0.0) + context_max = torch.amax(torch.nan_to_num(context, nan=float("-inf")), dim=-1, keepdim=True) + context_max = torch.nan_to_num(context_max, neginf=0.0) + context_minmax = context_min, context_max + context, loc_scale = self.instance_norm(context) # scaling is done in 32-bit precision, then the context is moved to model's dtype @@ -420,7 +430,7 @@ class Chronos2Model(PreTrainedModel): # concat time encoding, context and mask along the last (feature) dim patched_context = torch.cat([context_time_enc, patched_context, patched_mask], dim=-1) - return patched_context, attention_mask, loc_scale + return patched_context, attention_mask, loc_scale, context_minmax def _prepare_patched_future( self, @@ -558,6 +568,7 @@ class Chronos2Model(PreTrainedModel): future_target: torch.Tensor | None = None, future_target_mask: torch.Tensor | None = None, output_attentions: bool = False, + return_minmax: bool = False, ): self._validate_input( context=context, @@ -571,8 +582,8 @@ class Chronos2Model(PreTrainedModel): ) batch_size = context.shape[0] - patched_context, attention_mask, loc_scale = self._prepare_patched_context( - context=context, context_mask=context_mask + patched_context, attention_mask, loc_scale, context_minmax = self._prepare_patched_context( + context=context, context_mask=context_mask, return_minmax=return_minmax ) num_context_patches = attention_mask.shape[-1] @@ -613,7 +624,7 @@ class Chronos2Model(PreTrainedModel): group_ids=group_ids, output_attentions=output_attentions, ) - return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches + return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches, context_minmax def forward( self, @@ -626,6 +637,7 @@ class Chronos2Model(PreTrainedModel): future_target: torch.Tensor | None = None, future_target_mask: torch.Tensor | None = None, output_attentions: bool = False, + clip_factor: float | None = None, ) -> Chronos2Output: """Forward pass of the Chronos2 model. @@ -694,7 +706,7 @@ class Chronos2Model(PreTrainedModel): - enc_group_self_attn_weights: Group self attention weights, if output_attentions=True """ batch_size = context.shape[0] - encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode( + encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches, context_minmax = self.encode( context=context, context_mask=context_mask, group_ids=group_ids, @@ -704,6 +716,7 @@ class Chronos2Model(PreTrainedModel): future_target=future_target, future_target_mask=future_target_mask, output_attentions=output_attentions, + return_minmax=clip_factor is not None, ) hidden_states: torch.Tensor = encoder_outputs[0] assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim) @@ -741,6 +754,13 @@ class Chronos2Model(PreTrainedModel): h=num_output_patches * self.chronos_config.output_patch_size, ) quantile_preds = self.instance_norm.inverse(quantile_preds, loc_scale) + + if clip_factor is not None: + assert context_minmax is not None + clamp_min = context_minmax[0] - clip_factor * loc_scale[1] + clamp_max = context_minmax[1] + clip_factor * loc_scale[1] + quantile_preds = quantile_preds.clamp(min=clamp_min, max=clamp_max) + quantile_preds = rearrange( quantile_preds, "b (q h) -> b q h", diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 223689d..c036583 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -413,6 +413,7 @@ class Chronos2Pipeline(BaseChronosPipeline): unrolled_quantiles: torch.Tensor, unrolled_sample_weights: torch.Tensor, num_output_patches: int, + clip_factor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Get unrolled_quantiles from prediction and append it to the expanded context prediction_unrolled = interpolate_quantiles( @@ -439,6 +440,7 @@ class Chronos2Pipeline(BaseChronosPipeline): else None, group_ids=rearrange(group_ids, "b n -> (b n)"), num_output_patches=num_output_patches, + clip_factor=clip_factor, ) # Reshape predictions from (batch * n_paths, n_quantiles, length) to (batch, n_paths * n_quantiles, length) prediction = rearrange(prediction, "(b n) q h -> b (n q) h", n=n_paths) @@ -463,6 +465,7 @@ class Chronos2Pipeline(BaseChronosPipeline): context_length: int | None = None, cross_learning: bool = False, limit_prediction_length: bool = False, + clip_factor: float | None = None, **kwargs, ) -> list[torch.Tensor]: """ @@ -647,8 +650,14 @@ class Chronos2Pipeline(BaseChronosPipeline): prediction_length=prediction_length, max_output_patches=max_output_patches, target_idx_ranges=batch_target_idx_ranges, + clip_factor=clip_factor, ) - all_predictions.extend(batch_prediction) + + # Remove floating point noise around integers + for item in batch_prediction: + item = torch.where(torch.abs(item - item.round()) < 1e-5, item.round(), item) + all_predictions.append(item) + after_batch_callback() return all_predictions @@ -662,6 +671,7 @@ class Chronos2Pipeline(BaseChronosPipeline): prediction_length: int, max_output_patches: int, target_idx_ranges: list[tuple[int, int]], + clip_factor: float | None = None, ) -> list[torch.Tensor]: context = context.to(device=self.model.device, dtype=torch.float32) group_ids = group_ids.to(device=self.model.device) @@ -682,6 +692,7 @@ class Chronos2Pipeline(BaseChronosPipeline): group_ids=group_ids, future_covariates=future_covariates, num_output_patches=get_num_output_patches(remaining), + clip_factor=clip_factor, ) predictions.append(prediction) remaining -= prediction.shape[-1] @@ -707,6 +718,7 @@ class Chronos2Pipeline(BaseChronosPipeline): unrolled_quantiles=unrolled_quantiles_tensor, unrolled_sample_weights=unrolled_sample_weights, num_output_patches=get_num_output_patches(remaining), + clip_factor=clip_factor, ) predictions.append(prediction) remaining -= prediction.shape[-1] @@ -723,6 +735,7 @@ class Chronos2Pipeline(BaseChronosPipeline): group_ids: torch.Tensor, future_covariates: torch.Tensor | None, num_output_patches: int, + clip_factor: float | None = None, ) -> torch.Tensor: kwargs = {} if future_covariates is not None: @@ -741,7 +754,11 @@ class Chronos2Pipeline(BaseChronosPipeline): kwargs["future_covariates"] = future_covariates with torch.no_grad(): prediction: torch.Tensor = self.model( - context=context, group_ids=group_ids, num_output_patches=num_output_patches, **kwargs + context=context, + group_ids=group_ids, + num_output_patches=num_output_patches, + clip_factor=clip_factor, + **kwargs, ).quantile_preds.to(context) return prediction