Update gemma.py

This commit is contained in:
Daniel Han-Chen 2024-03-04 16:14:50 +11:00
parent 440c29273f
commit cb193f7e0a

View file

@@ -357,7 +357,8 @@ class FastGemmaModel(FastLlamaModel):
if isinstance(module, GemmaRMSNorm):
# Must be in float32
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
-module = module.to(torch.float32)
+# module = module.to(torch.float32)
+# Don't convert to float32 since error analysis shows it makes it worse!!
module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon