diff --git a/README.md b/README.md index 872bf7fd4..4e2ebe3c6 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ If you trained a model with Unsloth, we made a cool sticker!! # Installation Instructions - Conda Unsloth currently only supports Linux distros and Pytorch == 2.1. -``` +```bash conda install cudatoolkit xformers bitsandbytes pytorch pytorch-cuda=12.1 \ -c pytorch -c nvidia -c xformers -c conda-forge -y pip install "unsloth[kaggle] @ git+https://github.com/unslothai/unsloth.git" @@ -41,16 +41,16 @@ pip install "unsloth[kaggle] @ git+https://github.com/unslothai/unsloth.git" # Installation Instructions - Pip 1. Find your CUDA version via -``` +```python import torch; torch.version.cuda ``` 2. We only support Pytorch 2.1 (2.1.1 bugs out for now): You can update Pytorch via Pip (interchange cu121 / cu118) -``` +```bash pip install --upgrade --force-reinstall --no-cache-dir torch==2.1.0 triton \ --index-url https://download.pytorch.org/whl/cu121 ``` 2. Select either cu118 for CUDA 11.8 or cu121 for CUDA 12.1. If you have a RTX 3060 or higher (A100, H100 etc), use the "ampere" path. -``` +```bash pip install "unsloth[cu118] @ git+https://github.com/unslothai/unsloth.git" pip install "unsloth[cu121] @ git+https://github.com/unslothai/unsloth.git" pip install "unsloth[cu118_ampere] @ git+https://github.com/unslothai/unsloth.git" @@ -59,13 +59,13 @@ pip install "unsloth[cu121_ampere] @ git+https://github.com/unslothai/unsloth.gi Change `cu121` to `cu118` for CUDA version 11.8 or 12.1. Go to https://pytorch.org/ to learn more. 4. If you get errors, try the below first, then go back to step 1: -``` +```bash pip install --upgrade pip ``` # Documentation We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code! -``` +```python from unsloth import FastLlamaModel, FastMistralModel import torch max_seq_length = 2048 # Can change to any number <= 4096 @@ -305,7 +305,7 @@ $$ # Troubleshooting 1. Sometimes `bitsandbytes` or `xformers` does not link properly. Try running: -``` +```bash !ldconfig /usr/lib64-nvidia ``` 2. Windows is not supported as of yet - we rely on Xformers and Triton support, so until both packages support Windows officially, Unsloth will then support Windows. @@ -315,5 +315,5 @@ $$ # Credits 1. [RandomInternetPreson](https://github.com/RandomInternetPreson) for confirming WSL support 2. [152334H](https://github.com/152334H) for experimental DPO support - +3. [atgctg](https://github.com/atgctg) for syntax highlighting diff --git a/images/unsloth made with love.png b/images/unsloth made with love.png index 20dac04f3..9bf7ec936 100644 Binary files a/images/unsloth made with love.png and b/images/unsloth made with love.png differ diff --git a/images/unsloth new logo.png b/images/unsloth new logo.png index fa05a8ef7..20dac04f3 100644 Binary files a/images/unsloth new logo.png and b/images/unsloth new logo.png differ diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index de124c9dd..769669ae6 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -20,13 +20,36 @@ import gc warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch") import bitsandbytes as bnb from transformers.models.llama.modeling_llama import logger -import platform +from platform import system as platform_system +platform_system = platform_system() __version__ = "2023.12" + +# Get Flash Attention v2 if Ampere (RTX 30xx, A100) +major_version, minor_version = torch.cuda.get_device_capability() +if major_version >= 8: + try: + from flash_attn import flash_attn_func + HAS_FLASH_ATTENTION = True + except: + HAS_FLASH_ATTENTION = False +else: + # Tri Dao's benchmark shows xformers is faster for now. + HAS_FLASH_ATTENTION = False +pass +import xformers.ops.fmha as xformers +xformers_attention = xformers.memory_efficient_attention +from xformers import __version__ as xformers_version + __all__ = [ "prepare_model_for_kbit_training", "patch_tokenizer", - "print_unsloth_message", + "xformers", + "xformers_attention", + "xformers_version", + "__version__", + "HAS_FLASH_ATTENTION", + "platform_system", ] @@ -71,6 +94,7 @@ pass def patch_tokenizer(model, tokenizer): + model.config.update({"unsloth_version" : __version__}) if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: # Fixes https://github.com/unslothai/unsloth/issues/5 if hasattr(tokenizer, "unk_token"): @@ -88,18 +112,3 @@ def patch_tokenizer(model, tokenizer): pass return model, tokenizer pass - - -def print_unsloth_message(name): - SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported() - gpu_stats = torch.cuda.get_device_properties(0) - max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) - - statistics = \ - f"==((====))== Unsloth: Fast {name} patching release {__version__}\n"\ - f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB\n"\ - f"O^O/ \_/ \\ CUDA compute capability = {gpu_stats.major}.{gpu_stats.minor}\n"\ - f"\ / Pytorch version: {torch.__version__}. CUDA Toolkit = {torch.version.cuda}\n"\ - f' "-____-" bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Platform = {platform.system()}\n' - print(statistics) -pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1f0046ec0..be39b90af 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -23,21 +23,9 @@ from transformers.models.llama.modeling_llama import ( ) from ..kernels import * from ._utils import * - -# Get Flash Attention v2 if Ampere (RTX 30xx, A100) -major_version, minor_version = torch.cuda.get_device_capability() -if major_version >= 8: - try: - from flash_attn import flash_attn_func - HAS_FLASH_ATTENTION = True - except: - HAS_FLASH_ATTENTION = False -else: - # Tri Dao's benchmark shows xformers is faster for now. - HAS_FLASH_ATTENTION = False -pass -import xformers.ops.fmha as xformers -xformers_attention = xformers.memory_efficient_attention +from ._utils import __version__ +if HAS_FLASH_ATTENTION: + from flash_attn import flash_attn_func # Final patching code from transformers.models.llama.modeling_llama import ( @@ -139,19 +127,20 @@ def LlamaAttention_fast_forward_inference( # V = repeat_kv(V, n_groups) if n_groups != 1: _, _, cached_len, _ = Kn.shape - Kn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) - Vn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) - Kn = Kn.reshape(bsz, n_heads, cached_len, head_dim) - Vn = Vn.reshape(bsz, n_heads, cached_len, head_dim) - pass + Knn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Vnn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Knn = Knn.view(bsz, n_heads, cached_len, head_dim) + Vnn = Vnn.view(bsz, n_heads, cached_len, head_dim) + else: + Knn, Vnn = Kn, Vn # Attention - A = torch.matmul(Qn, Kn.transpose(2, 3)) + A = torch.matmul(Qn, Knn.transpose(2, 3)) A *= 1.0 / (self.head_dim**0.5) A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(A.dtype) - A = torch.matmul(A, Vn) + A = torch.matmul(A, Vnn) A = A.transpose(1, 2) - A = A.reshape(bsz, 1, self.hidden_size) + A = A.view(bsz, 1, self.hidden_size) A = original_apply_o(self, A) return A, (Kn, Vn) pass @@ -359,13 +348,13 @@ def LlamaModel_fast_forward( # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError("Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError("Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds") seq_length_with_past = seq_length past_key_values_length = 0 @@ -419,7 +408,7 @@ def LlamaModel_fast_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + "Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`" ) use_cache = False pass @@ -614,7 +603,16 @@ class FastLlamaModel: rope_scaling = None, ): SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported() - print_unsloth_message("Llama") + gpu_stats = torch.cuda.get_device_properties(0) + max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) + + statistics = \ + f"==((====))== Unsloth: Fast Llama patching release {__version__}\n"\ + f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB\n"\ + f"O^O/ \_/ \\ CUDA capability = {gpu_stats.major}.{gpu_stats.minor}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\ + f"\ / Pytorch version: {torch.__version__}. CUDA Toolkit = {torch.version.cuda}\n"\ + f' "-____-" bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Platform = {platform_system}\n' + logger.warning_once(statistics) FastLlamaModel.pre_patch() if dtype is None: @@ -632,7 +630,7 @@ class FastLlamaModel: if (rope_scaling is None) and (max_seq_length > model_max_seq_length): rope_scaling = max_seq_length / model_max_seq_length logger.warning_once( - f"Unsloth: {model_name} can only handle sequence lengths of of most "\ + f"Unsloth: {model_name} can only handle sequence lengths of at most "\ f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\ f"{round(rope_scaling, 3)}, it can be magically be extended to "\ f"{max_seq_length}!" @@ -686,6 +684,7 @@ class FastLlamaModel: # Torch.compile fails on embedding matrix?? # Workaround randomnly fixes it for torch versions < 2.2 model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) + model.config.update({"unsloth_version" : __version__}) # We also do this for the lm_head lm_head = torch.nn.Linear(1, 1, bias = None) @@ -747,6 +746,7 @@ class FastLlamaModel: accepted_modules = frozenset(("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",),) + model.config.update({"unsloth_version" : __version__}) for module in target_modules: assert(module in accepted_modules) pass @@ -771,6 +771,9 @@ class FastLlamaModel: model = _get_peft_model(model, lora_config) # Do patching + n_mlp = 0 + n_qkv = 0 + n_o = 0 for idx, layer in enumerate(model.model.model.layers): # MLP patching @@ -780,6 +783,7 @@ class FastLlamaModel: # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) + n_mlp += 1 pass # QKV attention patching @@ -788,15 +792,22 @@ class FastLlamaModel: hasattr(layer.self_attn.v_proj, "lora_A"): layer.self_attn.apply_qkv = apply_lora_qkv + n_qkv += 1 pass # O attention patching if hasattr(layer.self_attn.o_proj, "lora_A"): layer.self_attn.apply_o = apply_lora_o + n_o += 1 pass pass + logger.warning_once( + f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "\ + f"{n_qkv} QKV layers, {n_o} O layers and {n_mlp} MLP layers.", + ) + # Patch cross entropy loss labels # Fixes https://github.com/unslothai/unsloth/issues/10 extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda") diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index f75ebbcd5..0ace3f494 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -45,7 +45,7 @@ class FastLanguageModel: ) elif model_type == "mistral": if rope_scaling is not None: - logger.warning_once("Mistral models do not support RoPE scaling.") + logger.warning_once("Unsloth: Mistral models do not support RoPE scaling.") return FastMistralModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -57,7 +57,8 @@ class FastLanguageModel: ) else: raise NotImplementedError( - f"{model_name} not supported yet! Make an issue to https://github.com/unslothai/unsloth!", + f"Unsloth: {model_name} not supported yet!\n"\ + "Make an issue to https://github.com/unslothai/unsloth!", ) pass pass diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 9a91a8fc1..323ec39f8 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -13,6 +13,7 @@ # limitations under the License. from .llama import * +from ._utils import __version__ from transformers.models.mistral.modeling_mistral import ( MistralAttention, @@ -245,7 +246,16 @@ class FastMistralModel(FastLlamaModel): # rope_scaling = None, Mistral does not support RoPE scaling ): SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported() - print_unsloth_message("Mistral") + gpu_stats = torch.cuda.get_device_properties(0) + max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) + + statistics = \ + f"==((====))== Unsloth: Fast Mistral patching release {__version__}\n"\ + f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB\n"\ + f"O^O/ \_/ \\ CUDA capability = {gpu_stats.major}.{gpu_stats.minor}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\ + f"\ / Pytorch version: {torch.__version__}. CUDA Toolkit = {torch.version.cuda}\n"\ + f' "-____-" bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Platform = {platform_system}\n' + logger.warning_once(statistics) FastMistralModel.pre_patch() if dtype is None: