Skip init.copy_ for transformers v4

This commit is contained in:
Kashif Rasul 2026-01-09 19:19:41 +01:00
parent f29ff338e5
commit 16a91d5b58

View file

@@ -16,7 +16,9 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput
if version.parse(transformers_version) >= version.parse("5.0.0.dev0"):
_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0.dev0")
if _TRANSFORMERS_V5:
from transformers import initialization as init
else:
from torch.nn import init
@@ -299,10 +301,11 @@ class Chronos2Model(PreTrainedModel):
init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
elif isinstance(module, Chronos2Model):
init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
quantiles = torch.tensor(
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
)
init.copy_(module.quantiles, quantiles)
if _TRANSFORMERS_V5:
quantiles = torch.tensor(
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
)
init.copy_(module.quantiles, quantiles)
elif isinstance(module, ResidualBlock):
init.normal_(
module.hidden_layer.weight,