mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
UI Changes (#4782)
* UI Changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unrelated test file --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3b613eb1e8
commit
f9c4b08726
11 changed files with 267 additions and 6 deletions
|
|
@ -6,6 +6,14 @@
|
|||
import utils.hardware.hardware as hw
|
||||
|
||||
DEFAULT_MODELS_GGUF = [
|
||||
"unsloth/gemma-4-E2B-it-GGUF",
|
||||
"unsloth/gemma-4-E4B-it-GGUF",
|
||||
"unsloth/gemma-4-31B-it-GGUF",
|
||||
"unsloth/gemma-4-26B-A4B-it-GGUF",
|
||||
"unsloth/Qwen3.5-4B-GGUF",
|
||||
"unsloth/Qwen3.5-9B-GGUF",
|
||||
"unsloth/Qwen3.5-35B-A3B-GGUF",
|
||||
"unsloth/Qwen3.5-0.8B-GGUF",
|
||||
"unsloth/Llama-3.2-1B-Instruct-GGUF",
|
||||
"unsloth/Llama-3.2-3B-Instruct-GGUF",
|
||||
"unsloth/Llama-3.1-8B-Instruct-GGUF",
|
||||
|
|
@ -15,6 +23,18 @@ DEFAULT_MODELS_GGUF = [
|
|||
]
|
||||
|
||||
DEFAULT_MODELS_STANDARD = [
|
||||
"unsloth/gemma-4-E2B-it-GGUF",
|
||||
"unsloth/gemma-4-E4B-it-GGUF",
|
||||
"unsloth/gemma-4-31B-it-GGUF",
|
||||
"unsloth/gemma-4-26B-A4B-it-GGUF",
|
||||
"unsloth/Qwen3.5-4B-GGUF",
|
||||
"unsloth/Qwen3.5-9B-GGUF",
|
||||
"unsloth/Qwen3.5-35B-A3B-GGUF",
|
||||
"unsloth/Qwen3.5-0.8B-GGUF",
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
"unsloth/Qwen3-4B-Instruct-2507",
|
||||
"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
|
||||
"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
|
||||
|
|
|
|||
|
|
@ -810,7 +810,9 @@ class LlamaCppBackend:
|
|||
# Detect tool calling support from chat template
|
||||
tool_markers = [
|
||||
"{%- if tools %}",
|
||||
"{%- if tools -%}",
|
||||
"{% if tools %}",
|
||||
"{% if tools -%}",
|
||||
'"role" == "tool"',
|
||||
"'role' == 'tool'",
|
||||
'message.role == "tool"',
|
||||
|
|
|
|||
|
|
@ -11,6 +11,6 @@ git+https://github.com/meta-pytorch/OpenEnv.git
|
|||
# executorch>=1.0.1 # 41.5 MB - no imports in unsloth/zoo/studio
|
||||
torch-c-dlpack-ext
|
||||
sentence_transformers==5.2.0
|
||||
transformers==4.57.6
|
||||
transformers>=4.57.6
|
||||
pytorch_tokenizers
|
||||
kernels
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# Single-env pins for unsloth + studio + data-designer
|
||||
# Keep compatible with unsloth transformers bounds.
|
||||
transformers==4.57.6
|
||||
transformers>=4.57.6
|
||||
trl==0.23.1
|
||||
huggingface-hub==0.36.2
|
||||
huggingface-hub>=0.36.2
|
||||
|
||||
# Studio stack
|
||||
datasets==4.3.0
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ TRANSFORMERS_5_MODEL_SUBSTRINGS: tuple[str, ...] = (
|
|||
"qwen3.5", # Qwen3.5 family (35B-A3B, etc.)
|
||||
"qwen3-next", # Qwen3-Next and variants
|
||||
"tiny_qwen3_moe", # imdatta0/tiny_qwen3_moe_2.8B_0.7B
|
||||
"gemma-4", # Gemma-4 (E2B-it, E4B-it, 31B-it, 26B-A4B-it)
|
||||
"gemma4", # Gemma-4 alternate naming
|
||||
)
|
||||
|
||||
# Tokenizer classes that only exist in transformers>=5.x
|
||||
|
|
@ -58,7 +60,7 @@ _TRANSFORMERS_5_TOKENIZER_CLASSES: set[str] = {
|
|||
_tokenizer_class_cache: dict[str, bool] = {}
|
||||
|
||||
# Versions
|
||||
TRANSFORMERS_5_VERSION = "5.3.0"
|
||||
TRANSFORMERS_5_VERSION = "5.5.0.dev0"
|
||||
TRANSFORMERS_DEFAULT_VERSION = "4.57.6"
|
||||
|
||||
# Pre-installed directory for transformers 5.x — created by setup.sh / setup.ps1
|
||||
|
|
@ -258,7 +260,7 @@ def _purge_modules() -> int:
|
|||
|
||||
_VENV_T5_PACKAGES = (
|
||||
f"transformers=={TRANSFORMERS_5_VERSION}",
|
||||
"huggingface_hub==1.7.1",
|
||||
"huggingface_hub==1.8.0",
|
||||
"hf_xet==1.4.2",
|
||||
"tiktoken",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -141,6 +141,10 @@ export const MODEL_TYPE_TO_HF_TASK: Record<ModelType, PipelineType> = {
|
|||
|
||||
|
||||
export const PRIORITY_TRAINING_MODELS: readonly string[] = [
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
"unsloth/Qwen3.5-2B",
|
||||
"unsloth/Qwen3.5-9B",
|
||||
"unsloth/gpt-oss-20b",
|
||||
|
|
|
|||
|
|
@ -863,6 +863,114 @@ DEFAULT_SYSTEM_MESSAGE["gemma-3n"] = None # No system message in Gemma-3n
|
|||
CHAT_TEMPLATES["gemma3n"] = (gemma3n_template, gemma3n_template_eos_token, False, gemma3n_ollama,)
|
||||
DEFAULT_SYSTEM_MESSAGE["gemma3n"] = None # No system message in Gemma-3n
|
||||
|
||||
# =========================================== Gemma-4
|
||||
# Gemma-4 uses <|turn>role\n...<turn|>\n format
|
||||
gemma4_template = \
|
||||
"""{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = "" -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- endif -%}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
||||
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif -%}
|
||||
{%- if (message['role'] == 'assistant') -%}
|
||||
{%- set role = "model" -%}
|
||||
{%- else -%}
|
||||
{%- set role = message['role'] -%}
|
||||
{%- endif -%}
|
||||
{{ '<|turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
|
||||
{%- if message['content'] is string -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'audio' -%}
|
||||
{{ '<|audio|>' }}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{ '<|image|>' }}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{ '<|video|>' }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type") }}
|
||||
{%- endif -%}
|
||||
{{ '<turn|>\n' }}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{'<|turn>model\n'}}
|
||||
{%- endif -%}
|
||||
"""
|
||||
|
||||
try:
|
||||
gemma4_ollama = _ollama_template("gemma-4")
|
||||
except KeyError:
|
||||
gemma4_ollama = ""
|
||||
gemma4_template_eos_token = "<turn|>"
|
||||
CHAT_TEMPLATES["gemma-4"] = (gemma4_template, gemma4_template_eos_token, False, gemma4_ollama,)
|
||||
DEFAULT_SYSTEM_MESSAGE["gemma-4"] = None
|
||||
|
||||
CHAT_TEMPLATES["gemma4"] = (gemma4_template, gemma4_template_eos_token, False, gemma4_ollama,)
|
||||
DEFAULT_SYSTEM_MESSAGE["gemma4"] = None
|
||||
|
||||
# Gemma-4 with empty thought channel (required for larger models like 31B, 26B-A4B)
|
||||
# Injects <|channel>thought\n<channel|> at the start of each model response during training
|
||||
gemma4_thinking_template = \
|
||||
"""{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = "" -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- endif -%}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
||||
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif -%}
|
||||
{%- if (message['role'] == 'assistant') -%}
|
||||
{%- set role = "model" -%}
|
||||
{%- else -%}
|
||||
{%- set role = message['role'] -%}
|
||||
{%- endif -%}
|
||||
{{ '<|turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
|
||||
{%- if role == "model" -%}
|
||||
{{ '<|channel>thought\n<channel|>' }}
|
||||
{%- endif -%}
|
||||
{%- if message['content'] is string -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'audio' -%}
|
||||
{{ '<|audio|>' }}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{ '<|image|>' }}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{ '<|video|>' }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type") }}
|
||||
{%- endif -%}
|
||||
{{ '<turn|>\n' }}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{'<|turn>model\n'}}
|
||||
{%- endif -%}
|
||||
"""
|
||||
|
||||
CHAT_TEMPLATES["gemma-4-thinking"] = (gemma4_thinking_template, gemma4_template_eos_token, False, gemma4_ollama,)
|
||||
DEFAULT_SYSTEM_MESSAGE["gemma-4-thinking"] = None
|
||||
|
||||
CHAT_TEMPLATES["gemma4-thinking"] = (gemma4_thinking_template, gemma4_template_eos_token, False, gemma4_ollama,)
|
||||
DEFAULT_SYSTEM_MESSAGE["gemma4-thinking"] = None
|
||||
|
||||
# =========================================== GPT-OSS
|
||||
# Obtained via
|
||||
# print(tokenizer.chat_template.replace("}\n", "####").replace("\n", "\\n").replace("####", "}\n"))
|
||||
|
|
|
|||
|
|
@ -537,6 +537,15 @@ try:
|
|||
except:
|
||||
pass
|
||||
|
||||
# Gemma4 It is strongly recommended to train Gemma4 models with the `eager`
|
||||
try:
|
||||
from transformers.models.gemma4.modeling_gemma4 import logger as gemma4_logger
|
||||
|
||||
gemma4_logger.addFilter(HideLoggingMessage("strongly recommended"))
|
||||
del gemma4_logger
|
||||
except:
|
||||
pass
|
||||
|
||||
# Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed.
|
||||
try:
|
||||
from huggingface_hub.file_download import logger as hub_logger
|
||||
|
|
@ -1930,6 +1939,18 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
|
|||
_has_ccm = _mod is not None and hasattr(_mod, "create_causal_mask_mapping")
|
||||
if _has_ccm and _inner.training:
|
||||
inputs["token_type_ids"] = torch.zeros_like(inputs["input_ids"])
|
||||
# Gemma4 uses mm_token_type_ids (not token_type_ids) for VLM masking
|
||||
if "mm_token_type_ids" not in inputs and "input_ids" in inputs:
|
||||
_inner = model
|
||||
for _attr in ("base_model", "model", "model"):
|
||||
_inner = getattr(_inner, _attr, _inner)
|
||||
if getattr(getattr(_inner, "config", None), "model_type", "") in ("gemma4",):
|
||||
import sys as _sys
|
||||
|
||||
_mod = _sys.modules.get(type(_inner).__module__)
|
||||
_has_ccm = _mod is not None and hasattr(_mod, "create_causal_mask_mapping")
|
||||
if _has_ccm and _inner.training:
|
||||
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
|
||||
|
||||
outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
|
||||
return outputs
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ SUPPORTS_QWEN3_MOE = transformers_version >= Version("4.50.3")
|
|||
SUPPORTS_FALCON_H1 = transformers_version >= Version("4.53.0")
|
||||
SUPPORTS_GEMMA3N = transformers_version >= Version("4.53.0")
|
||||
SUPPORTS_GPTOSS = transformers_version >= Version("4.55.0")
|
||||
SUPPORTS_GEMMA4 = transformers_version >= Version("5.5.0.dev0")
|
||||
# Transformers v5 meta-device loading corrupts non-persistent buffers (inv_freq).
|
||||
# See _fix_rope_inv_freq() below for details.
|
||||
_NEEDS_ROPE_FIX = transformers_version >= Version("5.0.0")
|
||||
|
|
@ -107,6 +108,8 @@ FORCE_FLOAT32 = [
|
|||
"gemma3n",
|
||||
"gpt_oss",
|
||||
"qwen3_5", # Qwen3.5 GDN layers produce NaN grad norms in float16 training
|
||||
"gemma4,", # Add comma bc gemma4 will match gemma4_text
|
||||
"gemma4_text",
|
||||
]
|
||||
|
||||
global DISABLE_COMPILE_MODEL_NAMES
|
||||
|
|
@ -1130,6 +1133,17 @@ class FastModel(FastBaseModel):
|
|||
raise RuntimeError(
|
||||
"Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST
|
||||
)
|
||||
# Gemma 4 must be before Gemma 3N and Gemma 3
|
||||
elif "gemma4" in model_types_all:
|
||||
if not SUPPORTS_GEMMA4:
|
||||
raise RuntimeError(
|
||||
"Unsloth: Gemma 4 requires transformers >= 5.5.0" + LATEST
|
||||
)
|
||||
os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
|
||||
os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
|
||||
# Disable flex_attention for Gemma-4: flex compile overhead is 2.7x slower
|
||||
# than SDPA. Our attention patch ensures Q/K/V dtype alignment for SDPA.
|
||||
os.environ["UNSLOTH_ENABLE_FLEX_ATTENTION"] = "0"
|
||||
# Gemma 3N must be before Gemma 3
|
||||
elif "gemma3n" in model_types_all:
|
||||
if transformers_version < Version("4.53.0"):
|
||||
|
|
|
|||
|
|
@ -216,6 +216,10 @@ def unsloth_base_fast_generate(
|
|||
kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
kwargs["pixel_values_videos"] = kwargs["pixel_values_videos"].to(dtype)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Mixed precision autocast
|
||||
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
|
||||
|
|
@ -1029,6 +1033,15 @@ class FastBaseModel:
|
|||
f"Unsloth: Warning - VLM processor fallback returned None for model_type={model_type_arch}",
|
||||
file = sys.stderr,
|
||||
)
|
||||
# Backwards compat: if processor has no chat_template (e.g. old saves without
|
||||
# chat_template.jinja) but the inner tokenizer does, copy it to the processor.
|
||||
if (
|
||||
hasattr(tokenizer, "tokenizer")
|
||||
and getattr(tokenizer, "chat_template", None) is None
|
||||
and getattr(tokenizer.tokenizer, "chat_template", None) is not None
|
||||
):
|
||||
tokenizer.chat_template = tokenizer.tokenizer.chat_template
|
||||
|
||||
if hasattr(tokenizer, "tokenizer"):
|
||||
__tokenizer = tokenizer.tokenizer
|
||||
# Add padding side as well
|
||||
|
|
@ -1285,7 +1298,59 @@ class FastBaseModel:
|
|||
model,
|
||||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
)
|
||||
# Gemma4 ClippableLinear wraps nn.Linear -- PEFT can't inject LoRA on it directly.
|
||||
# Monkey-patch PEFT to target the inner .linear child instead.
|
||||
_clippable_linear_cls = None
|
||||
try:
|
||||
from transformers.models.gemma4.modeling_gemma4 import (
|
||||
Gemma4ClippableLinear as _clippable_linear_cls,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
if _clippable_linear_cls is not None:
|
||||
from peft.tuners.lora.model import LoraModel as _LoraModel
|
||||
|
||||
_original_car = _LoraModel._create_and_replace
|
||||
|
||||
def _patched_car(
|
||||
self,
|
||||
peft_config,
|
||||
adapter_name,
|
||||
target,
|
||||
target_name,
|
||||
parent,
|
||||
current_key = None,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(target, _clippable_linear_cls):
|
||||
return _original_car(
|
||||
self,
|
||||
peft_config,
|
||||
adapter_name,
|
||||
target.linear,
|
||||
"linear",
|
||||
target,
|
||||
current_key = current_key,
|
||||
**kwargs,
|
||||
)
|
||||
return _original_car(
|
||||
self,
|
||||
peft_config,
|
||||
adapter_name,
|
||||
target,
|
||||
target_name,
|
||||
parent,
|
||||
current_key = current_key,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
_LoraModel._create_and_replace = _patched_car
|
||||
|
||||
model = _get_peft_model(model, lora_config)
|
||||
|
||||
# Restore original PEFT method
|
||||
if _clippable_linear_cls is not None:
|
||||
_LoraModel._create_and_replace = _original_car
|
||||
# Apply QAT + LoRA if specified
|
||||
if qat_scheme is not None:
|
||||
print("Unsloth: Applying QAT to mitigate quantization degradation")
|
||||
|
|
@ -1383,7 +1448,7 @@ class FastBaseModel:
|
|||
# after this point, so we intercept gradient_checkpointing_enable
|
||||
# to always force use_reentrant=True for Gemma3N.
|
||||
_model_type = getattr(getattr(model, "config", None), "model_type", "") or ""
|
||||
if "gemma3n" in _model_type.lower():
|
||||
if "gemma3n" in _model_type.lower() or "gemma4" in _model_type.lower():
|
||||
_original_gc_enable = model.gradient_checkpointing_enable
|
||||
|
||||
def _gc_enable_reentrant(**kwargs):
|
||||
|
|
|
|||
|
|
@ -1199,6 +1199,21 @@ TEMPLATE """{{- range $i, $_ := .Messages }}
|
|||
OLLAMA_TEMPLATES["gemma-3n"] = gemma3n_ollama
|
||||
OLLAMA_TEMPLATES["gemma3n"] = gemma3n_ollama
|
||||
|
||||
# =========================================== Gemma-4
|
||||
gemma4_ollama = '''
|
||||
FROM {__FILE_LOCATION__}
|
||||
TEMPLATE """{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
||||
<|turn>{{ .Role }}
|
||||
{{ .Content }}{{ if not $last }}<turn|>
|
||||
{{ end }}
|
||||
{{- end }}<turn|>
|
||||
<|turn>model
|
||||
"""
|
||||
'''
|
||||
OLLAMA_TEMPLATES["gemma-4"] = gemma4_ollama
|
||||
OLLAMA_TEMPLATES["gemma4"] = gemma4_ollama
|
||||
|
||||
# =========================================== GPT-OSS
|
||||
|
||||
# Ollama from https://ollama.com/library/gpt-oss:latest/blobs/fa6710a93d78
|
||||
|
|
@ -1961,6 +1976,16 @@ OLLAMA_TEMPLATE_TO_MODEL_MAPPER = {
|
|||
"google/medgemma-27b-text-it",
|
||||
"unsloth/medgemma-27b-text-it-bnb-4bit",
|
||||
),
|
||||
"gemma4": (
|
||||
"unsloth/gemma-4-E2B-it",
|
||||
"unsloth/gemma-4-E2B",
|
||||
"unsloth/gemma-4-E4B-it",
|
||||
"unsloth/gemma-4-E4B",
|
||||
"unsloth/gemma-4-31B-it",
|
||||
"unsloth/gemma-4-31B",
|
||||
"unsloth/gemma-4-26B-A4B-it",
|
||||
"unsloth/gemma-4-26B-A4B",
|
||||
),
|
||||
"gemma3n": (
|
||||
"unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
|
||||
"unsloth/gemma-3n-E4B-it",
|
||||
|
|
|
|||
Loading…
Reference in a new issue