Pre-release 2023 December version (Mistral, Prelim DPO, WSL, bug fixes) (#16)

* Immediate bug fixes

* Update README.md

* Update README.md

* Update llama.py

* Update llama.py

* Rope Scaling and max_seq_len will change

* Update llama.py

* new images

* Update README.md

* Images

* Update README.md

* Update pyproject.toml

* GQA

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py
This commit is contained in:
Daniel Han 2023-12-06 02:59:22 +11:00 committed by GitHub
parent d1c1748266
commit 6416f235cf
12 changed files with 122 additions and 3047 deletions

@ -1,23 +1,25 @@
<div class="align-center">
<img src="./images/unsloth new logo.png" width="400" />
<a href="https://discord.gg/u54VK8m8tk"><img src="./images/Discord.png" width="180"></a>
<a href="https://colab.research.google.com/drive/1oW55fBmwzCOrBVX66RcpptL3a99qWBxb?usp=sharing"><img src="./images/try live demo green.png" width="130"></a>
</div>
## 80% faster 50% less memory local QLoRA finetuning
## 2-5x faster 50% less memory local LLM finetuning
* Manual autograd engine - hand derived backprop steps.
* QLoRA / LoRA 80% faster, 50% less memory.
* All kernels written in OpenAI's Triton language.
* 2x to 5x faster than QLoRA. 50% less memory usage.
* All kernels written in [OpenAI's Triton](https://openai.com/research/triton) language.
* 0% loss in accuracy - no approximation methods - all exact.
* No change of hardware necessary. Supports NVIDIA GPUs since 2018+. CUDA 7.5+. Tesla T4, RTX 20, 30, 40 series, A100, H100s
* Flash Attention support via Xformers.
* Supports 4bit and 16bit LoRA finetuning.
* No change of hardware necessary. Supports NVIDIA GPUs since 2018+. Minimum CUDA Compute Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40 etc) [Check your GPU](https://developer.nvidia.com/cuda-gpus)
* [Flash Attention v2](https://github.com/Dao-AILab/flash-attention) support via [Xformers](https://github.com/facebookresearch/xformers).
* **NEW!** Works on **Linux** and **Windows** via WSL.
* **NEW!** Experimental support for [DPO (Direct Preference Optimization)](https://arxiv.org/abs/2305.18290)!
* Supports 4bit and 16bit QLoRA / LoRA finetuning via [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
* Train Slim Orca **fully locally in 260 hours from 1301 hours (5x faster).**
* Open source version trains 5x faster or you can check out [Unsloth Pro and Max](https://unsloth.ai/) codepaths for **30x faster training**!
<div class="align-center">
<img src="./images/Slim Orca 2GPUs.png" width="400" />
<img src="./images/LAION%202GPU.svg" width="400" />
<img src="./images/LAION 2GPU.png" width="400" />
</div>
1. Try our Colab examples for [the Alpaca 52K dataset](https://colab.research.google.com/drive/1oW55fBmwzCOrBVX66RcpptL3a99qWBxb?usp=sharing) or [the Slim Orca 518K dataset](https://colab.research.google.com/drive/1VNqLARpE8N8eYwNrUSDoHVjtbR9W0_c7?usp=sharing).
@ -49,7 +51,13 @@ pip install --upgrade --force-reinstall --no-cache-dir torch triton \
```
Use `cu118` for CUDA 11.8 or `cu121` for CUDA 12.1. Go to https://pytorch.org/ to learn more.
# Alpaca Example
4. If you get errors, try the below first, then go back to step 1:
```
pip install --upgrade pip
```
# Documentation
We support Huggingface's TRL, Trainer, Seq2SeqTrainer, or even plain PyTorch code!
```
from unsloth import FastLlamaModel
import torch
@ -59,7 +67,7 @@ load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False
# Load Llama model
model, tokenizer = FastLlamaModel.from_pretrained(
model_name = "unsloth/llama-2-7b", # Supports any llama model
model_name = "unsloth/llama-2-7b", # Supports any llama model eg meta-llama/Llama-2-7b-hf
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
@ -80,12 +88,16 @@ model = FastLlamaModel.get_peft_model(
max_seq_length = max_seq_length,
)
trainer = .... Use Huggingface's Trainer and dataset loading
trainer = .... Use Huggingface's Trainer and dataset loading (TRL, transformers etc)
```
If you trained a model with Unsloth, we made a cool sticker!!
<img src="./images/unsloth made with love.png" width="200" />
# Experimental DPO (Direct Preference Optimization) support
[152334H](https://github.com/152334H) hacked Unsloth to work with DPO via TRL!
1. Hack the model's `config.json` to be a llama model. [Example gist](https://gist.github.com/152334H/d8a68b51b83bac008a02e69ecc81d5c1).
2. Use Unsloth for DPO for both base and reference models. [Example gist](https://gist.github.com/152334H/4847f3a8cca12894877e6b30698b0b64).
# Future Milestones and limitations
1. Support sqrt gradient checkpointing which further slashes memory usage by 25%.
@ -94,6 +106,9 @@ If you trained a model with Unsloth, we made a cool sticker!!
# Performance comparisons on 1 Tesla T4 GPU:
**Time taken for 1 epoch**
One Tesla T4 on Google Colab
`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`
| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 1 T4 | 23h 15m | 56h 28m | 8h 38m | 391h 41m |
@ -113,19 +128,28 @@ If you trained a model with Unsloth, we made a cool sticker!!
# Performance comparisons on 2 Tesla T4 GPUs via DDP:
**Time taken for 1 epoch**
| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
Two Tesla T4s on Kaggle
`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`
| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 2 T4 | 84h 47m | 163h 48m | 30h 51m | 1301h 24m |
| Unsloth Pro | 2 T4 | 3h 20m (25.4x) | 5h 43m (28.7x) | 1h 12m (25.7x) | 71h 40m (18.1x) |
| Unsloth Max | 2 T4 | 3h 4m (27.6x) | 5h 14m (31.3x) | 1h 6m (28.1x) | 54h 20m (23.9x) |
| Huggingface | 2 T4 | 84h 47m | 163h 48m | 30h 51m | 1301h 24m * |
| Unsloth Pro | 2 T4 | 3h 20m (25.4x) | 5h 43m (28.7x) | 1h 12m (25.7x) | 71h 40m (18.1x) * |
| Unsloth Max | 2 T4 | 3h 4m (27.6x) | 5h 14m (31.3x) | 1h 6m (28.1x) | 54h 20m (23.9x) * |
**Peak Memory Usage on a Multi GPU System (2 GPUs)**
| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 2 T4 | 8.4GB \| 6GB | 7.2GB \| 5.3GB | 14.3GB \| 6.6GB | 10.9GB \| 5.9GB |
| Unsloth Pro | 2 T4 | 7.7GB \| 4.9GB | 7.5GB \| 4.9GB | 8.5GB \| 4.9GB | 6.2GB \| 4.7GB |
| Unsloth Max | 2 T4 | 10.5GB \| 5GB | 10.6GB \| 5GB | 10.6GB \| 5GB | 10.5GB \| 5GB |
| Huggingface | 2 T4 | 8.4GB \| 6GB | 7.2GB \| 5.3GB | 14.3GB \| 6.6GB | 10.9GB \| 5.9GB * |
| Unsloth Pro | 2 T4 | 7.7GB \| 4.9GB | 7.5GB \| 4.9GB | 8.5GB \| 4.9GB | 6.2GB \| 4.7GB * |
| Unsloth Max | 2 T4 | 10.5GB \| 5GB | 10.6GB \| 5GB | 10.6GB \| 5GB | 10.5GB \| 5GB * |
* Slim Orca `bsz=1` for all benchmarks since `bsz=2` OOMs. We can handle `bsz=2`, but we benchmark it with `bsz=1` for consistency.
### For replication of timings:
* [Huggingface LAION DDP reference implementation](https://www.kaggle.com/code/danielhanchen/huggingface-original-laion-oig) 60 steps on DDP Kaggle 2 Tesla T4 GPUs takes 40 minutes and 46 seconds
* [Unsloth LAION DDP fast implementation](https://www.kaggle.com/code/danielhanchen/unsloth-laion-chip2-kaggle) 60 steps on DDP Kaggle 2 Tesla T4 GPUs takes 4 minutes and 34 seconds **(8.64x speedup)**. **The open source version only uses 1 GPU, whilst Pro plans use more.**
# Troubleshooting
1. Sometimes `bitsandbytes` or `xformers` does not link properly. Try running:
@ -136,4 +160,8 @@ If you trained a model with Unsloth, we made a cool sticker!!
3. If it doesn't install - maybe try updating `pip`.
# Credits
1. [RandomInternetPreson](https://github.com/RandomInternetPreson) for confirming WSL support
2. [152334H](https://github.com/152334H) for experimental DPO support
<img src="./images/unsloth loading page render.png" width="300" />

@ -33,7 +33,7 @@ exclude = ["images*"]
[project.optional-dependencies]
huggingface = [
"transformers",
"transformers",
"datasets",
"sentencepiece",
"accelerate",
@ -70,4 +70,4 @@ colab = [
[project.urls]
homepage = "http://www.unsloth.ai"
documentation = "https://github.com/unslothai/unsloth"
repository = "https://github.com/unslothai/unsloth"
repository = "https://github.com/unslothai/unsloth"

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "2023.11"
__version__ = "2023.12"
import os
import warnings
import importlib
@ -35,7 +35,7 @@ if "CUDA_VISIBLE_DEVICES" in os.environ:
)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
else:
warnings.warn("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
# warnings.warn("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
pass

@ -43,7 +43,6 @@ def _cross_entropy_forward(logits_ptr, logits_row_stride,
mask = col_offsets < n_cols
# TODO: Fixup int32 locations to int64
# https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
max_logits = tl.max(logits, 0)
@ -88,7 +87,6 @@ def _cross_entropy_backward(logits_ptr, logits_row_stride,
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# TODO: Fixup int32 locations to int64
# https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
if label_idx != -100:

@ -35,7 +35,6 @@ def _rope_embedding(
mask = col_offsets < half_head_dim
# TODO: Fixup int32 locations to int64
# https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72
rot_position = row_position % seqlen
Q += row_position* Q_row_stride + head_position*head_dim
@ -48,8 +47,6 @@ def _rope_embedding(
Q2 = tl.load(Q + half_head_dim*1 + col_offsets, mask = mask, other = 0)
# RoPE repeats sin and cos so 128 = [64, 64].
# sin2 = tl.load(sin + half_head_dim*1, mask = mask, other = 0)
# cos2 = tl.load(cos + half_head_dim*1, mask = mask, other = 0)
if BACKWARD_PASS:
"""
@ -62,11 +59,8 @@ def _rope_embedding(
where R.T is again the same [ 0, -I]
but the minus is transposed. [ I, 0]
"""
# sin1, sin2 = -sin1, -sin2
sin1 = -sin1
# tl.store(Q + half_head_dim*0, Q1*cos1 - Q2*sin1, mask = mask)
# tl.store(Q + half_head_dim*1, Q2*cos2 + Q1*sin2, mask = mask)
# RoPE repeats sin and cos so 128 = [64, 64].
tl.store(Q + half_head_dim*0 + col_offsets, Q1*cos1 - Q2*sin1, mask = mask)
tl.store(Q + half_head_dim*1 + col_offsets, Q2*cos1 + Q1*sin1, mask = mask)

@ -13,12 +13,12 @@
# limitations under the License.
import triton
MAX_FUSED_SIZE = 65535 # 2**16 - 1
MAX_FUSED_SIZE = 65536 # 2**16 Solves https://github.com/unslothai/unsloth/issues/7
next_power_of_2 = triton.next_power_of_2
def calculate_settings(n):
BLOCK_SIZE = next_power_of_2(n)
# CUDA only supports 65535 - 2^16-1 threads per block
# CUDA only supports 65536 - 2^16 threads per block
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
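The launcher above rounds the row length up to the next power of two and refuses anything past `MAX_FUSED_SIZE`. As a minimal pure-Python sketch of that logic (the `num_warps` thresholds are assumptions, and `next_power_of_2` stands in for `triton.next_power_of_2`):

```python
# Sketch of the Triton block-size calculation; pure Python, no Triton needed.
MAX_FUSED_SIZE = 65536  # 2**16

def next_power_of_2(n):
    # Smallest power of two >= n, stand-in for triton.next_power_of_2.
    return 1 << (n - 1).bit_length()

def calculate_settings(n):
    BLOCK_SIZE = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    # Heuristic (assumed here): more warps for larger blocks.
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps

print(calculate_settings(32000))  # Llama vocab size -> (32768, 32)
```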

@ -16,14 +16,15 @@ import torch
from typing import Optional, Tuple, List, Union
from torch.nn.functional import scaled_dot_product_attention
from transformers.models.llama.modeling_llama import (
# apply_rotary_pos_emb,
# repeat_kv,
# _prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask,
logger,
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ..kernels import *
from ._utils import (
prepare_model_for_kbit_training,
)
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
major_version, minor_version = torch.cuda.get_device_capability()
@ -37,7 +38,6 @@ else:
# Tri Dao's benchmark shows xformers is faster for now.
HAS_FLASH_ATTENTION = False
pass
import xformers.ops.fmha as xformers
xformers_attention = xformers.memory_efficient_attention
@ -55,12 +55,9 @@ import bitsandbytes as bnb
import numpy as np
import types
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from transformers import set_seed as transformers_set_seed
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
from ._utils import (
prepare_model_for_kbit_training,
)
def original_apply_qkv(self, X):
@ -92,10 +89,6 @@ def LlamaAttention_fast_forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
# Q = self.q_proj(hidden_states)
# K = self.k_proj(hidden_states)
# V = self.v_proj(hidden_states)
Q, K, V = self.apply_qkv(self, hidden_states)
n_heads = self.num_heads
@ -112,8 +105,6 @@ def LlamaAttention_fast_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
# Q, K = apply_rotary_pos_emb(Q, K, cos, sin, position_ids)
if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
@ -130,10 +121,9 @@ def LlamaAttention_fast_forward(
past_key_value = (K, V) if use_cache else None
# Attention module
# no_attention_mask = attention_mask is None
# Ignore attention_mask
if (not HAS_FLASH_ATTENTION): #and no_attention_mask:
# Xformers doesn't support the backward pass for GQA (yet)
# TEMP fix
if (n_groups == 1) and (not HAS_FLASH_ATTENTION):
# Xformers memory efficient attention
# Also has Flash Attention v2 dispatching
# (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
@ -143,18 +133,17 @@ def LlamaAttention_fast_forward(
# Grouped query attention
if n_groups != 1:
Q = Q.reshape(bsz, q_len, n_groups, n_kv_heads, head_dim)
K = K.reshape(bsz, q_len, n_groups, 1, head_dim)
V = V.reshape(bsz, q_len, n_groups, 1, head_dim)
K = K .expand(bsz, q_len, n_groups, n_kv_heads, head_dim)
V = V .expand(bsz, q_len, n_groups, n_kv_heads, head_dim)
Q = Q.reshape(bsz, q_len, n_kv_heads, n_groups, head_dim)
K = K.reshape(bsz, q_len, n_kv_heads, 1, head_dim)
V = V.reshape(bsz, q_len, n_kv_heads, 1, head_dim)
K = K .expand(bsz, q_len, n_kv_heads, n_groups, head_dim)
V = V .expand(bsz, q_len, n_kv_heads, n_groups, head_dim)
pass
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION:# and no_attention_mask:
elif HAS_FLASH_ATTENTION:
# Flash Attention
# (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
Q = Q.transpose(1, 2)
@ -163,37 +152,22 @@ def LlamaAttention_fast_forward(
# Flash Attention v2 auto supports grouped query attention
A = flash_attn_func(Q, K, V, causal = True)
else:
# Uses Pytorch's scaled dot product attention
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
pass
# Grouped query attention
# K = repeat_kv(K, n_groups)
# V = repeat_kv(V, n_groups)
if n_groups != 1:
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
K = K.reshape(bsz, n_heads, q_len, head_dim)
V = V.reshape(bsz, n_heads, q_len, head_dim)
pass
# Needs (batch_size, n_heads, seq_len, head_dim)
# is_causal and attention_mask must not both be set!
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = attention_mask is None)
A = scaled_dot_product_attention(Q, K, V, attn_mask = None, is_causal = True)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2)
pass
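The expand/reshape in the SDPA branch implements grouped-query attention's KV-head repetition (the same effect as transformers' `repeat_kv`): each of the `n_kv_heads` key/value heads is shared by `n_groups` query heads. A rough pure-Python sketch of the semantics, on nested lists with hypothetical shapes instead of tensors:

```python
# List-based sketch of repeat_kv: one entry per KV head, each head a
# (seq_len, head_dim) nested list. Shapes here are illustrative only.
def repeat_kv(kv, n_groups):
    # Repeat every KV head n_groups times so n_kv_heads * n_groups == n_heads.
    return [head for head in kv for _ in range(n_groups)]

n_kv_heads, n_groups, seq_len, head_dim = 2, 4, 3, 2
K = [[[h] * head_dim for _ in range(seq_len)] for h in range(n_kv_heads)]
K_expanded = repeat_kv(K, n_groups)
print(len(K_expanded))  # n_heads = 8
```

The real code uses `expand` rather than copying, so no extra memory is materialised until the reshape.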
attn_output = A.reshape(bsz, q_len, self.hidden_size)
# attn_output = self.o_proj(attn_output)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
pass
@ -227,7 +201,6 @@ def LlamaDecoderLayer_fast_forward(
"""
residual = hidden_states
# hidden_states = self.input_layernorm(hidden_states)
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
# Self Attention
@ -245,7 +218,6 @@ def LlamaDecoderLayer_fast_forward(
# Fully Connected
residual = hidden_states
# hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
@ -308,7 +280,7 @@ def LlamaModel_fast_forward(
if (past_key_values_length != 0):
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length,
dtype = torch.int32,#dtype=torch.long,
dtype = torch.int32,
device = "cuda",
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
@ -326,11 +298,7 @@ def LlamaModel_fast_forward(
inputs_embeds = self.embed_tokens(input_ids)
# Ignore attention_mask
if True:
# if attention_mask is None:
# attention_mask = torch.ones(
# (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
# )
if attention_mask is None:
padding_mask = None
else:
if 0 in attention_mask:
@ -339,7 +307,7 @@ def LlamaModel_fast_forward(
padding_mask = None
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length,
)
pass
@ -403,7 +371,6 @@ def LlamaModel_fast_forward(
all_self_attns += (layer_outputs[1],)
pass
# hidden_states = self.norm(hidden_states)
hidden_states = fast_rms_layernorm(self.norm, hidden_states)
# add hidden states from the last decoder layer
@ -466,19 +433,13 @@ def LlamaForCausalLM_fast_forward(
loss = None
if labels is not None:
# logits = logits.float()
# shift_logits = logits[..., :-1, :].contiguous()
# shift_labels = labels[..., 1:].contiguous()
# shift_labels = shift_labels.view(-1)
# shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda")
pass
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
# loss_fct = torch.nn.CrossEntropyLoss(
# ignore_index = self.ignore_index,
# label_smoothing = self.label_smoothing,
# )
# loss = loss_fct(shift_logits, shift_labels)
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
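The hunk above replaces the usual "slice off the last logit row" with the opposite trick: labels are shifted left by one and padded with the ignore index `-100`, so logits keep their full shape. A minimal pure-Python sketch of that shift (list-based, no torch):

```python
# Sketch of the label-shifting trick: shift each label row left by one
# and append -100, so logits and labels stay the same length.
IGNORE_INDEX = -100

def shift_labels(labels):
    return [row[1:] + [IGNORE_INDEX] for row in labels]

labels = [[5, 6, 7, 8]]
print(shift_labels(labels))  # [[6, 7, 8, -100]]
```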
@ -547,13 +508,14 @@ class FastLlamaModel:
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
):
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()
statistics = \
"==((====))== Unsloth: Fast Llama patching release 23.11\n"\
"==((====))== Unsloth: Fast Llama patching release 2023.12\n"\
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB\n"\
f"O^O/ \_/ \\ CUDA compute capability = {gpu_stats.major}.{gpu_stats.minor}\n"\
f"\ / Pytorch version: {torch.__version__}. CUDA Toolkit = {torch.version.cuda}\n"\
@ -570,9 +532,20 @@ class FastLlamaModel:
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
# [TODO]: Determine RoPE scaling
# https://github.com/huggingface/transformers/pull/24653
assert(max_seq_length <= 4096)
# RoPE scaling
model_max_seq_length = \
AutoConfig.from_pretrained(model_name, token = token).max_position_embeddings
if (rope_scaling is None) and (max_seq_length > model_max_seq_length):
rope_scaling = max_seq_length / model_max_seq_length
logger.warning_once(
f"Unsloth: {model_name} can only handle sequence lengths of at most "\
f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\
f"{round(rope_scaling, 3)}, it can magically be extended to "\
f"{max_seq_length}!"
)
rope_scaling = {"type": "linear", "factor": rope_scaling,}
pass
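The new RoPE-scaling path computes the linear scaling factor as the ratio of the requested context length to the model's native `max_position_embeddings`. A sketch of that decision, extracted as a standalone helper (the function name is an assumption):

```python
# Sketch of the linear (kaiokendev-style) RoPE scaling decision above.
def make_rope_scaling(max_seq_length, model_max_seq_length):
    if max_seq_length <= model_max_seq_length:
        return None  # native context is enough; no scaling needed
    factor = max_seq_length / model_max_seq_length
    return {"type": "linear", "factor": factor}

print(make_rope_scaling(8192, 4096))  # {'type': 'linear', 'factor': 2.0}
print(make_rope_scaling(2048, 4096))  # None
```

The resulting dict is what gets passed through to `AutoModelForCausalLM.from_pretrained(..., rope_scaling = ...)`.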
bnb_config = None
if load_in_4bit:
@ -589,6 +562,7 @@ class FastLlamaModel:
torch_dtype = dtype,
quantization_config = bnb_config,
token = token,
rope_scaling = rope_scaling,
)
tokenizer = AutoTokenizer.from_pretrained(
model_name,
@ -596,9 +570,22 @@ class FastLlamaModel:
padding_side = "right",
token = token,
)
tokenizer.add_special_tokens({"pad_token" : tokenizer.unk_token});
tokenizer.pad_token = tokenizer.unk_token
config = model.config.update({"pad_token_id" : tokenizer.unk_token_id});
if not hasattr(tokenizer, "pad_token"):
# Fixes https://github.com/unslothai/unsloth/issues/5
if hasattr(tokenizer, "unk_token"):
tokenizer.add_special_tokens({"pad_token" : tokenizer.unk_token})
tokenizer.pad_token = tokenizer.unk_token
else:
logger.warning_once(
f"{model_name} does not have a padding or unknown token!\n"\
f"Will use the EOS token of id {tokenizer.eos_token_id} as padding."
)
assert(hasattr(tokenizer, "eos_token"))
tokenizer.add_special_tokens({"pad_token" : tokenizer.eos_token})
tokenizer.pad_token = tokenizer.eos_token
config = model.config.update({"pad_token_id" : tokenizer.eos_token_id})
pass
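The pad-token fix above is a fallback chain: prefer the tokenizer's unknown token as padding, else fall back to EOS. A sketch of just that chain, using a tiny stand-in object (the real code operates on a Huggingface tokenizer; only the attribute names mirror it):

```python
# Stand-in tokenizer with the three attributes the fallback logic touches.
class DummyTokenizer:
    def __init__(self, unk_token=None, eos_token="</s>"):
        self.unk_token = unk_token
        self.eos_token = eos_token
        self.pad_token = None

def set_pad_token(tokenizer):
    # Prefer the unknown token as padding; otherwise fall back to EOS.
    if tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token
    else:
        assert tokenizer.eos_token is not None
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer.pad_token

print(set_pad_token(DummyTokenizer(unk_token="<unk>")))  # <unk>
print(set_pad_token(DummyTokenizer(unk_token=None)))     # </s>
```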
model = FastLlamaModel.post_patch(model)
@ -607,6 +594,8 @@ class FastLlamaModel:
layer.self_attn.apply_qkv = original_apply_qkv
layer.self_attn.apply_o = original_apply_o
pass
model.max_seq_length = max_seq_length
return model, tokenizer
pass
@ -668,6 +657,8 @@ class FastLlamaModel:
random_state = 3407,
max_seq_length = 2048,
):
assert(max_seq_length <= model.max_seq_length)
if lora_dropout != 0:
raise TypeError("Unsloth: Fast Llama patching only works with dropout = 0.")
if bias != "none":
@ -727,8 +718,14 @@ class FastLlamaModel:
pass
# Patch cross entropy loss labels
model.model.extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda")
# Fixes https://github.com/unslothai/unsloth/issues/10
extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda")
model.model.extra_ignored_labels = extra_ignored_labels
internal_model = model
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_seq_length
internal_model = internal_model.model
pass
return model
pass
pass