Fix VLM GRPO matmul shape mismatch in _get_per_token_logps_and_entropies (#4301)

* Fix VLM GRPO matmul shape mismatch in _get_per_token_logps_and_entropies

VLM models (e.g. Qwen2.5-VL) can return logits of shape [B*T, vocab_size]
instead of hidden states of shape [B*T, hidden_dim] from their forward pass.
When this happens, chunked_hidden_states_selective_log_softmax attempts
logits @ lm_head.t(), which fails with a shape mismatch.
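A toy reproduction of the mismatch (all tensor names and sizes here are illustrative, not the real model's):

```python
import torch

B_T, hidden_dim, vocab_size = 4, 8, 32        # toy sizes, not real model dims
lm_head = torch.randn(vocab_size, hidden_dim)  # [vocab_size, hidden_dim]

hidden_states = torch.randn(B_T, hidden_dim)   # what the text-only path expects
logits = torch.randn(B_T, vocab_size)          # what some VLMs return directly

ok = hidden_states @ lm_head.t()  # OK: [B*T, hidden_dim] @ [hidden_dim, vocab_size]
try:
    logits @ lm_head.t()          # inner dims vocab_size vs hidden_dim differ
except RuntimeError as e:
    print("matmul fails:", e)
```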

Add a shape guard in the VLM branch of _get_per_token_logps_and_entropies:
check output.shape[-1] against lm_head.shape[1] (hidden_dim). When hidden
states are returned, the existing path is taken. When logits are returned,
scaling/softcapping/temperature are applied manually and
chunked_selective_log_softmax is used instead.
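The guard can be sketched as follows (a simplified stand-in, not the actual trainer code; it assumes lm_head is laid out [vocab_size, hidden_dim] and that vocab_size != hidden_dim):

```python
import torch

def select_log_softmax_path(output: torch.Tensor, lm_head: torch.Tensor) -> str:
    """Toy version of the shape guard.

    lm_head: [vocab_size, hidden_dim]. If the last dim of `output` matches
    hidden_dim, the model returned hidden states and we must still project
    through lm_head; otherwise it already returned vocab-sized logits.
    Assumes vocab_size != hidden_dim, as in any realistic LM.
    """
    if output.shape[-1] == lm_head.shape[1]:
        return "hidden_states"  # project through lm_head, then log-softmax
    return "logits"             # skip projection, log-softmax directly
```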

Also add chunked_selective_log_softmax to the import from unsloth_zoo.

The text-only branch (pixel_values is None) is unchanged.

Companion PR to unslothai/unsloth-zoo for grpo_accumulated_loss.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove redundant scaling in logits fallback path

When COMPILE_DISABLE=1 and the model returns logits directly, scaling
and softcapping are already applied by the model forward. Only
temperature (a GRPO training parameter) needs to be applied.

* Pass temperature to chunked_selective_log_softmax instead of manual cast

Use the new temperature parameter in chunked_selective_log_softmax
(added in companion zoo PR) to avoid casting the entire logits tensor
to float32 before the function call.
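A rough sketch of the idea (chunked_log_softmax_gather is a hypothetical stand-in for chunked_selective_log_softmax, whose real signature lives in unsloth-zoo): casting per chunk inside the loop limits the float32 footprint to one chunk at a time instead of the full logits tensor.

```python
import torch

def chunked_log_softmax_gather(logits, ids, temperature=1.0, chunks=2):
    # Hypothetical stand-in: apply temperature and float32 upcast per chunk,
    # gather the log-prob of each target id, and concatenate the results.
    out = []
    for lc, ic in zip(logits.chunk(chunks), ids.chunk(chunks)):
        lp = torch.log_softmax(lc.float() / temperature, dim=-1)
        out.append(lp.gather(-1, ic.unsqueeze(-1)).squeeze(-1))
    return torch.cat(out)
```

The result should match the full-tensor computation up to float tolerance; only the peak memory differs.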

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Daniel Han 2026-03-16 03:54:16 -07:00 committed by GitHub
parent 356538d760
commit 11449208f4


@@ -26,7 +26,11 @@ import torch
 import inspect
 import linecache
 from collections import defaultdict
-from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
+from unsloth_zoo.rl_replacements import (
+    RL_REPLACEMENTS,
+    left_pack_padding,
+    chunked_selective_log_softmax,
+)
 from unsloth_zoo.utils import Version
 from trl import __version__ as trl_version_raw
 from importlib.metadata import version as importlib_version
@@ -859,6 +863,18 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
                     :, -(logits_to_keep + max_left_pad + 1) :, :
                 ]
                 logits_chunk = logits_chunk[:, :-1, :]
+                logprobs_chunk = (
+                    chunked_hidden_states_selective_log_softmax(
+                        logits_chunk,
+                        lm_head,
+                        completion_input_ids_chunk,
+                        chunks = input_ids_chunk.shape[0] * multiplier,
+                        logit_scale_multiply = logit_scale_multiply,
+                        logit_scale_divide = logit_scale_divide,
+                        logit_softcapping = logit_softcapping,
+                        temperature = temperature,
+                    )
+                )
             else:
                 # Essentially, for VLMs we do not go via the optimized path in models/,
                 # so we don't encounter the Flash Attn left-padding issue.
@@ -876,17 +892,27 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
                 completion_input_ids_chunk = input_ids_chunk[
                     :, -logits_to_keep:
                 ]
-                logprobs_chunk = chunked_hidden_states_selective_log_softmax(
-                    logits_chunk,
-                    lm_head,
-                    completion_input_ids_chunk,
-                    chunks = input_ids_chunk.shape[0] * multiplier,
-                    logit_scale_multiply = logit_scale_multiply,
-                    logit_scale_divide = logit_scale_divide,
-                    logit_softcapping = logit_softcapping,
-                    temperature = temperature,
-                )
+                # Guard: check if model returned hidden states or logits
+                if logits_chunk.shape[-1] == lm_head.shape[1]:
+                    logprobs_chunk = (
+                        chunked_hidden_states_selective_log_softmax(
+                            logits_chunk,
+                            lm_head,
+                            completion_input_ids_chunk,
+                            chunks = input_ids_chunk.shape[0] * multiplier,
+                            logit_scale_multiply = logit_scale_multiply,
+                            logit_scale_divide = logit_scale_divide,
+                            logit_softcapping = logit_softcapping,
+                            temperature = temperature,
+                        )
+                    )
+                else:
+                    # Model returned logits directly - scaling/softcapping already applied by model forward
+                    logprobs_chunk = chunked_selective_log_softmax(
+                        logits_chunk,
+                        completion_input_ids_chunk,
+                        temperature,
+                    )
             # This is needed to avoid race conditions with GPT OSS offload_embbed=True
             # However, it seems that this line does not slow down or disrupt models.
             device_synchronize()