diff --git a/pyproject.toml b/pyproject.toml index d9e7117..dd4e4d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,8 @@ license = { file = "LICENSE" } requires-python = ">=3.10" dependencies = [ "torch>=2.2,<3", - "transformers>=4.41,<5", - "accelerate>=0.34,<2", + "transformers>=4.41", + "accelerate>=1.1.0", "numpy>=1.21,<3", "einops>=0.7.0,<1", "scikit-learn>=1.6.0,<2", @@ -41,14 +41,14 @@ path = "src/chronos/__about__.py" [project.optional-dependencies] extras = [ "boto3>=1.10,<2", - "peft>=0.13.0,<0.18", + "peft>=0.18.1", "fev>=0.6.1", "pandas[pyarrow]>=2.0,<2.4", ] test = [ "pytest~=8.0", "boto3>=1.10,<2", - "peft>=0.13.0,<1", + "peft>=0.18.1", "fev>=0.6.1", "pandas[pyarrow]>=2.0,<2.4", ] diff --git a/scripts/training/train.py b/scripts/training/train.py index 09d5d8e..c586460 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -21,6 +21,7 @@ import torch import torch.distributed as dist from torch.utils.data import IterableDataset, get_worker_info import transformers +from packaging import version from transformers import ( AutoModelForSeq2SeqLM, AutoModelForCausalLM, @@ -46,6 +47,7 @@ from gluonts.transform import ( from chronos import ChronosConfig, ChronosTokenizer +_TRANSFORMERS_V5 = version.parse(transformers.__version__) >= version.parse("5.0.0") app = typer.Typer(pretty_exceptions_enable=False) @@ -661,7 +663,7 @@ def main( per_device_train_batch_size=per_device_train_batch_size, learning_rate=learning_rate, lr_scheduler_type=lr_scheduler_type, - warmup_ratio=warmup_ratio, + **({"warmup_steps": round(warmup_ratio * max_steps)} if _TRANSFORMERS_V5 else {"warmup_ratio": warmup_ratio}), optim=optim, logging_strategy="steps", logging_steps=log_steps, diff --git a/src/chronos/base.py b/src/chronos/base.py index 807f91d..c42e87a 100644 --- a/src/chronos/base.py +++ b/src/chronos/base.py @@ -362,9 +362,11 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): from transformers import AutoConfig - torch_dtype = kwargs.get("torch_dtype", "auto") - if torch_dtype != "auto" and isinstance(torch_dtype, str): - kwargs["torch_dtype"] = cls.dtypes[torch_dtype] + # Handle both torch_dtype (deprecated) and dtype arguments + dtype_value = kwargs.pop("torch_dtype", None) or kwargs.pop("dtype", "auto") + if dtype_value != "auto" and isinstance(dtype_value, str): + dtype_value = cls.dtypes[dtype_value] + kwargs["dtype"] = dtype_value config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(config, "chronos_config") diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index b00e8a8..a632034 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -10,42 +10,75 @@ from einops import rearrange from torch import nn from transformers.activations import ACT2FN from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils.generic import maybe_autocast from transformers.utils import ModelOutput from .config import Chronos2CoreConfig -class RoPE(nn.Module): +class Chronos2RotaryEmbedding(nn.Module): """Applies rotary position embeddings (RoPE) to input tensors. + This implementation follows the transformers v5 pattern for RotaryEmbedding classes, + which enables automatic buffer reinitialization via the base `_init_weights` method. + Implementation adapted from: - https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/llama/modeling_llama.py#L95 + https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py """ - def __init__(self, dim: int, base: float = 10000): + inv_freq: torch.Tensor # type hint for register_buffer + + def __init__(self, config: Chronos2CoreConfig, device=None): super().__init__() - self.dim = dim - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.inv_freq: torch.Tensor # type hint for type checker - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + self.config = config + self.rope_type = "default" + inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: Chronos2CoreConfig, + device: torch.device | None = None, + seq_len: int | None = None, + ) -> tuple[torch.Tensor, float]: + """ + Computes the inverse frequencies for RoPE embeddings. + + Args: + config: The model configuration containing rope_theta and d_kv. + device: The device to use for initialization. + seq_len: Unused, kept for API compatibility with transformers. + + Returns: + Tuple of (inv_freq tensor, attention_scaling factor). + """ + base = config.rope_theta + dim = config.d_kv + + attention_factor = 1.0 # Unused in default RoPE + + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor @torch.no_grad() def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): + with maybe_autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @staticmethod @@ -78,8 +111,8 @@ class RoPE(nn.Module): """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (RoPE.rotate_half(q) * sin) - k_embed = (k * cos) + (RoPE.rotate_half(k) * sin) + q_embed = (q * cos) + (Chronos2RotaryEmbedding.rotate_half(q) * sin) + k_embed = (k * cos) + (Chronos2RotaryEmbedding.rotate_half(k) * sin) return q_embed, k_embed @@ -164,7 +197,7 @@ class MHA(nn.Module): self.use_rope = use_rope if use_rope: - self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta) + self.rope_embed = Chronos2RotaryEmbedding(config=config) def _eager_attention( self, @@ -277,7 +310,7 @@ class MHA(nn.Module): value_states = shape(self.v(hidden_states)) if self.use_rope: cos, sin = self.rope_embed(value_states, position_ids) - query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = Chronos2RotaryEmbedding.apply_rotary_pos_emb(query_states, key_states, cos, sin) if attn_implementation == "sdpa": attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask) diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index 0397be2..99eefb0 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -10,9 +10,19 @@ from typing import cast import torch import torch.nn as nn from einops import rearrange, repeat +from packaging import version +from transformers import __version__ as transformers_version from transformers.modeling_utils import PreTrainedModel from transformers.utils import ModelOutput + +_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0") + +if _TRANSFORMERS_V5: + from transformers import initialization as init +else: + from torch.nn import init + from chronos.chronos_bolt import InstanceNorm, Patch from .config import Chronos2CoreConfig, Chronos2ForecastingConfig @@ -268,49 +278,58 @@ class Chronos2Model(PreTrainedModel): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, Chronos2LayerNorm): - module.weight.data.fill_(factor * 1.0) + init.constant_(module.weight, factor * 1.0) elif isinstance(module, MLP): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.zeros_(module.wi.bias) + init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + init.zeros_(module.wo.bias) elif isinstance(module, MHA): # Mesh TensorFlow attention initialization to avoid scaling before softmax # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model kv_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5)) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5)) - elif isinstance(module, (Chronos2Model)): - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5)) + init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5)) + init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5)) + elif isinstance(module, Chronos2Model): + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) + if _TRANSFORMERS_V5: + quantiles = torch.tensor( + module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device + ) + init.copy_(module.quantiles, quantiles) elif isinstance(module, ResidualBlock): - module.hidden_layer.weight.data.normal_( + init.normal_( + module.hidden_layer.weight, mean=0.0, std=factor * (module.hidden_layer.weight.size(-1) ** -0.5), ) if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None: - module.hidden_layer.bias.data.zero_() + init.zeros_(module.hidden_layer.bias) - module.residual_layer.weight.data.normal_( + init.normal_( + module.residual_layer.weight, mean=0.0, std=factor * (module.residual_layer.weight.size(-1) ** -0.5), ) if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None: - module.residual_layer.bias.data.zero_() + init.zeros_(module.residual_layer.bias) - module.output_layer.weight.data.normal_( - mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5) + init.normal_( + module.output_layer.weight, + mean=0.0, + std=factor * (module.output_layer.weight.size(-1) ** -0.5), ) if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None: - module.output_layer.bias.data.zero_() + init.zeros_(module.output_layer.bias) def _validate_input( self, diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 223689d..2cefa5f 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -22,6 +22,7 @@ from transformers.utils.peft_utils import find_adapter_config_file import chronos.chronos2 from chronos.base import BaseChronosPipeline, ForecastType from chronos.chronos2 import Chronos2Model +from chronos.chronos2.model import _TRANSFORMERS_V5 from chronos.chronos2.dataset import Chronos2Dataset, DatasetMode, TensorOrArray from chronos.df_utils import convert_df_input_to_list_of_dicts_input from chronos.utils import interpolate_quantiles, weighted_quantile @@ -270,7 +271,7 @@ class Chronos2Pipeline(BaseChronosPipeline): per_device_eval_batch_size=batch_size, learning_rate=learning_rate, lr_scheduler_type="linear", - warmup_ratio=0.0, + **({"warmup_steps": 0} if _TRANSFORMERS_V5 else {"warmup_ratio": 0.0}), optim="adamw_torch_fused", logging_strategy="steps", logging_steps=100, diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index 743ec06..aa9c1b7 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -13,7 +13,8 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn as nn -from transformers import AutoConfig +from packaging import version +from transformers import AutoConfig, __version__ as transformers_version from transformers.models.t5.modeling_t5 import ( ACT2FN, T5Config, @@ -28,6 +29,29 @@ from .base import BaseChronosPipeline, ForecastType logger = logging.getLogger(__file__) +_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0") + +# In transformers v5, use guarded init functions that check _is_hf_initialized +# to avoid re-initializing weights loaded from checkpoint +if _TRANSFORMERS_V5: + from transformers import initialization as init +else: + from torch.nn import init + + +def _create_t5_stack(config: T5Config, embed_tokens: nn.Embedding) -> T5Stack: + """ + Create a T5Stack with the given config and embed_tokens. + + This helper function provides backward compatibility between transformers v4 and v5. + In v4, T5Stack.__init__ accepts (config, embed_tokens). + In v5, T5Stack.__init__ only accepts (config), and embed_tokens must be set separately. + """ + if _TRANSFORMERS_V5: + return T5Stack(config) + else: + return T5Stack(config, embed_tokens) + @dataclass class ChronosBoltConfig: @@ -150,7 +174,15 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel): r"output_patch_embedding\.", ] _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"] # type: ignore - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # type: ignore + # In transformers v5, _tied_weights_keys changed from list to dict {target: source} + _tied_weights_keys = ( # type: ignore + { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } + if _TRANSFORMERS_V5 + else ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + ) def __init__(self, config: T5Config): assert hasattr(config, "chronos_config"), "Not a Chronos config file" @@ -188,7 +220,7 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - self.encoder = T5Stack(encoder_config, self.shared) + self.encoder = _create_t5_stack(encoder_config, self.shared) self._init_decoder(config) @@ -217,25 +249,33 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, (self.__class__)): - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) + # Reinitialize quantiles buffer for transformers v5 meta device compatibility + if _TRANSFORMERS_V5: + quantiles = torch.tensor( + module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device + ) + init.copy_(module.quantiles, quantiles) elif isinstance(module, ResidualBlock): - module.hidden_layer.weight.data.normal_( + init.normal_( + module.hidden_layer.weight, mean=0.0, std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5), ) if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None: - module.hidden_layer.bias.data.zero_() + init.zeros_(module.hidden_layer.bias) - module.residual_layer.weight.data.normal_( + init.normal_( + module.residual_layer.weight, mean=0.0, std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5), ) if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None: - module.residual_layer.bias.data.zero_() + init.zeros_(module.residual_layer.bias) - module.output_layer.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.normal_(module.output_layer.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None: - module.output_layer.bias.data.zero_() + init.zeros_(module.output_layer.bias) def encode( self, context: torch.Tensor, mask: Optional[torch.Tensor] = None @@ -359,7 +399,7 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel): decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack(decoder_config, self.shared) + self.decoder = _create_t5_stack(decoder_config, self.shared) def decode( self, diff --git a/test/dummy-chronos-model/model.safetensors b/test/dummy-chronos-model/model.safetensors new file mode 100644 index 0000000..40bef76 Binary files /dev/null and b/test/dummy-chronos-model/model.safetensors differ