Fix review findings for chat-template repair (#5049) (#5056)

* Fix review findings for PR #5049

1. Sandbox fallback Jinja env in _VariantTokenizerProxy.apply_chat_template
   (use SandboxedEnvironment, matching _derive_assistant_prefix_by_render)
2. Unwrap benign outer-If guards in _template_ends_with_toplevel_for so
   templates like {% if messages %}{% for ... %}{% endfor %}{% endif %}
   are still repairable (preserves Qwen3-Guard rejection via else-branch
   and add_generation_prompt-name checks)
3. Preserve raw name_or_path in _VariantTokenizerProxy._source_path so
   local-path detection works for dict/list variant tokenizers
4. Context-aware strict-mode messages: omit "will still load" and
   "Set UNSLOTH_STRICT_CHAT_TEMPLATE=1" when already raising

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Daniel Han 2026-04-16 08:02:05 -07:00 committed by GitHub
parent b42e3a120d
commit ff23ce40b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -669,11 +669,11 @@ def _find_end_position(template, endfor = None, endif = None):
def _template_ends_with_toplevel_for(chat_template):
"""Return True if the last structural node at the template's top level is
a For (message-iteration) loop, ignoring trailing pure-whitespace Output
nodes. Used to gate the GH#4150 ChatML repair: if the outermost structure
is something else (e.g. an outer If that wraps the whole template, as in
Qwen3-Guard), we shouldn't inject an {% if add_generation_prompt %}
block at the end -- it would land inside or after an unrelated control
structure."""
nodes. Unwraps benign outer-If guards (no else branch, not testing
add_generation_prompt) so that templates like
``{% if messages %}{% for ... %}{% endfor %}{% endif %}`` are still
repairable. Rejects real structural wrappers (e.g. Qwen3-Guard with
else branches)."""
try:
import jinja2
import jinja2.nodes
@ -681,20 +681,31 @@ def _template_ends_with_toplevel_for(chat_template):
ast = jinja2.Environment().parse(chat_template)
except Exception:
return False
for node in reversed(ast.body):
# Skip trailing output nodes that are only whitespace -- they come
# from trailing whitespace/newlines in the source, not from real
# message-rendering logic.
if isinstance(node, jinja2.nodes.Output):
only_ws = all(
isinstance(child, jinja2.nodes.TemplateData)
and child.data.strip() == ""
for child in node.nodes
)
if only_ws:
continue
return isinstance(node, jinja2.nodes.For)
return False
def _last_structural(nodes):
for node in reversed(nodes):
if isinstance(node, jinja2.nodes.Output):
only_ws = all(
isinstance(child, jinja2.nodes.TemplateData)
and child.data.strip() == ""
for child in node.nodes
)
if only_ws:
continue
return node
return None
node = _last_structural(ast.body)
while isinstance(node, jinja2.nodes.If) and not node.else_:
names = []
if isinstance(node.test, jinja2.nodes.Name):
names.append(node.test)
names.extend(node.test.find_all(jinja2.nodes.Name))
if any(n.name == "add_generation_prompt" for n in names):
break
node = _last_structural(node.body)
return isinstance(node, jinja2.nodes.For)
def _if_body_emits_content(if_node):
@ -944,10 +955,18 @@ def _name_is_local_path(name_or_path):
return False
def _format_chat_template_message(name_or_path, repaired):
def _format_chat_template_message(
name_or_path,
repaired,
has_generation_block = False,
local_path_source = None,
strict = False,
):
"""Build a user-facing warning/error message that points at the right
responsible party (user's downstream tool vs. upstream model maintainer)."""
local = _name_is_local_path(name_or_path)
local = _name_is_local_path(
local_path_source if local_path_source is not None else name_or_path
)
if local:
source_hint = (
"This tokenizer was loaded from a local path. The likely cause is a "
@ -961,19 +980,39 @@ def _format_chat_template_message(name_or_path, repaired):
"The chat_template shipped with `{name}` appears incomplete. "
"Consider filing a bug report with the model maintainers."
).format(name = name_or_path)
strict_suffix = (
""
if strict
else (" Set UNSLOTH_STRICT_CHAT_TEMPLATE=1 to raise instead of warn.")
)
if repaired:
return (
"Unsloth: Patched the chat_template on `{name}` to add a "
"{{% if add_generation_prompt %}} block. {hint}"
).format(name = name_or_path, hint = source_hint)
if has_generation_block:
return (
"Unsloth: The tokenizer `{name}` has a "
"{{% if add_generation_prompt %}} block, but it does not change "
"the rendered output. {hint}{suffix}"
).format(name = name_or_path, hint = source_hint, suffix = strict_suffix)
load_clause = (
"Loading is blocked in strict mode."
if strict
else "The model will still load, but "
"`apply_chat_template(add_generation_prompt=True)` may not produce a "
"correct assistant-turn marker."
)
return (
"Unsloth: The tokenizer `{name}` does not have a "
"{{% if add_generation_prompt %}} block for generation purposes, and "
"automatic repair was not possible. The model will still load, but "
"`apply_chat_template(add_generation_prompt=True)` may not produce a "
"correct assistant-turn marker. {hint} Set "
"UNSLOTH_STRICT_CHAT_TEMPLATE=1 to raise instead of warn."
).format(name = name_or_path, hint = source_hint)
"automatic repair was not possible. {load_clause} {hint}{suffix}"
).format(
name = name_or_path,
load_clause = load_clause,
hint = source_hint,
suffix = strict_suffix,
)
def _validate_patched_template(tokenizer, patched_template, is_sharegpt):
@ -1041,6 +1080,7 @@ def _fix_chat_template_for_tokenizer(tokenizer, chat_template):
UNSLOTH_STRICT_CHAT_TEMPLATE: warn + return original (default) or raise
RuntimeError (strict)."""
name = getattr(tokenizer, "name_or_path", "unknown")
source_path = getattr(tokenizer, "_source_path", name)
# Detect ShareGPT vs HF style by probing apply_chat_template.
is_sharegpt = None
@ -1092,19 +1132,38 @@ def _fix_chat_template_for_tokenizer(tokenizer, chat_template):
if _has_add_generation_prompt_block(chat_template):
# Template has the block but it does not change output. This is the
# "wasn't provided correctly" case from the pre-warn code path.
msg = _format_chat_template_message(name, repaired = False)
if _is_strict_chat_template_mode():
strict = _is_strict_chat_template_mode()
msg = _format_chat_template_message(
name,
repaired = False,
has_generation_block = True,
local_path_source = source_path,
strict = strict,
)
if strict:
raise RuntimeError(msg)
logger.warning_once(msg)
return chat_template
repaired = _repair_string_template(tokenizer, chat_template, is_sharegpt)
if repaired is not None:
logger.warning_once(_format_chat_template_message(name, repaired = True))
logger.warning_once(
_format_chat_template_message(
name,
repaired = True,
local_path_source = source_path,
)
)
return repaired
msg = _format_chat_template_message(name, repaired = False)
if _is_strict_chat_template_mode():
strict = _is_strict_chat_template_mode()
msg = _format_chat_template_message(
name,
repaired = False,
local_path_source = source_path,
strict = strict,
)
if strict:
raise RuntimeError(msg)
logger.warning_once(msg)
return chat_template
@ -1125,6 +1184,7 @@ class _VariantTokenizerProxy:
self._base = base_tokenizer
self._template = variant_template
base_name = getattr(base_tokenizer, "name_or_path", "unknown")
self._source_path = base_name
self.name_or_path = (
f"{base_name} ({variant_label})" if variant_label else base_name
)