mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
fix buffer inits
This commit is contained in:
parent
08eee1fc15
commit
f29ff338e5
2 changed files with 53 additions and 17 deletions
|
|
@@ -15,27 +15,59 @@ from transformers.utils import ModelOutput
|
|||
from .config import Chronos2CoreConfig
|
||||
|
||||
|
||||
class RoPE(nn.Module):
|
||||
class Chronos2RotaryEmbedding(nn.Module):
|
||||
"""Applies rotary position embeddings (RoPE) to input tensors.
|
||||
|
||||
This implementation follows the transformers v5 pattern for RotaryEmbedding classes,
|
||||
which enables automatic buffer reinitialization via the base `_init_weights` method.
|
||||
|
||||
Implementation adapted from:
|
||||
https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/llama/modeling_llama.py#L95
|
||||
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, base: float = 10000):
|
||||
inv_freq: torch.Tensor # type hint for register_buffer
|
||||
|
||||
def __init__(self, config: Chronos2CoreConfig, device=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
|
||||
self.inv_freq: torch.Tensor # type hint for type checker
|
||||
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
|
||||
self.config = config
|
||||
self.rope_type = "default"
|
||||
inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def compute_default_rope_parameters(
|
||||
config: Chronos2CoreConfig,
|
||||
device: torch.device | None = None,
|
||||
seq_len: int | None = None,
|
||||
) -> tuple[torch.Tensor, float]:
|
||||
"""
|
||||
Computes the inverse frequencies for RoPE embeddings.
|
||||
|
||||
Args:
|
||||
config: The model configuration containing rope_theta and d_kv.
|
||||
device: The device to use for initialization.
|
||||
seq_len: Unused, kept for API compatibility with transformers.
|
||||
|
||||
Returns:
|
||||
Tuple of (inv_freq tensor, attention_scaling factor).
|
||||
"""
|
||||
base = config.rope_theta
|
||||
dim = config.d_kv
|
||||
|
||||
attention_factor = 1.0 # Unused in default RoPE
|
||||
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
||||
)
|
||||
return inv_freq, attention_factor
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
self.inv_freq.to(x.device)
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 since bfloat16 loses precision on long contexts
|
||||
# See https://github.com/huggingface/transformers/pull/29285
|
||||
|
|
@@ -44,8 +76,8 @@ class RoPE(nn.Module):
|
|||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
@staticmethod
|
||||
|
|
@@ -78,8 +110,8 @@ class RoPE(nn.Module):
|
|||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (RoPE.rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (RoPE.rotate_half(k) * sin)
|
||||
q_embed = (q * cos) + (Chronos2RotaryEmbedding.rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (Chronos2RotaryEmbedding.rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
|
|
@@ -164,7 +196,7 @@ class MHA(nn.Module):
|
|||
|
||||
self.use_rope = use_rope
|
||||
if use_rope:
|
||||
self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta)
|
||||
self.rope_embed = Chronos2RotaryEmbedding(config=config)
|
||||
|
||||
def _eager_attention(
|
||||
self,
|
||||
|
|
@@ -277,7 +309,7 @@ class MHA(nn.Module):
|
|||
value_states = shape(self.v(hidden_states))
|
||||
if self.use_rope:
|
||||
cos, sin = self.rope_embed(value_states, position_ids)
|
||||
query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
query_states, key_states = Chronos2RotaryEmbedding.apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if attn_implementation == "sdpa":
|
||||
attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask)
|
||||
|
|
|
|||
|
|
@@ -297,8 +297,12 @@ class Chronos2Model(PreTrainedModel):
|
|||
init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
|
||||
init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
|
||||
init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
|
||||
elif isinstance(module, (Chronos2Model)):
|
||||
elif isinstance(module, Chronos2Model):
|
||||
init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
|
||||
quantiles = torch.tensor(
|
||||
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
|
||||
)
|
||||
init.copy_(module.quantiles, quantiles)
|
||||
elif isinstance(module, ResidualBlock):
|
||||
init.normal_(
|
||||
module.hidden_layer.weight,
|
||||
|
|
|
|||
Loading…
Reference in a new issue