mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
fix bolt buffer for transformers v5
This commit is contained in:
parent
f669e6c5ca
commit
dc0c9a1b5b
1 changed files with 6 additions and 0 deletions
|
|
@ -250,6 +250,12 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
|||
factor = self.config.initializer_factor
|
||||
if isinstance(module, (self.__class__)):
|
||||
init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
|
||||
# Reinitialize quantiles buffer for transformers v5 meta device compatibility
|
||||
if _TRANSFORMERS_V5:
|
||||
quantiles = torch.tensor(
|
||||
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
|
||||
)
|
||||
init.copy_(module.quantiles, quantiles)
|
||||
elif isinstance(module, ResidualBlock):
|
||||
init.normal_(
|
||||
module.hidden_layer.weight,
|
||||
|
|
|
|||
Loading…
Reference in a new issue