* Update mapper.py

* Update loader.py

* Update llama.py

* Update tokenizer_utils.py

* info

* edits

* Create chat template

* Fix tokenizer

* Update tokenizer_utils.py

* fix case where gguf saving fails due to first_conversion dtype (#630)

* Support revision parameter in FastLanguageModel.from_pretrained (#629)

* support `revision` parameter

* match unsloth formatting of named parameters

* clears any selected_adapters before calling internal_model.save_pretrained (#609)

* Update __init__.py (#602)

Check for incompatible modules before importing unsloth

* Fixed unsloth/tokenizer_utils.py for chat training (#604)

* Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345)

* Add save to llama.cpp GGML to save.py.

* Fix conversion command and path of convert to GGML function.

* Add autosaving lora to the GGML function

* Create lora save function for conversion to GGML

* Test fix #2 for saving lora

* Test fix #3 to save  the lora adapters to convert to GGML

* Remove unwated tokenizer saving for conversion to ggml and added a few print statements.

* Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages.

* Positional arguments didn't work out, so reverted to older version of the code, and added a few comments.

* Test fix 1 for arch

* Test fix 2 new Mistral error.

* Test fix 3

* Revert to old version for testing.

* Upload issue test fix 1

* Fix 2 uploading ggml

* Positional ags added.

* Temporray remove positional args

* Fix upload again!!!

* Add print statements and fix link

* Make the calling name better

* Create local saving for GGML

* Add choosing directory to save local GGML.

* Fix lil variable error in the save_to_custom_dir func

* docs: Add LoraConfig parameters documentation (#619)

* llama.cpp failing (#371)

llama.cpp is failing to generate quantize versions for the trained models.

Error:

```bash
You might have to compile llama.cpp yourself, then run this again.
You do not need to close this Python program. Run the following commands in a new terminal:
You must run this in the same folder as you're saving your model.
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j
Once that's done, redo the quantization.
```

But when i do clone this with recursive it works.

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* fix libcuda_dirs import for triton 3.0 (#227)

* fix libcuda_dirs import for triton 3.0

* Update __init__.py

* Update __init__.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update save.py

* Update __init__.py

* Update fast_lora.py

* Update save.py

* Update save.py

* Update save.py

* Update loader.py

* Update save.py

* Update save.py

* quantize now llama-quantize

* Update chat_templates.py

* Update loader.py

* Update mapper.py

* Update __init__.py

* embedding size

* Update qwen2.py

* docs

* Update README.md

* Update qwen2.py

* README: Fix minor typo. (#559)

* README: Fix minor typo.

One-character typo fix while reading.

* Update README.md

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update mistral.py

* Update qwen2.py

* Update qwen2.py

* Update qwen2.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update README.md

* FastMistralModel

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Auto check rope scaling

* Update llama.py

* Update llama.py

* Update llama.py

* GPU support

* Typo

* Update gemma.py

* gpu

* Multiple GGUF saving

* Update save.py

* Update save.py

* check PEFT and base

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update chat_templates.py

* Fix breaking bug in save.py with interpreting quantization_method as a string when saving to gguf (#651)

* Nightly (#649)

* Update llama.py

* offload

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* continued pretraining trainer

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* is_bfloat16_supported

* Update __init__.py

* Update README.md

* Update llama.py

* is_bfloat16_supported

* Update __init__.py

* Mistral v3

* Phi 3 medium

* Update chat_templates.py

* Update chat_templates.py

* Phi-3

* Update save.py

* Update README.md

Mistral v3 to Mistral v0.3

* Untrained tokens

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update llama.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update save.py

* Update save.py

* Update save.py

* checkpoint

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update llama.py

* accelerate

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update tokenizer_utils.py

* train_dataloader

* Update llama.py

* Update llama.py

* Update llama.py

* use_fast_convert

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* remove_special_tokens

* Ollama

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update llama.py

* Update chat_templates.py

* Support bfloat16 GGUF

* Update save.py

* Update llama.py

* fast_forward_inference

* Update mapper.py

* Update loader.py

* Update llama.py

* Update tokenizer_utils.py

* info

* edits

* Create chat template

* Fix tokenizer

* Update tokenizer_utils.py

* fix case where gguf saving fails due to first_conversion dtype (#630)

* Support revision parameter in FastLanguageModel.from_pretrained (#629)

* support `revision` parameter

* match unsloth formatting of named parameters

* clears any selected_adapters before calling internal_model.save_pretrained (#609)

* Update __init__.py (#602)

Check for incompatible modules before importing unsloth

* Fixed unsloth/tokenizer_utils.py for chat training (#604)

* Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345)

* Add save to llama.cpp GGML to save.py.

* Fix conversion command and path of convert to GGML function.

* Add autosaving lora to the GGML function

* Create lora save function for conversion to GGML

* Test fix #2 for saving lora

* Test fix #3 to save  the lora adapters to convert to GGML

* Remove unwated tokenizer saving for conversion to ggml and added a few print statements.

* Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages.

* Positional arguments didn't work out, so reverted to older version of the code, and added a few comments.

* Test fix 1 for arch

* Test fix 2 new Mistral error.

* Test fix 3

* Revert to old version for testing.

* Upload issue test fix 1

* Fix 2 uploading ggml

* Positional ags added.

* Temporray remove positional args

* Fix upload again!!!

* Add print statements and fix link

* Make the calling name better

* Create local saving for GGML

* Add choosing directory to save local GGML.

* Fix lil variable error in the save_to_custom_dir func

* docs: Add LoraConfig parameters documentation (#619)

* llama.cpp failing (#371)

llama.cpp is failing to generate quantize versions for the trained models.

Error:

```bash
You might have to compile llama.cpp yourself, then run this again.
You do not need to close this Python program. Run the following commands in a new terminal:
You must run this in the same folder as you're saving your model.
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j
Once that's done, redo the quantization.
```

But when i do clone this with recursive it works.

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* fix libcuda_dirs import for triton 3.0 (#227)

* fix libcuda_dirs import for triton 3.0

* Update __init__.py

* Update __init__.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update save.py

* Update __init__.py

* Update fast_lora.py

* Update save.py

* Update save.py

* Update save.py

* Update loader.py

* Update save.py

* Update save.py

* quantize now llama-quantize

* Update chat_templates.py

* Update loader.py

* Update mapper.py

* Update __init__.py

* embedding size

* Update qwen2.py

* docs

* Update README.md

* Update qwen2.py

* README: Fix minor typo. (#559)

* README: Fix minor typo.

One-character typo fix while reading.

* Update README.md

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update mistral.py

* Update qwen2.py

* Update qwen2.py

* Update qwen2.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update README.md

* FastMistralModel

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Auto check rope scaling

* Update llama.py

* Update llama.py

* Update llama.py

* GPU support

* Typo

* Update gemma.py

* gpu

* Multiple GGUF saving

* Update save.py

* Update save.py

* check PEFT and base

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update chat_templates.py

---------

Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com>
Co-authored-by: Rickard Edén <rickardeden@gmail.com>
Co-authored-by: XiaoYang <xyangk@gmail.com>
Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com>
Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com>
Co-authored-by: Sébastien De Greef <sebdg@binarycompany.com>
Co-authored-by: Alberto Ferrer <albertof@barrahome.org>
Co-authored-by: Thomas Viehmann <tv.github-private@beamnet.de>
Co-authored-by: Walter Korman <lemurware@gmail.com>

* Fix bug in save.py with interpreting quantization_method as a string that prevents GGUF from saving

* Implemented better list management and then forgot to actually call the new list variable, fixed

* Check type of given quantization method and return type error if not list or string

* Update save.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com>
Co-authored-by: Rickard Edén <rickardeden@gmail.com>
Co-authored-by: XiaoYang <xyangk@gmail.com>
Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com>
Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com>
Co-authored-by: Sébastien De Greef <sebdg@binarycompany.com>
Co-authored-by: Alberto Ferrer <albertof@barrahome.org>
Co-authored-by: Thomas Viehmann <tv.github-private@beamnet.de>
Co-authored-by: Walter Korman <lemurware@gmail.com>

* Revert "Fix breaking bug in save.py with interpreting quantization_method as …" (#652)

This reverts commit 506cb68867296237e95bc53c32f1bfc9b1757960.

* Revert "Revert "Fix breaking bug in save.py with interpreting quantization_me…" (#653)

This reverts commit 2f48cc9af385579876fd45bd833169d1f1a2ea58.

* Update llama.py

* peft

* patch

* Update loader.py

* retrain

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* offload

* Update llama.py

* Create a starter script for command-line training to integrate in ML ops pipelines. (#623)

* Update chat_templates.py

* Ollama

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Ollama

* Update chat_templates.py

* ollama

* Update mapper.py

* Update chat_templates.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update llama.py

* Fixes

* clearer messages

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update llama.py

* Update llama.py

* Update llama.py

* log

* Update __init__.py

* Update llama.py

* Update __init__.py

* Create Merge.png

* Create ollama.png

* Gemma2

* Update llama.py

* Update loader.py

* Update pyproject.toml

* Update pyproject.toml

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update _utils.py

* Revert Gemma2

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update rms_layernorm.py

* Update gemma2.py

* logit softcapping

* Update cross_entropy_loss.py

* Update llama.py

* Update llama.py

* Update gemma2.py

* Update gemma2.py

* Update cross_entropy_loss.py

* Update llama.py

* Update llama.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update llama.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update llama.py

* Update gemma2.py

* Update llama.py

* Update llama.py

* Update gemma2.py

* Update gemma2.py

* Update llama.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update _utils.py

* Update _utils.py

* Update gemma2.py

* compile flags

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update gemma2.py

* Update gemma2.py

* fixes

* Update _utils.py

* Fix generation

* Update llama.py

* Update llama.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* pad token

* Update gemma2.py

* pad token

* Update _utils.py

* Update llama.py

* Update gemma2.py

* edit warning

* Update tokenizer_utils.py

---------

Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com>
Co-authored-by: Rickard Edén <rickardeden@gmail.com>
Co-authored-by: XiaoYang <xyangk@gmail.com>
Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com>
Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com>
Co-authored-by: Sébastien De Greef <sebdg@binarycompany.com>
Co-authored-by: Alberto Ferrer <albertof@barrahome.org>
Co-authored-by: Thomas Viehmann <tv.github-private@beamnet.de>
Co-authored-by: Walter Korman <lemurware@gmail.com>
Co-authored-by: ArcadaLabs-Jason <52756218+ArcadaLabs-Jason@users.noreply.github.com>
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
This commit is contained in:
Daniel Han 2024-07-02 22:51:01 -07:00 committed by GitHub
parent cfddc79bc8
commit cc4c5d7785
17 changed files with 772 additions and 60 deletions

BIN
images/Merge.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

BIN
images/ollama.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

View file

@ -34,7 +34,7 @@ exclude = ["images*"]
[project.optional-dependencies]
huggingface = [
"tyro",
"transformers>=4.38.2",
"transformers>=4.42.3",
"datasets>=2.16.0",
"sentencepiece>=0.2.0",
"tqdm",
@ -185,9 +185,9 @@ colab-ampere-torch220 = [
]
colab-new = [
"tyro",
"transformers>=4.38.2",
"transformers>=4.42.3",
"datasets>=2.16.0",
"sentencepiece",
"sentencepiece>=0.2.0",
"tqdm",
"psutil",
"wheel>=0.42.0",

View file

@ -19,14 +19,17 @@ from .utils import calculate_settings, MAX_FUSED_SIZE
from transformers.models.llama.modeling_llama import logger
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
@ -58,13 +61,19 @@ def _cross_entropy_forward(
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)
logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
loss = logsumexp - x
x = tl.load(logits_ptr + label_idx)
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
loss = logsumexp - x.to(tl.float32)
else:
loss = 0.0
tl.store(logsumexp_ptr, logsumexp)
@ -72,15 +81,18 @@ def _cross_entropy_forward(
pass
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _chunked_cross_entropy_forward(
logits_ptr, logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
N_CHUNKS : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
256K vocab divided in 4 chunks
@ -117,7 +129,11 @@ def _chunked_cross_entropy_forward(
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)
logits = logits.to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
@ -126,7 +142,9 @@ def _chunked_cross_entropy_forward(
# Do the -x separately
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
loss = -1.0 * x
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
loss = -1.0 * x.to(tl.float32)
else:
loss = 0.0
tl.store(loss_ptr, loss)
@ -135,14 +153,17 @@ def _chunked_cross_entropy_forward(
pass
@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_backward(
logits_ptr, logits_row_stride,
dloss_ptr, dloss_row_stride,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
VOCAB_SIZE : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
DO_SOFTCAPPING : tl.constexpr,
SOFTCAP : tl.constexpr,
):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
@ -173,15 +194,27 @@ def _cross_entropy_backward(
else:
dloss = 0.0
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
partial = tl.math.tanh(x / SOFTCAP)
x = SOFTCAP * partial
pass
logsumexp = tl.load(logsumexp_ptr + row_idx)
y = tl.exp(x - logsumexp)
y = tl.exp(x.to(tl.float32) - logsumexp)
y = tl.where(
col_offsets == label_idx,
y - 1.0, # exp(x - logsumexp) - 1
y, # exp(x - logsumexp)
)
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
y = y * (1.0 - partial*partial)
pass
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
pass
@ -191,40 +224,46 @@ MAX_FUSED_SIZE = 65536 # 2**16
class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels):
def forward(ctx, logits, labels, logit_softcapping = 0):
n_rows, vocab_size = logits.shape
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
n_chunks = div + (mod != 0)
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
DO_SOFTCAPPING = (logit_softcapping != 0)
if n_chunks == 1:
# For small vocabs <= 65336 like Llama, Mistral
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
_cross_entropy_forward[(n_rows,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = num_warps,
)
else:
# For large vocabs > 65336 like Gemma 256K
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda")
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")
_chunked_cross_entropy_forward[(n_rows, n_chunks,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
num_warps = 32,
VOCAB_SIZE = vocab_size,
N_CHUNKS = n_chunks,
BLOCK_SIZE = MAX_FUSED_SIZE,
DO_SOFTCAPPING = DO_SOFTCAPPING,
SOFTCAP = logit_softcapping,
num_warps = 32,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
@ -234,6 +273,8 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
pass
ctx.save_for_backward(logits, logsumexp, labels)
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
ctx.logit_softcapping = logit_softcapping
return losses
pass
@ -251,16 +292,18 @@ class Fast_CrossEntropyLoss(torch.autograd.Function):
dlosses, dlosses.stride(0),
logsumexp,
labels,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = 8,
VOCAB_SIZE = vocab_size,
BLOCK_SIZE = BLOCK_SIZE,
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
SOFTCAP = ctx.logit_softcapping,
num_warps = 8,
)
return logits, None, None,
pass
pass
def fast_cross_entropy_loss(logits, labels):
def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
"""
Arguments:
logits: (batch, seq_len, vocab_size)
@ -274,6 +317,7 @@ def fast_cross_entropy_loss(logits, labels):
loss = Fast_CrossEntropyLoss.apply(
logits.view(batch*seq_len, d),
labels.view(-1),
logit_softcapping,
)
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items

View file

@ -41,7 +41,7 @@ pass
def geglu_exact_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
@ -133,7 +133,7 @@ pass
def geglu_approx_forward_kernel(gate, up):
batch, seq_len, hd = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out

View file

@ -119,7 +119,7 @@ def _gemma_rms_layernorm_forward(
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
output = normed * (W_row + 1.0)
@ -137,8 +137,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
fx[(n_rows,)](

View file

@ -41,7 +41,7 @@ pass
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda")
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
return h

View file

@ -105,14 +105,14 @@ def fast_dequantize(W, quant_state = None, out = None):
# Create weight matrix
if out is None:
out = torch.empty(shape, dtype = dtype, device = "cuda")
out = torch.empty(shape, dtype = dtype, device = "cuda:0")
else:
assert(out.shape == shape)
assert(out.dtype == dtype)
# NF4 dequantization of statistics
n_elements_absmax = absmax.numel()
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda")
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
@ -161,7 +161,7 @@ def fast_gemv(X, W, quant_state, out = None):
bout = shape[0]
if out is None:
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda")
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
# else:
# assert(out.shape == (1, 1, bout,))
# pass
@ -179,7 +179,7 @@ def fast_gemv(X, W, quant_state, out = None):
ldb = ctypes.c_int32(ldb)
ldc = ctypes.c_int32(ldc)
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda")
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),

View file

@ -21,6 +21,12 @@ warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
# Stop "Special tokens have been added in the vocabulary, ..."
import logging
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
import bitsandbytes as bnb
from transformers.models.llama.modeling_llama import logger
from transformers import AutoTokenizer
@ -31,7 +37,7 @@ import numpy as np
import os
import psutil
__version__ = "2024.6"
__version__ = "2024.7"
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
major_version, minor_version = torch.cuda.get_device_capability()
@ -80,8 +86,49 @@ __all__ = [
"offload_output_embeddings",
"is_bfloat16_supported",
"unsloth_offloaded_gradient_checkpoint",
"torch_compile_options",
]
# Just remove max_autotune_gemm warning
import functools
@functools.lru_cache(None)
def is_big_gpu(index):
sms = torch.cuda.get_device_properties(index).multi_processor_count
if sms < 80: # V100
# log.warning("not enough SMs to use max_autotune_gemm mode")
return False
return True
import torch._inductor.utils
torch._inductor.utils.is_big_gpu = is_big_gpu
# Torch compile arguments
torch_compile_arguments = [
"config.dce = True",
"config.memory_planning = True",
"config.memory_pool = 'combined'",
"config.coordinate_descent_tuning = True",
"config.max_autotune_gemm = False", # GEMM is unnecessary
"config.autotune_multi_device = False",
"config.max_autotune_gemm_backends = 'ATEN'", # Not much faster
"config.aggressive_fusion = False", # Careful changes results!
"config.cuda.enable_cuda_lto = True",
"config.cuda.use_fast_math = True",
"config.cuda.compile_opt_level = '-O2'",
]
import torch._inductor.config as config
for _try_compile_argument in torch_compile_arguments:
try: exec(_try_compile_argument)
except: pass
pass
torch_compile_options = {
"epilogue_fusion" : True,
"max_autotune" : True,
"shape_padding" : True,
"trace.enabled" : False, # Output Triton kernel outputs!
"triton.cudagraphs" : False,
}
def prepare_model_for_kbit_training(
model : Any,

View file

@ -247,6 +247,8 @@ class FastGemmaModel(FastLlamaModel):
GemmaModel .forward = LlamaModel_fast_forward
GemmaForCausalLM .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(GemmaForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.

538
unsloth/models/gemma2.py Normal file
View file

@ -0,0 +1,538 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llama import *
from ._utils import __version__
from .gemma import (
GemmaFixedRotaryEmbedding,
fast_geglu_inference,
)
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
Gemma2DecoderLayer,
Gemma2Model,
Gemma2ForCausalLM,
Gemma2RotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.models.gemma2.modeling_gemma2 import *
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2SdpaAttention,
Gemma2FlashAttention2,
)
except:
Gemma2SdpaAttention = Gemma2Attention
Gemma2FlashAttention2 = Gemma2Attention
pass
# [TODO] We must randomnly use torch.compile?
# I checked the gradients and formulas and I'm sure it's correct.
# I'm stumped :(
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True):
old_dtype = X.dtype
X = X.float()
X = X * torch.rsqrt(X.square().mean(-1, keepdim = True) + layernorm.eps) * \
(1.0 + layernorm.weight.float())
return X.to(old_dtype)
pass
# Logit softcapping
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def gemma2_attention(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.num_heads
head_dim = self.head_dim
n_kv_heads = self.num_key_value_heads
n_groups = self.num_key_value_groups
# Grouped query attention
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
K = K.reshape(bsz, n_heads, q_len, head_dim)
V = V.reshape(bsz, n_heads, q_len, head_dim)
s = self.config.hidden_size // self.config.num_attention_heads
t = self.config.attn_logit_softcapping
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
A = torch.matmul(Q, K.transpose(2, 3))
A = t * torch.tanh(A / t) # Logit softcapping
A += causal_mask[:q_len, :q_len]
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
A = torch.matmul(A, V)
A = A.transpose(1, 2).contiguous()
A = A.reshape(bsz, q_len, n_heads*head_dim)
return A
pass
# Logit softcapping
def Gemma2Attention_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
del self.paged_attention_V
del self.paged_attention
del self.temp_QA
del self.temp_KV
del self.RH_Q
del self.attention
pass
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None
A = gemma2_attention(Q, K, V, causal_mask, self, bsz, kv_seq_len)
A = self.apply_o(self, A)
return A, None, past_key_value
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def Gemma2DecoderLayer_fast_forward(
self,
hidden_states: torch.Tensor,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
):
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(self. pre_feedforward_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True)
hidden_states = self.mlp(hidden_states)
hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
if output_attentions: outputs += (self_attn_weights,)
if use_cache: outputs += (present_key_value,)
return outputs
pass
from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
def Gemma2Attention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
use_sliding_window = False,
):
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
self.half_head_dim = head_dim // 2
self. t = self.config.attn_logit_softcapping
self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping
elif kv_seq_len >= self.paged_attention.shape[0]:
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
self.paged_attention_K = self.paged_attention[:,0]
self.paged_attention_V = self.paged_attention[:,1]
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
h = self.half_head_dim
RH_Q = self.RH_Q
RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:,:,:,:h] = Kn[:,:,:,h:]
RH_K[:,:,:,h:] = Kn[:,:,:,:h]
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
# Handle sliding windows
sliding_window = self.config.sliding_window
if use_sliding_window and kv_seq_len > sliding_window:
# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
slicing_tokens = 1 - sliding_window
Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
pass
# Grouped query attention
_, _, cached_len, _ = Knn.shape
if n_groups != 1:
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
pass
# else:
# Knn, Vnn = Knn, Vnn
# pass
# Attention
# if bsz == 1:
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A *= self.reciprocal_t; torch.tanh(A, out = A); A *= self.t; # Logit softcapping
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch.matmul(A, Vnn, out = Qn)
# else:
# A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
# pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_QA[1][:,:,:self.hidden_size])
return A, (Kn, Vn)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def Gemma2Model_fast_forward_inference(
self,
input_ids,
past_key_values,
position_ids,
attention_mask = None,
):
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
SWA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
sliding_window = self.config.sliding_window,
)
GA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
)
else:
SWA = attention_mask
GA = attention_mask
pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
use_sliding_window = idx % 2 == 0
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = SWA if use_sliding_window else GA,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
use_sliding_window = use_sliding_window,
)
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
pass
class FastGemma2Model(FastLlamaModel):
@staticmethod
def pre_patch():
Gemma2Attention .forward = Gemma2Attention_fast_forward
Gemma2SdpaAttention .forward = Gemma2Attention_fast_forward
Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward
Gemma2DecoderLayer .forward = Gemma2DecoderLayer_fast_forward
Gemma2Model .forward = LlamaModel_fast_forward
Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(Gemma2ForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.gemma2.modeling_gemma2
transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding
return
pass
@staticmethod
def post_patch(model):
# Patch model for Gemma
layers = model.model.layers
# Torch.compile fails on embedding matrix??
# Workaround randomnly fixes it for torch versions < 2.2
model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.lm_head.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
# Gemma has tied weights! This means lm_head == embed_tokens
if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.model.embed_tokens.weight
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
pass
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
correct_dtype = lm_head.weight.dtype
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
weight = module.weight
quant_state = weight.quant_state
if type(quant_state) is list:
# BnB seems to have float16 as default!
module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
pass
pass
# Downcast RoPE embedding to correct data type
# RoPE must be done in float32 for Gemma
# if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
# and (module.cos_cached.dtype != correct_dtype):
# module.cos_cached = module.cos_cached.to(correct_dtype)
# module.sin_cached = module.sin_cached.to(correct_dtype)
# pass
# pass
pass
# Add 1 to weight
# return output * (1 + self.weight)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
# Freeze all parameters except LoRA
# We do this first since += 1 seems to not be liked by requires_grad = True
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
pass
# Patch RMS Layernorm
for name, module in model.named_modules():
if isinstance(module, Gemma2RMSNorm):
# Must be in float32
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
# module = module.to(torch.float32)
# Leave + 1 to Triton kernel itself
# module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
pass
# Clear deleted GPU items
import gc
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
pass
pass

View file

@ -15,6 +15,8 @@
import torch
import gc
from typing import Optional, Tuple, List, Union
from ._utils import *
from ._utils import __version__
from torch.nn.functional import scaled_dot_product_attention
from transformers.models.llama.modeling_llama import (
logger,
@ -25,8 +27,6 @@ from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ..kernels import *
from ._utils import *
from ._utils import __version__
from ..tokenizer_utils import *
if HAS_FLASH_ATTENTION:
from flash_attn import flash_attn_func
@ -78,6 +78,24 @@ from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
# Fix new HF's inference code
def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,):
if "past_key_values" in kwargs:
input_ids = input_ids[:,[-1]]
kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]]
kwargs["position_ids"] = kwargs["cache_position"]
return { "input_ids" : input_ids, **kwargs, }
pass
def fix_prepare_inputs_for_generation(module):
# Fix prepare_inputs_for_generation
if hasattr(module, "prepare_inputs_for_generation"):
module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
pass
pass
def LlamaAttention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
@ -542,7 +560,8 @@ def LlamaModel_fast_forward(
inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
# Normalized from Gemma
IS_GEMMA = self.config.model_type == "gemma"
IS_GEMMA = self.config.model_type.startswith("gemma")
IS_GEMMA2 = self.config.model_type.startswith("gemma2")
train_embed_tokens = self.embed_tokens.weight.requires_grad
if IS_GEMMA:
@ -642,17 +661,38 @@ def LlamaModel_fast_forward(
offloaded_gradient_checkpointing = True
pass
# Gemma2 has alternating SWA and global attn
if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
n = self.config.max_position_embeddings
self.SWA_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = self.config.sliding_window,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
self.GA_mask = AttentionMaskConverter(
is_causal = True,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
pass
# Go through every layer!
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states: all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
mask = causal_mask
if IS_GEMMA2: mask = self.SWA_mask if (idx % 2 == 0) else self.GA_mask
if offloaded_gradient_checkpointing:
hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
hidden_states,
causal_mask,
mask,
attention_mask,
position_ids,
past_key_values,
@ -670,7 +710,7 @@ def LlamaModel_fast_forward(
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
causal_mask,
mask,
attention_mask,
position_ids,
use_reentrant = True,
@ -681,7 +721,7 @@ def LlamaModel_fast_forward(
else:
layer_outputs = decoder_layer(
hidden_states,
causal_mask=causal_mask,
causal_mask=mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
@ -838,6 +878,7 @@ def CausalLM_fast_forward(fast_forward_inference):
logits = logits.to(self.config.torch_dtype)
loss = None
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
@ -849,7 +890,12 @@ def CausalLM_fast_forward(fast_forward_inference):
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
logit_softcapping = logit_softcapping,
)
elif logit_softcapping != 0:
logits *= (1.0 / logit_softcapping)
torch.tanh(logits, out = logits)
logits *= logit_softcapping
pass
if not return_dict:
@ -983,11 +1029,22 @@ def _wrap_fast_inference(generate, device_type, dtype, model):
pass
internal_model._flag_for_generation = True
# For newer HF
kwargs["cache_implementation"] = "dynamic"
# Set pad token
old_pad_token_id = getattr(model.config, "pad_token_id", None)
old_eos_token_id = getattr(model.config, "eos_token_id", None)
model.config.pad_token_id = old_eos_token_id
# Autocasted
with torch.autocast(device_type = device_type, dtype = dtype):
output = generate(*args, **kwargs)
pass
# Revert
model.config.pad_token_id = old_pad_token_id
# Unset a flag for generation!
internal_model = model
while hasattr(internal_model, "model"):
@ -1013,6 +1070,7 @@ class FastLlamaModel:
LlamaModel .forward = LlamaModel_fast_forward
LlamaForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(LlamaForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
@ -1056,7 +1114,7 @@ class FastLlamaModel:
f"==((====))== Unsloth: Fast {model_patcher.__name__[4:-5]} patching release {__version__}\n"\
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
f' "-____-" Free Apache license: http://github.com/unslothai/unsloth'
print(statistics)
model_patcher.pre_patch()
@ -1200,11 +1258,11 @@ class FastLlamaModel:
'nvidia-smi --query-gpu=memory.used --format=csv', shell = True)
output = re.findall(rb'([\\d]{1,})[\\s]{1,}M', output)
output = sum(int(x.decode('utf-8'))/1024 > 4 for x in output)
if output > 1: raise RuntimeError(
'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\
if output > 1: print(
'********************\\nUnsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\
'enabling it will require much more work, so we have to prioritize. Please understand!\\n'\\
'We do have a separate beta version, which you can contact us about!\\n'\\
'Thank you for your understanding and we appreciate it immensely!')
'********************\\nWe do have a separate beta version, which you can contact us about!\\n'\\
'********************\\nThank you for your understanding and we appreciate it immensely!')
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()"""
@ -1760,6 +1818,7 @@ class FastLlamaModel:
elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx
elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx
else:
raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
pass

View file

@ -26,8 +26,11 @@ major, minor = transformers_version.split(".")[:2]
major, minor = int(major), int(minor)
SUPPORTS_FOURBIT = (major > 4) or (major == 4 and minor >= 37)
SUPPORTS_GEMMA = (major > 4) or (major == 4 and minor >= 38)
SUPPORTS_GEMMA2 = (major > 4) or (major == 4 and minor >= 42)
if SUPPORTS_GEMMA:
from .gemma import FastGemmaModel
from .gemma import FastGemmaModel
if SUPPORTS_GEMMA2:
from .gemma2 import FastGemma2Model
del major, minor
@ -138,6 +141,15 @@ class FastLanguageModel(FastLlamaModel):
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastGemmaModel
elif model_type == "gemma2":
if not SUPPORTS_GEMMA2:
raise RuntimeError(
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
f"The minimum required version is 4.43.\n"\
f'Try `pip install --upgrade "transformers>=4.43"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastGemma2Model
elif model_type == "qwen2":
dispatch_model = FastQwen2Model
else:

View file

@ -191,6 +191,14 @@ __INT_TO_FLOAT_MAPPER = \
"mistralai/Codestral-22B-v0.1" : (
"mistral-community/Codestral-22B-v0.1",
),
"unsloth/gemma-2-9b-bnb-4bit" : (
"unsloth/gemma-2-9b",
"google/gemma-2-9b",
),
"unsloth/gemma-2-27b-bnb-4bit" : (
"unsloth/gemma-2-27b",
"google/gemma-2-27b",
),
}
INT_TO_FLOAT_MAPPER = {}

View file

@ -275,7 +275,8 @@ class FastMistralModel(FastLlamaModel):
MistralModel .forward = LlamaModel_fast_forward
MistralForCausalLM .forward = MistralForCausalLM_fast_forward
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(MistralForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.

View file

@ -43,6 +43,7 @@ class FastQwen2Model(FastLlamaModel):
Qwen2Model .forward = LlamaModel_fast_forward
Qwen2ForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(Qwen2ForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.

View file

@ -963,11 +963,11 @@ def patch_sft_trainer_tokenizer():
" 'nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"\
"output = re.findall(rb'([\\d]{1,})[\\s]{1,}M', output)\n"\
"output = sum(int(x.decode('utf-8'))/1024 > 4 for x in output)\n"\
"if output > 1: raise RuntimeError(\n"\
" 'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\\n"\
"if output > 1: print(\n"\
" '********************\\nUnsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\\\n"\
" 'enabling it will require much more work, so we have to prioritize. Please understand!\\n'\\\n"\
" 'We do have a separate beta version, which you can contact us about!\\n'\\\n"\
" 'Thank you for your understanding and we appreciate it immensely!')\n"\
" '********************\\nWe do have a separate beta version, which you can contact us about!\\n'\\\n"\
" '********************\\nThank you for your understanding and we appreciate it immensely!')\n"\
"for _ in range(3):\n"\
" gc.collect()\n"\
" torch.cuda.empty_cache()\n"\