diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index 84e9405..5b68029 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -250,6 +250,12 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel): factor = self.config.initializer_factor if isinstance(module, (self.__class__)): init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) + # Reinitialize quantiles buffer for transformers v5 meta device compatibility + if _TRANSFORMERS_V5: + quantiles = torch.tensor( + module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device + ) + init.copy_(module.quantiles, quantiles) elif isinstance(module, ResidualBlock): init.normal_( module.hidden_layer.weight,