use maybe_autocast

This commit is contained in:
Kashif Rasul 2026-04-14 11:46:12 +02:00
parent 9ee10fc9b6
commit 8f249a961e

View file

@ -10,6 +10,7 @@ from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils.generic import maybe_autocast
from transformers.utils import ModelOutput
from .config import Chronos2CoreConfig
@ -73,7 +74,7 @@ class Chronos2RotaryEmbedding(nn.Module):
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
with maybe_autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling