mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
* Fix review findings for PR #49 1. Sandbox fallback Jinja env in _VariantTokenizerProxy.apply_chat_template (use SandboxedEnvironment, matching _derive_assistant_prefix_by_render) 2. Unwrap benign outer-If guards in _template_ends_with_toplevel_for so templates like {% if messages %}{% for ... %}{% endfor %}{% endif %} are still repairable (preserves Qwen3-Guard rejection via else-branch and add_generation_prompt-name checks) 3. Preserve raw name_or_path in _VariantTokenizerProxy._source_path so local-path detection works for dict/list variant tokenizers 4. Context-aware strict-mode messages: omit "will still load" and "Set UNSLOTH_STRICT_CHAT_TEMPLATE=1" when already raising * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b42e3a120d
commit
ff23ce40b4
1 changed files with 91 additions and 31 deletions
|
|
@ -669,11 +669,11 @@ def _find_end_position(template, endfor = None, endif = None):
|
|||
def _template_ends_with_toplevel_for(chat_template):
|
||||
"""Return True if the last structural node at the template's top level is
|
||||
a For (message-iteration) loop, ignoring trailing pure-whitespace Output
|
||||
nodes. Used to gate the GH#4150 ChatML repair: if the outermost structure
|
||||
is something else (e.g. an outer If that wraps the whole template, as in
|
||||
Qwen3-Guard), we shouldn't inject an {% if add_generation_prompt %}
|
||||
block at the end -- it would land inside or after an unrelated control
|
||||
structure."""
|
||||
nodes. Unwraps benign outer-If guards (no else branch, not testing
|
||||
add_generation_prompt) so that templates like
|
||||
``{% if messages %}{% for ... %}{% endfor %}{% endif %}`` are still
|
||||
repairable. Rejects real structural wrappers (e.g. Qwen3-Guard with
|
||||
else branches)."""
|
||||
try:
|
||||
import jinja2
|
||||
import jinja2.nodes
|
||||
|
|
@ -681,20 +681,31 @@ def _template_ends_with_toplevel_for(chat_template):
|
|||
ast = jinja2.Environment().parse(chat_template)
|
||||
except Exception:
|
||||
return False
|
||||
for node in reversed(ast.body):
|
||||
# Skip trailing output nodes that are only whitespace -- they come
|
||||
# from trailing whitespace/newlines in the source, not from real
|
||||
# message-rendering logic.
|
||||
if isinstance(node, jinja2.nodes.Output):
|
||||
only_ws = all(
|
||||
isinstance(child, jinja2.nodes.TemplateData)
|
||||
and child.data.strip() == ""
|
||||
for child in node.nodes
|
||||
)
|
||||
if only_ws:
|
||||
continue
|
||||
return isinstance(node, jinja2.nodes.For)
|
||||
return False
|
||||
|
||||
def _last_structural(nodes):
|
||||
for node in reversed(nodes):
|
||||
if isinstance(node, jinja2.nodes.Output):
|
||||
only_ws = all(
|
||||
isinstance(child, jinja2.nodes.TemplateData)
|
||||
and child.data.strip() == ""
|
||||
for child in node.nodes
|
||||
)
|
||||
if only_ws:
|
||||
continue
|
||||
return node
|
||||
return None
|
||||
|
||||
node = _last_structural(ast.body)
|
||||
while isinstance(node, jinja2.nodes.If) and not node.else_:
|
||||
names = []
|
||||
if isinstance(node.test, jinja2.nodes.Name):
|
||||
names.append(node.test)
|
||||
names.extend(node.test.find_all(jinja2.nodes.Name))
|
||||
if any(n.name == "add_generation_prompt" for n in names):
|
||||
break
|
||||
node = _last_structural(node.body)
|
||||
|
||||
return isinstance(node, jinja2.nodes.For)
|
||||
|
||||
|
||||
def _if_body_emits_content(if_node):
|
||||
|
|
@ -944,10 +955,18 @@ def _name_is_local_path(name_or_path):
|
|||
return False
|
||||
|
||||
|
||||
def _format_chat_template_message(name_or_path, repaired):
|
||||
def _format_chat_template_message(
|
||||
name_or_path,
|
||||
repaired,
|
||||
has_generation_block = False,
|
||||
local_path_source = None,
|
||||
strict = False,
|
||||
):
|
||||
"""Build a user-facing warning/error message that points at the right
|
||||
responsible party (user's downstream tool vs. upstream model maintainer)."""
|
||||
local = _name_is_local_path(name_or_path)
|
||||
local = _name_is_local_path(
|
||||
local_path_source if local_path_source is not None else name_or_path
|
||||
)
|
||||
if local:
|
||||
source_hint = (
|
||||
"This tokenizer was loaded from a local path. The likely cause is a "
|
||||
|
|
@ -961,19 +980,39 @@ def _format_chat_template_message(name_or_path, repaired):
|
|||
"The chat_template shipped with `{name}` appears incomplete. "
|
||||
"Consider filing a bug report with the model maintainers."
|
||||
).format(name = name_or_path)
|
||||
strict_suffix = (
|
||||
""
|
||||
if strict
|
||||
else (" Set UNSLOTH_STRICT_CHAT_TEMPLATE=1 to raise instead of warn.")
|
||||
)
|
||||
if repaired:
|
||||
return (
|
||||
"Unsloth: Patched the chat_template on `{name}` to add a "
|
||||
"{{% if add_generation_prompt %}} block. {hint}"
|
||||
).format(name = name_or_path, hint = source_hint)
|
||||
if has_generation_block:
|
||||
return (
|
||||
"Unsloth: The tokenizer `{name}` has a "
|
||||
"{{% if add_generation_prompt %}} block, but it does not change "
|
||||
"the rendered output. {hint}{suffix}"
|
||||
).format(name = name_or_path, hint = source_hint, suffix = strict_suffix)
|
||||
load_clause = (
|
||||
"Loading is blocked in strict mode."
|
||||
if strict
|
||||
else "The model will still load, but "
|
||||
"`apply_chat_template(add_generation_prompt=True)` may not produce a "
|
||||
"correct assistant-turn marker."
|
||||
)
|
||||
return (
|
||||
"Unsloth: The tokenizer `{name}` does not have a "
|
||||
"{{% if add_generation_prompt %}} block for generation purposes, and "
|
||||
"automatic repair was not possible. The model will still load, but "
|
||||
"`apply_chat_template(add_generation_prompt=True)` may not produce a "
|
||||
"correct assistant-turn marker. {hint} Set "
|
||||
"UNSLOTH_STRICT_CHAT_TEMPLATE=1 to raise instead of warn."
|
||||
).format(name = name_or_path, hint = source_hint)
|
||||
"automatic repair was not possible. {load_clause} {hint}{suffix}"
|
||||
).format(
|
||||
name = name_or_path,
|
||||
load_clause = load_clause,
|
||||
hint = source_hint,
|
||||
suffix = strict_suffix,
|
||||
)
|
||||
|
||||
|
||||
def _validate_patched_template(tokenizer, patched_template, is_sharegpt):
|
||||
|
|
@ -1041,6 +1080,7 @@ def _fix_chat_template_for_tokenizer(tokenizer, chat_template):
|
|||
UNSLOTH_STRICT_CHAT_TEMPLATE: warn + return original (default) or raise
|
||||
RuntimeError (strict)."""
|
||||
name = getattr(tokenizer, "name_or_path", "unknown")
|
||||
source_path = getattr(tokenizer, "_source_path", name)
|
||||
|
||||
# Detect ShareGPT vs HF style by probing apply_chat_template.
|
||||
is_sharegpt = None
|
||||
|
|
@ -1092,19 +1132,38 @@ def _fix_chat_template_for_tokenizer(tokenizer, chat_template):
|
|||
if _has_add_generation_prompt_block(chat_template):
|
||||
# Template has the block but it does not change output. This is the
|
||||
# "wasn't provided correctly" case from the pre-warn code path.
|
||||
msg = _format_chat_template_message(name, repaired = False)
|
||||
if _is_strict_chat_template_mode():
|
||||
strict = _is_strict_chat_template_mode()
|
||||
msg = _format_chat_template_message(
|
||||
name,
|
||||
repaired = False,
|
||||
has_generation_block = True,
|
||||
local_path_source = source_path,
|
||||
strict = strict,
|
||||
)
|
||||
if strict:
|
||||
raise RuntimeError(msg)
|
||||
logger.warning_once(msg)
|
||||
return chat_template
|
||||
|
||||
repaired = _repair_string_template(tokenizer, chat_template, is_sharegpt)
|
||||
if repaired is not None:
|
||||
logger.warning_once(_format_chat_template_message(name, repaired = True))
|
||||
logger.warning_once(
|
||||
_format_chat_template_message(
|
||||
name,
|
||||
repaired = True,
|
||||
local_path_source = source_path,
|
||||
)
|
||||
)
|
||||
return repaired
|
||||
|
||||
msg = _format_chat_template_message(name, repaired = False)
|
||||
if _is_strict_chat_template_mode():
|
||||
strict = _is_strict_chat_template_mode()
|
||||
msg = _format_chat_template_message(
|
||||
name,
|
||||
repaired = False,
|
||||
local_path_source = source_path,
|
||||
strict = strict,
|
||||
)
|
||||
if strict:
|
||||
raise RuntimeError(msg)
|
||||
logger.warning_once(msg)
|
||||
return chat_template
|
||||
|
|
@ -1125,6 +1184,7 @@ class _VariantTokenizerProxy:
|
|||
self._base = base_tokenizer
|
||||
self._template = variant_template
|
||||
base_name = getattr(base_tokenizer, "name_or_path", "unknown")
|
||||
self._source_path = base_name
|
||||
self.name_or_path = (
|
||||
f"{base_name} ({variant_label})" if variant_label else base_name
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue