diff --git a/src/chronos/chronos2/layers.py b/src/chronos/chronos2/layers.py index 15b3e0a..a632034 100644 --- a/src/chronos/chronos2/layers.py +++ b/src/chronos/chronos2/layers.py @@ -10,6 +10,7 @@ from einops import rearrange from torch import nn from transformers.activations import ACT2FN from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils.generic import maybe_autocast from transformers.utils import ModelOutput from .config import Chronos2CoreConfig @@ -73,7 +74,7 @@ class Chronos2RotaryEmbedding(nn.Module): # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): + with maybe_autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling