From 9844366902e5c6e04a7dcfa15dd07105bb3ceefb Mon Sep 17 00:00:00 2001 From: xymli Date: Mon, 22 Dec 2025 21:05:50 +0800 Subject: [PATCH] Align code formatting Signed-off-by: xymli --- src/chronos/chronos2/layers.py | 150 ++++-------------- src/chronos/chronos2/model.py | 277 +++++++-------------------------- 2 files changed, 90 insertions(+), 337 deletions(-) diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index 4e58ecf..ca3132b 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -28,35 +28,22 @@ class RoPE(nn.Module): self.dim = dim self.base = base - inv_freq = 1.0 / ( - self.base - ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim) - ) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) self.inv_freq: torch.Tensor # type hint for type checker self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() - def forward( - self, x: torch.Tensor, position_ids: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # x: [bs, num_attention_heads, seq_len, head_size] self.inv_freq.to(x.device) - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = ( - device_type - if isinstance(device_type, str) and device_type != "mps" - else "cpu" - ) + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = ( - inv_freq_expanded.float() @ position_ids_expanded.float() - ).transpose(1, 2) + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() @@ -71,11 +58,7 @@ class RoPE(nn.Module): @staticmethod def apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - unsqueeze_dim: int = 1, + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 ) -> tuple[torch.Tensor, torch.Tensor]: """Applies Rotary Position Embedding to the query and key tensors. @@ -147,9 +130,7 @@ class FeedForward(nn.Module): assert not config.is_gated_act, "gated activations are unsupported" self.mlp: nn.Module = MLP(config) - self.layer_norm = Chronos2LayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -206,14 +187,10 @@ class MHA(nn.Module): attn_weights: [batch, n_heads, q_len, kv_len] """ # Compute attention weights (no scaling - this is the original Chronos-2 implementation) - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # "bnqd,bnkd->bnqk" + scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk" scores += mask attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) return attn_output, attn_weights @@ -271,9 +248,7 @@ class MHA(nn.Module): - attn_weights : Attention weights if output_attentions=True """ if self.use_rope: - assert ( - position_ids is not None - ), "position_ids must be provided when self.use_rope=True" + assert position_ids is not None, "position_ids must be provided when self.use_rope=True" # Force eager attention if output_attentions is True (only eager returns weights) attn_implementation = self.config._attn_implementation @@ -284,23 +259,11 @@ class MHA(nn.Module): def shape(states: torch.Tensor) -> torch.Tensor: """(batch, seq_len, inner_dim) -> (batch, n_heads, seq_len, kv_proj_dim)""" - return rearrange( - states, - "b s (h d) -> b h s d", - h=self.n_heads, - s=seq_length, - d=self.kv_proj_dim, - ) + return rearrange(states, "b s (h d) -> b h s d", h=self.n_heads, s=seq_length, d=self.kv_proj_dim) def unshape(states: torch.Tensor) -> torch.Tensor: """(batch, n_heads, seq_len, kv_proj_dim) -> (batch, seq_len, inner_dim)""" - return rearrange( - states, - "b h s d -> b s (h d)", - h=self.n_heads, - s=seq_length, - d=self.kv_proj_dim, - ) + return rearrange(states, "b h s d -> b s (h d)", h=self.n_heads, s=seq_length, d=self.kv_proj_dim) # Construct query states query_states = shape(self.q(hidden_states)) @@ -315,36 +278,25 @@ class MHA(nn.Module): value_states = shape(self.v(hidden_states)) if self.use_rope: cos, sin = self.rope_embed(value_states, position_ids) - query_states, key_states = RoPE.apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) + query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin) if attn_implementation == "sdpa": - attn_output, attn_weights = self._sdpa_attention( - query_states, key_states, value_states, mask - ) + attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask) else: # eager - attn_output, attn_weights = self._eager_attention( - query_states, key_states, value_states, mask - ) + attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask) # Project attention output attn_output = unshape(attn_output) attn_output = self.o(attn_output) - return AttentionOutput( - hidden_states=attn_output, - attn_weights=attn_weights if output_attentions else None, - ) + return AttentionOutput(hidden_states=attn_output, attn_weights=attn_weights if output_attentions else None) class TimeSelfAttention(nn.Module): def __init__(self, config: Chronos2CoreConfig): super().__init__() self.self_attention = MHA(config, use_rope=True) - self.layer_norm = Chronos2LayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -356,25 +308,18 @@ class TimeSelfAttention(nn.Module): ) -> AttentionOutput: normed_hidden_states = self.layer_norm(hidden_states) attention_output: AttentionOutput = self.self_attention( - normed_hidden_states, - position_ids=position_ids, - mask=attention_mask, - output_attentions=output_attentions, + normed_hidden_states, position_ids=position_ids, mask=attention_mask, output_attentions=output_attentions ) hidden_states = hidden_states + self.dropout(attention_output[0]) - return AttentionOutput( - hidden_states=hidden_states, attn_weights=attention_output.attn_weights - ) + return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights) class TimeCrossAttention(nn.Module): def __init__(self, config: Chronos2CoreConfig): super().__init__() self.cross_attention = MHA(config, use_rope=False) - self.layer_norm = Chronos2LayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward( @@ -393,9 +338,7 @@ class TimeCrossAttention(nn.Module): ) hidden_states = hidden_states + self.dropout(attention_output[0]) - return AttentionOutput( - hidden_states=hidden_states, attn_weights=attention_output.attn_weights - ) + return AttentionOutput(hidden_states=hidden_states, attn_weights=attention_output.attn_weights) class GroupSelfAttention(nn.Module): @@ -405,35 +348,20 @@ class GroupSelfAttention(nn.Module): super().__init__() # we don't use RoPE here because there's no natural ordering along the batch axis self.self_attention = MHA(config, use_rope=False) - self.layer_norm = Chronos2LayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | tuple, - output_attentions: bool = False, + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | tuple, output_attentions: bool = False ) -> AttentionOutput: # flip time and batch axes because attention operates along dim=-2 hidden_states = rearrange(hidden_states, "batch time d -> time batch d") normed_hidden_states = self.layer_norm(hidden_states) if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1": - ( - flast_group_time_mask, - group_start_index, - counts, - max_group_len, - group_num, - ) = attention_mask + flast_group_time_mask, group_start_index, counts, max_group_len, group_num = attention_mask fast_normed_hidden_states = torch.zeros( - ( - group_num * normed_hidden_states.size(0), - max_group_len, - normed_hidden_states.size(2), - ), + (group_num * normed_hidden_states.size(0), max_group_len, normed_hidden_states.size(2)), dtype=normed_hidden_states.dtype, device=normed_hidden_states.device, ) @@ -441,17 +369,11 @@ class GroupSelfAttention(nn.Module): start_index = group_start_index[i].item() group_len = counts[i].item() end_index = start_index + group_len - fast_normed_hidden_states[ - i - * normed_hidden_states.size(0) : (i + 1) - * normed_hidden_states.size(0), - :group_len, - :, - ] = normed_hidden_states[:, start_index:end_index, :] + fast_normed_hidden_states[i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, :] = normed_hidden_states[ + :, start_index:end_index, : + ] fast_attention_output: AttentionOutput = self.self_attention( - fast_normed_hidden_states, - mask=flast_group_time_mask, - output_attentions=output_attentions, + fast_normed_hidden_states, mask=flast_group_time_mask, output_attentions=output_attentions ) attn_weights = fast_attention_output.attn_weights attention_output = torch.empty_like(normed_hidden_states) @@ -459,20 +381,12 @@ class GroupSelfAttention(nn.Module): start_index = group_start_index[i].item() group_len = counts[i].item() end_index = start_index + group_len - attention_output[:, start_index:end_index, :] = fast_attention_output[ - 0 - ][ - i - * normed_hidden_states.size(0) : (i + 1) - * normed_hidden_states.size(0), - :group_len, - :, + attention_output[:, start_index:end_index, :] = fast_attention_output[0][ + i * normed_hidden_states.size(0) : (i + 1) * normed_hidden_states.size(0), :group_len, : ] else: attention_output: AttentionOutput = self.self_attention( - normed_hidden_states, - mask=attention_mask, - output_attentions=output_attentions, + normed_hidden_states, mask=attention_mask, output_attentions=output_attentions ) attn_weights = attention_output.attn_weights attention_output = attention_output[0] diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index 451db9e..411fb54 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -66,9 +66,7 @@ class Chronos2EncoderBlock(nn.Module): # apply group attention group_self_attn_outputs: AttentionOutput = self.layer[1]( - hidden_states, - attention_mask=group_time_mask, - output_attentions=output_attentions, + hidden_states, attention_mask=group_time_mask, output_attentions=output_attentions ) hidden_states = group_self_attn_outputs[0] @@ -94,21 +92,15 @@ class Chronos2Encoder(nn.Module): super().__init__() assert not config.is_decoder - self.block = nn.ModuleList( - [Chronos2EncoderBlock(config) for i in range(config.num_layers)] - ) - self.final_layer_norm = Chronos2LayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.block = nn.ModuleList([Chronos2EncoderBlock(config) for i in range(config.num_layers)]) + self.final_layer_norm = Chronos2LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @staticmethod def _expand_and_invert_time_attention_mask( attention_mask: torch.Tensor, floating_type: torch.dtype ) -> torch.Tensor: - assert ( - attention_mask.ndim == 2 - ), "attention_mask must have shape (batch, seq_len)" + assert attention_mask.ndim == 2, "attention_mask must have shape (batch, seq_len)" # Add new dims for attention heads and q_len attention_mask = attention_mask[:, None, None, :] @@ -120,9 +112,7 @@ class Chronos2Encoder(nn.Module): @staticmethod def _construct_and_invert_group_time_mask( - group_ids: torch.Tensor, - attention_mask: torch.Tensor, - floating_type: torch.dtype, + group_ids: torch.Tensor, attention_mask: torch.Tensor, floating_type: torch.dtype ) -> torch.Tensor: # construct group_mask (batch, batch) from group ids # a cell is True if both row and col had the same group id @@ -146,9 +136,7 @@ class Chronos2Encoder(nn.Module): max_group_len = counts.max().item() group_num = unique_groups.size(0) fast_group_time_mask = torch.full( - (group_num * group_time_mask.size(0), 1, max_group_len, max_group_len), - torch.finfo(floating_type).min, - device=group_time_mask.device, + (group_num * group_time_mask.size(0), 1, max_group_len, max_group_len), torch.finfo(floating_type).min, device=group_time_mask.device ) group_start_index = counts.cumsum(dim=0) - counts for i in range(group_num): @@ -156,18 +144,9 @@ class Chronos2Encoder(nn.Module): start_index = group_start_index[i].item() end_index = start_index + group_len fast_group_time_mask[ - i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0), - :, - :group_len, - :group_len, - ] = group_time_mask[:, :, start_index:end_index, start_index:end_index] - return ( - fast_group_time_mask, - group_start_index, - counts, - max_group_len, - group_num, - ) + i * group_time_mask.size(0) : (i + 1) * group_time_mask.size(0), :, :group_len, :group_len + ] = group_time_mask[:, :, start_index : end_index, start_index : end_index] + return fast_group_time_mask, group_start_index, counts, max_group_len, group_num return group_time_mask @@ -183,45 +162,22 @@ class Chronos2Encoder(nn.Module): batch_size, seq_length = inputs_embeds.size()[:-1] if position_ids is None: - position_ids = torch.arange( - 0, seq_length, dtype=torch.long, device=inputs_embeds.device - ).unsqueeze(0) + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=inputs_embeds.device).unsqueeze(0) if attention_mask is None: - attention_mask = torch.ones( - batch_size, - seq_length, - device=inputs_embeds.device, - dtype=inputs_embeds.dtype, - ) + attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=inputs_embeds.dtype) # make the time attention mask broadcastable to attention scores (batch, n_heads, q_len, kv_len) and invert - extended_attention_mask = self._expand_and_invert_time_attention_mask( - attention_mask, inputs_embeds.dtype - ) + extended_attention_mask = self._expand_and_invert_time_attention_mask(attention_mask, inputs_embeds.dtype) # construct group time mask if os.environ.get("CHRONOS2_USE_FAST_GROUP_ATTENTION", "0") == "1": - ( - flast_group_time_mask, - group_start_index, - counts, - max_group_len, - group_num, - ) = self._construct_and_invert_group_time_mask( + flast_group_time_mask, group_start_index, counts, max_group_len, group_num = self._construct_and_invert_group_time_mask( group_ids, attention_mask, inputs_embeds.dtype ) - group_time_mask = ( - flast_group_time_mask, - group_start_index, - counts, - max_group_len, - group_num, - ) + group_time_mask = (flast_group_time_mask, group_start_index, counts, max_group_len, group_num) else: - group_time_mask = self._construct_and_invert_group_time_mask( - group_ids, attention_mask, inputs_embeds.dtype - ) + group_time_mask = self._construct_and_invert_group_time_mask(group_ids, attention_mask, inputs_embeds.dtype) all_time_self_attentions: tuple[torch.Tensor, ...] = () all_group_self_attentions: tuple[torch.Tensor, ...] = () @@ -243,14 +199,8 @@ class Chronos2Encoder(nn.Module): assert layer_outputs.time_self_attn_weights is not None assert layer_outputs.group_self_attn_weights is not None - all_time_self_attentions = ( - *all_time_self_attentions, - layer_outputs.time_self_attn_weights, - ) - all_group_self_attentions = ( - *all_group_self_attentions, - layer_outputs.group_self_attn_weights, - ) + all_time_self_attentions = (*all_time_self_attentions, layer_outputs.time_self_attn_weights) + all_group_self_attentions = (*all_group_self_attentions, layer_outputs.group_self_attn_weights) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -288,10 +238,7 @@ class Chronos2Model(PreTrainedModel): ) self.chronos_config = Chronos2ForecastingConfig(**config.chronos_config) - assert ( - self.chronos_config.input_patch_size - == self.chronos_config.output_patch_size - ), ( + assert self.chronos_config.input_patch_size == self.chronos_config.output_patch_size, ( "input_patch_size and output_patch_size sizes must be equal, " f"but found {self.chronos_config.input_patch_size} and {self.chronos_config.output_patch_size}" ) @@ -315,8 +262,7 @@ class Chronos2Model(PreTrainedModel): # patching layer self.patch = Patch( - patch_size=self.chronos_config.input_patch_size, - patch_stride=self.chronos_config.input_patch_stride, + patch_size=self.chronos_config.input_patch_size, patch_stride=self.chronos_config.input_patch_stride ) # instance normalization, also referred to as "scaling" in Chronos and GluonTS @@ -352,14 +298,10 @@ class Chronos2Model(PreTrainedModel): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_model) ** -0.5) - ) + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: module.wi.bias.data.zero_() - module.wo.weight.data.normal_( - mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) - ) + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() elif isinstance(module, MHA): @@ -368,14 +310,10 @@ class Chronos2Model(PreTrainedModel): d_model = self.config.d_model kv_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_( - mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5) - ) + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5)) module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_( - mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5) - ) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5)) elif isinstance(module, (Chronos2Model)): module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, ResidualBlock): @@ -383,29 +321,20 @@ class Chronos2Model(PreTrainedModel): mean=0.0, std=factor * (module.hidden_layer.weight.size(-1) ** -0.5), ) - if ( - hasattr(module.hidden_layer, "bias") - and module.hidden_layer.bias is not None - ): + if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None: module.hidden_layer.bias.data.zero_() module.residual_layer.weight.data.normal_( mean=0.0, std=factor * (module.residual_layer.weight.size(-1) ** -0.5), ) - if ( - hasattr(module.residual_layer, "bias") - and module.residual_layer.bias is not None - ): + if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None: module.residual_layer.bias.data.zero_() module.output_layer.weight.data.normal_( mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5) ) - if ( - hasattr(module.output_layer, "bias") - and module.output_layer.bias is not None - ): + if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None: module.output_layer.bias.data.zero_() def _validate_input( @@ -421,18 +350,11 @@ class Chronos2Model(PreTrainedModel): ): output_patch_size = self.chronos_config.output_patch_size if context.ndim != 2: - raise ValueError( - f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}" - ) + raise ValueError(f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}") if context_mask is not None and context_mask.shape != context.shape: - raise ValueError( - f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}" - ) + raise ValueError(f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}") if future_covariates is not None: - if ( - future_covariates.shape[0] != context.shape[0] - or future_covariates.ndim != 2 - ): + if future_covariates.shape[0] != context.shape[0] or future_covariates.ndim != 2: raise ValueError( f"future_covariates must have shape (batch_size={context.shape[0]}, future_length), found: {tuple(future_covariates.shape)}" ) @@ -441,27 +363,20 @@ class Chronos2Model(PreTrainedModel): f"{num_output_patches=} must be large enough to accommodate the length of future_covariates, " f"found: {future_covariates.shape[-1]} > {num_output_patches} * {output_patch_size}" ) - if ( - future_target is not None - and future_target.shape != future_covariates.shape - ): + if future_target is not None and future_target.shape != future_covariates.shape: raise ValueError( f"future_target must have the same shape as future_covariates, found: {tuple(future_target.shape)} and {tuple(future_covariates.shape)}" ) if future_covariates_mask is not None: if future_covariates is None: - raise ValueError( - "future_covariates must be provided if future_covariates_mask is provided" - ) + raise ValueError("future_covariates must be provided if future_covariates_mask is provided") if future_covariates_mask.shape != future_covariates.shape: raise ValueError( f"future_covariates_mask must have the same shape as future_covariates, " f"found: {tuple(future_covariates_mask.shape)} and {tuple(future_covariates.shape)}" ) if group_ids is not None and group_ids.shape != (context.shape[0],): - raise ValueError( - f"group_ids must have shape (batch_size,), found: {tuple(group_ids.shape)}" - ) + raise ValueError(f"group_ids must have shape (batch_size,), found: {tuple(group_ids.shape)}") if future_target is not None: if future_target.shape[0] != context.shape[0] or future_target.ndim != 2: raise ValueError( @@ -474,9 +389,7 @@ class Chronos2Model(PreTrainedModel): ) if future_target_mask is not None: if future_target is None: - raise ValueError( - "future_target must be provided if future_target_mask is provided" - ) + raise ValueError("future_target must be provided if future_target_mask is provided") if future_target_mask.shape != future_target.shape: raise ValueError( f"future_target_mask must have the same shape as future_target, found: {tuple(future_target_mask.shape)} and {tuple(future_target.shape)}" @@ -515,12 +428,8 @@ class Chronos2Model(PreTrainedModel): # context time encoding: every observation is assigned a sequential time index, # scaled by model's context length = [-C, -(C-1), ..., -1] / context_length - final_context_length = ( - num_context_patches * self.chronos_config.input_patch_size - ) - context_time_enc = torch.arange( - start=-final_context_length, end=0, device=self.device, dtype=torch.float32 - ) + final_context_length = num_context_patches * self.chronos_config.input_patch_size + context_time_enc = torch.arange(start=-final_context_length, end=0, device=self.device, dtype=torch.float32) context_time_enc = ( repeat( context_time_enc, @@ -534,9 +443,7 @@ class Chronos2Model(PreTrainedModel): ) # concat time encoding, context and mask along the last (feature) dim - patched_context = torch.cat( - [context_time_enc, patched_context, patched_mask], dim=-1 - ) + patched_context = torch.cat([context_time_enc, patched_context, patched_mask], dim=-1) return patched_context, attention_mask, loc_scale @@ -555,15 +462,9 @@ class Chronos2Model(PreTrainedModel): future_covariates = future_covariates.to(self.dtype) if future_covariates_mask is None: - future_covariates_mask = ( - torch.isnan(future_covariates) - .logical_not() - .to(future_covariates.dtype) - ) + future_covariates_mask = torch.isnan(future_covariates).logical_not().to(future_covariates.dtype) - future_covariates = torch.where( - future_covariates_mask > 0.0, future_covariates, 0.0 - ) + future_covariates = torch.where(future_covariates_mask > 0.0, future_covariates, 0.0) if torch.isnan(future_covariates).any(): raise ValueError( @@ -575,58 +476,33 @@ class Chronos2Model(PreTrainedModel): if num_output_patches * output_patch_size > future_covariates.shape[-1]: padding_shape = ( *future_covariates.shape[:-1], - num_output_patches * output_patch_size - - future_covariates.shape[-1], + num_output_patches * output_patch_size - future_covariates.shape[-1], ) future_covariates = torch.cat( - [ - future_covariates, - torch.zeros(padding_shape).to(future_covariates), - ], - dim=-1, + [future_covariates, torch.zeros(padding_shape).to(future_covariates)], dim=-1 ) future_covariates_mask = torch.cat( - [ - future_covariates_mask, - torch.zeros(padding_shape).to(future_covariates_mask), - ], - dim=-1, + [future_covariates_mask, torch.zeros(padding_shape).to(future_covariates_mask)], dim=-1 ) patched_future_covariates = rearrange( - future_covariates, - "b (n p) -> b n p", - n=num_output_patches, - p=output_patch_size, + future_covariates, "b (n p) -> b n p", n=num_output_patches, p=output_patch_size ) patched_future_covariates_mask = rearrange( - future_covariates_mask, - "b (n p) -> b n p", - n=num_output_patches, - p=output_patch_size, + future_covariates_mask, "b (n p) -> b n p", n=num_output_patches, p=output_patch_size ) else: patched_future_covariates = torch.zeros( - batch_size, - num_output_patches, - output_patch_size, - device=self.device, - dtype=self.dtype, + batch_size, num_output_patches, output_patch_size, device=self.device, dtype=self.dtype ) patched_future_covariates_mask = torch.zeros( - batch_size, - num_output_patches, - output_patch_size, - device=self.device, - dtype=self.dtype, + batch_size, num_output_patches, output_patch_size, device=self.device, dtype=self.dtype ) # future time encoding: every future timestep is assigned a sequential time index, # scaled by model's context length = [0, 1, ..., h-1] / context_length final_future_length = num_output_patches * output_patch_size - future_time_enc = torch.arange( - start=0, end=final_future_length, device=self.device, dtype=torch.float32 - ) + future_time_enc = torch.arange(start=0, end=final_future_length, device=self.device, dtype=torch.float32) future_time_enc = ( repeat( future_time_enc, @@ -640,12 +516,7 @@ class Chronos2Model(PreTrainedModel): ) patched_future = torch.cat( - [ - future_time_enc, - patched_future_covariates, - patched_future_covariates_mask, - ], - dim=-1, + [future_time_enc, patched_future_covariates, patched_future_covariates_mask], dim=-1 ) return patched_future, patched_future_covariates_mask @@ -661,10 +532,7 @@ class Chronos2Model(PreTrainedModel): ) -> torch.Tensor: batch_size = future_target.shape[0] output_patch_size = self.chronos_config.output_patch_size - assert ( - quantile_preds.shape[0] == batch_size - and quantile_preds.shape[-1] >= future_target.shape[-1] - ) + assert quantile_preds.shape[0] == batch_size and quantile_preds.shape[-1] >= future_target.shape[-1] # normalize target and mask future_target, _ = self.instance_norm(future_target, loc_scale) @@ -679,22 +547,15 @@ class Chronos2Model(PreTrainedModel): # pad target and target_mask if they are shorter than model's prediction if quantile_preds.shape[-1] > future_target.shape[-1]: - padding_shape = ( - *future_target.shape[:-1], - quantile_preds.shape[-1] - future_target.shape[-1], - ) - future_target = torch.cat( - [future_target, torch.zeros(padding_shape).to(future_target)], dim=-1 - ) + padding_shape = (*future_target.shape[:-1], quantile_preds.shape[-1] - future_target.shape[-1]) + future_target = torch.cat([future_target, torch.zeros(padding_shape).to(future_target)], dim=-1) future_target_mask = torch.cat( - [future_target_mask, torch.zeros(padding_shape).to(future_target_mask)], - dim=-1, + [future_target_mask, torch.zeros(padding_shape).to(future_target_mask)], dim=-1 ) quantiles = rearrange(self.quantiles, "num_quantiles -> 1 num_quantiles 1") quantile_loss = 2 * torch.abs( - (future_target - quantile_preds) - * ((future_target <= quantile_preds).float() - quantiles) + (future_target - quantile_preds) * ((future_target <= quantile_preds).float() - quantiles) ) inv_future_covariate_mask = 1 - rearrange( patched_future_covariates_mask, @@ -744,17 +605,11 @@ class Chronos2Model(PreTrainedModel): input_embeds: torch.Tensor = self.input_patch_embedding(patched_context) # append [REG] special token embedding, if needed if self.chronos_config.use_reg_token: - reg_input_ids = torch.full( - (batch_size, 1), self.config.reg_token_id, device=input_embeds.device - ) + reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device) reg_embeds = self.shared(reg_input_ids) input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2) attention_mask = torch.cat( - [ - attention_mask.to(self.dtype), - torch.ones_like(reg_input_ids).to(self.dtype), - ], - dim=-1, + [attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1 ) patched_future, patched_future_covariates_mask = self._prepare_patched_future( @@ -764,9 +619,7 @@ class Chronos2Model(PreTrainedModel): num_output_patches=num_output_patches, batch_size=batch_size, ) - future_attention_mask = torch.ones( - batch_size, num_output_patches, dtype=self.dtype, device=self.device - ) + future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device) # get future embeddings of shape (batch, num_output_patches, d_model) future_embeds: torch.Tensor = self.input_patch_embedding(patched_future) @@ -785,12 +638,7 @@ class Chronos2Model(PreTrainedModel): group_ids=group_ids, output_attentions=output_attentions, ) - return ( - encoder_outputs, - loc_scale, - patched_future_covariates_mask, - num_context_patches, - ) + return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches def forward( self, @@ -871,12 +719,7 @@ class Chronos2Model(PreTrainedModel): - enc_group_self_attn_weights: Group self attention weights, if output_attentions=True """ batch_size = context.shape[0] - ( - encoder_outputs, - loc_scale, - patched_future_covariates_mask, - num_context_patches, - ) = self.encode( + encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode( context=context, context_mask=context_mask, group_ids=group_ids, @@ -888,11 +731,7 @@ class Chronos2Model(PreTrainedModel): output_attentions=output_attentions, ) hidden_states: torch.Tensor = encoder_outputs[0] - assert hidden_states.shape == ( - batch_size, - num_context_patches + 1 + num_output_patches, - self.model_dim, - ) + assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim) # slice the last num_output_patches hidden states to be input into the output_patch_embedding forecast_embeds = hidden_states[:, -num_output_patches:]