fix bolt buffer for transformers v5

This commit is contained in:
Kashif Rasul 2026-01-13 09:55:35 +01:00
parent f669e6c5ca
commit dc0c9a1b5b

View file

@ -250,6 +250,12 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
factor = self.config.initializer_factor
if isinstance(module, (self.__class__)):
init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
# Reinitialize quantiles buffer for transformers v5 meta device compatibility
if _TRANSFORMERS_V5:
quantiles = torch.tensor(
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
)
init.copy_(module.quantiles, quantiles)
elif isinstance(module, ResidualBlock):
init.normal_(
module.hidden_layer.weight,