Apply use_reentrant removal to all TRL trainer configs, not just GRPO

The existing fix that removes use_reentrant=False from
gradient_checkpointing_kwargs was gated behind RLConfig_name ==
"GRPOConfig", so only GRPOConfig was protected; SFTConfig, DPOConfig,
KTOConfig, CPOConfig, ORPOConfig, etc. were all still affected.

Remove the GRPOConfig guard so the fix applies to all compiled trainer
configs when TRL >= 0.27.0.

This is defense-in-depth alongside the unsloth_zoo fix that forces
use_reentrant=True in unsloth_checkpoint() itself.
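A minimal sketch of what the injected post-__init__ cleanup does. The
Config/PatchedConfig classes below are hypothetical stand-ins for the
compiled TRL trainer configs; only the getattr check and the removal of
use_reentrant mirror the actual patch.

```python
class Config:
    """Stand-in for a TRL trainer config (e.g. SFTConfig, GRPOConfig)."""
    def __init__(self, gradient_checkpointing_kwargs=None):
        # TRL 0.27.0+ may auto-apply use_reentrant=False here.
        self.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs

class PatchedConfig(Config):
    """Stand-in for the Unsloth-compiled config with RLConfig_post applied."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Unsloth: remove use_reentrant=False forced by TRL 0.27.0+,
        # since Unsloth gradient checkpointing requires use_reentrant=True.
        gc_kwargs = getattr(self, "gradient_checkpointing_kwargs", None)
        if gc_kwargs is not None:
            gc_kwargs.pop("use_reentrant", None)

cfg = PatchedConfig(gradient_checkpointing_kwargs={"use_reentrant": False})
print(cfg.gradient_checkpointing_kwargs)  # {}
```

With the GRPOConfig guard removed, this cleanup runs for every compiled
trainer config on TRL >= 0.27.0, not just GRPOConfig.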
commit 356538d760
parent ec9a0906eb
Author:    Daniel Han
Committer: Daniel Han
Date:      2026-03-16 10:17:15 +00:00

@@ -1232,7 +1232,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     # Unsloth gradient checkpointing requires use_reentrant=True, so we remove
     # the setting after super().__init__() when it gets auto-applied.
     RLConfig_post = ""
-    if trl_version >= Version("0.27.0") and RLConfig_name == "GRPOConfig":
+    if trl_version >= Version("0.27.0"):
         RLConfig_post = (
             " # Unsloth: Remove use_reentrant=False forced by TRL 0.27.0+\n"
             " if getattr(self, 'gradient_checkpointing_kwargs', None) is not None:\n"