diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index b00e8a8..15b3e0a 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -15,27 +15,59 @@ from transformers.utils import ModelOutput from .config import Chronos2CoreConfig -class RoPE(nn.Module): +class Chronos2RotaryEmbedding(nn.Module): """Applies rotary position embeddings (RoPE) to input tensors. + This implementation follows the transformers v5 pattern for RotaryEmbedding classes, + which enables automatic buffer reinitialization via the base `_init_weights` method. + Implementation adapted from: - https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/llama/modeling_llama.py#L95 + https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py """ - def __init__(self, dim: int, base: float = 10000): + inv_freq: torch.Tensor # type hint for register_buffer + + def __init__(self, config: Chronos2CoreConfig, device=None): super().__init__() - self.dim = dim - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.inv_freq: torch.Tensor # type hint for type checker - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + self.config = config + self.rope_type = "default" + inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: Chronos2CoreConfig, + device: torch.device | None = None, + seq_len: int | None = None, + ) -> tuple[torch.Tensor, float]: + """ + Computes the inverse frequencies for RoPE embeddings. + + Args: + config: The model configuration containing rope_theta and d_kv. + device: The device to use for initialization. + seq_len: Unused, kept for API compatibility with transformers. + + Returns: + Tuple of (inv_freq tensor, attention_scaling factor). + """ + base = config.rope_theta + dim = config.d_kv + + attention_factor = 1.0 # Unused in default RoPE + + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor @torch.no_grad() def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 @@ -44,8 +76,8 @@ class RoPE(nn.Module): with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @staticmethod @@ -78,8 +110,8 @@ class RoPE(nn.Module): """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (RoPE.rotate_half(q) * sin) - k_embed = (k * cos) + (RoPE.rotate_half(k) * sin) + q_embed = (q * cos) + (Chronos2RotaryEmbedding.rotate_half(q) * sin) + k_embed = (k * cos) + (Chronos2RotaryEmbedding.rotate_half(k) * sin) return q_embed, k_embed @@ -164,7 +196,7 @@ class MHA(nn.Module): self.use_rope = use_rope if use_rope: - self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta) + self.rope_embed = Chronos2RotaryEmbedding(config=config) def _eager_attention( self, @@ -277,7 +309,7 @@ class MHA(nn.Module): value_states = shape(self.v(hidden_states)) if self.use_rope: cos, sin = self.rope_embed(value_states, position_ids) - query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = Chronos2RotaryEmbedding.apply_rotary_pos_emb(query_states, key_states, cos, sin) if attn_implementation == "sdpa": attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask) diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index cee0f3c..4a9ff4c 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -297,8 +297,12 @@ class Chronos2Model(PreTrainedModel): init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5)) init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5)) init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5)) - elif isinstance(module, (Chronos2Model)): + elif isinstance(module, Chronos2Model): init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) + quantiles = torch.tensor( + module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device + ) + init.copy_(module.quantiles, quantiles) elif isinstance(module, ResidualBlock): init.normal_( module.hidden_layer.weight,