Update llama.py

This commit is contained in:
Daniel Han 2024-07-08 10:44:19 -07:00
parent 2eb950872a
commit 316aaefdf2

View file

@@ -897,9 +897,15 @@ def CausalLM_fast_forward(fast_forward_inference):
logit_softcapping = logit_softcapping,
)
elif logit_softcapping != 0:
logits *= (1.0 / logit_softcapping)
logits = torch.tanh(logits, out = logits if not logits.requires_grad else None)
logits *= logit_softcapping
if logits.requires_grad:
logits = (1.0 / logit_softcapping) * logits
logits = torch.tanh(logits)
logits = logit_softcapping * logits
else:
logits *= (1.0 / logit_softcapping)
torch.tanh(logits, out = logits)
logits *= logit_softcapping
pass
pass
if not return_dict: