fix buffer inits

This commit is contained in:
Kashif Rasul 2026-01-09 19:14:50 +01:00
parent 08eee1fc15
commit f29ff338e5
2 changed files with 53 additions and 17 deletions

View file

@ -15,27 +15,59 @@ from transformers.utils import ModelOutput
from .config import Chronos2CoreConfig
class RoPE(nn.Module):
class Chronos2RotaryEmbedding(nn.Module):
"""Applies rotary position embeddings (RoPE) to input tensors.
This implementation follows the transformers v5 pattern for RotaryEmbedding classes,
which enables automatic buffer reinitialization via the base `_init_weights` method.
Implementation adapted from:
https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/llama/modeling_llama.py#L95
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
"""
def __init__(self, dim: int, base: float = 10000):
inv_freq: torch.Tensor # type hint for register_buffer
def __init__(self, config: Chronos2CoreConfig, device=None):
super().__init__()
self.dim = dim
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.inv_freq: torch.Tensor # type hint for type checker
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
self.config = config
self.rope_type = "default"
inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
@staticmethod
def compute_default_rope_parameters(
config: Chronos2CoreConfig,
device: torch.device | None = None,
seq_len: int | None = None,
) -> tuple[torch.Tensor, float]:
"""
Computes the inverse frequencies for RoPE embeddings.
Args:
config: The model configuration containing rope_theta and d_kv.
device: The device to use for initialization.
seq_len: Unused, kept for API compatibility with transformers.
Returns:
Tuple of (inv_freq tensor, attention_scaling factor).
"""
base = config.rope_theta
dim = config.d_kv
attention_factor = 1.0 # Unused in default RoPE
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, attention_factor
@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
@ -44,8 +76,8 @@ class RoPE(nn.Module):
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@staticmethod
@ -78,8 +110,8 @@ class RoPE(nn.Module):
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (RoPE.rotate_half(q) * sin)
k_embed = (k * cos) + (RoPE.rotate_half(k) * sin)
q_embed = (q * cos) + (Chronos2RotaryEmbedding.rotate_half(q) * sin)
k_embed = (k * cos) + (Chronos2RotaryEmbedding.rotate_half(k) * sin)
return q_embed, k_embed
@ -164,7 +196,7 @@ class MHA(nn.Module):
self.use_rope = use_rope
if use_rope:
self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta)
self.rope_embed = Chronos2RotaryEmbedding(config=config)
def _eager_attention(
self,
@ -277,7 +309,7 @@ class MHA(nn.Module):
value_states = shape(self.v(hidden_states))
if self.use_rope:
cos, sin = self.rope_embed(value_states, position_ids)
query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin)
query_states, key_states = Chronos2RotaryEmbedding.apply_rotary_pos_emb(query_states, key_states, cos, sin)
if attn_implementation == "sdpa":
attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask)

View file

@ -297,8 +297,12 @@ class Chronos2Model(PreTrainedModel):
init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
elif isinstance(module, (Chronos2Model)):
elif isinstance(module, Chronos2Model):
init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
quantiles = torch.tensor(
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
)
init.copy_(module.quantiles, quantiles)
elif isinstance(module, ResidualBlock):
init.normal_(
module.hidden_layer.weight,