mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
use maybe_autocast
This commit is contained in:
parent
9ee10fc9b6
commit
8f249a961e
1 changed files with 2 additions and 1 deletions
|
|
@ -10,6 +10,7 @@ from einops import rearrange
|
|||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from transformers.utils.generic import maybe_autocast
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from .config import Chronos2CoreConfig
|
||||
|
|
@ -73,7 +74,7 @@ class Chronos2RotaryEmbedding(nn.Module):
|
|||
# See https://github.com/huggingface/transformers/pull/29285
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
with maybe_autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
|
|
|
|||
Loading…
Reference in a new issue