Chat-template repair: warn-by-default, AST classification, dict support (#5049)

* Chat-template repair: warn-by-default, AST classification, dict support

Follow-up hardening on top of PR #4426 (which fixed the #4150
RuntimeError for ChatML LoRA reloads).

Behavior changes:

- Warn-by-default instead of RuntimeError. When fix_chat_template cannot
  repair a broken template, emit a warning and return the original.
  Set UNSLOTH_STRICT_CHAT_TEMPLATE=1 to restore the pre-warn hard fail.
  Fixes the UX where a missing `{% if add_generation_prompt %}` block on
  a saved LoRA (typical after LlamaFactory / Axolotl re-serialize) would
  block model loading entirely.

- Local path vs HF hub distinguished in the warning message. For local
  paths the message points at the likely downstream tool; for HF IDs it
  points at the upstream model maintainers. Previously both said "file a
  bug report to the maintainers of <path>" even when <path> was the
  user's own saves/ directory.

- Dict / list chat_template now handled. Hermes-3 ships with
  {default, tool_use} and the previous code crashed with
  AttributeError: 'dict' object has no attribute 'find' when entering
  _fix_chat_template with a dict. Each variant is now fixed
  independently; structure is preserved.
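
For reference, the strict-mode toggle reduces to a small env-var parse. A sketch matching the `_is_strict_chat_template_mode` helper this commit adds:

```python
import os

def _is_strict_chat_template_mode():
    # Same truthy-spelling parse the commit adds (opt-in strict mode).
    val = os.environ.get("UNSLOTH_STRICT_CHAT_TEMPLATE", "0")
    return str(val).strip().lower() in ("1", "true", "yes", "on")

os.environ["UNSLOTH_STRICT_CHAT_TEMPLATE"] = "TRUE"
strict = _is_strict_chat_template_mode()
os.environ["UNSLOTH_STRICT_CHAT_TEMPLATE"] = "0"
lenient = _is_strict_chat_template_mode()
```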

Internals:

- _find_end_position now matches all four Jinja whitespace-control
  variants ({% %}, {%- %}, {% -%}, {%- -%}) and returns the rightmost
  endfor/endif so multi-for templates aren't locked onto the first loop.
  Previously {%- endfor -%} (both-side dash, used by Qwen3-Guard) was
  silently bypassed.

- _has_add_generation_prompt_block uses Jinja AST via
  jinja2.nodes.If/Name walks instead of substring matching, so
  templates that hide the block behind comments or dash-style variants
  are classified correctly.

- _template_ends_with_toplevel_for gates the GH#4150 ChatML repair on
  the AST: only fires when the last structural top-level node is a For
  (standard ChatML shape), ignoring trailing pure-whitespace output
  nodes. Templates wrapped in an outer If (Qwen3-Guard) are now
  explicitly skipped at the _fix_chat_template level as well, not just
  at load_correct_tokenizer's name-based exemption.

- _validate_patched_template renders the patched template with and
  without add_generation_prompt and confirms the patched output
  responds to the flag by appending (not replacing) content. If
  validation fails, the patch is discarded and we fall through to the
  warn path.
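
The end-scan described above can be exercised with the exact regex the commit introduces:

```python
import re

# All four Jinja whitespace-control variants of endfor:
# {% endfor %}  {%- endfor %}  {% endfor -%}  {%- endfor -%}
_RE_ENDFOR = re.compile(r"\{%(-?)\s*endfor\s*(-?)%\}")

template = (
    "{% for x in a %}A{% endfor %}"
    "{% for m in messages %}B{%- endfor -%}"
)
# Rightmost match wins, so multi-for templates aren't locked to the first loop.
last = list(_RE_ENDFOR.finditer(template))[-1]
```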

Verified with an expanded regression suite in tests/:
- test_fix_chat_template_pr4426.py: 42/42 template-matrix cells
- test_load_correct_tokenizer_pr4426.py: 5/5 tokenizer loads
- test_chat_template_followups.py: 10/10 new follow-up tests
- test_mistral_pr4426.py: 5 Mistral variants byte-identical
- test_qwen_pr4426.py: 14 Qwen variants byte-identical
  (Qwen1.5, Qwen2, Qwen2.5-Instruct/Coder/Math/VL, Qwen3,
  Qwen3-Coder, QwQ, Qwen3-Guard-Gen)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Guard _validate_patched_template against read-only chat_template

If tokenizer.chat_template is a property or otherwise read-only, the
validation helper would crash with AttributeError when trying to
temporarily set the patched template. Catch the assignment failure and
return False (skip validation), and best-effort restore in the finally
block.
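
A minimal sketch of that guard (hypothetical `validate_with_restore` / `check` names; the real helper is `_validate_patched_template`):

```python
def validate_with_restore(tokenizer, patched_template, check):
    # Temporarily install the patched template; bail out (skip validation)
    # if assignment fails, and best-effort restore in the finally block.
    original = getattr(tokenizer, "chat_template", None)
    try:
        try:
            tokenizer.chat_template = patched_template
        except Exception:
            return False  # read-only tokenizer: skip validation
        return check(tokenizer)
    finally:
        try:
            tokenizer.chat_template = original
        except Exception:
            pass  # best-effort restore


class ReadOnlyTok:
    @property
    def chat_template(self):
        return "{{ messages }}"
```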

* Replace regex separator inference with render-diff; broaden repair to non-ChatML templates

The previous `_infer_assistant_separator` was a four-tier regex heuristic that
only worked on ChatML-shaped templates and forced a hard `<|im_start|>` /
`<|im_end|>` presence gate on Case 2 repair. This meant a Llama-3, Gemma, or
Phi-3 template stripped of its generation-prompt block by a downstream tool
(LlamaFactory, Axolotl, etc.) would still warn-and-return even though the
structural shape is identical to the ChatML case the PR already handles.

This replaces the regex with `_derive_assistant_prefix_by_render`: render the
template with two dialogs that differ only in assistant content, then
`os.path.commonprefix` on the tails captures the exact assistant-turn prefix
the template emits. The template itself is ground truth, so non-ChatML shapes
work as long as the assistant block is a literal the template emits once per
message.

Three guards keep the derivation safe:
  A. both assistant renders extend the base render (no reordering);
  B. the divergence point is exactly the content-insertion site (sentinel
     follows the common prefix);
  C. a user-role cross-check: if a render with a user sentinel also emits
     the same prefix, role has no effect on output and we reject. A render
     failure on [user, user] (e.g. Gemma's `raise_exception` alternation
     check) is evidence that role matters; we accept.

Sentinels differ at character 0 so `commonprefix` cannot absorb them, and
trailing whitespace/comments after the last `{% endfor %}` are stripped
before probing (they would appear in base but not after the appended
assistant turn and break Guard A).
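
The derivation can be illustrated with hand-written stand-ins for the renders (hypothetical ChatML-shaped template output; no Jinja needed):

```python
import os.path

# Sentinels differ at character 0 so commonprefix cannot absorb them.
SENT_A = "AAAA_0123456789_SENTINEL"
SENT_B = "BBBB_0123456789_SENTINEL"
out_base = "<|im_start|>user\nHi<|im_end|>\n"
out_a = out_base + "<|im_start|>assistant\n" + SENT_A + "<|im_end|>\n"
out_b = out_base + "<|im_start|>assistant\n" + SENT_B + "<|im_end|>\n"

# Guard A: both assistant renders extend the base render.
guard_a = out_a.startswith(out_base) and out_b.startswith(out_base)
tail_a = out_a[len(out_base):]
tail_b = out_b[len(out_base):]
prefix = os.path.commonprefix([tail_a, tail_b])
# Guard B: the divergence point is exactly the content-insertion site.
guard_b = tail_a[len(prefix):].startswith(SENT_A)
```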

`_fix_chat_template` and `_repair_string_template` now thread an
`is_sharegpt` kwarg; `_fix_chat_template` retries once with
`is_sharegpt=True` if the first probe returns None (dual-probe fallback
for dict/list callers).

The ChatML `<|im_start|>` / `<|im_end|>` hard gate in Case 2 is dropped.
`_infer_assistant_separator` is deleted.

Verified via:
  - tests/test_fix_chat_template_pr4426.py: 51/51 cells (new Llama-3,
    Gemma, Phi-3 broken-template rows all repair FIX-OK)
  - tests/test_load_correct_tokenizer_pr4426.py: 5/5
  - tests/test_chat_template_followups.py: 18/18 (T11-T18 cover
    non-ChatML repair + probe failure modes)
  - tests/test_mistral_pr4426.py: 5/5 byte-identical
  - tests/test_qwen_pr4426.py: 14/14 byte-identical (Qwen3-Guard AST
    gate still rejects)
  - tests/hermes3_lora_pr4426.py reload: patched template ends with
    `<|im_start|>assistant\n`, inference returns sensible output.
  - temp/sim/battery.py: 79/79 followup; vs baseline: 0 regressions,
    9 improvements.
  - Spot-check probe on real stripped tokenizers (Hermes-3, Phi-4,
    Llama-3.2-1B, Gemma-3-1B): all derive the expected prefix.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address reviewer findings: variant routing, positive-gate detection, comment-safe end scan

Resolves three reviewer findings on PR #5049 (`fix/chat-template-followups`):

Finding #1 [10/10]: dict/list variants now route through
`_fix_chat_template_for_tokenizer` via a new `_VariantTokenizerProxy`
adapter. Previously the dict/list branches called `_fix_chat_template`
directly, silently bypassing the warn/strict (`UNSLOTH_STRICT_CHAT_TEMPLATE`)
contract, the `no == yes` diagnostic, broken-existing-block detection,
and `_validate_patched_template` guard. The proxy swaps
`base.chat_template` to the variant string before each
`apply_chat_template` call so tokenizer globals (`bos_token`, custom
filters, `raise_exception`) remain available; if the base is read-only
it falls back to isolated Jinja rendering.

Finding #2 [1/10]: `_has_add_generation_prompt_block` now requires the
`If` body to contain at least one `Output` node (a new
`_if_body_emits_content` helper walks descendants). This distinguishes a
real generation-prompt block from a header guard like
`{% if not add_generation_prompt is defined %}{% set ... %}{% endif %}`
(body contains only `Assign`) which references the name but emits
nothing. Also dropped a now-redundant `"add_generation_prompt" not in
scrubbed` guard in `_fix_chat_template` Case 2 so header-guarded
templates still get repaired.

Finding #4 [1/10]: `_find_end_position` now replaces Jinja comments with
equal-length whitespace before scanning for `{% endfor %}` / `{% endif %}`
tokens. This prevents a trailing comment containing those tokens from
being picked as the real end tag. Positions in the padded string map 1:1
to positions in the original template.
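
A small reproduction of the comment-padding scan, using the same regexes the commit uses:

```python
import re

_RE_JINJA_COMMENT = re.compile(r"\{#.*?#\}", flags = re.DOTALL)
_RE_ENDFOR = re.compile(r"\{%(-?)\s*endfor\s*(-?)%\}")

template = "{% for m in ms %}x{% endfor %}{# decoy {% endfor %} #}"
# Equal-length space padding keeps every offset identical to the original.
padded = _RE_JINJA_COMMENT.sub(lambda m: " " * len(m.group(0)), template)
last = list(_RE_ENDFOR.finditer(padded))[-1]
picked = template[last.start():last.end()]
```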

Tests:
  - tests/test_chat_template_followups.py: 21/21 (T19 strict-mode
    dict variant, T20 header-guard repair, T21 comment-endfor trap
    added; T4/T5 stubs updated with a working apply_chat_template
    that routes through Jinja).
  - tests/test_fix_chat_template_pr4426.py: 51/51 cells unchanged.
  - tests/test_load_correct_tokenizer_pr4426.py: 5/5.
  - tests/test_mistral_pr4426.py: 5/5 byte-identical.
  - tests/test_qwen_pr4426.py: 14/14 byte-identical.
  - temp/sim/battery.py: 79/79 followup; 0 regressions vs baseline.
  - Phase 3 Hermes-3 broken-LoRA reload: inference still returns
    `'The answer to the equation 2+2 is 4.'`.
  - Spot-checks on Hermes-3 / Phi-4 / Llama-3.2-1B / Gemma-3-1B real
    stripped templates: probe still derives the expected prefix.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Tighten comments in chat-template helpers

Pure comment minimization across `_find_end_position`,
`_has_add_generation_prompt_block`, `_if_body_emits_content`,
`_derive_assistant_prefix_by_render`, `_fix_chat_template` Case 2,
and `_VariantTokenizerProxy`. No behavior change; same intent,
fewer lines. All 21 follow-up tests and the 51-cell Phase 1 matrix
still pass.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Sandbox probe, fix is_sharegpt validator mismatch, reject negated gates

Three real bugs from the 10-agent Opus review:

1. Probe now uses `jinja2.sandbox.SandboxedEnvironment` instead of bare
   `jinja2.Environment`. The probe renders at model-load time (before
   the user calls `apply_chat_template`), so it was a new eager
   code-execution surface that the base HF tokenizer loading does not
   have. SandboxedEnvironment blocks attribute-chain exploits at
   negligible cost.

2. `_repair_string_template` now tries validation with both
   `is_sharegpt=False` and `is_sharegpt=True`. Previously, when
   `_fix_chat_template` internally fell back to the other schema via
   its dual-probe, the outer validation still used the caller's
   original `is_sharegpt` -- rendering with the wrong message keys and
   spuriously dropping a valid repair.

3. `_has_add_generation_prompt_block` now skips `If` nodes whose test
   is a `Not` expression. A negated gate like
   `{% if not add_generation_prompt %}{{ x }}{% endif %}` fires when
   agp=False, so its emitting body is not a generation block -- but the
   old code counted any Name reference regardless of polarity.

Cleanup: removed unused `self._label`, added `\r` escape in
generation-block literal, switched variant labels to `!r` formatting,
removed redundant `import os as _os`.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix jinja2.sandbox import and sandbox proxy fallback

Two critical findings from the 20-reviewer pass:

1. [20/20] The proxy read-only fallback used bare `jinja2.Environment`,
   not sandboxed. All 20 reviewers independently reproduced marker-file
   creation via `cycler.__init__.__globals__['os'].system(...)` during
   `fix_chat_template()`. Fixed: fallback now uses
   `from jinja2.sandbox import SandboxedEnvironment`.

2. [14/20] The render-diff probe did `import jinja2` then referenced
   `jinja2.sandbox.SandboxedEnvironment`. `jinja2.sandbox` is a
   submodule that is NOT auto-imported by `import jinja2` on Jinja 3.1.6.
   This caused `AttributeError` (swallowed by `except Exception`),
   making the entire Case 2 repair path silently return None in a clean
   process. The 6 reviewers who saw it work had `jinja2.sandbox`
   pre-imported by an earlier module in their process. Fixed: both the
   probe and the proxy fallback now use
   `from jinja2.sandbox import SandboxedEnvironment`.
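
The underlying pitfall is generic Python packaging behavior: importing a package does not import its submodules. A stdlib reproduction in a fresh interpreter (using `email.mime` so the demo does not require jinja2):

```python
import subprocess
import sys

# `import email` does not bind email.mime, just as `import jinja2` does
# not bind jinja2.sandbox. Run in a clean subprocess so nothing has
# pre-imported the submodule for us.
r = subprocess.run(
    [sys.executable, "-c", "import email; email.mime"],
    capture_output = True,
    text = True,
)
failed = "AttributeError" in r.stderr
```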

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Daniel Han 2026-04-16 05:52:33 -07:00 committed by GitHub
parent 6e87bade25
commit c5be8b1cd2


@@ -636,173 +636,579 @@ def load_correct_tokenizer(
return tokenizer
# All four Jinja whitespace-control variants of endfor/endif:
# {% endfor %}  {%- endfor %}  {% endfor -%}  {%- endfor -%}
_RE_ENDFOR = re.compile(r"\{%(-?)\s*endfor\s*(-?)%\}")
_RE_ENDIF = re.compile(r"\{%(-?)\s*endif\s*(-?)%\}")
_RE_JINJA_COMMENT = re.compile(r"\{#.*?#\}", flags = re.DOTALL)


def _find_end_position(template, endfor = None, endif = None):
    """Rightmost {% endfor %}/{% endif %} (any dash variant), as a dict
    with start/end/text/dash_left/dash_right. Tokens inside Jinja comments
    are ignored. `endfor`/`endif` kwargs kept for back-compat, ignored."""
    # Space-pad comments so positions still map 1:1 to the original.
    scrubbed = _RE_JINJA_COMMENT.sub(lambda m: " " * len(m.group(0)), template)
    endfor_matches = list(_RE_ENDFOR.finditer(scrubbed))
    endif_matches = list(_RE_ENDIF.finditer(scrubbed))
    last_endfor = endfor_matches[-1] if endfor_matches else None
    last_endif = endif_matches[-1] if endif_matches else None
    candidates = [m for m in (last_endfor, last_endif) if m is not None]
    if not candidates:
        return None
    m = max(candidates, key = lambda x: x.end())
    return {
        "start": m.start(),
        "end": m.end(),
        "text": m.group(0),
        "dash_left": bool(m.group(1)),
        "dash_right": bool(m.group(2)),
    }
def _template_ends_with_toplevel_for(chat_template):
"""Return True if the last structural node at the template's top level is
a For (message-iteration) loop, ignoring trailing pure-whitespace Output
nodes. Used to gate the GH#4150 ChatML repair: if the outermost structure
is something else (e.g. an outer If that wraps the whole template, as in
Qwen3-Guard), we shouldn't inject an {% if add_generation_prompt %}
block at the end -- it would land inside or after an unrelated control
structure."""
try:
import jinja2
import jinja2.nodes
ast = jinja2.Environment().parse(chat_template)
except Exception:
return False
for node in reversed(ast.body):
# Skip trailing output nodes that are only whitespace -- they come
# from trailing whitespace/newlines in the source, not from real
# message-rendering logic.
if isinstance(node, jinja2.nodes.Output):
only_ws = all(
isinstance(child, jinja2.nodes.TemplateData)
and child.data.strip() == ""
for child in node.nodes
)
if only_ws:
continue
return isinstance(node, jinja2.nodes.For)
return False
def _if_body_emits_content(if_node):
"""True if the If's body contains any Output node (directly or nested).
Distinguishes a real generation block from a header guard that only
does `{% set ... %}`."""
import jinja2.nodes
for node in if_node.body:
if isinstance(node, jinja2.nodes.Output):
return True
if any(
isinstance(d, jinja2.nodes.Output)
for d in node.find_all(jinja2.nodes.Output)
):
return True
return False
def _has_add_generation_prompt_block(chat_template):
"""True if the template has a *positive* `{% if add_generation_prompt %}`
gate whose body emits output. Rejects header guards like
`{% if not add_generation_prompt is defined %}{% set ... %}{% endif %}`
that reference the name but emit nothing. AST-based; string-scan
fallback if Jinja fails to parse."""
try:
import jinja2
import jinja2.nodes
ast = jinja2.Environment().parse(chat_template)
except Exception:
return "if add_generation_prompt" in chat_template and "%}" in chat_template
for if_node in ast.find_all(jinja2.nodes.If):
test = if_node.test
# Reject negated gates: `{% if not add_generation_prompt %}` fires
# when agp=False, so it's not a generation block even if it emits.
if isinstance(test, jinja2.nodes.Not):
continue
# find_all skips the test root, so check bare Name tests explicitly.
references_agp = False
if isinstance(test, jinja2.nodes.Name) and test.name == "add_generation_prompt":
references_agp = True
else:
for name_node in test.find_all(jinja2.nodes.Name):
if name_node.name == "add_generation_prompt":
references_agp = True
break
if references_agp and _if_body_emits_content(if_node):
return True
return False
# Sentinels for _derive_assistant_prefix_by_render. Diverge at char 0 so
# commonprefix can't absorb them; long random tail makes collision with real
# template literals negligible (see T18).
_RENDER_DIFF_SENTINEL_A = "AAAA_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL"
_RENDER_DIFF_SENTINEL_B = "BBBB_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL"
_RENDER_DIFF_SENTINEL_C = "CCCC_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL"
def _derive_assistant_prefix_by_render(chat_template, is_sharegpt = False):
"""Return the assistant-turn prefix the template emits, derived by
rendering two dialogs that differ only in assistant content: the common
prefix of their tails (after the base [user]-only render) is what the
template emits for an assistant turn. None if any guard fails.
Works for Llama-3 / Gemma / Phi-3 and other non-ChatML shapes; the
template is its own ground truth.
Known limitation: an `eos-on-non-last` pattern (turn-end sentinel only
emitted for non-last messages) would produce a consistent but wrong
prefix that `_validate_patched_template` can't catch. No real-world
template is known to use this.
"""
try:
from jinja2.sandbox import SandboxedEnvironment
except Exception:
return None
    if is_sharegpt:
        base_msgs = [{"from": "human", "value": "Hi"}]
        sent_a_msgs = base_msgs + [{"from": "gpt", "value": _RENDER_DIFF_SENTINEL_A}]
        sent_b_msgs = base_msgs + [{"from": "gpt", "value": _RENDER_DIFF_SENTINEL_B}]
        # User-role cross-check (Guard C below).
        sent_c_msgs = base_msgs + [{"from": "human", "value": _RENDER_DIFF_SENTINEL_C}]
    else:
        base_msgs = [{"role": "user", "content": "Hi"}]
        sent_a_msgs = base_msgs + [
            {"role": "assistant", "content": _RENDER_DIFF_SENTINEL_A}
        ]
        sent_b_msgs = base_msgs + [
            {"role": "assistant", "content": _RENDER_DIFF_SENTINEL_B}
        ]
        sent_c_msgs = base_msgs + [{"role": "user", "content": _RENDER_DIFF_SENTINEL_C}]
# Strip trailing whitespace/comments after the last endfor/endif: they
# appear after the message loop and would break Guard A. The splice in
# `_fix_chat_template` drops them too.
probe_template = chat_template
end = _find_end_position(chat_template)
if end is not None:
after = chat_template[end["end"] :]
if _RE_JINJA_COMMENT.sub("", after).strip() == "":
probe_template = chat_template[: end["end"]]
# Sandboxed: probe renders at load time, before user calls
# apply_chat_template. SandboxedEnvironment blocks attribute-chain exploits.
try:
env = SandboxedEnvironment(
autoescape = False,
keep_trailing_newline = True,
)
tmpl = env.from_string(probe_template)
out_base = tmpl.render(messages = base_msgs, add_generation_prompt = False)
out_a = tmpl.render(messages = sent_a_msgs, add_generation_prompt = False)
out_b = tmpl.render(messages = sent_b_msgs, add_generation_prompt = False)
except Exception:
return None
# Best-effort: alternation-enforcing templates (e.g. Gemma's
# raise_exception) fail on [user, user]; that's a positive signal
# for Guard C, not a probe failure.
out_user_c = None
try:
out_user_c = tmpl.render(messages = sent_c_msgs, add_generation_prompt = False)
except Exception:
pass
# Guard A: assistant renders extend base (no reordering).
if not (out_a.startswith(out_base) and out_b.startswith(out_base)):
return None
tail_a = out_a[len(out_base) :]
tail_b = out_b[len(out_base) :]
if not tail_a or not tail_b:
return None
prefix = os.path.commonprefix([tail_a, tail_b])
# Guard B: divergence is exactly at the content-insertion site.
if not (
tail_a[len(prefix) :].startswith(_RENDER_DIFF_SENTINEL_A)
and tail_b[len(prefix) :].startswith(_RENDER_DIFF_SENTINEL_B)
):
return None
# Guard C: reject if a [user, user] render also emits the same prefix
# (role-insensitive template, e.g. `{% set greeting='Hi' %}...`).
if out_user_c is not None and out_user_c.startswith(out_base):
tail_c = out_user_c[len(out_base) :]
if tail_c.startswith(prefix) and prefix != "":
return None
if not prefix:
return None
return prefix
def _fix_chat_template(chat_template, is_sharegpt = False):
    # Fast path: already has an {% if add_generation_prompt %} block, nothing
    # to do. This catches cases the old string-based check would miss (e.g.
    # templates that use {%- if add_generation_prompt -%} with both-side dash,
    # or that sneak the block into a nested If/For).
    if _has_add_generation_prompt_block(chat_template):
        return chat_template
    end = _find_end_position(chat_template)
    if end is None:
        return chat_template
    after_endfor = chat_template[end["end"] :]
    dash_l = "-" if end["dash_left"] else ""
    dash_r = "-" if end["dash_right"] else ""
    open_tag = lambda body: "{%" + dash_l + " " + body + " " + dash_r + "%}"
    # Case 1 (pre-existing base case): template ends with a single trailing
    # {{ expr }} that is the generation prefix. Wrap it in an
    # {% if add_generation_prompt %} ... {% endif %}.
    if (
        "{%" + dash_l + " if" not in after_endfor
        and "{%" + dash_l + " set " not in after_endfor
        and after_endfor.startswith("{{")
        and after_endfor.endswith("}}")
        and after_endfor.count("{{") == 1
        and after_endfor.count("}}") == 1
    ):
        wrapped = (
            open_tag("if add_generation_prompt") + after_endfor + open_tag("endif")
        )
        return chat_template[: end["end"]] + wrapped
    # Case 2 (GH#4150): template ends at {% endfor %} with only whitespace
    # or comments left. Inject an {% if add_generation_prompt %} block with
    # the assistant prefix derived by render-diff. The top-level-For gate
    # keeps us out of outer-If wrappers (e.g. Qwen3-Guard).
    if _RE_JINJA_COMMENT.sub(
        "", after_endfor
    ).strip() == "" and _template_ends_with_toplevel_for(chat_template):
        # No redundant "agp not in scrubbed" check: the fast path already
        # confirmed no *positive* block, and a mere reference (header
        # guard) should still get repaired.
        assistant_prefix = _derive_assistant_prefix_by_render(
            chat_template, is_sharegpt
        )
        # Dual-probe: dict/list callers don't know the shape up front.
        if assistant_prefix is None and not is_sharegpt:
            assistant_prefix = _derive_assistant_prefix_by_render(
                chat_template, is_sharegpt = True
            )
        if assistant_prefix is None:
            return chat_template
        # Escape for a double-quoted Jinja string literal.
        escaped = (
            assistant_prefix.replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
            .replace("\r", "\\r")
        )
        generation_block = (
            open_tag("if add_generation_prompt")
            + '{{ "'
            + escaped
            + '" }}'
            + open_tag("endif")
        )
        return chat_template[: end["end"]] + generation_block
    return chat_template
def _is_strict_chat_template_mode():
"""Opt-in strict mode restores the pre-warn RuntimeError behavior."""
val = os.environ.get("UNSLOTH_STRICT_CHAT_TEMPLATE", "0")
return str(val).strip().lower() in ("1", "true", "yes", "on")
def _name_is_local_path(name_or_path):
"""True if name_or_path refers to an existing local directory. Used to
tailor the warning message: for local paths the user cannot 'file a bug
report to the maintainers of <path>' since that path is their own."""
if not name_or_path:
return False
try:
return os.path.isdir(str(name_or_path))
except Exception:
return False
def _format_chat_template_message(name_or_path, repaired):
"""Build a user-facing warning/error message that points at the right
responsible party (user's downstream tool vs. upstream model maintainer)."""
local = _name_is_local_path(name_or_path)
if local:
source_hint = (
"This tokenizer was loaded from a local path. The likely cause is a "
"downstream tool (LlamaFactory, Axolotl, etc.) that re-serialized "
"the tokenizer during save and stripped the generation-prompt "
"block. Either re-save with the original template, or set "
"`tokenizer.chat_template` manually before loading."
)
else:
source_hint = (
"The chat_template shipped with `{name}` appears incomplete. "
"Consider filing a bug report with the model maintainers."
).format(name = name_or_path)
if repaired:
return (
"Unsloth: Patched the chat_template on `{name}` to add a "
"{{% if add_generation_prompt %}} block. {hint}"
).format(name = name_or_path, hint = source_hint)
return (
"Unsloth: The tokenizer `{name}` does not have a "
"{{% if add_generation_prompt %}} block for generation purposes, and "
"automatic repair was not possible. The model will still load, but "
"`apply_chat_template(add_generation_prompt=True)` may not produce a "
"correct assistant-turn marker. {hint} Set "
"UNSLOTH_STRICT_CHAT_TEMPLATE=1 to raise instead of warn."
).format(name = name_or_path, hint = source_hint)
def _validate_patched_template(tokenizer, patched_template, is_sharegpt):
"""Render the just-patched template with and without
add_generation_prompt, and confirm the patched output responds to the
flag by appending (not replacing) content. Returns True if validation
passes."""
msgs = (
[{"from": "human", "value": "Hi"}]
if is_sharegpt
else [{"role": "user", "content": "Hi"}]
)
original = getattr(tokenizer, "chat_template", None)
try:
try:
tokenizer.chat_template = patched_template
except Exception:
return False # read-only tokenizer, skip validation
try:
yes = tokenizer.apply_chat_template(
msgs,
add_generation_prompt = True,
tokenize = False,
)
no = tokenizer.apply_chat_template(
msgs,
add_generation_prompt = False,
tokenize = False,
)
except Exception:
return False
finally:
try:
tokenizer.chat_template = original
except Exception:
pass # best-effort restore
# Contract after a successful repair: the two renders differ, and the
# "yes" render is a strict extension of the "no" render (we only
# appended content inside the new add_generation_prompt block).
return yes != no and yes.startswith(no)
def _repair_string_template(tokenizer, chat_template, is_sharegpt):
"""Core string-template repair. Returns the repaired template on success,
or None if repair was not possible / failed validation."""
candidate = _fix_chat_template(chat_template, is_sharegpt = is_sharegpt)
if not _has_add_generation_prompt_block(candidate):
return None
# Validate with the caller's is_sharegpt first. If that fails, the
# dual-probe in _fix_chat_template may have fallen back to the other
# schema internally -- try validating with the opposite schema before
# giving up.
if _validate_patched_template(tokenizer, candidate, is_sharegpt):
return candidate
if _validate_patched_template(tokenizer, candidate, not is_sharegpt):
return candidate
return None
def _fix_chat_template_for_tokenizer(tokenizer, chat_template):
"""Entry point for a string chat_template. Runs the no==yes diagnostic,
attempts repair if needed, and returns the (possibly patched) template.
On repair failure, the behavior is controlled by
UNSLOTH_STRICT_CHAT_TEMPLATE: warn + return original (default) or raise
RuntimeError (strict)."""
name = getattr(tokenizer, "name_or_path", "unknown")
# Detect ShareGPT vs HF style by probing apply_chat_template.
is_sharegpt = None
try:
tokenizer.apply_chat_template(
[{"role": "user", "content": "Who are you?"}],
add_generation_prompt = False,
tokenize = False,
)
is_sharegpt = False
except Exception:
try:
tokenizer.apply_chat_template(
[{"from": "human", "value": "Who are you?"}],
add_generation_prompt = False,
tokenize = False,
)
is_sharegpt = True
except Exception:
is_sharegpt = None
if is_sharegpt is None:
return chat_template
messages = (
[{"from": "human", "value": "Who are you?"}]
if is_sharegpt
else [{"role": "user", "content": "Who are you?"}]
)
try:
no = tokenizer.apply_chat_template(
messages,
add_generation_prompt = False,
tokenize = False,
)
yes = tokenizer.apply_chat_template(
messages,
add_generation_prompt = True,
tokenize = False,
)
except Exception:
return chat_template
if no != yes:
# Template already responds to the flag; leave as is.
return chat_template
# no == yes: template ignores add_generation_prompt. Try to repair.
if _has_add_generation_prompt_block(chat_template):
# Template has the block but it does not change output. This is the
# "wasn't provided correctly" case from the pre-warn code path.
msg = _format_chat_template_message(name, repaired = False)
if _is_strict_chat_template_mode():
raise RuntimeError(msg)
logger.warning_once(msg)
return chat_template
repaired = _repair_string_template(tokenizer, chat_template, is_sharegpt)
if repaired is not None:
logger.warning_once(_format_chat_template_message(name, repaired = True))
return repaired
msg = _format_chat_template_message(name, repaired = False)
if _is_strict_chat_template_mode():
raise RuntimeError(msg)
logger.warning_once(msg)
return chat_template
class _VariantTokenizerProxy:
    """Single-variant view of a multi-variant tokenizer. Routes each variant
    through `_fix_chat_template_for_tokenizer` so the full contract
    (is_sharegpt probe, no==yes, warn/strict, `_validate_patched_template`)
    applies instead of jumping straight to structural repair.

    `apply_chat_template` swaps `base.chat_template` to the variant before
    calling so tokenizer globals (bos_token, filters, raise_exception) are
    preserved; falls back to bare Jinja for read-only stubs.
    """

    def __init__(self, base_tokenizer, variant_template, variant_label = ""):
        self._base = base_tokenizer
        self._template = variant_template
        base_name = getattr(base_tokenizer, "name_or_path", "unknown")
        self.name_or_path = (
            f"{base_name} ({variant_label})" if variant_label else base_name
        )

    @property
    def chat_template(self):
        return self._template

    @chat_template.setter
    def chat_template(self, value):
        self._template = value

    def apply_chat_template(self, *args, **kwargs):
        base_original = getattr(self._base, "chat_template", None)
        swapped = False
        try:
            try:
                self._base.chat_template = self._template
                swapped = True
            except Exception:
                swapped = False
            if swapped:
                return self._base.apply_chat_template(*args, **kwargs)
            # Read-only base: fall back to sandboxed Jinja.
            from jinja2.sandbox import SandboxedEnvironment
            env = SandboxedEnvironment(
                autoescape = False,
                keep_trailing_newline = True,
            )
            messages = args[0] if args else kwargs.get("messages", [])
            add_generation_prompt = kwargs.get("add_generation_prompt", False)
            return env.from_string(self._template).render(
                messages = messages,
                add_generation_prompt = add_generation_prompt,
            )
        finally:
            if swapped:
                try:
                    self._base.chat_template = base_original
                except Exception:
                    pass  # best-effort restore
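A toy usage sketch of the swap-and-restore pattern the proxy relies on (`_ToyTokenizer` is hypothetical; the real base is an HF tokenizer whose globals must survive the swap):

```python
class _ToyTokenizer:
    # Hypothetical minimal base: exposes a writable chat_template and
    # echoes whichever template is active when apply_chat_template runs.
    def __init__(self, template):
        self.chat_template = template
        self.name_or_path = "toy/model"

    def apply_chat_template(self, messages, add_generation_prompt = False,
                            tokenize = False):
        return self.chat_template

# Swap the variant in, render, then restore the original in `finally`,
# mirroring the proxy's apply_chat_template.
base = _ToyTokenizer("BASE")
saved = base.chat_template
base.chat_template = "VARIANT"
try:
    rendered = base.apply_chat_template([])
finally:
    base.chat_template = saved
```

The `finally` restore is what lets many variants share one base tokenizer without leaking a variant template into later calls.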
def fix_chat_template(tokenizer):
    chat_template = getattr(tokenizer, "chat_template", None)
    if chat_template is None:
        return None

    # Multi-variant dict (e.g. Hermes-3 {default, tool_use}): route each
    # variant through the full repair contract via _VariantTokenizerProxy.
    if isinstance(chat_template, dict):
        fixed = {}
        for key, tmpl in chat_template.items():
            if not isinstance(tmpl, str):
                fixed[key] = tmpl
                continue
            proxy = _VariantTokenizerProxy(
                tokenizer, tmpl, variant_label = f"variant={key!r}"
            )
            fixed[key] = _fix_chat_template_for_tokenizer(proxy, tmpl)
        return fixed

    # List-of-dicts form (older HF multi-template style).
    if isinstance(chat_template, list):
        fixed = []
        for item in chat_template:
            if not isinstance(item, dict) or "template" not in item:
                fixed.append(item)
                continue
            tmpl = item["template"]
            if not isinstance(tmpl, str):
                fixed.append(item)
                continue
            label = f"variant={item.get('name', '?')!r}"
            proxy = _VariantTokenizerProxy(tokenizer, tmpl, variant_label = label)
            new_tmpl = _fix_chat_template_for_tokenizer(proxy, tmpl)
            if new_tmpl is tmpl or new_tmpl == tmpl:
                fixed.append(item)
            else:
                fixed.append({**item, "template": new_tmpl})
        return fixed

    # Plain string template: single repair path (is_sharegpt probe, no==yes
    # check, warn/strict handling all live in _fix_chat_template_for_tokenizer).
    return _fix_chat_template_for_tokenizer(tokenizer, chat_template)
def check_tokenizer(