From 49edd9b876f5dc772884e351961560197ae6eb5b Mon Sep 17 00:00:00 2001 From: xymli Date: Mon, 22 Dec 2025 21:03:47 +0800 Subject: [PATCH] Speed up group attention by remove useless padding. Signed-off-by: xymli --- src/chronos/chronos2/layers.py | 174 ++++++++++++++++---- src/chronos/chronos2/model.py | 292 +++++++++++++++++++++++++++------ 2 files changed, 385 insertions(+), 81 deletions(-) diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index b00e8a8..4e58ecf 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -5,6 +5,7 @@ from dataclasses import dataclass +import os import torch from einops import rearrange from torch import nn @@ -27,22 +28,35 @@ class RoPE(nn.Module): self.dim = dim self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim) + ) self.inv_freq: torch.Tensor # type hint for type checker self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() - def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: # x: [bs, num_attention_heads, seq_len, head_size] self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = ( + device_type + if isinstance(device_type, str) and device_type != "mps" + else "cpu" + ) with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() @@ -57,7 +71,11 @@ class RoPE(nn.Module): @staticmethod def apply_rotary_pos_emb( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, ) -> tuple[torch.Tensor, torch.Tensor]: """Applies Rotary Position Embedding to the query and key tensors. @@ -129,7 +147,9 @@ class FeedForward(nn.Module): assert not config.is_gated_act, "gated activations are unsupported" self.mlp: nn.Module = MLP(config) - self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.layer_norm = Chronos2LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -186,10 +206,14 @@ class MHA(nn.Module): attn_weights: [batch, n_heads, q_len, kv_len] """ # Compute attention weights (no scaling - this is the original Chronos-2 implementation) - scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk" + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # "bnqd,bnkd->bnqk" scores += mask attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) attn_output = torch.matmul(attn_weights, value_states) return attn_output, attn_weights @@ -247,7 +271,9 @@ class MHA(nn.Module): - attn_weights : Attention weights if output_attentions=True """ if self.use_rope: - assert position_ids is not None, "position_ids must be provided when self.use_rope=True" + assert ( + position_ids is not None + ), "position_ids must be provided when self.use_rope=True" # Force eager attention if output_attentions is True (only eager returns weights) attn_implementation = self.config._attn_implementation @@ -258,11 +284,23 @@ class MHA(nn.Module): def shape(states: torch.Tensor) -> torch.Tensor: """(batch, seq_len, inner_dim) -> (batch, n_heads, seq_len, kv_proj_dim)""" - return rearrange(states, "b s (h d) -> b h s d", h=self.n_heads, s=seq_length, d=self.kv_proj_dim) + return rearrange( + states, + "b s (h d) -> b h s d", + h=self.n_heads, + s=seq_length, + d=self.kv_proj_dim, + ) def unshape(states: torch.Tensor) -> torch.Tensor: """(batch, n_heads, seq_len, kv_proj_dim) -> (batch, seq_len, inner_dim)""" - return rearrange(states, "b h s d -> b s (h d)", h=self.n_heads, s=seq_length, d=self.kv_proj_dim) + return rearrange( + states, + "b h s d -> b s (h d)", + h=self.n_heads, + s=seq_length, + d=self.kv_proj_dim, + ) # Construct query states query_states = shape(self.q(hidden_states)) @@ -277,25 +315,36 @@ class MHA(nn.Module): value_states = shape(self.v(hidden_states)) if self.use_rope: cos, sin = self.rope_embed(value_states, position_ids) - query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = RoPE.apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) if attn_implementation == "sdpa": - attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask) + attn_output, attn_weights = self._sdpa_attention( + query_states, key_states, value_states, mask + ) else: # eager - attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask) + attn_output, attn_weights = self._eager_attention( + query_states, key_states, value_states, mask + ) # Project attention output attn_output = unshape(attn_output) attn_output = self.o(attn_output) - return AttentionOutput(hidden_states=attn_output, attn_weights=attn_weights if output_attentions else None) + return AttentionOutput( + hidden_states=attn_output, + attn_weights=attn_weights if output_attentions else None, + ) class TimeSelfAttention(nn.Module): def __init__(self, config: Chronos2CoreConfig): super().__init__() self.self_attention = MHA(config, use_rope=True) - self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.layer_norm = Chronos2LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -307,18 +356,25 @@ class TimeSelfAttention(nn.Module): ) -> AttentionOutput: normed_hidden_states = self.layer_norm(hidden_states) attention_output: AttentionOutput = self.self_attention( - normed_hidden_states, position_ids=position_ids, mask=attention_mask, output_attentions=output_attentions + normed_hidden_states, + position_ids=position_ids, + mask=attention_mask, + output_attentions=output_attentions, ) hidden_states = hidden_states + self.dropout(attention_output[0]) - return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights) + return AttentionOutput( + hidden_states=hidden_states, attn_weights=attention_output.attn_weights + ) class TimeCrossAttention(nn.Module): def __init__(self, config: Chronos2CoreConfig): super().__init__() self.cross_attention = MHA(config, use_rope=False) - self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.layer_norm = Chronos2LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -337,7 +393,9 @@ class TimeCrossAttention(nn.Module): ) hidden_states = hidden_states + self.dropout(attention_output[0]) - return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights) + return AttentionOutput( + hidden_states=hidden_states, attn_weights=attention_output.attn_weights + ) class GroupSelfAttention(nn.Module): @@ -347,23 +405,83 @@ class GroupSelfAttention(nn.Module): super().__init__() # we don't use RoPE here because there's no natural ordering along the batch axis self.self_attention = MHA(config, use_rope=False) - self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.layer_norm = Chronos2LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) def forward( - self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | tuple, + output_attentions: bool = False, ) -> AttentionOutput: # flip time and batch axes because attention operates along dim=-2 hidden_states = rearrange(hidden_states, "batch time d -> time batch d") normed_hidden_states = self.layer_norm(hidden_states) - attention_output: AttentionOutput = self.self_attention( - normed_hidden_states, mask=attention_mask, output_attentions=output_attentions - ) - hidden_states = hidden_states + self.dropout(attention_output[0]) + + if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1": + ( + flast_group_time_mask, + group_start_index, + counts, + max_group_len, + group_num, + ) = attention_mask + fast_normed_hidden_states = torch.zeros( + ( + group_num * normed_hidden_states.size(0), + max_group_len, + normed_hidden_states.size(2), + ), + dtype=normed_hidden_states.dtype, + device=normed_hidden_states.device, + ) + for i in range(group_num): + start_index = group_start_index[i].item() + group_len = counts[i].item() + end_index = start_index + group_len + fast_normed_hidden_states[ + i + * normed_hidden_states.size(0) : (i + 1) + * normed_hidden_states.size(0), + :group_len, + :, + ] = normed_hidden_states[:, start_index:end_index, :] + fast_attention_output: AttentionOutput = self.self_attention( + fast_normed_hidden_states, + mask=flast_group_time_mask, + output_attentions=output_attentions, + ) + attn_weights = fast_attention_output.attn_weights + attention_output = torch.empty_like(normed_hidden_states) + for i in range(group_num): + start_index = group_start_index[i].item() + group_len = counts[i].item() + end_index = start_index + group_len + attention_output[:, start_index:end_index, :] = fast_attention_output[ + 0 + ][ + i + * normed_hidden_states.size(0) : (i + 1) + * normed_hidden_states.size(0), + :group_len, + :, + ] + else: + attention_output: AttentionOutput = self.self_attention( + normed_hidden_states, + mask=attention_mask, + output_attentions=output_attentions, + ) + attn_weights = attention_output.attn_weights + attention_output = attention_output[0] + + hidden_states = hidden_states + self.dropout(attention_output) # flip time and batch axes back to their original position hidden_states = rearrange(hidden_states, "time batch d -> batch time d") - return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights) + return AttentionOutput(hidden_states=hidden_states, attn_weights=attn_weights) class ResidualBlock(nn.Module): diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index 0397be2..451db9e 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -4,6 +4,7 @@ # Authors: Abdul Fatir Ansari import copy +import os from dataclasses import dataclass from typing import cast @@ -51,7 +52,7 @@ class Chronos2EncoderBlock(nn.Module): *, position_ids: torch.Tensor, attention_mask: torch.Tensor, - group_time_mask: torch.Tensor, + group_time_mask: torch.Tensor | tuple, output_attentions: bool = False, ) -> Chronos2EncoderBlockOutput: # apply time attention @@ -65,7 +66,9 @@ class Chronos2EncoderBlock(nn.Module): # apply group attention group_self_attn_outputs: AttentionOutput = self.layer[1]( - hidden_states, attention_mask=group_time_mask, output_attentions=output_attentions + hidden_states, + attention_mask=group_time_mask, + output_attentions=output_attentions, ) hidden_states = group_self_attn_outputs[0] @@ -91,15 +94,21 @@ class Chronos2Encoder(nn.Module): super().__init__() assert not config.is_decoder - self.block = nn.ModuleList([Chronos2EncoderBlock(config) for i in range(config.num_layers)]) - self.final_layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.block = nn.ModuleList( + [Chronos2EncoderBlock(config) for i in range(config.num_layers)] + ) + self.final_layer_norm = Chronos2LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) self.dropout = nn.Dropout(config.dropout_rate) @staticmethod def _expand_and_invert_time_attention_mask( attention_mask: torch.Tensor, floating_type: torch.dtype ) -> torch.Tensor: - assert attention_mask.ndim == 2, "attention_mask must have shape (batch, seq_len)" + assert ( + attention_mask.ndim == 2 + ), "attention_mask must have shape (batch, seq_len)" # Add new dims for attention heads and q_len attention_mask = attention_mask[:, None, None, :] @@ -111,7 +120,9 @@ class Chronos2Encoder(nn.Module): @staticmethod def _construct_and_invert_group_time_mask( - group_ids: torch.Tensor, attention_mask: torch.Tensor, floating_type: torch.dtype + group_ids: torch.Tensor, + attention_mask: torch.Tensor, + floating_type: torch.dtype, ) -> torch.Tensor: # construct group_mask (batch, batch) from group ids # a cell is True if both row and col had the same group id @@ -129,6 +140,35 @@ class Chronos2Encoder(nn.Module): group_time_mask = rearrange(group_time_mask, "q b t -> t 1 q b") group_time_mask = (1.0 - group_time_mask) * torch.finfo(floating_type).min + if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1": + # pad group_time_mask to max group size for less useless computation in self-attention + unique_groups, counts = torch.unique(group_ids, return_counts=True) + max_group_len = counts.max().item() + group_num = unique_groups.size(0) + fast_group_time_mask = torch.full( + (group_num * group_time_mask.size(0), 1, max_group_len, max_group_len), + torch.finfo(floating_type).min, + device=group_time_mask.device, + ) + group_start_index = counts.cumsum(dim=0) - counts + for i in range(group_num): + group_len = counts[i].item() + start_index = group_start_index[i].item() + end_index = start_index + group_len + fast_group_time_mask[ + i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0), + :, + :group_len, + :group_len, + ] = group_time_mask[:, :, start_index:end_index, start_index:end_index] + return ( + fast_group_time_mask, + group_start_index, + counts, + max_group_len, + group_num, + ) + return group_time_mask def forward( @@ -143,16 +183,45 @@ class Chronos2Encoder(nn.Module): batch_size, seq_length = inputs_embeds.size()[:-1] if position_ids is None: - position_ids = torch.arange(0, seq_length, dtype=torch.long, device=inputs_embeds.device).unsqueeze(0) + position_ids = torch.arange( + 0, seq_length, dtype=torch.long, device=inputs_embeds.device + ).unsqueeze(0) if attention_mask is None: - attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + attention_mask = torch.ones( + batch_size, + seq_length, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + ) # make the time attention mask broadcastable to attention scores (batch, n_heads, q_len, kv_len) and invert - extended_attention_mask = self._expand_and_invert_time_attention_mask(attention_mask, inputs_embeds.dtype) + extended_attention_mask = self._expand_and_invert_time_attention_mask( + attention_mask, inputs_embeds.dtype + ) # construct group time mask - group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype) + if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1": + ( + flast_group_time_mask, + group_start_index, + counts, + max_group_len, + group_num, + ) = self._construct_and_invert_group_time_mask( + group_ids, attention_mask, inputs_embeds.dtype + ) + group_time_mask = ( + flast_group_time_mask, + group_start_index, + counts, + max_group_len, + group_num, + ) + else: + group_time_mask = self._construct_and_invert_group_time_mask( + group_ids, attention_mask, inputs_embeds.dtype + ) all_time_self_attentions: tuple[torch.Tensor, ...] = () all_group_self_attentions: tuple[torch.Tensor, ...] = () @@ -174,8 +243,14 @@ class Chronos2Encoder(nn.Module): assert layer_outputs.time_self_attn_weights is not None assert layer_outputs.group_self_attn_weights is not None - all_time_self_attentions = (*all_time_self_attentions, layer_outputs.time_self_attn_weights) - all_group_self_attentions = (*all_group_self_attentions, layer_outputs.group_self_attn_weights) + all_time_self_attentions = ( + *all_time_self_attentions, + layer_outputs.time_self_attn_weights, + ) + all_group_self_attentions = ( + *all_group_self_attentions, + layer_outputs.group_self_attn_weights, + ) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -213,7 +288,10 @@ class Chronos2Model(PreTrainedModel): ) self.chronos_config = Chronos2ForecastingConfig(**config.chronos_config) - assert self.chronos_config.input_patch_size == self.chronos_config.output_patch_size, ( + assert ( + self.chronos_config.input_patch_size + == self.chronos_config.output_patch_size + ), ( "input_patch_size and output_patch_size sizes must be equal, " f"but found {self.chronos_config.input_patch_size} and {self.chronos_config.output_patch_size}" ) @@ -237,7 +315,8 @@ class Chronos2Model(PreTrainedModel): # patching layer self.patch = Patch( - patch_size=self.chronos_config.input_patch_size, patch_stride=self.chronos_config.input_patch_stride + patch_size=self.chronos_config.input_patch_size, + patch_stride=self.chronos_config.input_patch_stride, ) # instance normalization, also referred to as "scaling" in Chronos and GluonTS @@ -273,10 +352,14 @@ class Chronos2Model(PreTrainedModel): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, MHA): @@ -285,10 +368,14 @@ class Chronos2Model(PreTrainedModel): d_model = self.config.d_model kv_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5)) + module.q.weight.data.normal_( + mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5) + ) module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5)) + module.o.weight.data.normal_( + mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5) + ) elif isinstance(module, (Chronos2Model)): module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, ResidualBlock): @@ -296,20 +383,29 @@ class Chronos2Model(PreTrainedModel): mean=0.0, std=factor * (module.hidden_layer.weight.size(-1) ** -0.5), ) - if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None: + if ( + hasattr(module.hidden_layer, "bias") + and module.hidden_layer.bias is not None + ): module.hidden_layer.bias.data.zero_() module.residual_layer.weight.data.normal_( mean=0.0, std=factor * (module.residual_layer.weight.size(-1) ** -0.5), ) - if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None: + if ( + hasattr(module.residual_layer, "bias") + and module.residual_layer.bias is not None + ): module.residual_layer.bias.data.zero_() module.output_layer.weight.data.normal_( mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5) ) - if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None: + if ( + hasattr(module.output_layer, "bias") + and module.output_layer.bias is not None + ): module.output_layer.bias.data.zero_() def _validate_input( @@ -325,11 +421,18 @@ class Chronos2Model(PreTrainedModel): ): output_patch_size = self.chronos_config.output_patch_size if context.ndim != 2: - raise ValueError(f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}") + raise ValueError( + f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}" + ) if context_mask is not None and context_mask.shape != context.shape: - raise ValueError(f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}") + raise ValueError( + f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}" + ) if future_covariates is not None: - if future_covariates.shape[0] != context.shape[0] or future_covariates.ndim != 2: + if ( + future_covariates.shape[0] != context.shape[0] + or future_covariates.ndim != 2 + ): raise ValueError( f"future_covariates must have shape (batch_size={context.shape[0]}, future_length), found: {tuple(future_covariates.shape)}" ) @@ -338,20 +441,27 @@ class Chronos2Model(PreTrainedModel): f"{num_output_patches=} must be large enough to accommodate the length of future_covariates, " f"found: {future_covariates.shape[-1]} > {num_output_patches} * {output_patch_size}" ) - if future_target is not None and future_target.shape != future_covariates.shape: + if ( + future_target is not None + and future_target.shape != future_covariates.shape + ): raise ValueError( f"future_target must have the same shape as future_covariates, found: {tuple(future_target.shape)} and {tuple(future_covariates.shape)}" ) if future_covariates_mask is not None: if future_covariates is None: - raise ValueError("future_covariates must be provided if future_covariates_mask is provided") + raise ValueError( + "future_covariates must be provided if future_covariates_mask is provided" + ) if future_covariates_mask.shape != future_covariates.shape: raise ValueError( f"future_covariates_mask must have the same shape as future_covariates, " f"found: {tuple(future_covariates_mask.shape)} and {tuple(future_covariates.shape)}" ) if group_ids is not None and group_ids.shape != (context.shape[0],): - raise ValueError(f"group_ids must have shape (batch_size,), found: {tuple(group_ids.shape)}") + raise ValueError( + f"group_ids must have shape (batch_size,), found: {tuple(group_ids.shape)}" + ) if future_target is not None: if future_target.shape[0] != context.shape[0] or future_target.ndim != 2: raise ValueError( @@ -364,7 +474,9 @@ class Chronos2Model(PreTrainedModel): ) if future_target_mask is not None: if future_target is None: - raise ValueError("future_target must be provided if future_target_mask is provided") + raise ValueError( + "future_target must be provided if future_target_mask is provided" + ) if future_target_mask.shape != future_target.shape: raise ValueError( f"future_target_mask must have the same shape as future_target, found: {tuple(future_target_mask.shape)} and {tuple(future_target.shape)}" @@ -403,8 +515,12 @@ class Chronos2Model(PreTrainedModel): # context time encoding: every observation is assigned a sequential time index, # scaled by model's context length = [-C, -(C-1), ..., -1] / context_length - final_context_length = num_context_patches * self.chronos_config.input_patch_size - context_time_enc = torch.arange(start=-final_context_length, end=0, device=self.device, dtype=torch.float32) + final_context_length = ( + num_context_patches * self.chronos_config.input_patch_size + ) + context_time_enc = torch.arange( + start=-final_context_length, end=0, device=self.device, dtype=torch.float32 + ) context_time_enc = ( repeat( context_time_enc, @@ -418,7 +534,9 @@ class Chronos2Model(PreTrainedModel): ) # concat time encoding, context and mask along the last (feature) dim - patched_context = torch.cat([context_time_enc, patched_context, patched_mask], dim=-1) + patched_context = torch.cat( + [context_time_enc, patched_context, patched_mask], dim=-1 + ) return patched_context, attention_mask, loc_scale @@ -437,9 +555,15 @@ class Chronos2Model(PreTrainedModel): future_covariates = future_covariates.to(self.dtype) if future_covariates_mask is None: - future_covariates_mask = torch.isnan(future_covariates).logical_not().to(future_covariates.dtype) + future_covariates_mask = ( + torch.isnan(future_covariates) + .logical_not() + .to(future_covariates.dtype) + ) - future_covariates = torch.where(future_covariates_mask > 0.0, future_covariates, 0.0) + future_covariates = torch.where( + future_covariates_mask > 0.0, future_covariates, 0.0 + ) if torch.isnan(future_covariates).any(): raise ValueError( @@ -451,33 +575,58 @@ class Chronos2Model(PreTrainedModel): if num_output_patches * output_patch_size > future_covariates.shape[-1]: padding_shape = ( *future_covariates.shape[:-1], - num_output_patches * output_patch_size - future_covariates.shape[-1], + num_output_patches * output_patch_size + - future_covariates.shape[-1], ) future_covariates = torch.cat( - [future_covariates, torch.zeros(padding_shape).to(future_covariates)], dim=-1 + [ + future_covariates, + torch.zeros(padding_shape).to(future_covariates), + ], + dim=-1, ) future_covariates_mask = torch.cat( - [future_covariates_mask, torch.zeros(padding_shape).to(future_covariates_mask)], dim=-1 + [ + future_covariates_mask, + torch.zeros(padding_shape).to(future_covariates_mask), + ], + dim=-1, ) patched_future_covariates = rearrange( - future_covariates, "b (n p) -> b n p", n=num_output_patches, p=output_patch_size + future_covariates, + "b (n p) -> b n p", + n=num_output_patches, + p=output_patch_size, ) patched_future_covariates_mask = rearrange( - future_covariates_mask, "b (n p) -> b n p", n=num_output_patches, p=output_patch_size + future_covariates_mask, + "b (n p) -> b n p", + n=num_output_patches, + p=output_patch_size, ) else: patched_future_covariates = torch.zeros( - batch_size, num_output_patches, output_patch_size, device=self.device, dtype=self.dtype + batch_size, + num_output_patches, + output_patch_size, + device=self.device, + dtype=self.dtype, ) patched_future_covariates_mask = torch.zeros( - batch_size, num_output_patches, output_patch_size, device=self.device, dtype=self.dtype + batch_size, + num_output_patches, + output_patch_size, + device=self.device, + dtype=self.dtype, ) # future time encoding: every future timestep is assigned a sequential time index, # scaled by model's context length = [0, 1, ..., h-1] / context_length final_future_length = num_output_patches * output_patch_size - future_time_enc = torch.arange(start=0, end=final_future_length, device=self.device, dtype=torch.float32) + future_time_enc = torch.arange( + start=0, end=final_future_length, device=self.device, dtype=torch.float32 + ) future_time_enc = ( repeat( future_time_enc, @@ -491,7 +640,12 @@ class Chronos2Model(PreTrainedModel): ) patched_future = torch.cat( - [future_time_enc, patched_future_covariates, patched_future_covariates_mask], dim=-1 + [ + future_time_enc, + patched_future_covariates, + patched_future_covariates_mask, + ], + dim=-1, ) return patched_future, patched_future_covariates_mask @@ -507,7 +661,10 @@ class Chronos2Model(PreTrainedModel): ) -> torch.Tensor: batch_size = future_target.shape[0] output_patch_size = self.chronos_config.output_patch_size - assert quantile_preds.shape[0] == batch_size and quantile_preds.shape[-1] >= future_target.shape[-1] + assert ( + quantile_preds.shape[0] == batch_size + and quantile_preds.shape[-1] >= future_target.shape[-1] + ) # normalize target and mask future_target, _ = self.instance_norm(future_target, loc_scale) @@ -522,15 +679,22 @@ class Chronos2Model(PreTrainedModel): # pad target and target_mask if they are shorter than model's prediction if quantile_preds.shape[-1] > future_target.shape[-1]: - padding_shape = (*future_target.shape[:-1], quantile_preds.shape[-1] - future_target.shape[-1]) - future_target = torch.cat([future_target, torch.zeros(padding_shape).to(future_target)], dim=-1) + padding_shape = ( + *future_target.shape[:-1], + quantile_preds.shape[-1] - future_target.shape[-1], + ) + future_target = torch.cat( + [future_target, torch.zeros(padding_shape).to(future_target)], dim=-1 + ) future_target_mask = torch.cat( - [future_target_mask, torch.zeros(padding_shape).to(future_target_mask)], dim=-1 + [future_target_mask, torch.zeros(padding_shape).to(future_target_mask)], + dim=-1, ) quantiles = rearrange(self.quantiles, "num_quantiles -> 1 num_quantiles 1") quantile_loss = 2 * torch.abs( - (future_target - quantile_preds) * ((future_target <= quantile_preds).float() - quantiles) + (future_target - quantile_preds) + * ((future_target <= quantile_preds).float() - quantiles) ) inv_future_covariate_mask = 1 - rearrange( patched_future_covariates_mask, @@ -580,11 +744,17 @@ class Chronos2Model(PreTrainedModel): input_embeds: torch.Tensor = self.input_patch_embedding(patched_context) # append [REG] special token embedding, if needed if self.chronos_config.use_reg_token: - reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device) + reg_input_ids = torch.full( + (batch_size, 1), self.config.reg_token_id, device=input_embeds.device + ) reg_embeds = self.shared(reg_input_ids) input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2) attention_mask = torch.cat( - [attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1 + [ + attention_mask.to(self.dtype), + torch.ones_like(reg_input_ids).to(self.dtype), + ], + dim=-1, ) patched_future, patched_future_covariates_mask = self._prepare_patched_future( @@ -594,7 +764,9 @@ class Chronos2Model(PreTrainedModel): num_output_patches=num_output_patches, batch_size=batch_size, ) - future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device) + future_attention_mask = torch.ones( + batch_size, num_output_patches, dtype=self.dtype, device=self.device + ) # get future embeddings of shape (batch, num_output_patches, d_model) future_embeds: torch.Tensor = self.input_patch_embedding(patched_future) @@ -613,7 +785,12 @@ class Chronos2Model(PreTrainedModel): group_ids=group_ids, output_attentions=output_attentions, ) - return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches + return ( + encoder_outputs, + loc_scale, + patched_future_covariates_mask, + num_context_patches, + ) def forward( self, @@ -694,7 +871,12 @@ class Chronos2Model(PreTrainedModel): - enc_group_self_attn_weights: Group self attention weights, if output_attentions=True """ batch_size = context.shape[0] - encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode( + ( + encoder_outputs, + loc_scale, + patched_future_covariates_mask, + num_context_patches, + ) = self.encode( context=context, context_mask=context_mask, group_ids=group_ids, @@ -706,7 +888,11 @@ class Chronos2Model(PreTrainedModel): output_attentions=output_attentions, ) hidden_states: torch.Tensor = encoder_outputs[0] - assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim) + assert hidden_states.shape == ( + batch_size, + num_context_patches + 1 + num_output_patches, + self.model_dim, + ) # slice the last num_output_patches hidden states to be input into the output_patch_embedding forecast_embeds = hidden_states[:, -num_output_patches:]