mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Chat Templates
This commit is contained in:
parent
acd635aa0d
commit
2cdf43d8b7
4 changed files with 42 additions and 1 deletions
|
|
@ -82,3 +82,4 @@ pass
|
|||
|
||||
from .models import *
|
||||
from .save import *
|
||||
from .chat_templates import *
|
||||
|
|
|
|||
39
unsloth/chat_templates.py
Normal file
39
unsloth/chat_templates.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = [
    "add_chat_template",
]

# Registry of named Jinja2 chat templates, keyed by lowercase template name.
# Each template renders a `messages` list of {"from": ..., "value": ...} dicts;
# "human" turns map to the user role, "gpt" turns to assistant, and anything
# else falls through to the system role.
TEMPLATES = {
    "chatml": (
        "{% for message in messages %}"
        "{% if message['from'] == 'human' %}"
        "{{'<|im_start|>user\n' + message['value'] + '<|im_end|>\n'}}"
        "{% elif message['from'] == 'gpt' %}"
        "{{'<|im_start|>assistant\n' + message['value'] + '<|im_end|>\n' }}"
        "{% else %}"
        "{{ '<|im_start|>system\n' + message['value'] + '<|im_end|>\n' }}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n' }}"
        "{% endif %}"
    ),
}


def add_chat_template(tokenizer, method = "chatml"):
    """Attach the named chat template to *tokenizer*.

    Looks up *method* case-insensitively in TEMPLATES and assigns the
    resulting Jinja2 string to ``tokenizer.chat_template``. Raises
    KeyError when the name is unknown.
    """
    template_name = method.lower()
    tokenizer.chat_template = TEMPLATES[template_name]
pass
|
||||
|
|
@ -16,6 +16,7 @@ import torch
|
|||
from typing import Union, Optional, List, Any, Callable
|
||||
import warnings
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
|
||||
import bitsandbytes as bnb
|
||||
from transformers.models.llama.modeling_llama import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
|
|
|||
|
|
@ -540,7 +540,7 @@ def LlamaModel_fast_forward(
|
|||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if past_key_values is None and self.gradient_checkpointing and self.training:
|
||||
if past_key_values is None and self.training:
|
||||
use_cache = False
|
||||
# if use_cache:
|
||||
# logger.warning_once(
|
||||
|
|
|
|||
Loading…
Reference in a new issue