Remove Gemma-4 from FORCE_FLOAT32 (#4875)

Gemma-4 does not need FORCE_FLOAT32. Testing shows that both float16 and
bfloat16 work correctly without the forced float32 override:

- Inference: identical outputs for float16 and bfloat16 (greedy decoding)
- Training (100 steps, 4-bit LoRA, SFT on FineTome-100k):
  - float16 final loss: 3.048
  - bfloat16 final loss: 3.065
  - Losses converge to within 0.02 by step 60
  - Grad norms healthy and comparable for both dtypes
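The convergence check described above can be sketched as a simple comparison of two logged loss curves. This is an illustrative helper, not code from the Unsloth repo, and the curves below are synthetic stand-ins shaped like the reported runs:

```python
def losses_converged(losses_a, losses_b, tol=0.02, from_step=60):
    """Return True if two step-aligned loss curves stay within
    `tol` of each other from `from_step` onward."""
    tail = list(zip(losses_a, losses_b))[from_step:]
    return all(abs(a - b) <= tol for a, b in tail)

# Hypothetical curves: both decay linearly and sit within
# 0.02 of each other past step 60, like the float16/bfloat16 runs.
fp16_losses = [5.00 - 0.0200 * i for i in range(100)]
bf16_losses = [5.02 - 0.0202 * i for i in range(100)]
```

The same helper flags a genuinely diverging pair (e.g. one curve plateauing while the other keeps falling), which is the failure mode described below.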

The FORCE_FLOAT32 path was actually causing training divergence: with
it enabled, the compiled float32 run diverged at around step 28, with grad
norms collapsing to near zero and the loss plateauing at ~12.4. Without
the override, both dtypes train normally.

This enables float16 on Tesla T4 and other GPUs without bfloat16 support.
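Downstream dtype selection can now fall back to float16 on such hardware. A minimal sketch; the helper name is hypothetical, and in a real PyTorch setup the capability flag would come from `torch.cuda.is_bf16_supported()`:

```python
def pick_train_dtype(bf16_supported: bool) -> str:
    """Choose a training dtype: prefer bfloat16 where the hardware
    supports it (Ampere and newer), else fall back to float16
    (e.g. Tesla T4, which lacks bfloat16 support)."""
    return "bfloat16" if bf16_supported else "float16"
```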
Daniel Han, 2026-04-06 07:33:28 -07:00 (committed by GitHub)
parent ab65b47c73
commit 07b6fcc344


@@ -108,8 +108,6 @@ FORCE_FLOAT32 = [
     "gemma3n",
     "gpt_oss",
     "qwen3_5", # Qwen3.5 GDN layers produce NaN grad norms in float16 training
-    "gemma4,", # Add comma bc gemma4 will match gemma4_text
-    "gemma4_text",
 ]
 global DISABLE_COMPILE_MODEL_NAMES
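With the two entries gone, gemma4 model types no longer trigger the override. A minimal sketch of the effect, assuming a substring-style keyword match (the actual lookup logic in the codebase may differ):

```python
# FORCE_FLOAT32 after this commit (gemma4 entries removed).
FORCE_FLOAT32 = [
    "gemma3n",
    "gpt_oss",
    "qwen3_5",
]

def forces_float32(model_type: str) -> bool:
    # Assumed matching rule, for illustration only: a model is forced
    # to float32 if any keyword appears in its model_type string.
    return any(keyword in model_type for keyword in FORCE_FLOAT32)
```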