mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Fix VLM GRPO matmul shape mismatch in _get_per_token_logps_and_entropies (#4301)
* Fix VLM GRPO matmul shape mismatch in _get_per_token_logps_and_entropies VLM models (e.g. Qwen2.5-VL) can return logits [B*T, vocab_size] instead of hidden states [B*T, hidden_dim] from their forward pass. When this happens, chunked_hidden_states_selective_log_softmax tries to compute logits @ lm_head.t() which fails with a shape mismatch. Add a shape guard in the VLM branch of _get_per_token_logps_and_entropies: check output.shape[-1] against lm_head.shape[1] (hidden_dim). When hidden states are returned, the existing path is taken. When logits are returned, scaling/softcapping/temperature are applied manually and chunked_selective_log_softmax is used instead. Also add chunked_selective_log_softmax to the import from unsloth_zoo. The text-only branch (pixel_values is None) is unchanged. Companion PR to unslothai/unsloth-zoo for grpo_accumulated_loss. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove redundant scaling in logits fallback path When COMPILE_DISABLE=1 and the model returns logits directly, scaling and softcapping are already applied by the model forward. Only temperature (a GRPO training parameter) needs to be applied. * Pass temperature to chunked_selective_log_softmax instead of manual cast Use the new temperature parameter in chunked_selective_log_softmax (added in companion zoo PR) to avoid casting the entire logits tensor to float32 before the function call. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
356538d760
commit
11449208f4
1 changed files with 38 additions and 12 deletions
|
|
@ -26,7 +26,11 @@ import torch
|
|||
import inspect
|
||||
import linecache
|
||||
from collections import defaultdict
|
||||
from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
|
||||
from unsloth_zoo.rl_replacements import (
|
||||
RL_REPLACEMENTS,
|
||||
left_pack_padding,
|
||||
chunked_selective_log_softmax,
|
||||
)
|
||||
from unsloth_zoo.utils import Version
|
||||
from trl import __version__ as trl_version_raw
|
||||
from importlib.metadata import version as importlib_version
|
||||
|
|
@ -859,6 +863,18 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
|
|||
:, -(logits_to_keep + max_left_pad + 1) :, :
|
||||
]
|
||||
logits_chunk = logits_chunk[:, :-1, :]
|
||||
logprobs_chunk = (
|
||||
chunked_hidden_states_selective_log_softmax(
|
||||
logits_chunk,
|
||||
lm_head,
|
||||
completion_input_ids_chunk,
|
||||
chunks = input_ids_chunk.shape[0] * multiplier,
|
||||
logit_scale_multiply = logit_scale_multiply,
|
||||
logit_scale_divide = logit_scale_divide,
|
||||
logit_softcapping = logit_softcapping,
|
||||
temperature = temperature,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Essentially, for VLMs we do not go via the optimized path in models/,
|
||||
# so we don't encounter the Flash Attn left-padding issue.
|
||||
|
|
@ -876,17 +892,27 @@ def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
|
|||
completion_input_ids_chunk = input_ids_chunk[
|
||||
:, -logits_to_keep:
|
||||
]
|
||||
|
||||
logprobs_chunk = chunked_hidden_states_selective_log_softmax(
|
||||
logits_chunk,
|
||||
lm_head,
|
||||
completion_input_ids_chunk,
|
||||
chunks = input_ids_chunk.shape[0] * multiplier,
|
||||
logit_scale_multiply = logit_scale_multiply,
|
||||
logit_scale_divide = logit_scale_divide,
|
||||
logit_softcapping = logit_softcapping,
|
||||
temperature = temperature,
|
||||
)
|
||||
# Guard: check if model returned hidden states or logits
|
||||
if logits_chunk.shape[-1] == lm_head.shape[1]:
|
||||
logprobs_chunk = (
|
||||
chunked_hidden_states_selective_log_softmax(
|
||||
logits_chunk,
|
||||
lm_head,
|
||||
completion_input_ids_chunk,
|
||||
chunks = input_ids_chunk.shape[0] * multiplier,
|
||||
logit_scale_multiply = logit_scale_multiply,
|
||||
logit_scale_divide = logit_scale_divide,
|
||||
logit_softcapping = logit_softcapping,
|
||||
temperature = temperature,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Model returned logits directly - scaling/softcapping already applied by model forward
|
||||
logprobs_chunk = chunked_selective_log_softmax(
|
||||
logits_chunk,
|
||||
completion_input_ids_chunk,
|
||||
temperature,
|
||||
)
|
||||
# This is needed to avoid race conditions with GPT OSS offload_embed=True
|
||||
# However, this line does not appear to slow down or disrupt other models.
|
||||
device_synchronize()
|
||||
|
|
|
|||
Loading…
Reference in a new issue