mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Prelim Feb release (#173)
* Works? * Update pyproject.toml * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Swiglu * Update swiglu.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update swiglu.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * attention_mask * Update llama.py * Update llama.py * labels * Update mistral.py * Update llama.py * attention mask * Update save.py * Update save.py * Update mistral.py * attention mask * Update llama.py * Update llama.py * Update mistral.py * Update llama.py * Update llama.py * Update llama.py * Update dpo.py * Patch saving * Update save.py * Update save.py * patch_saving_functions * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * print * Mistral patch * Update mistral.py * Update save.py * saving * Update llama.py * Update llama.py * Fast inference repatch * Update llama.py * Update utils.py * Update utils.py * Update utils.py * Update mistral.py * Update __init__.py * Fix inference * Update mistral.py * fast lm_head * Remove fast path * Update rope_embedding.py * Update loader.py * LlamaAttention_fast_forward_inference * if past_key_value is not None and q_len == 1: * revert inference * Update loader.py * past_key_value * Update llama.py * Update llama.py * Fix SDPA * Update llama.py * padding * Inference * Update llama.py * Revert * Update mistral.py * faster inference * inference * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * inference * Update llama.py * Update utils.py * faster inference * Update llama.py * revert * lm_head * Update llama.py * inference * Update llama.py * Update llama.py * Update llama.py * Update llama.py * 
Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update mistral.py * Update llama.py * faster inference * Update llama.py * fast inference * Update llama.py * Update llama.py * Update mistral.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * torch compile * past_key_values * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update llama.py * fast inference + saving config.json * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update mistral.py * fast inference again * more temp matrices * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * fast inference * Update mistral.py * Update llama.py * SDPA * attention_mask * New version * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update utils.py * Update utils.py * Update save.py * Update save.py * Torch 2.2.0 * Update save.py * mistral swa * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Fix SWA inference * Fix llm_int8_skip_modules * SWA inference * Update save.py * Update save.py * Update pyproject.toml * __version__ * __version__ * Update save.py * Update save.py * Update mistral.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update 
llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Chat Templates * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * patch tokenizer * Update chat_templates.py * Saving, LlamaRotaryEmbedding issues * Update llama.py * Update mistral.py
This commit is contained in:
parent
474fd32f91
commit
0439b8508d
8 changed files with 494 additions and 6 deletions
|
|
@ -29,7 +29,7 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and
|
|||
| **CodeLlama 34b** A100 | [▶️ Start on Colab](https://colab.research.google.com/drive/1y7A0AxE3y8gdj4AVkl2aZX47Xu3P1wJT?usp=sharing) | 1.9x faster | 27% less |
|
||||
| **Mistral 7b** 1xT4 | [▶️ Start on Kaggle](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook) | 5x faster\* | 62% less |
|
||||
|
||||
- This [conversational notebook](https://colab.research.google.com/drive/1bMOKOBzxQWUIGZBs_B0zm8pimuEnZdfM?usp=sharing) is useful for ShareGPT ChatML datasets.
|
||||
- This [conversational notebook](https://colab.research.google.com/drive/1Aau3lgPzeZKQ-98h69CCu1UJcvIBLmy2?usp=sharing) is useful for ShareGPT ChatML / Vicuna templates.
|
||||
- Our [raw text notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is useful for text completion.
|
||||
- Colab provides a free GPU sometimes. Kaggle has 30 hrs free per week on a 12 hr running cap.
|
||||
- \* Kaggle has 2x T4s, but we use 1. Due to overhead, 1x T4 is 5x faster. Use Colab as Kaggle takes 10 mins to install.
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ huggingface = [
|
|||
"peft>=0.7.1",
|
||||
"tqdm",
|
||||
"psutil",
|
||||
"wheel>=0.42.0",
|
||||
]
|
||||
cu118only = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
|
||||
|
|
|
|||
|
|
@ -82,3 +82,4 @@ pass
|
|||
|
||||
from .models import *
|
||||
from .save import *
|
||||
from .chat_templates import *
|
||||
|
|
|
|||
384
unsloth/chat_templates.py
Normal file
384
unsloth/chat_templates.py
Normal file
|
|
@ -0,0 +1,384 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = [
|
||||
"get_chat_template",
|
||||
"test_chat_templates",
|
||||
]
|
||||
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
from torch import LongTensor, FloatTensor
|
||||
from transformers.models.llama.modeling_llama import logger
|
||||
from .models._utils import patch_tokenizer
|
||||
|
||||
# Registry mapping template name -> (jinja_template, stop_word).
CHAT_TEMPLATES = {}

# Unsloth's own efficient template, modelled after Zephyr.
unsloth_template = (
    "{{ bos_token }}"
    "{% if messages[0]['role'] == 'system' %}"
    "{{ messages[0]['content'] + '\n' }}"
    "{% set loop_messages = messages[1:] %}"
    "{% else %}"
    "{{ 'You are a helpful assistant to the user\n' }}"
    "{% set loop_messages = messages %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '>>> User: ' + message['content'] + '\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}"
    "{% else %}"
    "{{ raise_exception('Only user and assistant roles are supported!') }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '>>> Assistant: ' }}"
    "{% endif %}"
)
unsloth_eos_token = "eos_token"
CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token)
|
||||
|
||||
|
||||
# Zephyr has no BOS!
zephyr_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"
    "{% else %}"
    "{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|assistant|>\n' }}"
    "{% endif %}"
)
zephyr_eos_token = "eos_token"
CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token)
|
||||
|
||||
|
||||
# ChatML has no BOS and no EOS! Rather <|im_start|> and <|im_end|> act as BOS / EOS.
|
||||
chatml_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n'}}"
    "{% elif message['role'] == 'assistant' %}"
    "{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' }}"
    "{% else %}"
    "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|im_start|>assistant\n' }}"
    "{% endif %}"
)
# ChatML's <|im_end|> serves as the stop word instead of the model's EOS.
chatml_eos_token = "<|im_end|>"
CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token)
|
||||
|
||||
|
||||
# Mistral Instruct doesn't allow system prompts, so we append it to the user message.
mistral_template = (
    "{{ bos_token }}"
    "{% if messages[0]['role'] == 'system' %}"
    "{% if messages[1]['role'] == 'user' %}"
    "{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"
    "{% set loop_messages = messages[2:] %}"
    "{% else %}"
    "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"
    "{% set loop_messages = messages[1:] %}"
    "{% endif %}"
    "{% else %}"
    "{% set loop_messages = messages %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ message['content'] + eos_token }}"
    "{% else %}"
    "{{ raise_exception('Only user and assistant roles are supported!') }}"
    "{% endif %}"
    "{% endfor %}"
)
mistral_eos_token = "eos_token"
CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token)
|
||||
|
||||
|
||||
# Adds BOS to every convo! And weird <<SYS>> system messages.
llama_template = (
    "{% if messages[0]['role'] == 'system' %}"
    "{% if messages[1]['role'] == 'user' %}"
    "{{ bos_token + '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' + messages[1]['content'] + ' [/INST]' }}"
    "{% set loop_messages = messages[2:] %}"
    "{% else %}"
    "{{ bos_token + '[INST] ' + messages[0]['content'] + ' [/INST]' }}"
    "{% set loop_messages = messages[1:] %}"
    "{% endif %}"
    "{% else %}"
    "{% set loop_messages = messages %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ ' ' + message['content'].strip() + ' ' + eos_token }}"
    "{% else %}"
    "{{ raise_exception('Only user and assistant roles are supported!') }}"
    "{% endif %}"
    "{% endfor %}"
)
llama_eos_token = "eos_token"
CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token)
|
||||
|
||||
|
||||
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
vicuna_template = (
    "{{ bos_token }}"
    "{% if messages[0]['role'] == 'system' %}"
    "{{ messages[0]['content'] + ' ' }}"
    "{% set loop_messages = messages[1:] %}"
    "{% else %}"
    "{{ 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' + ' ' }}"
    "{% set loop_messages = messages %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ 'USER: ' + message['content'] + ' ' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ 'ASSISTANT: ' + message['content'] + eos_token }}"
    "{% else %}"
    "{{ raise_exception('Only user and assistant roles are supported!') }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ 'ASSISTANT:' }}"
    "{% endif %}"
)
vicuna_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token)
|
||||
|
||||
|
||||
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
vicuna_old_template = (
    "{{ bos_token }}"
    "{% if messages[0]['role'] == 'system' %}"
    "{{ messages[0]['content'] + '\n' }}"
    "{% set loop_messages = messages[1:] %}"
    "{% else %}"
    "{{ 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.' + '\n' }}"
    "{% set loop_messages = messages %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '### Human: ' + message['content'] + '\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '### Assistant: ' + message['content'] + eos_token + '\n' }}"
    "{% else %}"
    "{{ raise_exception('Only user and assistant roles are supported!') }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '### Assistant:' }}"
    "{% endif %}"
)
vicuna_old_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token)
|
||||
|
||||
|
||||
# https://github.com/tatsu-lab/stanford_alpaca Changed for multi-turn convos
alpaca_template = (
    "{{ bos_token }}"
    "{% if messages[0]['role'] == 'system' %}"
    "{{ messages[0]['content'] + '\n\n' }}"
    "{% set loop_messages = messages[1:] %}"
    "{% else %}"
    "{{ 'Below are some instructions that describes some tasks. Write responses that appropriately completes each request.\n\n' }}"
    "{% set loop_messages = messages %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '### Instruction:\n' + message['content'] + '\n\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '### Response:\n' + message['content'] + eos_token + '\n\n' }}"
    "{% else %}"
    "{{ raise_exception('Only user and assistant roles are supported!') }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '### Response:\n' }}"
    "{% endif %}"
)
alpaca_eos_token = "eos_token"
CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token)
|
||||
|
||||
|
||||
def get_chat_template(
    tokenizer,
    chat_template = "chatml",
    mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
    map_eos_token = True,
):
    """
    Attach a chat template to `tokenizer` and return the (possibly rebuilt) tokenizer.

    tokenizer     : a HF fast tokenizer — must expose `padding_side`, `eos_token`,
                    and `_tokenizer` (the backing tokenizers.Tokenizer).
    chat_template : either a key of CHAT_TEMPLATES, or a custom
                    (template_string, stop_word) tuple/list.
    mapping       : renames the Jinja keys/roles inside the template, e.g. for
                    ShareGPT use {"role": "from", "content": "value", ...}.
    map_eos_token : must stay True — adding genuinely new tokens is unsupported;
                    a stop word different from "eos_token" is instead mapped onto
                    the existing EOS token (see hack below).
    """
    if map_eos_token is False:
        # BUG FIX: the original `assert("Unsloth: ...")` asserted a non-empty
        # string literal, which is always truthy — it never enforced anything.
        # Raise explicitly so unsupported usage fails loudly.
        raise NotImplementedError(
            "Unsloth: Can only map new tokens to EOS for now. "
            "Adding new tokens is not yet supported."
        )
    pass

    old_padding_side = tokenizer.padding_side

    if type(chat_template) in (list, tuple):
        # Caller supplied a custom (template, stop_word) pair.
        chat_template, stop_word = chat_template
        assert(type(chat_template) is str)
        assert(type(stop_word) is str)

    elif type(chat_template) is str:

        chat_template, stop_word = CHAT_TEMPLATES[chat_template]

        if stop_word != "eos_token":
            logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")

            # Replaces the old EOS token with a new one.
            # Useful for ChatML <|im_end|> for example.
            # Usually we train 2 more tokens <|im_start|> and <|im_end|>
            # But training the lm_head and embeddings are slow!
            # This is a HACK!
            # Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser
            string_vocab = tokenizer._tokenizer.to_str()
            string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
            new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
            tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)
        pass
    else:
        raise TypeError(
            f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"\
            f"{CHAT_TEMPLATES.keys()}"
        )
    pass

    # For ShareGPT role -> from and content -> value
    chat_template = chat_template\
        .replace("'role'", "'" + mapping["role"] + "'")\
        .replace("'content'", "'" + mapping["content"] + "'")\
        .replace("'user'", "'" + mapping["user"] + "'")\
        .replace("'assistant'", "'" + mapping["assistant"] + "'")

    # Re-apply Unsloth's pad-token fixes on the (possibly rebuilt) tokenizer.
    _, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)
    tokenizer.padding_side = old_padding_side
    tokenizer.chat_template = chat_template

    #stopping_criteria = create_stopping_criteria(tokenizer, stop_word)

    return tokenizer#, stopping_criteria
pass
|
||||
|
||||
|
||||
def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
    """
    Build a StoppingCriteriaList that halts generation on `stop_word`.

    tokenizer : tokenizer used to encode `stop_word` (or to read `eos_token_id`).
    stop_word : the literal "eos_token" to stop on EOS, or any string whose token
                sequence should terminate generation.
    Returns a transformers StoppingCriteriaList with one criterion.
    """
    # BUG FIX: the module only does `from torch import LongTensor, FloatTensor`,
    # so the bare name `torch` used below was a NameError at runtime.
    import torch
    # BUG FIX: the original hard-coded "cuda" (and ignored its `device`
    # parameter); fall back to CPU when CUDA is unavailable.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    class StoppingCriteriaSub(StoppingCriteria):
        __slots__ = "stop_token", "single_match", "length",

        def __init__(self, stops = "eos_token", device = None, encounters = 1):
            super().__init__()
            if device is None: device = default_device
            if stops == "eos_token":
                # Single-token stop: just the EOS id.
                self.stop_token = torch.tensor(tokenizer.eos_token_id, device = device)
                self.length = 1
            else:
                # Prepend "\n" then drop the first token so tokenization of the
                # stop word is context-free; keep the rest as the match pattern.
                self.stop_token = tokenizer(["\n" + stops], add_special_tokens = False, return_tensors = "pt")
                self.stop_token = self.stop_token.input_ids.ravel()[1:].to(device)
                self.length = self.stop_token.shape[0]
            pass
            self.single_match = self.length == 1
        pass

        def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:
            input_ids = input_ids.ravel()
            last_token = input_ids[-1]
            # Fast path: single stop token only needs the last position.
            if self.single_match and (last_token == self.stop_token): return True

            # General path: compare the trailing window against the pattern.
            if input_ids.shape[0] >= self.length and \
                (input_ids[-self.length:] == self.stop_token).all(): return True
            return False
        pass
    pass
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = stop_word)])
    return stopping_criteria
pass
|
||||
|
||||
|
||||
def test_chat_templates():
    """
    Self-test: each of our templates must reproduce the reference tokenizer's own
    chat formatting (or FastChat's reference for the Vicuna styles).
    Downloads tokenizers from the Hub — requires network access.
    """
    # BUG FIX: `os.system` was called below without `os` ever being imported.
    import os

    messages = [
        {"role": "system","content": " You are a friendly chatbot.",},
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "It's 4."},
        {"role": "user", "content": " But 2+2 is equal to 5. "},
        {"role": "assistant", "content": "No I'm sure its 4."},
        {"role": "user", "content": " No it's 100% 5! "},
    ]

    from transformers import AutoTokenizer

    # Zephyr
    template = zephyr_template
    correct_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
    correct_tokenizer.chat_template = template
    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
    assert(correct_prompt == our_prompt)

    # ChatML
    template = chatml_template
    correct_tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
    correct_tokenizer.chat_template = template
    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
    assert(correct_prompt == our_prompt)

    # Mistral (no system prompts -> drop the system message)
    template = mistral_template
    correct_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
    correct_tokenizer.chat_template = template
    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
    assert(correct_prompt == our_prompt)

    # Llama
    template = llama_template
    correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-chat")
    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
    correct_tokenizer.chat_template = template
    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
    assert(correct_prompt == our_prompt)

    # Vicuna — FastChat is the reference implementation.
    try:
        from fastchat.conversation import get_conv_template
    except ImportError:
        os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
        from fastchat.conversation import get_conv_template

    # BUG FIX: the original built `correct_prompt` using an undefined name
    # `tokenizer`; load the tokenizer first and use its BOS token.
    correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
    correct_prompt = get_conv_template("vicuna_v1.1")
    for j in range(len(messages)-1):
        correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
    correct_prompt.append_message(correct_prompt.roles[1], "")
    correct_prompt = correct_tokenizer.bos_token + correct_prompt.get_prompt()

    template = vicuna_template
    correct_tokenizer.chat_template = template
    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
    assert(correct_prompt == our_prompt)

    # Old Vicuna ("zero_shot" in FastChat)
    try:
        from fastchat.conversation import get_conv_template
    except ImportError:
        os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
        from fastchat.conversation import get_conv_template

    # Same undefined-`tokenizer` fix as above.
    correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
    correct_prompt = get_conv_template("zero_shot")
    for j in range(len(messages)-1):
        correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
    correct_prompt.append_message(correct_prompt.roles[1], "")
    correct_prompt = correct_tokenizer.bos_token + correct_prompt.get_prompt()

    template = vicuna_old_template
    correct_tokenizer.chat_template = template
    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
    # We add </s> ourselves
    assert(correct_prompt == our_prompt.replace("</s>", ""))
pass
|
||||
|
|
@ -16,6 +16,7 @@ import torch
|
|||
from typing import Union, Optional, List, Any, Callable
|
||||
import warnings
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
|
||||
import bitsandbytes as bnb
|
||||
from transformers.models.llama.modeling_llama import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
|
@ -116,21 +117,24 @@ pass
|
|||
|
||||
|
||||
def patch_tokenizer(model, tokenizer):
    """
    Ensure `tokenizer` has a pad token; stamp `model.config` with the Unsloth version.

    model     : a HF model whose config is updated, or None (tokenizer-only patching).
    tokenizer : the tokenizer to patch in place.
    Returns the (model, tokenizer) pair.
    """
    if model is not None:
        model.config.update({"unsloth_version" : __version__})
    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        # Fixes https://github.com/unslothai/unsloth/issues/5
        # BUG FIX: `hasattr(tokenizer, "unk_token")` is True even when
        # unk_token is None, which would have set pad_token to None.
        if getattr(tokenizer, "unk_token", None) is not None:
            tokenizer.add_special_tokens({"pad_token" : tokenizer.unk_token})
            tokenizer.pad_token = tokenizer.unk_token
        else:
            name = model.config._name_or_path if model is not None else "Model"
            # BUG FIX: `logger.warning_one` is not a logger method — the typo
            # raised AttributeError whenever this branch fired.
            logger.warning_once(
                f"{name} does not have a padding or unknown token!\n"\
                f"Will use the EOS token of id {tokenizer.eos_token_id} as padding."
            )
            assert(hasattr(tokenizer, "eos_token"))
            tokenizer.add_special_tokens({"pad_token" : tokenizer.eos_token})
            tokenizer.pad_token = tokenizer.eos_token
            if model is not None:
                model.config.update({"pad_token_id" : tokenizer.eos_token_id})
    pass
    return model, tokenizer
pass
|
||||
|
|
|
|||
|
|
@ -540,7 +540,7 @@ def LlamaModel_fast_forward(
|
|||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if past_key_values is None and self.gradient_checkpointing and self.training:
|
||||
if past_key_values is None and self.training:
|
||||
use_cache = False
|
||||
# if use_cache:
|
||||
# logger.warning_once(
|
||||
|
|
@ -776,6 +776,73 @@ def PeftModelForCausalLM_fast_forward(
|
|||
pass
|
||||
|
||||
|
||||
# Solves https://github.com/unslothai/unsloth/issues/168
|
||||
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
|
||||
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
|
||||
# https://github.com/huggingface/transformers/pull/27931
|
||||
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
|
||||
class LlamaRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding, kept API-identical to transformers v4.37.2.

    Caches cos/sin tables up to `max_position_embeddings` and grows the cache
    lazily when a longer sequence is seen.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Per-pair inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Precompute the caches up front so `torch.jit.trace` works.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicated (not interleaved) layout — matches HF's permutation trick,
        # which yields the same rotation as the paper's formulation.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)
        return (cos, sin)
pass
|
||||
|
||||
|
||||
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__, which builds the cache and
        # therefore calls our overridden _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Linear scaling: compress positions by the scaling factor.
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) / self.scaling_factor
        angles = torch.outer(positions, self.inv_freq)
        # Same duplicated layout as the parent class.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
pass
|
||||
|
||||
|
||||
class FastLlamaModel:
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -787,6 +854,15 @@ class FastLlamaModel:
|
|||
LlamaModel .forward = LlamaModel_fast_forward
|
||||
LlamaForCausalLM .forward = LlamaForCausalLM_fast_forward
|
||||
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
|
||||
|
||||
# Solves https://github.com/unslothai/unsloth/issues/168
|
||||
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
|
||||
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
|
||||
# https://github.com/huggingface/transformers/pull/27931
|
||||
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
|
||||
import transformers.models.llama.modeling_llama
|
||||
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
|
||||
transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding
|
||||
return
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -271,6 +271,14 @@ class FastMistralModel(FastLlamaModel):
|
|||
MistralModel .forward = LlamaModel_fast_forward
|
||||
MistralForCausalLM .forward = MistralForCausalLM_fast_forward
|
||||
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
|
||||
|
||||
# Solves https://github.com/unslothai/unsloth/issues/168
|
||||
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
|
||||
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
|
||||
# https://github.com/huggingface/transformers/pull/27931
|
||||
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
|
||||
import transformers.models.mistral.modeling_mistral
|
||||
transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding = LlamaRotaryEmbedding
|
||||
return
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ LLAMA_LAYERNORMS = (
|
|||
"input_layernorm", "post_attention_layernorm",
|
||||
)
|
||||
|
||||
# https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/quantize.cpp#L19
|
||||
# From https://mlabonne.github.io/blog/posts/Quantize_Llama_2_models_using_ggml.html
|
||||
ALLOWED_QUANTS = \
|
||||
{
|
||||
|
|
@ -59,10 +60,16 @@ ALLOWED_QUANTS = \
|
|||
"q4_0" : "Original quant method, 4-bit.",
|
||||
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
|
||||
"q4_k_s" : "Uses Q4_K for all tensors",
|
||||
"q4_k" : "alias for q4_k_m",
|
||||
"q5_k" : "alias for q5_k_m",
|
||||
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
|
||||
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
|
||||
"q5_k_s" : "Uses Q5_K for all tensors",
|
||||
"q6_k" : "Uses Q8_K for all tensors",
|
||||
"iq2_xxs" : "2.06 bpw quantization",
|
||||
"iq2_xs" : "2.31 bpw quantization",
|
||||
"iq3_xxs" : "3.06 bpw quantization",
|
||||
"q3_k_xs" : "3-bit extra small quantization",
|
||||
}
|
||||
|
||||
def print_quantization_methods():
|
||||
|
|
@ -246,7 +253,8 @@ def unsloth_save_model(
|
|||
# If push_to_hub, we must remove the .../ part of a repo
|
||||
if push_to_hub and "/" in save_directory:
|
||||
|
||||
new_save_directory = save_directory[save_directory.find("/"):]
|
||||
# +1 solves absolute path issues
|
||||
new_save_directory = save_directory[save_directory.find("/")+1:]
|
||||
|
||||
logger.warning_once(
|
||||
f"Unsloth: You are pushing to hub, but you passed your HF username.\n"\
|
||||
|
|
@ -861,10 +869,16 @@ def unsloth_save_pretrained_gguf(
|
|||
"q4_0" : "Original quant method, 4-bit.",
|
||||
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
|
||||
"q4_k_s" : "Uses Q4_K for all tensors",
|
||||
"q4_k" : "alias for q4_k_m",
|
||||
"q5_k" : "alias for q5_k_m",
|
||||
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
|
||||
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
|
||||
"q5_k_s" : "Uses Q5_K for all tensors",
|
||||
"q6_k" : "Uses Q8_K for all tensors",
|
||||
"iq2_xxs" : "2.06 bpw quantization",
|
||||
"iq2_xs" : "2.31 bpw quantization",
|
||||
"iq3_xxs" : "3.06 bpw quantization",
|
||||
"q3_k_xs" : "3-bit extra small quantization",
|
||||
"""
|
||||
if tokenizer is None:
|
||||
raise ValueError("Unsloth: Saving to GGUF must have a tokenizer.")
|
||||
|
|
|
|||
Loading…
Reference in a new issue