mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Merge 0db7dc1bdb into 32111085d8
This commit is contained in:
commit
fe9de3302a
8 changed files with 152 additions and 55 deletions
|
|
@ -15,8 +15,8 @@ license = { file = "LICENSE" }
|
|||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.2,<3",
|
||||
"transformers>=4.41,<5",
|
||||
"accelerate>=0.34,<2",
|
||||
"transformers>=4.41",
|
||||
"accelerate>=1.1.0",
|
||||
"numpy>=1.21,<3",
|
||||
"einops>=0.7.0,<1",
|
||||
"scikit-learn>=1.6.0,<2",
|
||||
|
|
@ -41,14 +41,14 @@ path = "src/chronos/__about__.py"
|
|||
[project.optional-dependencies]
|
||||
extras = [
|
||||
"boto3>=1.10,<2",
|
||||
"peft>=0.13.0,<0.18",
|
||||
"peft>=0.18.1",
|
||||
"fev>=0.6.1",
|
||||
"pandas[pyarrow]>=2.0,<2.4",
|
||||
]
|
||||
test = [
|
||||
"pytest~=8.0",
|
||||
"boto3>=1.10,<2",
|
||||
"peft>=0.13.0,<1",
|
||||
"peft>=0.18.1",
|
||||
"fev>=0.6.1",
|
||||
"pandas[pyarrow]>=2.0,<2.4",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset, get_worker_info
|
||||
import transformers
|
||||
from packaging import version
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForCausalLM,
|
||||
|
|
@ -46,6 +47,7 @@ from gluonts.transform import (
|
|||
|
||||
from chronos import ChronosConfig, ChronosTokenizer
|
||||
|
||||
_TRANSFORMERS_V5 = version.parse(transformers.__version__) >= version.parse("5.0.0")
|
||||
|
||||
app = typer.Typer(pretty_exceptions_enable=False)
|
||||
|
||||
|
|
@ -661,7 +663,7 @@ def main(
|
|||
per_device_train_batch_size=per_device_train_batch_size,
|
||||
learning_rate=learning_rate,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
warmup_ratio=warmup_ratio,
|
||||
**({"warmup_steps": round(warmup_ratio * max_steps)} if _TRANSFORMERS_V5 else {"warmup_ratio": warmup_ratio}),
|
||||
optim=optim,
|
||||
logging_strategy="steps",
|
||||
logging_steps=log_steps,
|
||||
|
|
|
|||
|
|
@ -362,9 +362,11 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
|
||||
from transformers import AutoConfig
|
||||
|
||||
torch_dtype = kwargs.get("torch_dtype", "auto")
|
||||
if torch_dtype != "auto" and isinstance(torch_dtype, str):
|
||||
kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
|
||||
# Handle both torch_dtype (deprecated) and dtype arguments
|
||||
dtype_value = kwargs.pop("torch_dtype", None) or kwargs.pop("dtype", "auto")
|
||||
if dtype_value != "auto" and isinstance(dtype_value, str):
|
||||
dtype_value = cls.dtypes[dtype_value]
|
||||
kwargs["dtype"] = dtype_value
|
||||
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(config, "chronos_config")
|
||||
|
|
|
|||
|
|
@ -10,42 +10,75 @@ from einops import rearrange
|
|||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from transformers.utils.generic import maybe_autocast
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from .config import Chronos2CoreConfig
|
||||
|
||||
|
||||
class RoPE(nn.Module):
|
||||
class Chronos2RotaryEmbedding(nn.Module):
|
||||
"""Applies rotary position embeddings (RoPE) to input tensors.
|
||||
|
||||
This implementation follows the transformers v5 pattern for RotaryEmbedding classes,
|
||||
which enables automatic buffer reinitialization via the base `_init_weights` method.
|
||||
|
||||
Implementation adapted from:
|
||||
https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/llama/modeling_llama.py#L95
|
||||
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, base: float = 10000):
|
||||
inv_freq: torch.Tensor # type hint for register_buffer
|
||||
|
||||
def __init__(self, config: Chronos2CoreConfig, device=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
|
||||
self.inv_freq: torch.Tensor # type hint for type checker
|
||||
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
|
||||
self.config = config
|
||||
self.rope_type = "default"
|
||||
inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def compute_default_rope_parameters(
|
||||
config: Chronos2CoreConfig,
|
||||
device: torch.device | None = None,
|
||||
seq_len: int | None = None,
|
||||
) -> tuple[torch.Tensor, float]:
|
||||
"""
|
||||
Computes the inverse frequencies for RoPE embeddings.
|
||||
|
||||
Args:
|
||||
config: The model configuration containing rope_theta and d_kv.
|
||||
device: The device to use for initialization.
|
||||
seq_len: Unused, kept for API compatibility with transformers.
|
||||
|
||||
Returns:
|
||||
Tuple of (inv_freq tensor, attention_scaling factor).
|
||||
"""
|
||||
base = config.rope_theta
|
||||
dim = config.d_kv
|
||||
|
||||
attention_factor = 1.0 # Unused in default RoPE
|
||||
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
||||
)
|
||||
return inv_freq, attention_factor
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
self.inv_freq.to(x.device)
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 since bfloat16 loses precision on long contexts
|
||||
# See https://github.com/huggingface/transformers/pull/29285
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
with maybe_autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -78,8 +111,8 @@ class RoPE(nn.Module):
|
|||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (RoPE.rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (RoPE.rotate_half(k) * sin)
|
||||
q_embed = (q * cos) + (Chronos2RotaryEmbedding.rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (Chronos2RotaryEmbedding.rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
|
|
@ -164,7 +197,7 @@ class MHA(nn.Module):
|
|||
|
||||
self.use_rope = use_rope
|
||||
if use_rope:
|
||||
self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta)
|
||||
self.rope_embed = Chronos2RotaryEmbedding(config=config)
|
||||
|
||||
def _eager_attention(
|
||||
self,
|
||||
|
|
@ -277,7 +310,7 @@ class MHA(nn.Module):
|
|||
value_states = shape(self.v(hidden_states))
|
||||
if self.use_rope:
|
||||
cos, sin = self.rope_embed(value_states, position_ids)
|
||||
query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
query_states, key_states = Chronos2RotaryEmbedding.apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if attn_implementation == "sdpa":
|
||||
attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask)
|
||||
|
|
|
|||
|
|
@ -10,9 +10,19 @@ from typing import cast
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
from packaging import version
|
||||
from transformers import __version__ as transformers_version
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
|
||||
_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0")
|
||||
|
||||
if _TRANSFORMERS_V5:
|
||||
from transformers import initialization as init
|
||||
else:
|
||||
from torch.nn import init
|
||||
|
||||
from chronos.chronos_bolt import InstanceNorm, Patch
|
||||
|
||||
from .config import Chronos2CoreConfig, Chronos2ForecastingConfig
|
||||
|
|
@ -268,49 +278,58 @@ class Chronos2Model(PreTrainedModel):
|
|||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
if isinstance(module, Chronos2LayerNorm):
|
||||
module.weight.data.fill_(factor * 1.0)
|
||||
init.constant_(module.weight, factor * 1.0)
|
||||
elif isinstance(module, MLP):
|
||||
# Mesh TensorFlow FF initialization
|
||||
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
|
||||
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
|
||||
module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
|
||||
if hasattr(module.wi, "bias") and module.wi.bias is not None:
|
||||
module.wi.bias.data.zero_()
|
||||
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
|
||||
init.zeros_(module.wi.bias)
|
||||
init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
|
||||
if hasattr(module.wo, "bias") and module.wo.bias is not None:
|
||||
module.wo.bias.data.zero_()
|
||||
init.zeros_(module.wo.bias)
|
||||
elif isinstance(module, MHA):
|
||||
# Mesh TensorFlow attention initialization to avoid scaling before softmax
|
||||
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
|
||||
d_model = self.config.d_model
|
||||
kv_proj_dim = self.config.d_kv
|
||||
n_heads = self.config.num_heads
|
||||
module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
|
||||
module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
|
||||
module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
|
||||
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
|
||||
elif isinstance(module, (Chronos2Model)):
|
||||
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
||||
init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
|
||||
init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
|
||||
init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
|
||||
init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
|
||||
elif isinstance(module, Chronos2Model):
|
||||
init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
|
||||
if _TRANSFORMERS_V5:
|
||||
quantiles = torch.tensor(
|
||||
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
|
||||
)
|
||||
init.copy_(module.quantiles, quantiles)
|
||||
elif isinstance(module, ResidualBlock):
|
||||
module.hidden_layer.weight.data.normal_(
|
||||
init.normal_(
|
||||
module.hidden_layer.weight,
|
||||
mean=0.0,
|
||||
std=factor * (module.hidden_layer.weight.size(-1) ** -0.5),
|
||||
)
|
||||
if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
|
||||
module.hidden_layer.bias.data.zero_()
|
||||
init.zeros_(module.hidden_layer.bias)
|
||||
|
||||
module.residual_layer.weight.data.normal_(
|
||||
init.normal_(
|
||||
module.residual_layer.weight,
|
||||
mean=0.0,
|
||||
std=factor * (module.residual_layer.weight.size(-1) ** -0.5),
|
||||
)
|
||||
if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
|
||||
module.residual_layer.bias.data.zero_()
|
||||
init.zeros_(module.residual_layer.bias)
|
||||
|
||||
module.output_layer.weight.data.normal_(
|
||||
mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5)
|
||||
init.normal_(
|
||||
module.output_layer.weight,
|
||||
mean=0.0,
|
||||
std=factor * (module.output_layer.weight.size(-1) ** -0.5),
|
||||
)
|
||||
if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
|
||||
module.output_layer.bias.data.zero_()
|
||||
init.zeros_(module.output_layer.bias)
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from transformers.utils.peft_utils import find_adapter_config_file
|
|||
import chronos.chronos2
|
||||
from chronos.base import BaseChronosPipeline, ForecastType
|
||||
from chronos.chronos2 import Chronos2Model
|
||||
from chronos.chronos2.model import _TRANSFORMERS_V5
|
||||
from chronos.chronos2.dataset import Chronos2Dataset, DatasetMode, TensorOrArray
|
||||
from chronos.df_utils import convert_df_input_to_list_of_dicts_input
|
||||
from chronos.utils import interpolate_quantiles, weighted_quantile
|
||||
|
|
@ -270,7 +271,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
per_device_eval_batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
lr_scheduler_type="linear",
|
||||
warmup_ratio=0.0,
|
||||
**({"warmup_steps": 0} if _TRANSFORMERS_V5 else {"warmup_ratio": 0.0}),
|
||||
optim="adamw_torch_fused",
|
||||
logging_strategy="steps",
|
||||
logging_steps=100,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ from typing import List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
from packaging import version
|
||||
from transformers import AutoConfig, __version__ as transformers_version
|
||||
from transformers.models.t5.modeling_t5 import (
|
||||
ACT2FN,
|
||||
T5Config,
|
||||
|
|
@ -28,6 +29,29 @@ from .base import BaseChronosPipeline, ForecastType
|
|||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0")
|
||||
|
||||
# In transformers v5, use guarded init functions that check _is_hf_initialized
|
||||
# to avoid re-initializing weights loaded from checkpoint
|
||||
if _TRANSFORMERS_V5:
|
||||
from transformers import initialization as init
|
||||
else:
|
||||
from torch.nn import init
|
||||
|
||||
|
||||
def _create_t5_stack(config: T5Config, embed_tokens: nn.Embedding) -> T5Stack:
|
||||
"""
|
||||
Create a T5Stack with the given config and embed_tokens.
|
||||
|
||||
This helper function provides backward compatibility between transformers v4 and v5.
|
||||
In v4, T5Stack.__init__ accepts (config, embed_tokens).
|
||||
In v5, T5Stack.__init__ only accepts (config), and embed_tokens must be set separately.
|
||||
"""
|
||||
if _TRANSFORMERS_V5:
|
||||
return T5Stack(config)
|
||||
else:
|
||||
return T5Stack(config, embed_tokens)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChronosBoltConfig:
|
||||
|
|
@ -150,7 +174,15 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
|||
r"output_patch_embedding\.",
|
||||
]
|
||||
_keys_to_ignore_on_load_unexpected = [r"lm_head.weight"] # type: ignore
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # type: ignore
|
||||
# In transformers v5, _tied_weights_keys changed from list to dict {target: source}
|
||||
_tied_weights_keys = ( # type: ignore
|
||||
{
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
if _TRANSFORMERS_V5
|
||||
else ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
)
|
||||
|
||||
def __init__(self, config: T5Config):
|
||||
assert hasattr(config, "chronos_config"), "Not a Chronos config file"
|
||||
|
|
@ -188,7 +220,7 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
|||
encoder_config = copy.deepcopy(config)
|
||||
encoder_config.is_decoder = False
|
||||
encoder_config.use_cache = False
|
||||
self.encoder = T5Stack(encoder_config, self.shared)
|
||||
self.encoder = _create_t5_stack(encoder_config, self.shared)
|
||||
|
||||
self._init_decoder(config)
|
||||
|
||||
|
|
@ -217,25 +249,33 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
|||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
if isinstance(module, (self.__class__)):
|
||||
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
|
||||
init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
|
||||
# Reinitialize quantiles buffer for transformers v5 meta device compatibility
|
||||
if _TRANSFORMERS_V5:
|
||||
quantiles = torch.tensor(
|
||||
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
|
||||
)
|
||||
init.copy_(module.quantiles, quantiles)
|
||||
elif isinstance(module, ResidualBlock):
|
||||
module.hidden_layer.weight.data.normal_(
|
||||
init.normal_(
|
||||
module.hidden_layer.weight,
|
||||
mean=0.0,
|
||||
std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
|
||||
)
|
||||
if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
|
||||
module.hidden_layer.bias.data.zero_()
|
||||
init.zeros_(module.hidden_layer.bias)
|
||||
|
||||
module.residual_layer.weight.data.normal_(
|
||||
init.normal_(
|
||||
module.residual_layer.weight,
|
||||
mean=0.0,
|
||||
std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
|
||||
)
|
||||
if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
|
||||
module.residual_layer.bias.data.zero_()
|
||||
init.zeros_(module.residual_layer.bias)
|
||||
|
||||
module.output_layer.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
|
||||
init.normal_(module.output_layer.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
|
||||
if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
|
||||
module.output_layer.bias.data.zero_()
|
||||
init.zeros_(module.output_layer.bias)
|
||||
|
||||
def encode(
|
||||
self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
|
||||
|
|
@ -359,7 +399,7 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
|||
decoder_config = copy.deepcopy(config)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.num_layers = config.num_decoder_layers
|
||||
self.decoder = T5Stack(decoder_config, self.shared)
|
||||
self.decoder = _create_t5_stack(decoder_config, self.shared)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
|
|
|
|||
BIN
test/dummy-chronos-model/model.safetensors
Normal file
BIN
test/dummy-chronos-model/model.safetensors
Normal file
Binary file not shown.
Loading…
Reference in a new issue