From 16a91d5b584f870ec76315d9c42823a7fa8f4ecd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 9 Jan 2026 19:19:41 +0100 Subject: [PATCH] skip copy_ for v4 --- src/chronos/chronos2/model.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index 4a9ff4c..c96d0bc 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -16,7 +16,9 @@ from transformers.modeling_utils import PreTrainedModel from transformers.utils import ModelOutput -if version.parse(transformers_version) >= version.parse("5.0.0.dev0"): +_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0.dev0") + +if _TRANSFORMERS_V5: from transformers import initialization as init else: from torch.nn import init @@ -299,10 +301,11 @@ class Chronos2Model(PreTrainedModel): init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5)) elif isinstance(module, Chronos2Model): init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) - quantiles = torch.tensor( - module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device - ) - init.copy_(module.quantiles, quantiles) + if _TRANSFORMERS_V5: + quantiles = torch.tensor( + module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device + ) + init.copy_(module.quantiles, quantiles) elif isinstance(module, ResidualBlock): init.normal_( module.hidden_layer.weight,