mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
skip copy_ for v4
This commit is contained in:
parent
f29ff338e5
commit
16a91d5b58
1 changed file with 8 additions and 5 deletions
|
|
@@ -16,7 +16,9 @@ from transformers.modeling_utils import PreTrainedModel
|
|||
from transformers.utils import ModelOutput
|
||||
|
||||
|
||||
if version.parse(transformers_version) >= version.parse("5.0.0.dev0"):
|
||||
_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0.dev0")
|
||||
|
||||
if _TRANSFORMERS_V5:
|
||||
from transformers import initialization as init
|
||||
else:
|
||||
from torch.nn import init
|
||||
|
|
@@ -299,10 +301,11 @@ class Chronos2Model(PreTrainedModel):
|
|||
init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
|
||||
elif isinstance(module, Chronos2Model):
|
||||
init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
|
||||
quantiles = torch.tensor(
|
||||
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
|
||||
)
|
||||
init.copy_(module.quantiles, quantiles)
|
||||
if _TRANSFORMERS_V5:
|
||||
quantiles = torch.tensor(
|
||||
module.chronos_config.quantiles, dtype=module.dtype, device=module.quantiles.device
|
||||
)
|
||||
init.copy_(module.quantiles, quantiles)
|
||||
elif isinstance(module, ResidualBlock):
|
||||
init.normal_(
|
||||
module.hidden_layer.weight,
|
||||
|
|
|
|||
Loading…
Reference in a new issue