This commit is contained in:
Kashif Rasul 2026-04-26 14:25:24 +02:00 committed by GitHub
commit fe9de3302a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 152 additions and 55 deletions

View file

@ -15,8 +15,8 @@ license = { file = "LICENSE" }
requires-python = ">=3.10"
dependencies = [
"torch>=2.2,<3",
"transformers>=4.41,<5",
"accelerate>=0.34,<2",
"transformers>=4.41",
"accelerate>=1.1.0",
"numpy>=1.21,<3",
"einops>=0.7.0,<1",
"scikit-learn>=1.6.0,<2",
@ -41,14 +41,14 @@ path = "src/chronos/__about__.py"
[project.optional-dependencies]
extras = [
"boto3>=1.10,<2",
"peft>=0.13.0,<0.18",
"peft>=0.18.1",
"fev>=0.6.1",
"pandas[pyarrow]>=2.0,<2.4",
]
test = [
"pytest~=8.0",
"boto3>=1.10,<2",
"peft>=0.13.0,<1",
"peft>=0.18.1",
"fev>=0.6.1",
"pandas[pyarrow]>=2.0,<2.4",
]

View file

@ -21,6 +21,7 @@ import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, get_worker_info
import transformers
from packaging import version
from transformers import (
AutoModelForSeq2SeqLM,
AutoModelForCausalLM,
@ -46,6 +47,7 @@ from gluonts.transform import (
from chronos import ChronosConfig, ChronosTokenizer
_TRANSFORMERS_V5 = version.parse(transformers.__version__) >= version.parse("5.0.0")
app = typer.Typer(pretty_exceptions_enable=False)
@ -661,7 +663,7 @@ def main(
per_device_train_batch_size=per_device_train_batch_size,
learning_rate=learning_rate,
lr_scheduler_type=lr_scheduler_type,
warmup_ratio=warmup_ratio,
**({"warmup_steps": round(warmup_ratio * max_steps)} if _TRANSFORMERS_V5 else {"warmup_ratio": warmup_ratio}),
optim=optim,
logging_strategy="steps",
logging_steps=log_steps,

View file

@ -362,9 +362,11 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
from transformers import AutoConfig
torch_dtype = kwargs.get("torch_dtype", "auto")
if torch_dtype != "auto" and isinstance(torch_dtype, str):
kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
# Handle both torch_dtype (deprecated) and dtype arguments
dtype_value = kwargs.pop("torch_dtype", None) or kwargs.pop("dtype", "auto")
if dtype_value != "auto" and isinstance(dtype_value, str):
dtype_value = cls.dtypes[dtype_value]
kwargs["dtype"] = dtype_value
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(config, "chronos_config")

View file

@ -10,42 +10,75 @@ from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils.generic import maybe_autocast
from transformers.utils import ModelOutput
from .config import Chronos2CoreConfig
class RoPE(nn.Module):
class Chronos2RotaryEmbedding(nn.Module):
"""Applies rotary position embeddings (RoPE) to input tensors.
This implementation follows the transformers v5 pattern for RotaryEmbedding classes,
which enables automatic buffer reinitialization via the base `_init_weights` method.
Implementation adapted from:
https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/llama/modeling_llama.py#L95
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
"""
def __init__(self, dim: int, base: float = 10000):
inv_freq: torch.Tensor # type hint for register_buffer
def __init__(self, config: Chronos2CoreConfig, device=None):
super().__init__()
self.dim = dim
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.inv_freq: torch.Tensor # type hint for type checker
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
self.config = config
self.rope_type = "default"
inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
@staticmethod
def compute_default_rope_parameters(
config: Chronos2CoreConfig,
device: torch.device | None = None,
seq_len: int | None = None,
) -> tuple[torch.Tensor, float]:
"""
Computes the inverse frequencies for RoPE embeddings.
Args:
config: The model configuration containing rope_theta and d_kv.
device: The device to use for initialization.
seq_len: Unused, kept for API compatibility with transformers.
Returns:
Tuple of (inv_freq tensor, attention_scaling factor).
"""
base = config.rope_theta
dim = config.d_kv
attention_factor = 1.0 # Unused in default RoPE
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, attention_factor
@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
with maybe_autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@staticmethod
@ -78,8 +111,8 @@ class RoPE(nn.Module):
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (RoPE.rotate_half(q) * sin)
k_embed = (k * cos) + (RoPE.rotate_half(k) * sin)
q_embed = (q * cos) + (Chronos2RotaryEmbedding.rotate_half(q) * sin)
k_embed = (k * cos) + (Chronos2RotaryEmbedding.rotate_half(k) * sin)
return q_embed, k_embed
@ -164,7 +197,7 @@ class MHA(nn.Module):
self.use_rope = use_rope
if use_rope:
self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta)
self.rope_embed = Chronos2RotaryEmbedding(config=config)
def _eager_attention(
self,
@ -277,7 +310,7 @@ class MHA(nn.Module):
value_states = shape(self.v(hidden_states))
if self.use_rope:
cos, sin = self.rope_embed(value_states, position_ids)
query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin)
query_states, key_states = Chronos2RotaryEmbedding.apply_rotary_pos_emb(query_states, key_states, cos, sin)
if attn_implementation == "sdpa":
attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask)

View file

@ -10,9 +10,19 @@ from typing import cast
import torch
import torch.nn as nn
from einops import rearrange, repeat
from packaging import version
from transformers import __version__ as transformers_version
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput
_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0")
if _TRANSFORMERS_V5:
from transformers import initialization as init
else:
from torch.nn import init
from chronos.chronos_bolt import InstanceNorm, Patch
from .config import Chronos2CoreConfig, Chronos2ForecastingConfig
@ -268,49 +278,58 @@ class Chronos2Model(PreTrainedModel):
"""Initialize the weights"""
factor = self.config.initializer_factor
if isinstance(module, Chronos2LayerNorm):
module.weight.data.fill_(factor * 1.0)
init.constant_(module.weight, factor * 1.0)
elif isinstance(module, MLP):
# Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi, "bias") and module.wi.bias is not None:
module.wi.bias.data.zero_()
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
init.zeros_(module.wi.bias)
init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
init.zeros_(module.wo.bias)
elif isinstance(module, MHA):
# Mesh TensorFlow attention initialization to avoid scaling before softmax
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
d_model = self.config.d_model
kv_proj_dim = self.config.d_kv
n_heads = self.config.num_heads
module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
elif isinstance(module, (Chronos2Model)):
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
elif isinstance(module, Chronos2Model):
init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
if _TRANSFORMERS_V5:
quantiles = torch.tensor(
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
)
init.copy_(module.quantiles, quantiles)
elif isinstance(module, ResidualBlock):
module.hidden_layer.weight.data.normal_(
init.normal_(
module.hidden_layer.weight,
mean=0.0,
std=factor * (module.hidden_layer.weight.size(-1) ** -0.5),
)
if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
module.hidden_layer.bias.data.zero_()
init.zeros_(module.hidden_layer.bias)
module.residual_layer.weight.data.normal_(
init.normal_(
module.residual_layer.weight,
mean=0.0,
std=factor * (module.residual_layer.weight.size(-1) ** -0.5),
)
if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
module.residual_layer.bias.data.zero_()
init.zeros_(module.residual_layer.bias)
module.output_layer.weight.data.normal_(
mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5)
init.normal_(
module.output_layer.weight,
mean=0.0,
std=factor * (module.output_layer.weight.size(-1) ** -0.5),
)
if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
module.output_layer.bias.data.zero_()
init.zeros_(module.output_layer.bias)
def _validate_input(
self,

View file

@ -22,6 +22,7 @@ from transformers.utils.peft_utils import find_adapter_config_file
import chronos.chronos2
from chronos.base import BaseChronosPipeline, ForecastType
from chronos.chronos2 import Chronos2Model
from chronos.chronos2.model import _TRANSFORMERS_V5
from chronos.chronos2.dataset import Chronos2Dataset, DatasetMode, TensorOrArray
from chronos.df_utils import convert_df_input_to_list_of_dicts_input
from chronos.utils import interpolate_quantiles, weighted_quantile
@ -270,7 +271,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
per_device_eval_batch_size=batch_size,
learning_rate=learning_rate,
lr_scheduler_type="linear",
warmup_ratio=0.0,
**({"warmup_steps": 0} if _TRANSFORMERS_V5 else {"warmup_ratio": 0.0}),
optim="adamw_torch_fused",
logging_strategy="steps",
logging_steps=100,

View file

@ -13,7 +13,8 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoConfig
from packaging import version
from transformers import AutoConfig, __version__ as transformers_version
from transformers.models.t5.modeling_t5 import (
ACT2FN,
T5Config,
@ -28,6 +29,29 @@ from .base import BaseChronosPipeline, ForecastType
logger = logging.getLogger(__file__)
_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0")
# In transformers v5, use guarded init functions that check _is_hf_initialized
# to avoid re-initializing weights loaded from checkpoint
if _TRANSFORMERS_V5:
from transformers import initialization as init
else:
from torch.nn import init
def _create_t5_stack(config: T5Config, embed_tokens: nn.Embedding) -> T5Stack:
"""
Create a T5Stack with the given config and embed_tokens.
This helper function provides backward compatibility between transformers v4 and v5.
In v4, T5Stack.__init__ accepts (config, embed_tokens).
In v5, T5Stack.__init__ only accepts (config), and embed_tokens must be set separately.
"""
if _TRANSFORMERS_V5:
return T5Stack(config)
else:
return T5Stack(config, embed_tokens)
@dataclass
class ChronosBoltConfig:
@ -150,7 +174,15 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
r"output_patch_embedding\.",
]
_keys_to_ignore_on_load_unexpected = [r"lm_head.weight"] # type: ignore
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # type: ignore
# In transformers v5, _tied_weights_keys changed from list to dict {target: source}
_tied_weights_keys = ( # type: ignore
{
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
if _TRANSFORMERS_V5
else ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
)
def __init__(self, config: T5Config):
assert hasattr(config, "chronos_config"), "Not a Chronos config file"
@ -188,7 +220,7 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
self.encoder = T5Stack(encoder_config, self.shared)
self.encoder = _create_t5_stack(encoder_config, self.shared)
self._init_decoder(config)
@ -217,25 +249,33 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
"""Initialize the weights"""
factor = self.config.initializer_factor
if isinstance(module, (self.__class__)):
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
# Reinitialize quantiles buffer for transformers v5 meta device compatibility
if _TRANSFORMERS_V5:
quantiles = torch.tensor(
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
)
init.copy_(module.quantiles, quantiles)
elif isinstance(module, ResidualBlock):
module.hidden_layer.weight.data.normal_(
init.normal_(
module.hidden_layer.weight,
mean=0.0,
std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
)
if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
module.hidden_layer.bias.data.zero_()
init.zeros_(module.hidden_layer.bias)
module.residual_layer.weight.data.normal_(
init.normal_(
module.residual_layer.weight,
mean=0.0,
std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
)
if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
module.residual_layer.bias.data.zero_()
init.zeros_(module.residual_layer.bias)
module.output_layer.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
init.normal_(module.output_layer.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
module.output_layer.bias.data.zero_()
init.zeros_(module.output_layer.bias)
def encode(
self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
@ -359,7 +399,7 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.num_layers = config.num_decoder_layers
self.decoder = T5Stack(decoder_config, self.shared)
self.decoder = _create_t5_stack(decoder_config, self.shared)
def decode(
self,

Binary file not shown.