mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Update gemma.py
This commit is contained in:
parent
440c29273f
commit
cb193f7e0a
1 changed file with 2 additions and 1 deletion
|
|
@@ -357,7 +357,8 @@ class FastGemmaModel(FastLlamaModel):
|
|||
if isinstance(module, GemmaRMSNorm):
|
||||
# Must be in float32
|
||||
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
|
||||
- module = module.to(torch.float32)
|
||||
+ # module = module.to(torch.float32)
|
||||
+ # Don't convert to float32 since error analysis shows it makes it worse!!
|
||||
module.weight += 1.0 # return output * (1 + self.weight)
|
||||
if not hasattr(module, "variance_epsilon"):
|
||||
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
|
||||
|
|
|
|||
Loading…
Reference in a new issue