Mirror of https://github.com/unslothai/unsloth, synced 2026-04-21 13:37:39 +00:00
qwen3.6 patches for multi-turn chat (#5083)
* qwen3.6 patches for multi-turn chat

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent f1041f885f
commit 013c99e51b
2 changed files with 1791 additions and 0 deletions
unsloth/models/patches/mlx_vlm_qwen3_5/generate.py (new file, 1653 additions)
File diff suppressed because it is too large
unsloth/models/patches/mlx_vlm_qwen3_5/qwen3_5.py (new file, 138 additions)
@@ -0,0 +1,138 @@
from typing import Optional

import mlx.core as mx
import mlx.nn as nn

from ..base import InputEmbeddingsFeatures
from ..qwen3_vl import Model as Qwen3VLModel
from ..qwen3_vl import processing_qwen3_vl  # noqa: F401
from ..qwen3_vl.qwen3_vl import masked_scatter
from .config import ModelConfig
from .language import LanguageModel
from .vision import VisionModel


class Model(Qwen3VLModel):
    def __init__(self, config: ModelConfig):
        # Only initialize nn.Module; skip the parent class's initialization of
        # vision_tower and language_model so the variants below are used instead
        nn.Module.__init__(self)
        self.config = config
        self.vision_tower = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config.text_config, config)

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        **kwargs,
    ):
        image_grid_thw = kwargs.get("image_grid_thw", None)
        video_grid_thw = kwargs.get("video_grid_thw", None)
        mask = kwargs.get("mask", None)
        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw

        # Text-only input: plain token embeddings, no vision pass needed
        if pixel_values is None:
            return InputEmbeddingsFeatures(
                inputs_embeds = self.language_model.model.embed_tokens(input_ids)
            )

        dtype = self.vision_tower.patch_embed.proj.weight.dtype
        pixel_values = pixel_values.astype(dtype)

        # Get the input embeddings from the language model
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        # Reuse vision features computed on an earlier turn when available;
        # otherwise run the vision tower on the pixel values
        cached = kwargs.get("cached_image_features", None)
        if cached is not None:
            hidden_states = cached
        else:
            # Get the output hidden states from the vision model
            hidden_states, _ = self.vision_tower(pixel_values, grid_thw)

        # Scatter the image features into the positions of the special
        # image/video tokens in input_ids
        inputs_embeds, _ = self.merge_input_ids_with_image_features(
            hidden_states,
            inputs_embeds,
            input_ids,
            self.config.image_token_index,
            self.config.video_token_index,
        )

        # Pre-calculate position_ids for chunked prefill
        if image_grid_thw is not None or video_grid_thw is not None:
            position_ids, rope_deltas = self.language_model.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, mask
            )
            self.language_model._position_ids = position_ids
            self.language_model._rope_deltas = rope_deltas

        return InputEmbeddingsFeatures(
            inputs_embeds = inputs_embeds,
        )

    @staticmethod
    def merge_input_ids_with_image_features(
        image_features, inputs_embeds, input_ids, image_token_index, video_token_index
    ):
        # Image and video placeholder tokens are treated identically here
        special_image_mask = input_ids == image_token_index
        special_video_mask = input_ids == video_token_index
        special_image_mask = special_image_mask | special_video_mask
        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask[..., None]
        special_image_mask = mx.broadcast_to(special_image_mask, inputs_embeds.shape)

        n_image_features = image_features.shape[0]
        n_image_mask_elements = special_image_mask.sum()
        if n_image_mask_elements != image_features.size:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )

        inputs_embeds = masked_scatter(
            inputs_embeds, special_image_mask, image_features
        )

        return inputs_embeds, special_image_mask

    def sanitize(self, weights):
        # Ignore MTP weights; the multi-token-prediction head is not used here
        weights = {key: value for key, value in weights.items() if "mtp." not in key}

        if self.config.text_config.tie_word_embeddings:
            weights.pop("lm_head.weight", None)

        norm_keys = (
            ".input_layernorm.weight",
            ".post_attention_layernorm.weight",
            "model.norm.weight",
            ".q_norm.weight",
            ".k_norm.weight",
        )

        sanitized_weights = {}
        for key, value in weights.items():
            # Remap upstream checkpoint names onto this class's module tree
            if "model" in key:
                if "model.language_model" in key:
                    key = key.replace("model.language_model", "language_model.model")
                elif "model.visual" in key:
                    key = key.replace("model.visual", "vision_tower")
            elif "lm_head" in key:
                key = key.replace("lm_head", "language_model.lm_head")

            # Convert torch conv1d weight layout (out, in, kernel) to
            # MLX's (out, kernel, in)
            if "conv1d.weight" in key and value.shape[-1] != 1:
                value = value.moveaxis(2, 1)
            if any(key.endswith(sfx) for sfx in norm_keys):
                if value.ndim == 1:
                    # Shift norm weights by +1 (the checkpoint stores them
                    # zero-centered)
                    value += 1.0

            sanitized_weights[key] = value

        return sanitized_weights

    @property
    def quant_predicate(self):
        return self.language_model.quant_predicate

    @property
    def cast_predicate(self):
        return self.language_model.cast_predicate
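
Note on the multi-turn path: get_input_embeddings only runs the vision tower when cached_image_features is absent, so a chat loop can encode an image once and reuse the features on every later turn. A minimal sketch of such a caller, where turn1_ids, turn2_ids, pixel_values, and image_grid_thw are hypothetical placeholder inputs, not part of the patch:

# Sketch only: `model` is an instance of the Model class above; the input
# arrays are placeholders.
# Turn 1: run the vision tower once and keep its output features.
image_features, _ = model.vision_tower(pixel_values, image_grid_thw)
first_turn = model.get_input_embeddings(
    input_ids = turn1_ids,
    pixel_values = pixel_values,
    image_grid_thw = image_grid_thw,
    cached_image_features = image_features,
)
# Later turns: pass the cached features so the vision tower is skipped.
# pixel_values is still supplied so the method does not take the
# text-only early-return branch.
second_turn = model.get_input_embeddings(
    input_ids = turn2_ids,
    pixel_values = pixel_values,
    image_grid_thw = image_grid_thw,
    cached_image_features = image_features,
)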
Loading…
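
masked_scatter itself is imported unchanged from the existing qwen3_vl module. As a reference for what merge_input_ids_with_image_features relies on, here is a self-contained MLX sketch of the semantics, assuming source values fill the mask's True positions in row-major order; this is an illustration, not the imported implementation:

import mlx.core as mx

def masked_scatter_sketch(dest: mx.array, mask: mx.array, source: mx.array) -> mx.array:
    # Fill the True positions of `mask` in `dest` with values taken from
    # `source` in row-major order; False positions keep their old values.
    dest_flat = dest.reshape(-1)
    mask_flat = mask.reshape(-1)
    src_flat = source.reshape(-1)
    # Rank of each True position among all True positions (0-based).
    ranks = mx.cumsum(mask_flat.astype(mx.int32)) - 1
    filled = mx.take(src_flat, mx.maximum(ranks, 0))
    return mx.where(mask_flat, filled, dest_flat).reshape(dest.shape)

# Example: scatter two image-feature rows into a 4-token embedding table.
embeds = mx.zeros((4, 3))
mask = mx.broadcast_to(mx.array([False, True, True, False])[..., None], embeds.shape)
feats = mx.ones((2, 3))
print(masked_scatter_sketch(embeds, mask, feats))  # rows 1 and 2 become ones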
Reference in a new issue
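
sanitize bridges the upstream checkpoint naming (model.language_model.*, model.visual.*, lm_head.*) and this class's module tree (language_model.model.*, vision_tower.*, language_model.lm_head.*). A small illustration with made-up keys and arbitrary shapes, where `model` is again an instance of the class above:

import mlx.core as mx

# Hypothetical raw checkpoint entries following the upstream naming handled
# above; the shapes are arbitrary.
raw_weights = {
    "model.language_model.layers.0.input_layernorm.weight": mx.zeros((8,)),
    "model.visual.patch_embed.proj.weight": mx.zeros((8, 3, 2, 2)),
    "lm_head.weight": mx.zeros((16, 8)),
    "mtp.fc.weight": mx.zeros((8, 8)),  # dropped: the MTP head is unused
}

clean = model.sanitize(raw_weights)
# -> "language_model.model.layers.0.input_layernorm.weight" (value shifted by +1.0)
# -> "vision_tower.patch_embed.proj.weight"
# -> "language_model.lm_head.weight" (or dropped when tie_word_embeddings is set)
# -> the "mtp." entry is gone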