Chat Templates

This commit is contained in:
Daniel Han-Chen 2024-02-12 04:28:41 +11:00
parent acd635aa0d
commit 2cdf43d8b7
4 changed files with 42 additions and 1 deletions

View file

@@ -82,3 +82,4 @@ pass
from .models import *
from .save import *
from .chat_templates import *

39
unsloth/chat_templates.py Normal file
View file

@@ -0,0 +1,39 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
"add_chat_template",
]
TEMPLATES = \
{
"chatml" : \
"{% for message in messages %}"\
"{% if message['from'] == 'human' %}"\
"{{'<|im_start|>user\n' + message['value'] + '<|im_end|>\n'}}"\
"{% elif message['from'] == 'gpt' %}"\
"{{'<|im_start|>assistant\n' + message['value'] + '<|im_end|>\n' }}"\
"{% else %}"\
"{{ '<|im_start|>system\n' + message['value'] + '<|im_end|>\n' }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<|im_start|>assistant\n' }}"\
"{% endif %}",
}
def add_chat_template(tokenizer, method = "chatml"):
tokenizer.chat_template = TEMPLATES[method.lower()]
pass

View file

@@ -16,6 +16,7 @@ import torch
from typing import Union, Optional, List, Any, Callable
import warnings
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
import bitsandbytes as bnb
from transformers.models.llama.modeling_llama import logger
from transformers import AutoTokenizer

View file

@@ -540,7 +540,7 @@ def LlamaModel_fast_forward(
hidden_states = inputs_embeds
if past_key_values is None and self.gradient_checkpointing and self.training:
if past_key_values is None and self.training:
use_cache = False
# if use_cache:
# logger.warning_once(