mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Chat-template repair: warn-by-default, AST classification, dict support (#5049)
* Chat-template repair: warn-by-default, AST classification, dict support

  Follow-up hardening on top of PR #4426 (which fixed the #4150 RuntimeError for ChatML LoRA reloads).

  Behavior changes:

  - Warn-by-default instead of RuntimeError. When fix_chat_template cannot repair a broken template, emit a warning and return the original. Set UNSLOTH_STRICT_CHAT_TEMPLATE=1 to restore the pre-warn hard fail. Fixes the UX where a missing `{% if add_generation_prompt %}` block on a saved LoRA (typical after a LlamaFactory / Axolotl re-serialize) would block model loading entirely.
  - Local path vs HF hub distinguished in the warning message. For local paths the message points at the likely downstream tool; for HF IDs it points at the upstream model maintainers. Previously both said "file a bug report to the maintainers of <path>" even when <path> was the user's own saves/ directory.
  - Dict / list chat_template now handled. Hermes-3 ships with {default, tool_use}, and the previous code crashed with AttributeError: 'dict' object has no attribute 'find' when entering _fix_chat_template with a dict. Each variant is now fixed independently; structure is preserved.

  Internals:

  - _find_end_position now matches all four Jinja whitespace-control variants ({% %}, {%- %}, {% -%}, {%- -%}) and returns the rightmost endfor/endif so multi-for templates aren't locked onto the first loop. Previously {%- endfor -%} (both-side dash, used by Qwen3-Guard) was silently bypassed.
  - _has_add_generation_prompt_block uses the Jinja AST via jinja2.nodes.If/Name walks instead of substring matching, so templates that hide the block behind comments or dash-style variants are classified correctly.
  - _template_ends_with_toplevel_for gates the GH#4150 ChatML repair on the AST: it only fires when the last structural top-level node is a For (standard ChatML shape), ignoring trailing pure-whitespace output nodes. Templates wrapped in an outer If (Qwen3-Guard) are now explicitly skipped at the _fix_chat_template level as well, not just at load_correct_tokenizer's name-based exemption.
  - _validate_patched_template renders the patched template with and without add_generation_prompt and confirms the patched output responds to the flag by appending (not replacing) content. If validation fails, the patch is discarded and we fall through to the warn path.

  Verified with an expanded regression suite in tests/:

  - test_fix_chat_template_pr4426.py: 42/42 template-matrix cells
  - test_load_correct_tokenizer_pr4426.py: 5/5 tokenizer loads
  - test_chat_template_followups.py: 10/10 new follow-up tests
  - test_mistral_pr4426.py: 5 Mistral variants byte-identical
  - test_qwen_pr4426.py: 14 Qwen variants byte-identical (Qwen1.5, Qwen2, Qwen2.5-Instruct/Coder/Math/VL, Qwen3, Qwen3-Coder, QwQ, Qwen3-Guard-Gen)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

* Guard _validate_patched_template against read-only chat_template

  If tokenizer.chat_template is a property or otherwise read-only, the validation helper would crash with AttributeError when trying to temporarily set the patched template. Catch the assignment failure and return False (skip validation), and best-effort restore in the finally block.

* Replace regex separator inference with render-diff; broaden repair to non-ChatML templates

  The previous `_infer_assistant_separator` was a four-tier regex heuristic that only worked on ChatML-shaped templates and forced a hard `<|im_start|>` / `<|im_end|>` presence gate on Case 2 repair. This meant a Llama-3, Gemma, or Phi-3 template stripped of its generation-prompt block by a downstream tool (LlamaFactory, Axolotl, etc.) would still warn-and-return even though the structural shape is identical to the ChatML case the PR already handles.

  This replaces the regex with `_derive_assistant_prefix_by_render`: render the template with two dialogs that differ only in assistant content, then `os.path.commonprefix` on the tails captures the exact assistant-turn prefix the template emits. The template itself is ground truth, so non-ChatML shapes work as long as the assistant block is a literal the template emits once per message.

  Three guards keep the derivation safe:

  A. Both assistant renders extend the base render (no reordering).
  B. The divergence point is exactly the content-insertion site (the sentinel follows the common prefix).
  C. A user-role cross-check: if a render with a user sentinel also emits the same prefix, role has no effect on output and we reject. A render failure on [user, user] (e.g. Gemma's `raise_exception` alternation check) is evidence that role matters; we accept.

  Sentinels differ at character 0 so `commonprefix` cannot absorb them, and trailing whitespace/comments after the last `{% endfor %}` are stripped before probing (they would appear in the base render but not after the appended assistant turn, and break Guard A).

  `_fix_chat_template` and `_repair_string_template` now thread an `is_sharegpt` kwarg; `_fix_chat_template` retries once with `is_sharegpt=True` if the first probe returns None (dual-probe fallback for dict/list callers). The ChatML `<|im_start|>` / `<|im_end|>` hard gate in Case 2 is dropped. `_infer_assistant_separator` is deleted.

  Verified via:

  - tests/test_fix_chat_template_pr4426.py: 51/51 cells (new Llama-3, Gemma, Phi-3 broken-template rows all repair FIX-OK)
  - tests/test_load_correct_tokenizer_pr4426.py: 5/5
  - tests/test_chat_template_followups.py: 18/18 (T11-T18 cover non-ChatML repair + probe failure modes)
  - tests/test_mistral_pr4426.py: 5/5 byte-identical
  - tests/test_qwen_pr4426.py: 14/14 byte-identical (Qwen3-Guard AST gate still rejects)
  - tests/hermes3_lora_pr4426.py reload: patched template ends with `<|im_start|>assistant\n`, inference returns sensible output
  - temp/sim/battery.py: 79/79 followup; vs baseline: 0 regressions, 9 improvements
  - Spot-check probe on real stripped tokenizers (Hermes-3, Phi-4, Llama-3.2-1B, Gemma-3-1B): all derive the expected prefix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

* Address reviewer findings: variant routing, positive-gate detection, comment-safe end scan

  Resolves three reviewer findings on PR #5049 (`fix/chat-template-followups`):

  Finding #1 [10/10]: dict/list variants now route through `_fix_chat_template_for_tokenizer` via a new `_VariantTokenizerProxy` adapter. Previously the dict/list branches called `_fix_chat_template` directly, silently bypassing the warn/strict (`UNSLOTH_STRICT_CHAT_TEMPLATE`) contract, the `no == yes` diagnostic, broken-existing-block detection, and the `_validate_patched_template` guard. The proxy swaps `base.chat_template` to the variant string before each `apply_chat_template` call so tokenizer globals (`bos_token`, custom filters, `raise_exception`) remain available; if the base is read-only it falls back to isolated Jinja rendering.

  Finding #2 [1/10]: `_has_add_generation_prompt_block` now requires the `If` body to contain at least one `Output` node (a new `_if_body_emits_content` helper walks descendants). This distinguishes a real generation-prompt block from a header guard like `{% if not add_generation_prompt is defined %}{% set ... %}{% endif %}` (body contains only `Assign`), which references the name but emits nothing. Also dropped a now-redundant `"add_generation_prompt" not in scrubbed` guard in `_fix_chat_template` Case 2 so header-guarded templates still get repaired.

  Finding #4 [1/10]: `_find_end_position` now replaces Jinja comments with equal-length whitespace before scanning for `{% endfor %}` / `{% endif %}` tokens. This prevents a trailing comment containing those tokens from being picked as the real end tag. Positions in the padded string map 1:1 to positions in the original template.

  Tests:

  - tests/test_chat_template_followups.py: 21/21 (T19 strict-mode dict variant, T20 header-guard repair, T21 comment-endfor trap added; T4/T5 stubs updated with a working apply_chat_template that routes through Jinja)
  - tests/test_fix_chat_template_pr4426.py: 51/51 cells unchanged
  - tests/test_load_correct_tokenizer_pr4426.py: 5/5
  - tests/test_mistral_pr4426.py: 5/5 byte-identical
  - tests/test_qwen_pr4426.py: 14/14 byte-identical
  - temp/sim/battery.py: 79/79 followup; 0 regressions vs baseline
  - Phase 3 Hermes-3 broken-LoRA reload: inference still returns `'The answer to the equation 2+2 is 4.'`
  - Spot-checks on Hermes-3 / Phi-4 / Llama-3.2-1B / Gemma-3-1B real stripped templates: probe still derives the expected prefix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

* Tighten comments in chat-template helpers

  Pure comment minimization across `_find_end_position`, `_has_add_generation_prompt_block`, `_if_body_emits_content`, `_derive_assistant_prefix_by_render`, `_fix_chat_template` Case 2, and `_VariantTokenizerProxy`. No behavior change; same intent, fewer lines. All 21 follow-up tests and the 51-cell Phase 1 matrix still pass.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

* Sandbox probe, fix is_sharegpt validator mismatch, reject negated gates

  Three real bugs from the 10-agent Opus review:

  1. The probe now uses `jinja2.sandbox.SandboxedEnvironment` instead of bare `jinja2.Environment`. The probe renders at model-load time (before the user calls `apply_chat_template`), so it was a new eager code-execution surface that base HF tokenizer loading does not have. SandboxedEnvironment blocks attribute-chain exploits at negligible cost.
  2. `_repair_string_template` now tries validation with both `is_sharegpt=False` and `is_sharegpt=True`. Previously, when `_fix_chat_template` internally fell back to the other schema via its dual-probe, the outer validation still used the caller's original `is_sharegpt` -- rendering with the wrong message keys and spuriously dropping a valid repair.
  3. `_has_add_generation_prompt_block` now skips `If` nodes whose test is a `Not` expression. A negated gate like `{% if not add_generation_prompt %}{{ x }}{% endif %}` fires when add_generation_prompt=False, so its emitting body is not a generation block -- but the old code counted any Name reference regardless of polarity.

  Cleanup: removed the unused `self._label`, added a `\r` escape in the generation-block literal, switched variant labels to `!r` formatting, removed a redundant `import os as _os`.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

* Fix jinja2.sandbox import and sandbox proxy fallback

  Two critical findings from the 20-reviewer pass:

  1. [20/20] The proxy read-only fallback used bare `jinja2.Environment`, not the sandboxed one. All 20 reviewers independently reproduced marker-file creation via `cycler.__init__.__globals__['os'].system(...)` during `fix_chat_template()`. Fixed: the fallback now uses `from jinja2.sandbox import SandboxedEnvironment`.
  2. [14/20] The render-diff probe did `import jinja2` and then referenced `jinja2.sandbox.SandboxedEnvironment`. `jinja2.sandbox` is a submodule that is NOT auto-imported by `import jinja2` on Jinja 3.1.6. This caused an `AttributeError` (swallowed by `except Exception`), making the entire Case 2 repair path silently return None in a clean process. The 6 reviewers who saw it work had `jinja2.sandbox` pre-imported by an earlier module in their process. Fixed: both the probe and the proxy fallback now use `from jinja2.sandbox import SandboxedEnvironment`.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
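The render-diff derivation described in the commit message can be sketched with plain stdlib strings. This is an illustrative stand-in, not output from Jinja: the three "renders" below are hand-written ChatML-style strings, and `SENT_A` / `SENT_B` play the role of the sentinels.

```python
import os.path

# Hand-written stand-ins for template renders (illustrative, not from Jinja):
# the base render has only the user turn; A/B append an assistant turn whose
# content is a sentinel that differs at character 0.
SENT_A = "AAAA_SENTINEL"
SENT_B = "BBBB_SENTINEL"
out_base = "<|im_start|>user\nHi<|im_end|>\n"
out_a = out_base + "<|im_start|>assistant\n" + SENT_A + "<|im_end|>\n"
out_b = out_base + "<|im_start|>assistant\n" + SENT_B + "<|im_end|>\n"

# Guard A: both assistant renders extend the base render (no reordering).
assert out_a.startswith(out_base) and out_b.startswith(out_base)
tail_a, tail_b = out_a[len(out_base):], out_b[len(out_base):]

# The common prefix of the two tails is the assistant-turn prefix.
prefix = os.path.commonprefix([tail_a, tail_b])

# Guard B: the divergence point is exactly the content-insertion site.
assert tail_a[len(prefix):].startswith(SENT_A)
assert tail_b[len(prefix):].startswith(SENT_B)
print(repr(prefix))  # '<|im_start|>assistant\n'
```

The real probe renders the tokenizer's actual template in a sandboxed Jinja environment; only the string arithmetic is shown here.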
This commit is contained in:
parent 6e87bade25
commit c5be8b1cd2
1 changed file with 544 additions and 138 deletions

@@ -636,173 +636,579 @@ def load_correct_tokenizer(
    return tokenizer


# All four Jinja whitespace-control variants of endfor/endif:
# {% endfor %} {%- endfor %} {% endfor -%} {%- endfor -%}
_RE_ENDFOR = re.compile(r"\{%(-?)\s*endfor\s*(-?)%\}")
_RE_ENDIF = re.compile(r"\{%(-?)\s*endif\s*(-?)%\}")
_RE_JINJA_COMMENT = re.compile(r"\{#.*?#\}", flags = re.DOTALL)


def _find_end_position(template, endfor = None, endif = None):
    """Rightmost {% endfor %}/{% endif %} (any dash variant), as a dict
    with start/end/text/dash_left/dash_right. Tokens inside Jinja comments
    are ignored. `endfor`/`endif` kwargs kept for back-compat, ignored."""
    # Space-pad comments so positions still map 1:1 to the original.
    scrubbed = _RE_JINJA_COMMENT.sub(lambda m: " " * len(m.group(0)), template)
    endfor_matches = list(_RE_ENDFOR.finditer(scrubbed))
    endif_matches = list(_RE_ENDIF.finditer(scrubbed))
    last_endfor = endfor_matches[-1] if endfor_matches else None
    last_endif = endif_matches[-1] if endif_matches else None
    candidates = [m for m in (last_endfor, last_endif) if m is not None]
    if not candidates:
        return None
    m = max(candidates, key = lambda x: x.end())
    return {
        "start": m.start(),
        "end": m.end(),
        "text": m.group(0),
        "dash_left": bool(m.group(1)),
        "dash_right": bool(m.group(2)),
    }


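A quick sketch of the end-scan behavior, using standalone copies of the same regex shapes (assumed equivalent to `_RE_ENDFOR` / `_RE_JINJA_COMMENT`): a trailing Jinja comment that mentions `endfor` is space-padded away, so the scan finds the real rightmost tag while every match position still lines up with the original string.

```python
import re

# Standalone copies of the pattern shapes used above.
RE_ENDFOR = re.compile(r"\{%(-?)\s*endfor\s*(-?)%\}")
RE_COMMENT = re.compile(r"\{#.*?#\}", flags = re.DOTALL)

# A trailing comment mentions endfor; equal-length space-padding hides it
# from the scan while keeping positions 1:1 with the original template.
template = "{% for m in messages %}...{%- endfor -%}{# old {% endfor %} #}"
scrubbed = RE_COMMENT.sub(lambda m: " " * len(m.group(0)), template)
assert len(scrubbed) == len(template)

matches = list(RE_ENDFOR.finditer(scrubbed))
last = matches[-1]
print(last.group(0))                             # {%- endfor -%}
print(bool(last.group(1)), bool(last.group(2)))  # True True (both-side dash)
```

Note the `{% endfor %}` inside the comment is never seen: only one match survives the scrub.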
def _template_ends_with_toplevel_for(chat_template):
    """Return True if the last structural node at the template's top level is
    a For (message-iteration) loop, ignoring trailing pure-whitespace Output
    nodes. Used to gate the GH#4150 ChatML repair: if the outermost structure
    is something else (e.g. an outer If that wraps the whole template, as in
    Qwen3-Guard), we shouldn't inject an {% if add_generation_prompt %}
    block at the end -- it would land inside or after an unrelated control
    structure."""
    try:
        import jinja2
        import jinja2.nodes

        ast = jinja2.Environment().parse(chat_template)
    except Exception:
        return False
    for node in reversed(ast.body):
        # Skip trailing output nodes that are only whitespace -- they come
        # from trailing whitespace/newlines in the source, not from real
        # message-rendering logic.
        if isinstance(node, jinja2.nodes.Output):
            only_ws = all(
                isinstance(child, jinja2.nodes.TemplateData)
                and child.data.strip() == ""
                for child in node.nodes
            )
            if only_ws:
                continue
        return isinstance(node, jinja2.nodes.For)
    return False


def _if_body_emits_content(if_node):
    """True if the If's body contains any Output node (directly or nested).
    Distinguishes a real generation block from a header guard that only
    does `{% set ... %}`."""
    import jinja2.nodes

    for node in if_node.body:
        if isinstance(node, jinja2.nodes.Output):
            return True
        if any(
            isinstance(d, jinja2.nodes.Output)
            for d in node.find_all(jinja2.nodes.Output)
        ):
            return True
    return False


def _has_add_generation_prompt_block(chat_template):
    """True if the template has a *positive* `{% if add_generation_prompt %}`
    gate whose body emits output. Rejects header guards like
    `{% if not add_generation_prompt is defined %}{% set ... %}{% endif %}`
    that reference the name but emit nothing. AST-based; string-scan
    fallback if Jinja fails to parse."""
    try:
        import jinja2
        import jinja2.nodes

        ast = jinja2.Environment().parse(chat_template)
    except Exception:
        return "if add_generation_prompt" in chat_template and "%}" in chat_template
    for if_node in ast.find_all(jinja2.nodes.If):
        test = if_node.test
        # Reject negated gates: `{% if not add_generation_prompt %}` fires
        # when agp=False, so it's not a generation block even if it emits.
        if isinstance(test, jinja2.nodes.Not):
            continue
        # find_all skips the test root, so check bare Name tests explicitly.
        references_agp = False
        if isinstance(test, jinja2.nodes.Name) and test.name == "add_generation_prompt":
            references_agp = True
        else:
            for name_node in test.find_all(jinja2.nodes.Name):
                if name_node.name == "add_generation_prompt":
                    references_agp = True
                    break
        if references_agp and _if_body_emits_content(if_node):
            return True
    return False


# Sentinels for _derive_assistant_prefix_by_render. Diverge at char 0 so
# commonprefix can't absorb them; long random tail makes collision with real
# template literals negligible (see T18).
_RENDER_DIFF_SENTINEL_A = "AAAA_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL"
_RENDER_DIFF_SENTINEL_B = "BBBB_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL"
_RENDER_DIFF_SENTINEL_C = "CCCC_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL"


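Why the sentinels must differ at character 0: `os.path.commonprefix` compares character by character, so any shared leading run of the sentinels would be absorbed into the derived prefix. A two-line check of that property:

```python
import os.path

A = "AAAA_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL"
B = "BBBB_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL"
# Divergence at char 0 -> empty common prefix: no sentinel bytes can leak
# into the derived assistant prefix.
assert os.path.commonprefix([A, B]) == ""

# Counter-example: sentinels sharing a leading run would be absorbed.
assert os.path.commonprefix(["XX_A", "XX_B"]) == "XX_"
```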
def _derive_assistant_prefix_by_render(chat_template, is_sharegpt = False):
    """Return the assistant-turn prefix the template emits, derived by
    rendering two dialogs that differ only in assistant content: the common
    prefix of their tails (after the base [user]-only render) is what the
    template emits for an assistant turn. None if any guard fails.

    Works for Llama-3 / Gemma / Phi-3 and other non-ChatML shapes; the
    template is its own ground truth.

    Known limitation: an `eos-on-non-last` pattern (turn-end sentinel only
    emitted for non-last messages) would produce a consistent but wrong
    prefix that `_validate_patched_template` can't catch. No real-world
    template is known to use this.
    """
    try:
        from jinja2.sandbox import SandboxedEnvironment
    except Exception:
        return None

    if is_sharegpt:
        base_msgs = [{"from": "human", "value": "Hi"}]
        sent_a_msgs = base_msgs + [{"from": "gpt", "value": _RENDER_DIFF_SENTINEL_A}]
        sent_b_msgs = base_msgs + [{"from": "gpt", "value": _RENDER_DIFF_SENTINEL_B}]
        # User-role cross-check (Guard C below).
        sent_c_msgs = base_msgs + [{"from": "human", "value": _RENDER_DIFF_SENTINEL_C}]
    else:
        base_msgs = [{"role": "user", "content": "Hi"}]
        sent_a_msgs = base_msgs + [
            {"role": "assistant", "content": _RENDER_DIFF_SENTINEL_A}
        ]
        sent_b_msgs = base_msgs + [
            {"role": "assistant", "content": _RENDER_DIFF_SENTINEL_B}
        ]
        sent_c_msgs = base_msgs + [{"role": "user", "content": _RENDER_DIFF_SENTINEL_C}]

    # Strip trailing whitespace/comments after the last endfor/endif: they
    # appear after the message loop and would break Guard A. The splice in
    # `_fix_chat_template` drops them too.
    probe_template = chat_template
    end = _find_end_position(chat_template)
    if end is not None:
        after = chat_template[end["end"] :]
        if _RE_JINJA_COMMENT.sub("", after).strip() == "":
            probe_template = chat_template[: end["end"]]

    # Sandboxed: probe renders at load time, before user calls
    # apply_chat_template. SandboxedEnvironment blocks attribute-chain exploits.
    try:
        env = SandboxedEnvironment(
            autoescape = False,
            keep_trailing_newline = True,
        )
        tmpl = env.from_string(probe_template)
        out_base = tmpl.render(messages = base_msgs, add_generation_prompt = False)
        out_a = tmpl.render(messages = sent_a_msgs, add_generation_prompt = False)
        out_b = tmpl.render(messages = sent_b_msgs, add_generation_prompt = False)
    except Exception:
        return None

    # Best-effort: alternation-enforcing templates (e.g. Gemma's
    # raise_exception) fail on [user, user]; that's a positive signal
    # for Guard C, not a probe failure.
    out_user_c = None
    try:
        out_user_c = tmpl.render(messages = sent_c_msgs, add_generation_prompt = False)
    except Exception:
        pass

    # Guard A: assistant renders extend base (no reordering).
    if not (out_a.startswith(out_base) and out_b.startswith(out_base)):
        return None

    tail_a = out_a[len(out_base) :]
    tail_b = out_b[len(out_base) :]
    if not tail_a or not tail_b:
        return None

    prefix = os.path.commonprefix([tail_a, tail_b])

    # Guard B: divergence is exactly at the content-insertion site.
    if not (
        tail_a[len(prefix) :].startswith(_RENDER_DIFF_SENTINEL_A)
        and tail_b[len(prefix) :].startswith(_RENDER_DIFF_SENTINEL_B)
    ):
        return None

    # Guard C: reject if a [user, user] render also emits the same prefix
    # (role-insensitive template, e.g. `{% set greeting='Hi' %}...`).
    if out_user_c is not None and out_user_c.startswith(out_base):
        tail_c = out_user_c[len(out_base) :]
        if tail_c.startswith(prefix) and prefix != "":
            return None

    if not prefix:
        return None

    return prefix


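Guard C in isolation, with a hypothetical role-insensitive "template" modeled as a plain Python function (every message renders identically regardless of role), showing why the derivation must be rejected in that case:

```python
import os.path

SENT_A, SENT_B, SENT_C = "AAAA_S", "BBBB_S", "CCCC_S"

# Hypothetical role-insensitive template: every message becomes "> text\n",
# so the role key has no effect on output at all.
def render(contents):
    return "".join("> " + c + "\n" for c in contents)

out_base = render(["Hi"])
tail_a = render(["Hi", SENT_A])[len(out_base):]
tail_b = render(["Hi", SENT_B])[len(out_base):]
prefix = os.path.commonprefix([tail_a, tail_b])   # "> "

# Guard C: a *user*-role sentinel emits the same prefix, proving the
# derived prefix is not assistant-specific -- the probe must reject it.
tail_c = render(["Hi", SENT_C])[len(out_base):]
assert tail_c.startswith(prefix) and prefix != ""  # -> reject this derivation
```

Conversely, a template that raises on a [user, user] dialog (like Gemma's alternation check) never reaches this rejection, which is why a Guard C render failure counts as evidence that role matters.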
def _fix_chat_template(chat_template, is_sharegpt = False):
    # Fast path: already has an {% if add_generation_prompt %} block, nothing
    # to do. This catches cases the old string-based check would miss (e.g.
    # templates that use {%- if add_generation_prompt -%} with both-side dash,
    # or that sneak the block into a nested If/For).
    if _has_add_generation_prompt_block(chat_template):
        return chat_template

    end = _find_end_position(chat_template)
    if end is None:
        return chat_template

    after_endfor = chat_template[end["end"] :]
    dash_l = "-" if end["dash_left"] else ""
    dash_r = "-" if end["dash_right"] else ""
    open_tag = lambda body: "{%" + dash_l + " " + body + " " + dash_r + "%}"

    # Case 1 (pre-existing base case): template ends with a single trailing
    # {{ expr }} that is the generation prefix. Wrap it in an
    # {% if add_generation_prompt %} ... {% endif %}.
    if (
        "{%" + dash_l + " if" not in after_endfor
        and "{%" + dash_l + " set " not in after_endfor
        and after_endfor.startswith("{{")
        and after_endfor.endswith("}}")
        and after_endfor.count("{{") == 1
        and after_endfor.count("}}") == 1
    ):
        wrapped = (
            open_tag("if add_generation_prompt") + after_endfor + open_tag("endif")
        )
        return chat_template[: end["end"]] + wrapped

    # Case 2 (GH#4150): template ends at {% endfor %} with only whitespace
    # or comments left. Inject an {% if add_generation_prompt %} block with
    # the assistant prefix derived by render-diff. The top-level-For gate
    # keeps us out of outer-If wrappers (e.g. Qwen3-Guard).
    if _RE_JINJA_COMMENT.sub(
        "", after_endfor
    ).strip() == "" and _template_ends_with_toplevel_for(chat_template):
        # No redundant "agp not in scrubbed" check: the fast path already
        # confirmed no *positive* block, and a mere reference (header
        # guard) should still get repaired.
        assistant_prefix = _derive_assistant_prefix_by_render(
            chat_template, is_sharegpt
        )
        # Dual-probe: dict/list callers don't know the shape up front.
        if assistant_prefix is None and not is_sharegpt:
            assistant_prefix = _derive_assistant_prefix_by_render(
                chat_template, is_sharegpt = True
            )
        if assistant_prefix is None:
            return chat_template
        # Escape for a double-quoted Jinja string literal.
        escaped = (
            assistant_prefix.replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
            .replace("\r", "\\r")
        )
        generation_block = (
            open_tag("if add_generation_prompt")
            + '{{ "'
            + escaped
            + '" }}'
            + open_tag("endif")
        )
        return chat_template[: end["end"]] + generation_block

    return chat_template


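The Case 2 escape chain, applied to a hypothetical derived prefix containing a real newline. Order matters: backslashes are escaped first, so the backslashes introduced by the later `\n` / `\r` / quote escapes are not themselves re-escaped.

```python
# Hypothetical derived prefix containing a real newline character.
assistant_prefix = "<|im_start|>assistant\n"

# Same escape order as Case 2: backslashes first, then quote, \n, \r.
escaped = (
    assistant_prefix.replace("\\", "\\\\")
    .replace('"', '\\"')
    .replace("\n", "\\n")
    .replace("\r", "\\r")
)
generation_block = (
    '{% if add_generation_prompt %}{{ "' + escaped + '" }}{% endif %}'
)
# The injected block contains the two characters backslash + n, never a raw
# newline, so the Jinja string literal stays on one line and cannot be
# broken by quotes in the prefix.
assert "\n" not in generation_block
```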
def _is_strict_chat_template_mode():
    """Opt-in strict mode restores the pre-warn RuntimeError behavior."""
    val = os.environ.get("UNSLOTH_STRICT_CHAT_TEMPLATE", "0")
    return str(val).strip().lower() in ("1", "true", "yes", "on")


def _name_is_local_path(name_or_path):
    """True if name_or_path refers to an existing local directory. Used to
    tailor the warning message: for local paths the user cannot 'file a bug
    report to the maintainers of <path>' since that path is their own."""
    if not name_or_path:
        return False
    try:
        return os.path.isdir(str(name_or_path))
    except Exception:
        return False


def _format_chat_template_message(name_or_path, repaired):
    """Build a user-facing warning/error message that points at the right
    responsible party (user's downstream tool vs. upstream model maintainer)."""
    local = _name_is_local_path(name_or_path)
    if local:
        source_hint = (
            "This tokenizer was loaded from a local path. The likely cause is a "
            "downstream tool (LlamaFactory, Axolotl, etc.) that re-serialized "
            "the tokenizer during save and stripped the generation-prompt "
            "block. Either re-save with the original template, or set "
            "`tokenizer.chat_template` manually before loading."
        )
    else:
        source_hint = (
            "The chat_template shipped with `{name}` appears incomplete. "
            "Consider filing a bug report with the model maintainers."
        ).format(name = name_or_path)
    if repaired:
        return (
            "Unsloth: Patched the chat_template on `{name}` to add a "
            "{{% if add_generation_prompt %}} block. {hint}"
        ).format(name = name_or_path, hint = source_hint)
    return (
        "Unsloth: The tokenizer `{name}` does not have a "
        "{{% if add_generation_prompt %}} block for generation purposes, and "
        "automatic repair was not possible. The model will still load, but "
        "`apply_chat_template(add_generation_prompt=True)` may not produce a "
        "correct assistant-turn marker. {hint} Set "
        "UNSLOTH_STRICT_CHAT_TEMPLATE=1 to raise instead of warn."
    ).format(name = name_or_path, hint = source_hint)


def _validate_patched_template(tokenizer, patched_template, is_sharegpt):
    """Render the just-patched template with and without
    add_generation_prompt, and confirm the patched output responds to the
    flag by appending (not replacing) content. Returns True if validation
    passes."""
    msgs = (
        [{"from": "human", "value": "Hi"}]
        if is_sharegpt
        else [{"role": "user", "content": "Hi"}]
    )
    original = getattr(tokenizer, "chat_template", None)
    try:
        try:
            tokenizer.chat_template = patched_template
        except Exception:
            return False  # read-only tokenizer, skip validation
        try:
            yes = tokenizer.apply_chat_template(
                msgs,
                add_generation_prompt = True,
                tokenize = False,
            )
            no = tokenizer.apply_chat_template(
                msgs,
                add_generation_prompt = False,
                tokenize = False,
            )
        except Exception:
            return False
    finally:
        try:
            tokenizer.chat_template = original
        except Exception:
            pass  # best-effort restore
    # Contract after a successful repair: the two renders differ, and the
    # "yes" render is a strict extension of the "no" render (we only
    # appended content inside the new add_generation_prompt block).
    return yes != no and yes.startswith(no)


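The append-not-replace contract the validator checks, demonstrated on plain strings (illustrative renders, not produced by a real tokenizer):

```python
# Illustrative renders: the add_generation_prompt=True output must strictly
# extend the =False output (content appended inside the new block only).
no = "<|im_start|>user\nHi<|im_end|>\n"
yes = no + "<|im_start|>assistant\n"
assert yes != no and yes.startswith(no)          # passes validation

# A hypothetical replace-style patch rewrites the tail instead of
# appending, so it fails the startswith half of the contract.
bad_yes = "<|im_start|>user\nHi<|im_end|>assistant:"
assert not (bad_yes != no and bad_yes.startswith(no))
```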
def _repair_string_template(tokenizer, chat_template, is_sharegpt):
    """Core string-template repair. Returns the repaired template on success,
    or None if repair was not possible / failed validation."""
    candidate = _fix_chat_template(chat_template, is_sharegpt = is_sharegpt)
    if not _has_add_generation_prompt_block(candidate):
        return None
    # Validate with the caller's is_sharegpt first. If that fails, the
    # dual-probe in _fix_chat_template may have fallen back to the other
    # schema internally -- try validating with the opposite schema before
    # giving up.
    if _validate_patched_template(tokenizer, candidate, is_sharegpt):
        return candidate
    if _validate_patched_template(tokenizer, candidate, not is_sharegpt):
        return candidate
    return None


def _fix_chat_template_for_tokenizer(tokenizer, chat_template):
    """Entry point for a string chat_template. Runs the no==yes diagnostic,
    attempts repair if needed, and returns the (possibly patched) template.

    On repair failure, the behavior is controlled by
    UNSLOTH_STRICT_CHAT_TEMPLATE: warn + return original (default) or raise
    RuntimeError (strict)."""
    name = getattr(tokenizer, "name_or_path", "unknown")

    # Detect ShareGPT vs HF style by probing apply_chat_template.
    is_sharegpt = None
    try:
        tokenizer.apply_chat_template(
            [{"role": "user", "content": "Who are you?"}],
            add_generation_prompt = False,
            tokenize = False,
        )
        is_sharegpt = False
    except Exception:
        try:
            tokenizer.apply_chat_template(
                [{"from": "human", "value": "Who are you?"}],
                add_generation_prompt = False,
                tokenize = False,
            )
            is_sharegpt = True
        except Exception:
            is_sharegpt = None

    if is_sharegpt is None:
        return chat_template

    messages = (
        [{"from": "human", "value": "Who are you?"}]
        if is_sharegpt
        else [{"role": "user", "content": "Who are you?"}]
    )
    try:
        no = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt = False,
            tokenize = False,
        )
        yes = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt = True,
            tokenize = False,
        )
    except Exception:
        return chat_template

    if no != yes:
        # Template already responds to the flag; leave as is.
        return chat_template

    # no == yes: template ignores add_generation_prompt. Try to repair.
    if _has_add_generation_prompt_block(chat_template):
        # Template has the block but it does not change output. This is the
        # "wasn't provided correctly" case from the pre-warn code path.
        msg = _format_chat_template_message(name, repaired = False)
|
||||
if _is_strict_chat_template_mode():
|
||||
raise RuntimeError(msg)
|
||||
logger.warning_once(msg)
|
||||
return chat_template
|
||||
|
||||
repaired = _repair_string_template(tokenizer, chat_template, is_sharegpt)
|
||||
if repaired is not None:
|
||||
logger.warning_once(_format_chat_template_message(name, repaired = True))
|
||||
return repaired
|
||||
|
||||
msg = _format_chat_template_message(name, repaired = False)
|
||||
if _is_strict_chat_template_mode():
|
||||
raise RuntimeError(msg)
|
||||
logger.warning_once(msg)
|
||||
return chat_template
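
# Hedged, standalone sketch of the no==yes diagnostic above, using bare
# jinja2 instead of a tokenizer (the ChatML-style template strings are
# illustrative, not taken from any real model):
#
#   from jinja2 import Template
#   broken = (
#       "{% for m in messages %}<|im_start|>{{ m['role'] }}\n"
#       "{{ m['content'] }}<|im_end|>\n{% endfor %}"
#   )
#   fixed = broken + "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
#   msgs = [{"role": "user", "content": "Who are you?"}]
#   render = lambda t, f: Template(t).render(messages = msgs, add_generation_prompt = f)
#   assert render(broken, True) == render(broken, False) # flag ignored -> needs repair
#   assert render(fixed, True) != render(fixed, False)   # flag respected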


class _VariantTokenizerProxy:
    """Single-variant view of a multi-variant tokenizer. Routes each variant
    through `_fix_chat_template_for_tokenizer` so the full contract
    (is_sharegpt probe, no==yes, warn/strict, `_validate_patched_template`)
    applies instead of jumping straight to structural repair.

    `apply_chat_template` swaps `base.chat_template` to the variant before
    calling so tokenizer globals (bos_token, filters, raise_exception) are
    preserved; falls back to sandboxed Jinja for read-only stubs.
    """

    def __init__(self, base_tokenizer, variant_template, variant_label = ""):
        self._base = base_tokenizer
        self._template = variant_template
        base_name = getattr(base_tokenizer, "name_or_path", "unknown")
        self.name_or_path = (
            f"{base_name} ({variant_label})" if variant_label else base_name
        )

    @property
    def chat_template(self):
        return self._template

    @chat_template.setter
    def chat_template(self, value):
        self._template = value

    def apply_chat_template(self, *args, **kwargs):
        base_original = getattr(self._base, "chat_template", None)
        swapped = False
        try:
            try:
                self._base.chat_template = self._template
                swapped = True
            except Exception:
                swapped = False
            if swapped:
                return self._base.apply_chat_template(*args, **kwargs)
            # Read-only base: fall back to sandboxed Jinja.
            from jinja2.sandbox import SandboxedEnvironment

            env = SandboxedEnvironment(
                autoescape = False,
                keep_trailing_newline = True,
            )
            messages = args[0] if args else kwargs.get("messages", [])
            add_generation_prompt = kwargs.get("add_generation_prompt", False)
            return env.from_string(self._template).render(
                messages = messages,
                add_generation_prompt = add_generation_prompt,
            )
        finally:
            if swapped:
                try:
                    self._base.chat_template = base_original
                except Exception:
                    pass # best-effort restore
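
# Hedged sketch of the read-only fallback path above, standalone (no
# tokenizer involved; the template is an illustrative stub, not from any
# real model):
#
#   from jinja2.sandbox import SandboxedEnvironment
#   env = SandboxedEnvironment(autoescape = False, keep_trailing_newline = True)
#   tmpl = (
#       "{% for m in messages %}{{ m['content'] }}\n{% endfor %}"
#       "{% if add_generation_prompt %}assistant:{% endif %}"
#   )
#   out = env.from_string(tmpl).render(
#       messages = [{"role": "user", "content": "Hi"}],
#       add_generation_prompt = True,
#   )
#   assert out == "Hi\nassistant:"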


def fix_chat_template(tokenizer):
    chat_template = getattr(tokenizer, "chat_template", None)
    if chat_template is None:
        return None

    # Multi-variant dict (e.g. Hermes-3 {default, tool_use}): route each
    # variant through the full repair contract via _VariantTokenizerProxy.
    if isinstance(chat_template, dict):
        fixed = {}
        for key, tmpl in chat_template.items():
            if not isinstance(tmpl, str):
                fixed[key] = tmpl
                continue
            proxy = _VariantTokenizerProxy(
                tokenizer, tmpl, variant_label = f"variant={key!r}"
            )
            fixed[key] = _fix_chat_template_for_tokenizer(proxy, tmpl)
        return fixed

    # List-of-dicts form (older HF multi-template style).
    if isinstance(chat_template, list):
        fixed = []
        for item in chat_template:
            if not isinstance(item, dict) or "template" not in item:
                fixed.append(item)
                continue
            tmpl = item["template"]
            if not isinstance(tmpl, str):
                fixed.append(item)
                continue
            label = f"variant={item.get('name', '?')!r}"
            proxy = _VariantTokenizerProxy(tokenizer, tmpl, variant_label = label)
            new_tmpl = _fix_chat_template_for_tokenizer(proxy, tmpl)
            if new_tmpl is tmpl or new_tmpl == tmpl:
                fixed.append(item)
            else:
                fixed.append({**item, "template": new_tmpl})
        return fixed

    return _fix_chat_template_for_tokenizer(tokenizer, chat_template)


def check_tokenizer(