mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Gemma2 (#709)
* Update mapper.py * Update loader.py * Update llama.py * Update tokenizer_utils.py * info * edits * Create chat template * Fix tokenizer * Update tokenizer_utils.py * fix case where gguf saving fails due to first_conversion dtype (#630) * Support revision parameter in FastLanguageModel.from_pretrained (#629) * support `revision` parameter * match unsloth formatting of named parameters * clears any selected_adapters before calling internal_model.save_pretrained (#609) * Update __init__.py (#602) Check for incompatible modules before importing unsloth * Fixed unsloth/tokenizer_utils.py for chat training (#604) * Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345) * Add save to llama.cpp GGML to save.py. * Fix conversion command and path of convert to GGML function. * Add autosaving lora to the GGML function * Create lora save function for conversion to GGML * Test fix #2 for saving lora * Test fix #3 to save the lora adapters to convert to GGML * Remove unwated tokenizer saving for conversion to ggml and added a few print statements. * Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages. * Positional arguments didn't work out, so reverted to older version of the code, and added a few comments. * Test fix 1 for arch * Test fix 2 new Mistral error. * Test fix 3 * Revert to old version for testing. * Upload issue test fix 1 * Fix 2 uploading ggml * Positional ags added. * Temporray remove positional args * Fix upload again!!! * Add print statements and fix link * Make the calling name better * Create local saving for GGML * Add choosing directory to save local GGML. * Fix lil variable error in the save_to_custom_dir func * docs: Add LoraConfig parameters documentation (#619) * llama.cpp failing (#371) llama.cpp is failing to generate quantize versions for the trained models. Error: ```bash You might have to compile llama.cpp yourself, then run this again. You do not need to close this Python program. Run the following commands in a new terminal: You must run this in the same folder as you're saving your model. git clone https://github.com/ggerganov/llama.cpp cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j Once that's done, redo the quantization. ``` But when i do clone this with recursive it works. Co-authored-by: Daniel Han <danielhanchen@gmail.com> * fix libcuda_dirs import for triton 3.0 (#227) * fix libcuda_dirs import for triton 3.0 * Update __init__.py * Update __init__.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update save.py * Update __init__.py * Update fast_lora.py * Update save.py * Update save.py * Update save.py * Update loader.py * Update save.py * Update save.py * quantize now llama-quantize * Update chat_templates.py * Update loader.py * Update mapper.py * Update __init__.py * embedding size * Update qwen2.py * docs * Update README.md * Update qwen2.py * README: Fix minor typo. (#559) * README: Fix minor typo. One-character typo fix while reading. * Update README.md --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update mistral.py * Update qwen2.py * Update qwen2.py * Update qwen2.py * Update llama.py * Update llama.py * Update llama.py * Update README.md * FastMistralModel * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Auto check rope scaling * Update llama.py * Update llama.py * Update llama.py * GPU support * Typo * Update gemma.py * gpu * Multiple GGUF saving * Update save.py * Update save.py * check PEFT and base * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update chat_templates.py * Fix breaking bug in save.py with interpreting quantization_method as a string when saving to gguf (#651) * Nightly (#649) * Update llama.py * offload * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * continued pretraining trainer * Update trainer.py * Update trainer.py * Update trainer.py * Update trainer.py * is_bfloat16_supported * Update __init__.py * Update README.md * Update llama.py * is_bfloat16_supported * Update __init__.py * Mistral v3 * Phi 3 medium * Update chat_templates.py * Update chat_templates.py * Phi-3 * Update save.py * Update README.md Mistral v3 to Mistral v0.3 * Untrained tokens * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update llama.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update save.py * Update save.py * Update save.py * checkpoint * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update llama.py * accelerate * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update tokenizer_utils.py * train_dataloader * Update llama.py * Update llama.py * Update llama.py * use_fast_convert * Update save.py * Update save.py * Update save.py * Update save.py * remove_special_tokens * Ollama * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update llama.py * Update chat_templates.py * Support bfloat16 GGUF * Update save.py * Update llama.py * fast_forward_inference * Update mapper.py * Update loader.py * Update llama.py * Update tokenizer_utils.py * info * edits * Create chat template * Fix tokenizer * Update tokenizer_utils.py * fix case where gguf saving fails due to first_conversion dtype (#630) * Support revision parameter in FastLanguageModel.from_pretrained (#629) * support `revision` parameter * match unsloth formatting of named parameters * clears any selected_adapters before calling internal_model.save_pretrained (#609) * Update __init__.py (#602) Check for incompatible modules before importing unsloth * Fixed unsloth/tokenizer_utils.py for chat training (#604) * Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345) * Add save to llama.cpp GGML to save.py. * Fix conversion command and path of convert to GGML function. * Add autosaving lora to the GGML function * Create lora save function for conversion to GGML * Test fix #2 for saving lora * Test fix #3 to save the lora adapters to convert to GGML * Remove unwated tokenizer saving for conversion to ggml and added a few print statements. * Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages. * Positional arguments didn't work out, so reverted to older version of the code, and added a few comments. * Test fix 1 for arch * Test fix 2 new Mistral error. * Test fix 3 * Revert to old version for testing. * Upload issue test fix 1 * Fix 2 uploading ggml * Positional ags added. * Temporray remove positional args * Fix upload again!!! * Add print statements and fix link * Make the calling name better * Create local saving for GGML * Add choosing directory to save local GGML. * Fix lil variable error in the save_to_custom_dir func * docs: Add LoraConfig parameters documentation (#619) * llama.cpp failing (#371) llama.cpp is failing to generate quantize versions for the trained models. Error: ```bash You might have to compile llama.cpp yourself, then run this again. You do not need to close this Python program. Run the following commands in a new terminal: You must run this in the same folder as you're saving your model. git clone https://github.com/ggerganov/llama.cpp cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j Once that's done, redo the quantization. ``` But when i do clone this with recursive it works. Co-authored-by: Daniel Han <danielhanchen@gmail.com> * fix libcuda_dirs import for triton 3.0 (#227) * fix libcuda_dirs import for triton 3.0 * Update __init__.py * Update __init__.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update save.py * Update __init__.py * Update fast_lora.py * Update save.py * Update save.py * Update save.py * Update loader.py * Update save.py * Update save.py * quantize now llama-quantize * Update chat_templates.py * Update loader.py * Update mapper.py * Update __init__.py * embedding size * Update qwen2.py * docs * Update README.md * Update qwen2.py * README: Fix minor typo. (#559) * README: Fix minor typo. One-character typo fix while reading. * Update README.md --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update mistral.py * Update qwen2.py * Update qwen2.py * Update qwen2.py * Update llama.py * Update llama.py * Update llama.py * Update README.md * FastMistralModel * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Auto check rope scaling * Update llama.py * Update llama.py * Update llama.py * GPU support * Typo * Update gemma.py * gpu * Multiple GGUF saving * Update save.py * Update save.py * check PEFT and base * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update chat_templates.py --------- Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com> Co-authored-by: Rickard Edén <rickardeden@gmail.com> Co-authored-by: XiaoYang <xyangk@gmail.com> Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com> Co-authored-by: Sébastien De Greef <sebdg@binarycompany.com> Co-authored-by: Alberto Ferrer <albertof@barrahome.org> Co-authored-by: Thomas Viehmann <tv.github-private@beamnet.de> Co-authored-by: Walter Korman <lemurware@gmail.com> * Fix bug in save.py with interpreting quantization_method as a string that prevents GGUF from saving * Implemented better list management and then forgot to actually call the new list variable, fixed * Check type of given quantization method and return type error if not list or string * Update save.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com> Co-authored-by: Rickard Edén <rickardeden@gmail.com> Co-authored-by: XiaoYang <xyangk@gmail.com> Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com> Co-authored-by: Sébastien De Greef <sebdg@binarycompany.com> Co-authored-by: Alberto Ferrer <albertof@barrahome.org> Co-authored-by: Thomas Viehmann <tv.github-private@beamnet.de> Co-authored-by: Walter Korman <lemurware@gmail.com> * Revert "Fix breaking bug in save.py with interpreting quantization_method as …" (#652) This reverts commit 506cb68867296237e95bc53c32f1bfc9b1757960. * Revert "Revert "Fix breaking bug in save.py with interpreting quantization_me…" (#653) This reverts commit 2f48cc9af385579876fd45bd833169d1f1a2ea58. * Update llama.py * peft * patch * Update loader.py * retrain * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * offload * Update llama.py * Create a starter script for command-line training to integrate in ML ops pipelines. (#623) * Update chat_templates.py * Ollama * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Ollama * Update chat_templates.py * ollama * Update mapper.py * Update chat_templates.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update llama.py * Fixes * clearer messages * Update tokenizer_utils.py * Update tokenizer_utils.py * Update llama.py * Update llama.py * Update llama.py * log * Update __init__.py * Update llama.py * Update __init__.py * Create Merge.png * Create ollama.png * Gemma2 * Update llama.py * Update loader.py * Update pyproject.toml * Update pyproject.toml * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update _utils.py * Revert Gemma2 * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update rms_layernorm.py * Update gemma2.py * logit softcapping * Update cross_entropy_loss.py * Update llama.py * Update llama.py * Update gemma2.py * Update gemma2.py * Update cross_entropy_loss.py * Update llama.py * Update llama.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update llama.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update llama.py * Update gemma2.py * Update llama.py * Update llama.py * Update gemma2.py * Update gemma2.py * Update llama.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update _utils.py * Update _utils.py * Update gemma2.py * compile flags * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update gemma2.py * Update gemma2.py * fixes * Update _utils.py * Fix generation * Update llama.py * Update llama.py * Update _utils.py * Update _utils.py * Update _utils.py * pad token * Update gemma2.py * pad token * Update _utils.py * Update llama.py * Update gemma2.py * edit warning * Update tokenizer_utils.py --------- Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com> Co-authored-by: Rickard Edén <rickardeden@gmail.com> Co-authored-by: XiaoYang <xyangk@gmail.com> Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com> Co-authored-by: Sébastien De Greef <sebdg@binarycompany.com> Co-authored-by: Alberto Ferrer <albertof@barrahome.org> Co-authored-by: Thomas Viehmann <tv.github-private@beamnet.de> Co-authored-by: Walter Korman <lemurware@gmail.com> Co-authored-by: ArcadaLabs-Jason <52756218+ArcadaLabs-Jason@users.noreply.github.com> Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
This commit is contained in:
parent
cfddc79bc8
commit
cc4c5d7785
17 changed files with 772 additions and 60 deletions
BIN
images/Merge.png
Normal file
BIN
images/Merge.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 31 KiB |
BIN
images/ollama.png
Normal file
BIN
images/ollama.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 66 KiB |
|
|
@ -34,7 +34,7 @@ exclude = ["images*"]
|
|||
[project.optional-dependencies]
|
||||
huggingface = [
|
||||
"tyro",
|
||||
"transformers>=4.38.2",
|
||||
"transformers>=4.42.3",
|
||||
"datasets>=2.16.0",
|
||||
"sentencepiece>=0.2.0",
|
||||
"tqdm",
|
||||
|
|
@ -185,9 +185,9 @@ colab-ampere-torch220 = [
|
|||
]
|
||||
colab-new = [
|
||||
"tyro",
|
||||
"transformers>=4.38.2",
|
||||
"transformers>=4.42.3",
|
||||
"datasets>=2.16.0",
|
||||
"sentencepiece",
|
||||
"sentencepiece>=0.2.0",
|
||||
"tqdm",
|
||||
"psutil",
|
||||
"wheel>=0.42.0",
|
||||
|
|
|
|||
|
|
@ -19,14 +19,17 @@ from .utils import calculate_settings, MAX_FUSED_SIZE
|
|||
from transformers.models.llama.modeling_llama import logger
|
||||
|
||||
|
||||
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
|
||||
@triton.jit
|
||||
def _cross_entropy_forward(
|
||||
logits_ptr, logits_row_stride,
|
||||
loss_ptr,
|
||||
logsumexp_ptr,
|
||||
labels_ptr,
|
||||
VOCAB_SIZE : tl.constexpr,
|
||||
BLOCK_SIZE : tl.constexpr,
|
||||
VOCAB_SIZE : tl.constexpr,
|
||||
BLOCK_SIZE : tl.constexpr,
|
||||
DO_SOFTCAPPING : tl.constexpr,
|
||||
SOFTCAP : tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
|
||||
|
|
@ -58,13 +61,19 @@ def _cross_entropy_forward(
|
|||
mask = col_offsets < VOCAB_SIZE
|
||||
|
||||
label_idx = tl.load(labels_ptr).to(tl.int32)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
|
||||
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
||||
if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)
|
||||
|
||||
logits = logits.to(tl.float32)
|
||||
c = tl.max(logits, 0)
|
||||
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
|
||||
|
||||
if label_idx != -100:
|
||||
x = tl.load(logits_ptr + label_idx).to(tl.float32)
|
||||
loss = logsumexp - x
|
||||
x = tl.load(logits_ptr + label_idx)
|
||||
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
||||
if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
|
||||
loss = logsumexp - x.to(tl.float32)
|
||||
else:
|
||||
loss = 0.0
|
||||
tl.store(logsumexp_ptr, logsumexp)
|
||||
|
|
@ -72,15 +81,18 @@ def _cross_entropy_forward(
|
|||
pass
|
||||
|
||||
|
||||
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
|
||||
@triton.jit
|
||||
def _chunked_cross_entropy_forward(
|
||||
logits_ptr, logits_row_stride,
|
||||
loss_ptr,
|
||||
logsumexp_ptr,
|
||||
labels_ptr,
|
||||
VOCAB_SIZE : tl.constexpr,
|
||||
N_CHUNKS : tl.constexpr,
|
||||
BLOCK_SIZE : tl.constexpr,
|
||||
VOCAB_SIZE : tl.constexpr,
|
||||
N_CHUNKS : tl.constexpr,
|
||||
BLOCK_SIZE : tl.constexpr,
|
||||
DO_SOFTCAPPING : tl.constexpr,
|
||||
SOFTCAP : tl.constexpr,
|
||||
):
|
||||
"""
|
||||
256K vocab divided in 4 chunks
|
||||
|
|
@ -117,7 +129,11 @@ def _chunked_cross_entropy_forward(
|
|||
mask = col_offsets < VOCAB_SIZE
|
||||
|
||||
label_idx = tl.load(labels_ptr).to(tl.int32)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
|
||||
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
||||
if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)
|
||||
|
||||
logits = logits.to(tl.float32)
|
||||
c = tl.max(logits, 0)
|
||||
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
|
||||
|
||||
|
|
@ -126,7 +142,9 @@ def _chunked_cross_entropy_forward(
|
|||
# Do the -x separately
|
||||
if label_idx != -100:
|
||||
x = tl.load(logits_ptr + label_idx).to(tl.float32)
|
||||
loss = -1.0 * x
|
||||
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
||||
if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
|
||||
loss = -1.0 * x.to(tl.float32)
|
||||
else:
|
||||
loss = 0.0
|
||||
tl.store(loss_ptr, loss)
|
||||
|
|
@ -135,14 +153,17 @@ def _chunked_cross_entropy_forward(
|
|||
pass
|
||||
|
||||
|
||||
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
|
||||
@triton.jit
|
||||
def _cross_entropy_backward(
|
||||
logits_ptr, logits_row_stride,
|
||||
dloss_ptr, dloss_row_stride,
|
||||
logsumexp_ptr,
|
||||
labels_ptr,
|
||||
VOCAB_SIZE : tl.constexpr,
|
||||
BLOCK_SIZE : tl.constexpr,
|
||||
VOCAB_SIZE : tl.constexpr,
|
||||
BLOCK_SIZE : tl.constexpr,
|
||||
DO_SOFTCAPPING : tl.constexpr,
|
||||
SOFTCAP : tl.constexpr,
|
||||
):
|
||||
"""
|
||||
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
|
||||
|
|
@ -173,15 +194,27 @@ def _cross_entropy_backward(
|
|||
else:
|
||||
dloss = 0.0
|
||||
|
||||
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
||||
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
|
||||
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
||||
if DO_SOFTCAPPING:
|
||||
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
|
||||
partial = tl.math.tanh(x / SOFTCAP)
|
||||
x = SOFTCAP * partial
|
||||
pass
|
||||
|
||||
logsumexp = tl.load(logsumexp_ptr + row_idx)
|
||||
y = tl.exp(x - logsumexp)
|
||||
y = tl.exp(x.to(tl.float32) - logsumexp)
|
||||
y = tl.where(
|
||||
col_offsets == label_idx,
|
||||
y - 1.0, # exp(x - logsumexp) - 1
|
||||
y, # exp(x - logsumexp)
|
||||
)
|
||||
|
||||
if DO_SOFTCAPPING:
|
||||
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
|
||||
y = y * (1.0 - partial*partial)
|
||||
pass
|
||||
|
||||
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
|
||||
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
|
||||
pass
|
||||
|
|
@ -191,40 +224,46 @@ MAX_FUSED_SIZE = 65536 # 2**16
|
|||
|
||||
class Fast_CrossEntropyLoss(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, logits, labels):
|
||||
def forward(ctx, logits, labels, logit_softcapping = 0):
|
||||
n_rows, vocab_size = logits.shape
|
||||
|
||||
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
|
||||
n_chunks = div + (mod != 0)
|
||||
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
||||
|
||||
DO_SOFTCAPPING = (logit_softcapping != 0)
|
||||
|
||||
if n_chunks == 1:
|
||||
# For small vocabs <= 65336 like Llama, Mistral
|
||||
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
|
||||
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
||||
|
||||
_cross_entropy_forward[(n_rows,)](
|
||||
logits, logits.stride(0),
|
||||
losses,
|
||||
logsumexp,
|
||||
labels,
|
||||
VOCAB_SIZE = vocab_size,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
num_warps = num_warps,
|
||||
VOCAB_SIZE = vocab_size,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
DO_SOFTCAPPING = DO_SOFTCAPPING,
|
||||
SOFTCAP = logit_softcapping,
|
||||
num_warps = num_warps,
|
||||
)
|
||||
else:
|
||||
# For large vocabs > 65336 like Gemma 256K
|
||||
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda")
|
||||
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")
|
||||
|
||||
_chunked_cross_entropy_forward[(n_rows, n_chunks,)](
|
||||
logits, logits.stride(0),
|
||||
losses,
|
||||
logsumexp,
|
||||
labels,
|
||||
VOCAB_SIZE = vocab_size,
|
||||
N_CHUNKS = n_chunks,
|
||||
BLOCK_SIZE = MAX_FUSED_SIZE,
|
||||
num_warps = 32,
|
||||
VOCAB_SIZE = vocab_size,
|
||||
N_CHUNKS = n_chunks,
|
||||
BLOCK_SIZE = MAX_FUSED_SIZE,
|
||||
DO_SOFTCAPPING = DO_SOFTCAPPING,
|
||||
SOFTCAP = logit_softcapping,
|
||||
num_warps = 32,
|
||||
)
|
||||
# logsumexp(chunked_logsumexp) - x
|
||||
# Do the -x separately
|
||||
|
|
@ -234,6 +273,8 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
|
|||
pass
|
||||
|
||||
ctx.save_for_backward(logits, logsumexp, labels)
|
||||
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
|
||||
ctx.logit_softcapping = logit_softcapping
|
||||
return losses
|
||||
pass
|
||||
|
||||
|
|
@ -251,16 +292,18 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
|
|||
dlosses, dlosses.stride(0),
|
||||
logsumexp,
|
||||
labels,
|
||||
VOCAB_SIZE = vocab_size,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
num_warps = 8,
|
||||
VOCAB_SIZE = vocab_size,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
|
||||
SOFTCAP = ctx.logit_softcapping,
|
||||
num_warps = 8,
|
||||
)
|
||||
return logits, None, None,
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def fast_cross_entropy_loss(logits, labels):
|
||||
def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
|
||||
"""
|
||||
Arguments:
|
||||
logits: (batch, seq_len, vocab_size)
|
||||
|
|
@ -274,6 +317,7 @@ def fast_cross_entropy_loss(logits, labels):
|
|||
loss = Fast_CrossEntropyLoss.apply(
|
||||
logits.view(batch*seq_len, d),
|
||||
labels.view(-1),
|
||||
logit_softcapping,
|
||||
)
|
||||
n_items = torch.count_nonzero(labels != -100)
|
||||
return loss.sum() / n_items
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ pass
|
|||
def geglu_exact_forward_kernel(gate, up):
|
||||
batch, seq_len, hd = gate.shape
|
||||
n_elements = gate.numel()
|
||||
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
|
||||
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
|
||||
return out
|
||||
|
|
@ -133,7 +133,7 @@ pass
|
|||
def geglu_approx_forward_kernel(gate, up):
|
||||
batch, seq_len, hd = gate.shape
|
||||
n_elements = gate.numel()
|
||||
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
|
||||
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
|
||||
return out
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ def _gemma_rms_layernorm_forward(
|
|||
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
|
||||
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
||||
inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl
|
||||
inv_var = tl.math.rsqrt(row_var + eps)
|
||||
tl.store(r, inv_var)
|
||||
normed = X_row * inv_var
|
||||
output = normed * (W_row + 1.0)
|
||||
|
|
@ -137,8 +137,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
|
|||
n_rows, n_cols = X.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
|
||||
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
|
||||
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
|
||||
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
||||
|
||||
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
|
||||
fx[(n_rows,)](
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ pass
|
|||
def swiglu_fg_kernel(e, g):
|
||||
batch, seq_len, hd = e.shape
|
||||
n_elements = e.numel()
|
||||
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda")
|
||||
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
|
||||
return h
|
||||
|
|
|
|||
|
|
@ -105,14 +105,14 @@ def fast_dequantize(W, quant_state = None, out = None):
|
|||
|
||||
# Create weight matrix
|
||||
if out is None:
|
||||
out = torch.empty(shape, dtype = dtype, device = "cuda")
|
||||
out = torch.empty(shape, dtype = dtype, device = "cuda:0")
|
||||
else:
|
||||
assert(out.shape == shape)
|
||||
assert(out.dtype == dtype)
|
||||
|
||||
# NF4 dequantization of statistics
|
||||
n_elements_absmax = absmax.numel()
|
||||
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda")
|
||||
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
|
||||
|
||||
# Do dequantization
|
||||
ptr_out_absmax = get_ptr(out_absmax)
|
||||
|
|
@ -161,7 +161,7 @@ def fast_gemv(X, W, quant_state, out = None):
|
|||
bout = shape[0]
|
||||
|
||||
if out is None:
|
||||
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda")
|
||||
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
|
||||
# else:
|
||||
# assert(out.shape == (1, 1, bout,))
|
||||
# pass
|
||||
|
|
@ -179,7 +179,7 @@ def fast_gemv(X, W, quant_state, out = None):
|
|||
ldb = ctypes.c_int32(ldb)
|
||||
ldc = ctypes.c_int32(ldc)
|
||||
|
||||
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda")
|
||||
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
|
||||
cdequantize_blockwise_fp32(
|
||||
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
|
||||
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
|
||||
|
|
|
|||
|
|
@ -21,6 +21,12 @@ warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "
|
|||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
|
||||
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
|
||||
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
|
||||
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
|
||||
|
||||
# Stop "Special tokens have been added in the vocabulary, ..."
|
||||
import logging
|
||||
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
|
||||
|
||||
import bitsandbytes as bnb
|
||||
from transformers.models.llama.modeling_llama import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
|
@ -31,7 +37,7 @@ import numpy as np
|
|||
import os
|
||||
import psutil
|
||||
|
||||
__version__ = "2024.6"
|
||||
__version__ = "2024.7"
|
||||
|
||||
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
|
||||
major_version, minor_version = torch.cuda.get_device_capability()
|
||||
|
|
@ -80,8 +86,49 @@ __all__ = [
|
|||
"offload_output_embeddings",
|
||||
"is_bfloat16_supported",
|
||||
"unsloth_offloaded_gradient_checkpoint",
|
||||
"torch_compile_options",
|
||||
]
|
||||
|
||||
# Just remove max_autotune_gemm warning
|
||||
import functools
|
||||
@functools.lru_cache(None)
|
||||
def is_big_gpu(index):
|
||||
sms = torch.cuda.get_device_properties(index).multi_processor_count
|
||||
if sms < 80: # V100
|
||||
# log.warning("not enough SMs to use max_autotune_gemm mode")
|
||||
return False
|
||||
return True
|
||||
import torch._inductor.utils
|
||||
torch._inductor.utils.is_big_gpu = is_big_gpu
|
||||
|
||||
|
||||
# Torch compile arguments
|
||||
torch_compile_arguments = [
|
||||
"config.dce = True",
|
||||
"config.memory_planning = True",
|
||||
"config.memory_pool = 'combined'",
|
||||
"config.coordinate_descent_tuning = True",
|
||||
"config.max_autotune_gemm = False", # GEMM is unnecessary
|
||||
"config.autotune_multi_device = False",
|
||||
"config.max_autotune_gemm_backends = 'ATEN'", # Not much faster
|
||||
"config.aggressive_fusion = False", # Careful changes results!
|
||||
"config.cuda.enable_cuda_lto = True",
|
||||
"config.cuda.use_fast_math = True",
|
||||
"config.cuda.compile_opt_level = '-O2'",
|
||||
]
|
||||
import torch._inductor.config as config
|
||||
for _try_compile_argument in torch_compile_arguments:
|
||||
try: exec(_try_compile_argument)
|
||||
except: pass
|
||||
pass
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : True,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False, # Output Triton kernel outputs!
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
|
||||
def prepare_model_for_kbit_training(
|
||||
model : Any,
|
||||
|
|
|
|||
|
|
@ -247,6 +247,8 @@ class FastGemmaModel(FastLlamaModel):
|
|||
GemmaModel .forward = LlamaModel_fast_forward
|
||||
GemmaForCausalLM .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
|
||||
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
|
||||
fix_prepare_inputs_for_generation(GemmaForCausalLM)
|
||||
|
||||
# Solves https://github.com/unslothai/unsloth/issues/168
|
||||
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
|
||||
# Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
|
||||
|
|
|
|||
538
unsloth/models/gemma2.py
Normal file
538
unsloth/models/gemma2.py
Normal file
|
|
@ -0,0 +1,538 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .llama import *
|
||||
from ._utils import __version__
|
||||
from .gemma import (
|
||||
GemmaFixedRotaryEmbedding,
|
||||
fast_geglu_inference,
|
||||
)
|
||||
from transformers.models.gemma2.modeling_gemma2 import (
|
||||
Gemma2Attention,
|
||||
Gemma2DecoderLayer,
|
||||
Gemma2Model,
|
||||
Gemma2ForCausalLM,
|
||||
Gemma2RotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from transformers.models.gemma2.modeling_gemma2 import *
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
# For Pytorch 2.1.1
|
||||
try:
|
||||
from transformers.models.gemma2.modeling_gemma2 import (
|
||||
Gemma2SdpaAttention,
|
||||
Gemma2FlashAttention2,
|
||||
)
|
||||
except:
|
||||
Gemma2SdpaAttention = Gemma2Attention
|
||||
Gemma2FlashAttention2 = Gemma2Attention
|
||||
pass
|
||||
|
||||
|
||||
# [TODO] We must randomnly use torch.compile?
|
||||
# I checked the gradients and formulas and I'm sure it's correct.
|
||||
# I'm stumped :(
|
||||
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
|
||||
def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True):
|
||||
old_dtype = X.dtype
|
||||
X = X.float()
|
||||
X = X * torch.rsqrt(X.square().mean(-1, keepdim = True) + layernorm.eps) * \
|
||||
(1.0 + layernorm.weight.float())
|
||||
return X.to(old_dtype)
|
||||
pass
|
||||
|
||||
|
||||
# Logit softcapping
|
||||
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
|
||||
def gemma2_attention(Q, K, V, causal_mask, self, bsz, q_len):
|
||||
n_heads = self.num_heads
|
||||
head_dim = self.head_dim
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
|
||||
# Grouped query attention
|
||||
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
||||
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
||||
K = K.reshape(bsz, n_heads, q_len, head_dim)
|
||||
V = V.reshape(bsz, n_heads, q_len, head_dim)
|
||||
|
||||
s = self.config.hidden_size // self.config.num_attention_heads
|
||||
t = self.config.attn_logit_softcapping
|
||||
|
||||
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
|
||||
A = torch.matmul(Q, K.transpose(2, 3))
|
||||
A = t * torch.tanh(A / t) # Logit softcapping
|
||||
A += causal_mask[:q_len, :q_len]
|
||||
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
|
||||
A = torch.matmul(A, V)
|
||||
A = A.transpose(1, 2).contiguous()
|
||||
A = A.reshape(bsz, q_len, n_heads*head_dim)
|
||||
return A
|
||||
pass
|
||||
|
||||
|
||||
# Logit softcapping
|
||||
def Gemma2Attention_fast_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
*args, **kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
|
||||
# Clear inference
|
||||
if hasattr(self, "paged_attention"):
|
||||
del self.paged_attention_K
|
||||
del self.paged_attention_V
|
||||
del self.paged_attention
|
||||
del self.temp_QA
|
||||
del self.temp_KV
|
||||
del self.RH_Q
|
||||
del self.attention
|
||||
pass
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
Q, K, V = self.apply_qkv(self, hidden_states)
|
||||
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
|
||||
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
|
||||
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = K.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
if position_ids is None:
|
||||
cos = self.rotary_emb.cos_cached
|
||||
sin = self.rotary_emb.sin_cached
|
||||
Q, K = fast_rope_embedding(Q, K, cos, sin)
|
||||
else:
|
||||
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
|
||||
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
|
||||
pass
|
||||
|
||||
if past_key_value is not None:
|
||||
K = torch.cat([past_key_value[0], K], dim = 2)
|
||||
V = torch.cat([past_key_value[1], V], dim = 2)
|
||||
pass
|
||||
past_key_value = (K, V) if use_cache else None
|
||||
|
||||
A = gemma2_attention(Q, K, V, causal_mask, self, bsz, kv_seq_len)
|
||||
A = self.apply_o(self, A)
|
||||
return A, None, past_key_value
|
||||
pass
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
|
||||
def Gemma2DecoderLayer_fast_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
*args, **kwargs,
|
||||
):
|
||||
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
|
||||
out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
|
||||
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
|
||||
hidden_states += residual
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(self. pre_feedforward_layernorm, hidden_states, out_weight)
|
||||
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
|
||||
hidden_states += residual
|
||||
else:
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True)
|
||||
hidden_states = residual + hidden_states
|
||||
pass
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions: outputs += (self_attn_weights,)
|
||||
if use_cache: outputs += (present_key_value,)
|
||||
return outputs
|
||||
pass
|
||||
|
||||
|
||||
from math import sqrt as math_sqrt
|
||||
KV_CACHE_INCREMENT = 256 # KV Cache update size
|
||||
torch_nn_functional_softmax = torch.nn.functional.softmax
|
||||
|
||||
def Gemma2Attention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]],
|
||||
position_ids,
|
||||
do_prefill = False,
|
||||
attention_mask = None,
|
||||
use_sliding_window = False,
|
||||
):
|
||||
Xn = hidden_states
|
||||
bsz, _, hd = hidden_states.size()
|
||||
K1, V1 = past_key_value
|
||||
dtype = Xn.dtype
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
attention_size = n_heads*head_dim
|
||||
# assert(n_kv_heads * n_groups == n_heads)
|
||||
seq_len = K1.shape[-2]
|
||||
kv_seq_len = seq_len + 1
|
||||
|
||||
# Prefill phase
|
||||
# if not hasattr(self, "paged_attention"):
|
||||
if do_prefill:
|
||||
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
|
||||
self.paged_attention_K = self.paged_attention[:,0]
|
||||
self.paged_attention_V = self.paged_attention[:,1]
|
||||
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
|
||||
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
|
||||
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
|
||||
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
|
||||
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
|
||||
self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
|
||||
self.half_head_dim = head_dim // 2
|
||||
self. t = self.config.attn_logit_softcapping
|
||||
self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping
|
||||
elif kv_seq_len >= self.paged_attention.shape[0]:
|
||||
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
|
||||
self.paged_attention_K = self.paged_attention[:,0]
|
||||
self.paged_attention_V = self.paged_attention[:,1]
|
||||
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
|
||||
pass
|
||||
|
||||
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
|
||||
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
|
||||
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
|
||||
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
|
||||
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
|
||||
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
|
||||
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
|
||||
cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
|
||||
sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
|
||||
h = self.half_head_dim
|
||||
|
||||
RH_Q = self.RH_Q
|
||||
RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
|
||||
RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
|
||||
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
|
||||
Qn *= cos
|
||||
Qn.addcmul_(RH_Q, sin)
|
||||
|
||||
RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
|
||||
RH_K[:,:,:,:h] = Kn[:,:,:,h:]
|
||||
RH_K[:,:,:,h:] = Kn[:,:,:,:h]
|
||||
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
|
||||
Kn *= cos
|
||||
Kn.addcmul_(RH_K, sin)
|
||||
|
||||
# New KV cache
|
||||
# Kn = torch.cat([K1, Kn], dim = 2)
|
||||
# Vn = torch.cat([V1, Vn], dim = 2)
|
||||
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
|
||||
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
|
||||
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
|
||||
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
|
||||
|
||||
# Handle sliding windows
|
||||
sliding_window = self.config.sliding_window
|
||||
if use_sliding_window and kv_seq_len > sliding_window:
|
||||
# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
|
||||
slicing_tokens = 1 - sliding_window
|
||||
Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
|
||||
Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
|
||||
else:
|
||||
Knn, Vnn = Kn, Vn
|
||||
pass
|
||||
|
||||
# Grouped query attention
|
||||
_, _, cached_len, _ = Knn.shape
|
||||
if n_groups != 1:
|
||||
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
|
||||
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
|
||||
pass
|
||||
# else:
|
||||
# Knn, Vnn = Knn, Vnn
|
||||
# pass
|
||||
|
||||
# Attention
|
||||
# if bsz == 1:
|
||||
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
|
||||
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
|
||||
A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
|
||||
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
|
||||
|
||||
A *= self.reciprocal_t; torch.tanh(A, out = A); A *= self.t; # Logit softcapping
|
||||
|
||||
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
|
||||
A = torch.matmul(A, Vnn, out = Qn)
|
||||
# else:
|
||||
# A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
|
||||
# pass
|
||||
A = A.transpose(1, 2)
|
||||
A = A.reshape(bsz, 1, attention_size)
|
||||
A = fast_linear_forward(self.o_proj, A, out = self.temp_QA[1][:,:,:self.hidden_size])
|
||||
return A, (Kn, Vn)
|
||||
pass
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
|
||||
# @torch.inference_mode
|
||||
def Gemma2Model_fast_forward_inference(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values,
|
||||
position_ids,
|
||||
attention_mask = None,
|
||||
):
|
||||
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
|
||||
input_ids = input_ids[:,:self.max_seq_length]
|
||||
hidden_states = self.model.embed_tokens(input_ids)
|
||||
hidden_states = hidden_states.to(self.config.torch_dtype)
|
||||
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
|
||||
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
|
||||
hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
|
||||
|
||||
bsz, q_len, hd = hidden_states.shape
|
||||
seq_len = past_key_values[0][0].shape[-2]
|
||||
if bsz != 1:
|
||||
SWA = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(bsz, q_len),
|
||||
hidden_states,
|
||||
seq_len,
|
||||
sliding_window = self.config.sliding_window,
|
||||
)
|
||||
GA = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(bsz, q_len),
|
||||
hidden_states,
|
||||
seq_len,
|
||||
)
|
||||
else:
|
||||
SWA = attention_mask
|
||||
GA = attention_mask
|
||||
pass
|
||||
|
||||
next_decoder_cache = []
|
||||
for idx, decoder_layer in enumerate(self.model.layers):
|
||||
|
||||
use_sliding_window = idx % 2 == 0
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
|
||||
hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
|
||||
decoder_layer.self_attn,
|
||||
hidden_states = hidden_states,
|
||||
past_key_value = past_key_values[idx],
|
||||
position_ids = position_ids,
|
||||
attention_mask = SWA if use_sliding_window else GA,
|
||||
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
|
||||
use_sliding_window = use_sliding_window,
|
||||
)
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
|
||||
hidden_states += residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight)
|
||||
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
|
||||
hidden_states += residual
|
||||
|
||||
next_decoder_cache.append(present_key_value)
|
||||
pass
|
||||
hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state = hidden_states,
|
||||
past_key_values = next_decoder_cache,
|
||||
hidden_states = [],
|
||||
attentions = [],
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class FastGemma2Model(FastLlamaModel):
|
||||
|
||||
@staticmethod
|
||||
def pre_patch():
|
||||
Gemma2Attention .forward = Gemma2Attention_fast_forward
|
||||
Gemma2SdpaAttention .forward = Gemma2Attention_fast_forward
|
||||
Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward
|
||||
Gemma2DecoderLayer .forward = Gemma2DecoderLayer_fast_forward
|
||||
Gemma2Model .forward = LlamaModel_fast_forward
|
||||
Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
|
||||
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
|
||||
fix_prepare_inputs_for_generation(Gemma2ForCausalLM)
|
||||
|
||||
# Solves https://github.com/unslothai/unsloth/issues/168
|
||||
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
|
||||
# Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
|
||||
# https://github.com/huggingface/transformers/pull/27931
|
||||
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
|
||||
import transformers.models.gemma2.modeling_gemma2
|
||||
transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding
|
||||
return
|
||||
pass
|
||||
|
||||
|
||||
@staticmethod
|
||||
def post_patch(model):
|
||||
# Patch model for Gemma
|
||||
layers = model.model.layers
|
||||
|
||||
# Torch.compile fails on embedding matrix??
|
||||
# Workaround randomnly fixes it for torch versions < 2.2
|
||||
model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
|
||||
model.config.update({"unsloth_version" : __version__})
|
||||
|
||||
# We also do this for the lm_head
|
||||
lm_head = torch.nn.Linear(1, 1, bias = None)
|
||||
del lm_head.weight
|
||||
lm_head.weight = model.lm_head.weight
|
||||
lm_head.in_features = lm_head.weight.shape[1]
|
||||
lm_head.out_features = lm_head.weight.shape[0]
|
||||
model.lm_head = lm_head
|
||||
|
||||
# Gemma has tied weights! This means lm_head == embed_tokens
|
||||
if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
|
||||
lm_head = torch.nn.Linear(1, 1, bias = None)
|
||||
del lm_head.weight
|
||||
lm_head.weight = model.model.embed_tokens.weight
|
||||
lm_head.in_features = lm_head.weight.shape[1]
|
||||
lm_head.out_features = lm_head.weight.shape[0]
|
||||
model.lm_head = lm_head
|
||||
pass
|
||||
|
||||
# Also patch all dtypes - BnB seems to not allocate the correct type?
|
||||
# BnB default dtype seems to be float16!
|
||||
correct_dtype = lm_head.weight.dtype
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
|
||||
weight = module.weight
|
||||
quant_state = weight.quant_state
|
||||
|
||||
if type(quant_state) is list:
|
||||
# BnB seems to have float16 as default!
|
||||
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
|
||||
else:
|
||||
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
|
||||
quant_state.dtype = correct_dtype
|
||||
pass
|
||||
pass
|
||||
# Downcast RoPE embedding to correct data type
|
||||
# RoPE must be done in float32 for Gemma
|
||||
# if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
|
||||
# and (module.cos_cached.dtype != correct_dtype):
|
||||
|
||||
# module.cos_cached = module.cos_cached.to(correct_dtype)
|
||||
# module.sin_cached = module.sin_cached.to(correct_dtype)
|
||||
# pass
|
||||
# pass
|
||||
pass
|
||||
|
||||
# Add 1 to weight
|
||||
# return output * (1 + self.weight)
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
|
||||
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
|
||||
|
||||
# Freeze all parameters except LoRA
|
||||
# We do this first since += 1 seems to not be liked by requires_grad = True
|
||||
for name, param in model.named_parameters():
|
||||
if ".lora_A." in name or ".lora_B." in name:
|
||||
param.requires_grad_(True)
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
pass
|
||||
|
||||
# Patch RMS Layernorm
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, Gemma2RMSNorm):
|
||||
# Must be in float32
|
||||
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
|
||||
# module = module.to(torch.float32)
|
||||
# Leave + 1 to Triton kernel itself
|
||||
# module.weight += 1.0 # return output * (1 + self.weight)
|
||||
if not hasattr(module, "variance_epsilon"):
|
||||
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
|
||||
pass
|
||||
|
||||
# Clear deleted GPU items
|
||||
import gc
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return model
|
||||
pass
|
||||
pass
|
||||
|
|
@ -15,6 +15,8 @@
|
|||
import torch
|
||||
import gc
|
||||
from typing import Optional, Tuple, List, Union
|
||||
from ._utils import *
|
||||
from ._utils import __version__
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
logger,
|
||||
|
|
@ -25,8 +27,6 @@ from transformers.modeling_attn_mask_utils import (
|
|||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from ..kernels import *
|
||||
from ._utils import *
|
||||
from ._utils import __version__
|
||||
from ..tokenizer_utils import *
|
||||
if HAS_FLASH_ATTENTION:
|
||||
from flash_attn import flash_attn_func
|
||||
|
|
@ -78,6 +78,24 @@ from math import sqrt as math_sqrt
|
|||
KV_CACHE_INCREMENT = 256 # KV Cache update size
|
||||
torch_nn_functional_softmax = torch.nn.functional.softmax
|
||||
|
||||
# Fix new HF's inference code
|
||||
def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,):
|
||||
if "past_key_values" in kwargs:
|
||||
input_ids = input_ids[:,[-1]]
|
||||
kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]]
|
||||
kwargs["position_ids"] = kwargs["cache_position"]
|
||||
return { "input_ids" : input_ids, **kwargs, }
|
||||
pass
|
||||
|
||||
|
||||
def fix_prepare_inputs_for_generation(module):
|
||||
# Fix prepare_inputs_for_generation
|
||||
if hasattr(module, "prepare_inputs_for_generation"):
|
||||
module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def LlamaAttention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
@ -542,7 +560,8 @@ def LlamaModel_fast_forward(
|
|||
inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
|
||||
|
||||
# Normalized from Gemma
|
||||
IS_GEMMA = self.config.model_type == "gemma"
|
||||
IS_GEMMA = self.config.model_type.startswith("gemma")
|
||||
IS_GEMMA2 = self.config.model_type.startswith("gemma2")
|
||||
train_embed_tokens = self.embed_tokens.weight.requires_grad
|
||||
|
||||
if IS_GEMMA:
|
||||
|
|
@ -642,17 +661,38 @@ def LlamaModel_fast_forward(
|
|||
offloaded_gradient_checkpointing = True
|
||||
pass
|
||||
|
||||
# Gemma2 has alternating SWA and global attn
|
||||
if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
n = self.config.max_position_embeddings
|
||||
self.SWA_mask = AttentionMaskConverter(
|
||||
is_causal = True,
|
||||
sliding_window = self.config.sliding_window,
|
||||
)\
|
||||
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
|
||||
.squeeze(0).squeeze(0)
|
||||
|
||||
self.GA_mask = AttentionMaskConverter(
|
||||
is_causal = True,
|
||||
)\
|
||||
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
|
||||
.squeeze(0).squeeze(0)
|
||||
pass
|
||||
|
||||
# Go through every layer!
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
|
||||
if output_hidden_states: all_hidden_states += (hidden_states,)
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
mask = causal_mask
|
||||
if IS_GEMMA2: mask = self.SWA_mask if (idx % 2 == 0) else self.GA_mask
|
||||
|
||||
if offloaded_gradient_checkpointing:
|
||||
hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||
decoder_layer,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
mask,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
|
|
@ -670,7 +710,7 @@ def LlamaModel_fast_forward(
|
|||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
mask,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
use_reentrant = True,
|
||||
|
|
@ -681,7 +721,7 @@ def LlamaModel_fast_forward(
|
|||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
causal_mask=mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
|
|
@ -838,6 +878,7 @@ def CausalLM_fast_forward(fast_forward_inference):
|
|||
logits = logits.to(self.config.torch_dtype)
|
||||
|
||||
loss = None
|
||||
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
|
||||
if labels is not None:
|
||||
shift_logits = logits
|
||||
if not hasattr(self, "extra_ignored_labels"):
|
||||
|
|
@ -849,7 +890,12 @@ def CausalLM_fast_forward(fast_forward_inference):
|
|||
loss = fast_cross_entropy_loss(
|
||||
logits = shift_logits,
|
||||
labels = shift_labels,
|
||||
logit_softcapping = logit_softcapping,
|
||||
)
|
||||
elif logit_softcapping != 0:
|
||||
logits *= (1.0 / logit_softcapping)
|
||||
torch.tanh(logits, out = logits)
|
||||
logits *= logit_softcapping
|
||||
pass
|
||||
|
||||
if not return_dict:
|
||||
|
|
@ -983,11 +1029,22 @@ def _wrap_fast_inference(generate, device_type, dtype, model):
|
|||
pass
|
||||
internal_model._flag_for_generation = True
|
||||
|
||||
# For newer HF
|
||||
kwargs["cache_implementation"] = "dynamic"
|
||||
|
||||
# Set pad token
|
||||
old_pad_token_id = getattr(model.config, "pad_token_id", None)
|
||||
old_eos_token_id = getattr(model.config, "eos_token_id", None)
|
||||
model.config.pad_token_id = old_eos_token_id
|
||||
|
||||
# Autocasted
|
||||
with torch.autocast(device_type = device_type, dtype = dtype):
|
||||
output = generate(*args, **kwargs)
|
||||
pass
|
||||
|
||||
# Revert
|
||||
model.config.pad_token_id = old_pad_token_id
|
||||
|
||||
# Unset a flag for generation!
|
||||
internal_model = model
|
||||
while hasattr(internal_model, "model"):
|
||||
|
|
@ -1013,6 +1070,7 @@ class FastLlamaModel:
|
|||
LlamaModel .forward = LlamaModel_fast_forward
|
||||
LlamaForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
|
||||
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
|
||||
fix_prepare_inputs_for_generation(LlamaForCausalLM)
|
||||
|
||||
# Solves https://github.com/unslothai/unsloth/issues/168
|
||||
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
|
||||
|
|
@ -1056,7 +1114,7 @@ class FastLlamaModel:
|
|||
f"==((====))== Unsloth: Fast {model_patcher.__name__[4:-5]} patching release {__version__}\n"\
|
||||
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
|
||||
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
|
||||
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
|
||||
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
|
||||
f' "-____-" Free Apache license: http://github.com/unslothai/unsloth'
|
||||
print(statistics)
|
||||
model_patcher.pre_patch()
|
||||
|
|
@ -1200,11 +1258,11 @@ class FastLlamaModel:
|
|||
'nvidia-smi --query-gpu=memory.used --format=csv', shell = True)
|
||||
output = re.findall(rb'([\\d]{1,})[\\s]{1,}M', output)
|
||||
output = sum(int(x.decode('utf-8'))/1024 > 4 for x in output)
|
||||
if output > 1: raise RuntimeError(
|
||||
'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\
|
||||
if output > 1: print(
|
||||
'********************\\nUnsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\
|
||||
'enabling it will require much more work, so we have to prioritize. Please understand!\\n'\\
|
||||
'We do have a separate beta version, which you can contact us about!\\n'\\
|
||||
'Thank you for your understanding and we appreciate it immensely!')
|
||||
'********************\\nWe do have a separate beta version, which you can contact us about!\\n'\\
|
||||
'********************\\nThank you for your understanding and we appreciate it immensely!')
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()"""
|
||||
|
|
@ -1760,6 +1818,7 @@ class FastLlamaModel:
|
|||
elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu
|
||||
elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu
|
||||
elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx
|
||||
elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx
|
||||
else:
|
||||
raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -26,8 +26,11 @@ major, minor = transformers_version.split(".")[:2]
|
|||
major, minor = int(major), int(minor)
|
||||
SUPPORTS_FOURBIT = (major > 4) or (major == 4 and minor >= 37)
|
||||
SUPPORTS_GEMMA = (major > 4) or (major == 4 and minor >= 38)
|
||||
SUPPORTS_GEMMA2 = (major > 4) or (major == 4 and minor >= 42)
|
||||
if SUPPORTS_GEMMA:
|
||||
from .gemma import FastGemmaModel
|
||||
from .gemma import FastGemmaModel
|
||||
if SUPPORTS_GEMMA2:
|
||||
from .gemma2 import FastGemma2Model
|
||||
del major, minor
|
||||
|
||||
|
||||
|
|
@ -138,6 +141,15 @@ class FastLanguageModel(FastLlamaModel):
|
|||
f"to obtain the latest transformers build, then restart this session."\
|
||||
)
|
||||
dispatch_model = FastGemmaModel
|
||||
elif model_type == "gemma2":
|
||||
if not SUPPORTS_GEMMA2:
|
||||
raise RuntimeError(
|
||||
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
|
||||
f"The minimum required version is 4.43.\n"\
|
||||
f'Try `pip install --upgrade "transformers>=4.43"`\n'\
|
||||
f"to obtain the latest transformers build, then restart this session."\
|
||||
)
|
||||
dispatch_model = FastGemma2Model
|
||||
elif model_type == "qwen2":
|
||||
dispatch_model = FastQwen2Model
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -191,6 +191,14 @@ __INT_TO_FLOAT_MAPPER = \
|
|||
"mistralai/Codestral-22B-v0.1" : (
|
||||
"mistral-community/Codestral-22B-v0.1",
|
||||
),
|
||||
"unsloth/gemma-2-9b-bnb-4bit" : (
|
||||
"unsloth/gemma-2-9b",
|
||||
"google/gemma-2-9b",
|
||||
),
|
||||
"unsloth/gemma-2-27b-bnb-4bit" : (
|
||||
"unsloth/gemma-2-27b",
|
||||
"google/gemma-2-27b",
|
||||
),
|
||||
}
|
||||
|
||||
INT_TO_FLOAT_MAPPER = {}
|
||||
|
|
|
|||
|
|
@ -275,7 +275,8 @@ class FastMistralModel(FastLlamaModel):
|
|||
MistralModel .forward = LlamaModel_fast_forward
|
||||
MistralForCausalLM .forward = MistralForCausalLM_fast_forward
|
||||
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
|
||||
|
||||
fix_prepare_inputs_for_generation(MistralForCausalLM)
|
||||
|
||||
# Solves https://github.com/unslothai/unsloth/issues/168
|
||||
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
|
||||
# Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ class FastQwen2Model(FastLlamaModel):
|
|||
Qwen2Model .forward = LlamaModel_fast_forward
|
||||
Qwen2ForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
|
||||
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
|
||||
fix_prepare_inputs_for_generation(Qwen2ForCausalLM)
|
||||
|
||||
# Solves https://github.com/unslothai/unsloth/issues/168
|
||||
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
|
||||
|
|
|
|||
|
|
@ -963,11 +963,11 @@ def patch_sft_trainer_tokenizer():
|
|||
" 'nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"\
|
||||
"output = re.findall(rb'([\\d]{1,})[\\s]{1,}M', output)\n"\
|
||||
"output = sum(int(x.decode('utf-8'))/1024 > 4 for x in output)\n"\
|
||||
"if output > 1: raise RuntimeError(\n"\
|
||||
" 'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\\n"\
|
||||
"if output > 1: print(\n"\
|
||||
" '********************\\nUnsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\\n"\
|
||||
" 'enabling it will require much more work, so we have to prioritize. Please understand!\\n'\\\n"\
|
||||
" 'We do have a separate beta version, which you can contact us about!\\n'\\\n"\
|
||||
" 'Thank you for your understanding and we appreciate it immensely!')\n"\
|
||||
" '********************\\nWe do have a separate beta version, which you can contact us about!\\n'\\\n"\
|
||||
" '********************\\nThank you for your understanding and we appreciate it immensely!')\n"\
|
||||
"for _ in range(3):\n"\
|
||||
" gc.collect()\n"\
|
||||
" torch.cuda.empty_cache()\n"\
|
||||
|
|
|
|||
Loading…
Reference in a new issue