mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Update llama.py
This commit is contained in:
parent
2eb950872a
commit
316aaefdf2
1 changed files with 9 additions and 3 deletions
|
|
@ -897,9 +897,15 @@ def CausalLM_fast_forward(fast_forward_inference):
|
|||
logit_softcapping = logit_softcapping,
|
||||
)
|
||||
elif logit_softcapping != 0:
|
||||
logits *= (1.0 / logit_softcapping)
|
||||
logits = torch.tanh(logits, out = logits if not logits.requires_grad else None)
|
||||
logits *= logit_softcapping
|
||||
if logits.requires_grad:
|
||||
logits = (1.0 / logit_softcapping) * logits
|
||||
logits = torch.tanh(logits)
|
||||
logits = logit_softcapping * logits
|
||||
else:
|
||||
logits *= (1.0 / logit_softcapping)
|
||||
torch.tanh(logits, out = logits)
|
||||
logits *= logit_softcapping
|
||||
pass
|
||||
pass
|
||||
|
||||
if not return_dict:
|
||||
|
|
|
|||
Loading…
Reference in a new issue