Fix Gemma (#223)

* Update llama.py

* gemma

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update fast_lora.py

* Update fast_lora.py

* Fast CE Loss

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* CE

* Update llama.py

* Update llama.py

* Update cross_entropy_loss.py

* Update geglu.py

* Update cross_entropy_loss.py

* revert

* Update llama.py

* Update llama.py

* norm

* Update gemma.py

* Update gemma.py

* position_ids

* Update gemma.py

* Update gemma.py

* pos

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update cross_entropy_loss.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* revert

* revert

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update cross_entropy_loss.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* rope

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* llama

* Update llama.py

* gemma

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update save.py

* RoPE

* Update llama.py

* Update llama.py

* Update llama.py

* Update gemma.py

* correct_dtype

* Update gemma.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Chat Templates

* Update README.md

* Update README.md

* Update llama.py

* DoRA

* Update _utils.py

* Update chat_templates.py

* Update llama.py

* Hotfix - fix DoRA, Gemma prompt template (#202) (#203)

* Update save.py

* saving

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update __init__.py

* Update save.py

* Update save.py

* Update save.py

* save

* trainer

* spaces

* original

* Gemma

* Update pyproject.toml

* Update mapper.py

* Update fast_lora.py

* FastGemmaModel

* model_type

* Update llama.py

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update llama.py

* Update fast_lora.py

* Update llama.py

* Update llama.py

* Update cross_entropy_loss.py

* Update llama.py

* Update llama.py

* Update pyproject.toml

* Small fixes

* Update pyproject.toml

* Approx gelu

* Update geglu.py

* Approx gelu

* Update llama.py

* Update __init__.py

* Update __init__.py

* Update _utils.py

* Update geglu.py

* Update gemma.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update rms_layernorm.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Fix Gemma merging

* Update rms_layernorm.py

* Update gemma.py

* Update pyproject.toml

* Layernorms

* Gemma precision

* Update gemma.py

* sqrt

* Update gemma.py

* Update save.py

* RoPE and Gemma precision

* Update rms_layernorm.py

* Fix warning

* Update chat_templates.py
Author: Daniel Han
Date: 2024-03-07 04:34:06 +11:00 (committed via GitHub)
Parent: fedcafe281
Commit: 70f271b1d3
8 changed files with 250 additions and 195 deletions

File: pyproject.toml

@@ -43,6 +43,7 @@ huggingface = [
"psutil",
"wheel>=0.42.0",
"numpy",
"triton",
]
cu118only = [
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
@@ -104,6 +105,16 @@ cu121-torch211 = [
"bitsandbytes",
"unsloth[cu121onlytorch211]",
]
cu118-torch212 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu118onlytorch212]",
]
cu121-torch212 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu121onlytorch212]",
]
cu118-torch220 = [
"unsloth[huggingface]",
"bitsandbytes",

File: __init__.py

@@ -99,10 +99,14 @@ except:
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
libcuda_dirs()
except:
raise ImportError("Unsloth: CUDA is not linked properly.\n"\
"We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
"You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
"Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.")
warnings.warn(
"Unsloth: CUDA is not linked properly.\n"\
"Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
"We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
"You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
"Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
"Unsloth will still run for now, but maybe it might crash - let's hope it works!"
)
pass
from .models import *
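
The failed CUDA-linking check above now degrades to a warning instead of an ImportError, so Unsloth keeps importing and simply asks the user to run the two diagnostic commands named in the message. A minimal sketch of running those diagnostics from Python (the subprocess wrapper is illustrative, not part of the patch):

import subprocess, sys

# Print bitsandbytes' CUDA setup report, as the new warning suggests.
subprocess.run([sys.executable, "-m", "bitsandbytes"], check = False)
# Print xformers' build and runtime info, also suggested by the warning.
subprocess.run([sys.executable, "-m", "xformers.info"], check = False)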

File: chat_templates.py

@@ -257,6 +257,10 @@ def get_chat_template(
assert("Unsloth: Can only map new tokens to EOS for now. Adding new tokens is not yet supported.")
pass
if tokenizer.__class__.__name__.startswith("Gemma") and chat_template == "chatml":
chat_template = "gemma_chatml"
pass
old_padding_side = tokenizer.padding_side
if type(chat_template) in (list, tuple,):
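
The new branch above silently remaps a request for "chatml" to "gemma_chatml" whenever the tokenizer class name starts with "Gemma", so ChatML-style data keeps working on Gemma's own prompt format. A hedged usage sketch (the tokenizer is assumed to come from FastLanguageModel.from_pretrained; other arguments are omitted):

from unsloth.chat_templates import get_chat_template

# With this patch, requesting "chatml" on a Gemma tokenizer transparently
# selects the "gemma_chatml" template instead.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "chatml",
)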

File: rms_layernorm.py

@@ -53,6 +53,7 @@ def _rms_layernorm_forward(
pass
@triton.heuristics({"GEMMA": lambda args: args["GEMMA"],})
@triton.jit
def _rms_layernorm_backward(
dY, dY_row_stride,
@@ -61,6 +62,7 @@ def _rms_layernorm_backward(
r, r_row_stride,
dW, dW_row_stride,
n_cols, eps,
GEMMA : tl.constexpr,
BLOCK_SIZE : tl.constexpr,
):
"""
@@ -84,16 +86,51 @@ def _rms_layernorm_backward(
inv_var = tl.load(r).to(tl.float32)
normed = X_row * inv_var
dY_W = dY_row * W_row
if GEMMA: dY_W = dY_row * (W_row + 1.0)
else: dY_W = dY_row * W_row
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
tl.store(dY + col_offsets, output, mask = mask)
pass
@triton.jit
def _gemma_rms_layernorm_forward(
Y, Y_row_stride,
X, X_row_stride,
W, W_row_stride,
r, r_row_stride,
n_cols, eps,
BLOCK_SIZE : tl.constexpr,
):
# Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
# and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
# exactly. Essentially all in float32!
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
Y += row_idx * Y_row_stride
X += row_idx * X_row_stride
r += row_idx * r_row_stride
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl
tl.store(r, inv_var)
normed = X_row * inv_var
output = normed * (W_row + 1.0)
tl.store(Y + col_offsets, output, mask = mask)
pass
class Fast_RMS_Layernorm(torch.autograd.Function):
@staticmethod
def forward(ctx, X, W, eps):
def forward(ctx, X, W, eps, gemma = False):
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
@@ -103,7 +140,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
_rms_layernorm_forward[(n_rows,)](
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
fx[(n_rows,)](
Y, Y.stride(0),
X, X.stride(0),
W, W.stride(0),
@@ -115,6 +153,7 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.GEMMA = gemma
ctx.save_for_backward(X, W, r)
return Y.view(*shape)
pass
@@ -135,18 +174,19 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
r, r .stride(0),
dW, dW.stride(0),
n_cols, ctx.eps,
GEMMA = ctx.GEMMA,
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
dX = dY.view(*shape)
return dX, None, None
return dX, None, None, None
pass
pass
def fast_rms_layernorm(layernorm, X):
def fast_rms_layernorm(layernorm, X, gemma = False):
W = layernorm.weight
eps = layernorm.variance_epsilon
out = Fast_RMS_Layernorm.apply(X, W, eps)
out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
return out
pass
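
The new _gemma_rms_layernorm_forward kernel keeps every intermediate in float32 and applies the (weight + 1.0) scaling inside the kernel instead of mutating the weight at patch time. A plain PyTorch reference of the same math, only as a sketch to sanity-check the kernel against (it assumes weight is the raw Gemma layernorm weight, without the +1 folded in):

import torch

def gemma_rms_layernorm_reference(X, weight, eps):
    # All in float32, exactly like the Triton kernel above.
    X32 = X.to(torch.float32)
    W32 = weight.to(torch.float32)
    row_var = X32.square().mean(dim = -1, keepdim = True)
    inv_var = 1.0 / torch.sqrt(row_var + eps)   # 1/sqrt, matching DeepMind's impl
    normed  = X32 * inv_var
    return (normed * (W32 + 1.0)).to(X.dtype)   # Gemma scales by (weight + 1)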

File: rope_embedding.py

@@ -39,24 +39,28 @@ def _rope_embedding(
half_head_dim = head_dim // 2
mask = col_offsets < half_head_dim
Q1 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
half_head_dim*0 + col_offsets, mask = mask, other = 0)
Q2 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
half_head_dim*1 + col_offsets, mask = mask, other = 0)
sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
half_head_dim*0 + col_offsets, mask = mask, other = 0)
cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
half_head_dim*0 + col_offsets, mask = mask, other = 0)
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
Q1 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
half_head_dim*0 + col_offsets, mask = mask, other = 0).to(sin1.dtype)
Q2 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
half_head_dim*1 + col_offsets, mask = mask, other = 0).to(sin1.dtype)
if BACKWARD_PASS:
# See our blog post for more info.
sin1 = -sin1
pass
tl.store(Q + row_position*Q_row_stride + head_position*head_dim + \
half_head_dim*0 + col_offsets, Q1*cos1 - Q2*sin1, mask = mask)
half_head_dim*0 + col_offsets,
Q1*cos1 - Q2*sin1, mask = mask)
tl.store(Q + row_position*Q_row_stride + head_position*head_dim + \
half_head_dim*1 + col_offsets, Q2*cos1 + Q1*sin1, mask = mask)
half_head_dim*1 + col_offsets,
Q2*cos1 + Q1*sin1, mask = mask)
pass
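
The functional change above is that Q1 and Q2 are now loaded in the dtype of the sin/cos buffers, so when those buffers are float32 (as they now are for Gemma) the whole rotation runs in float32. The rotation itself is the usual half-split RoPE; a small PyTorch sketch of the same formula (shapes and names are illustrative only):

import torch

def rope_rotate_half(Q, cos, sin):
    # Q: (..., head_dim); cos, sin: (..., head_dim // 2), possibly float32.
    half = Q.shape[-1] // 2
    Q1 = Q[..., :half].to(sin.dtype)   # match the kernel: compute in sin/cos dtype
    Q2 = Q[..., half:].to(sin.dtype)
    out = torch.cat((Q1 * cos - Q2 * sin, Q2 * cos + Q1 * sin), dim = -1)
    return out.to(Q.dtype)             # stored back in Q's original dtype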

File: gemma.py

@@ -39,6 +39,7 @@ except:
pass
torch_nn_functional_gelu = torch.nn.functional.gelu
def fast_geglu_inference(self, X):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
@@ -48,7 +49,7 @@ def fast_geglu_inference(self, X):
gate = fast_linear_forward(self.gate_proj, X, out = temp[0])
up = fast_linear_forward(self. up_proj, X, out = temp[1])
gate = torch.nn.functional.gelu(gate, approximate = "tanh")
gate = torch_nn_functional_gelu(gate, approximate = "tanh")
gate *= up
# X = self.down_proj(gate)
@@ -57,6 +58,18 @@ def fast_geglu_inference(self, X):
pass
def fast_rms_layernorm_inference_gemma(self, X, out_weight):
XX = X.to(torch.float32)
variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
XX *= variance.rsqrt_()
out_weight[:] = self.weight
out_weight += 1.0
XX *= out_weight
return XX.to(X.dtype)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def GemmaDecoderLayer_fast_forward(
self,
@@ -72,10 +85,11 @@ def GemmaDecoderLayer_fast_forward(
):
if past_key_value is not None:
do_prefill = not hasattr(self.self_attn, "paged_attention")
out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda")
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
self.self_attn,
hidden_states,
@@ -87,12 +101,12 @@ def GemmaDecoderLayer_fast_forward(
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
# hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
@@ -108,7 +122,7 @@ def GemmaDecoderLayer_fast_forward(
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
# hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
@@ -137,15 +151,18 @@ def GemmaModel_fast_forward_inference(
):
# Fix out of bounds tokenization
input_ids = input_ids[:,:self.max_seq_length]
out_weight = torch.empty_like(self.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda")
hidden_states = self.embed_tokens(input_ids)
hidden_states *= math_sqrt(self.config.hidden_size)
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.layers):
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states,
@@ -156,13 +173,13 @@ def GemmaModel_fast_forward_inference(
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
hidden_states += residual
next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
hidden_states = fast_rms_layernorm_inference_gemma(self.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
@@ -173,91 +190,54 @@ def GemmaModel_fast_forward_inference(
pass
def GemmaForCausalLM_fast_forward(
self,
input_ids: torch.LongTensor = None,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
# Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45
# Formulates cos and sin differently from Llama!
class GemmaFixedRotaryEmbedding(torch.nn.Module):
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
if causal_mask is None and past_key_values is None:
causal_mask = xformers.attn_bias.LowerTriangularMask()
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
if past_key_values is not None and \
hasattr(self.model.layers[0].self_attn, "paged_attention"):
outputs = GemmaModel_fast_forward_inference(
self.model,
input_ids,
past_key_values,
)
else:
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
pass
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
if bsz == 1 and q_len == 1:
logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
logits = logits.unsqueeze(0).unsqueeze(0)
else:
logits = self.lm_head(hidden_states)
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.max_seq_len_cached = seq_len
# The difference is we do the division explicitly instead of t * (1/x), i.e. we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
positions = torch.arange(self.max_seq_len_cached, device = "cpu", dtype = torch.int64).float()
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
pass
loss = None
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda")
pass
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
pass
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
pass
@@ -270,7 +250,7 @@ class FastGemmaModel(FastLlamaModel):
GemmaFlashAttention2.forward = LlamaAttention_fast_forward
GemmaDecoderLayer .forward = GemmaDecoderLayer_fast_forward
GemmaModel .forward = LlamaModel_fast_forward
GemmaForCausalLM .forward = GemmaForCausalLM_fast_forward
GemmaForCausalLM .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
@@ -278,7 +258,7 @@ class FastGemmaModel(FastLlamaModel):
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.gemma.modeling_gemma
transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = LlamaRotaryEmbedding
transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = GemmaFixedRotaryEmbedding
return
pass
@@ -329,13 +309,14 @@ class FastGemmaModel(FastLlamaModel):
pass
pass
# Downcast RoPE embedding to correct data type
if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
and (module.cos_cached.dtype != correct_dtype):
# RoPE must be done in float32 for Gemma
# if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
# and (module.cos_cached.dtype != correct_dtype):
module.cos_cached = module.cos_cached.to(correct_dtype)
module.sin_cached = module.sin_cached.to(correct_dtype)
pass
pass
# module.cos_cached = module.cos_cached.to(correct_dtype)
# module.sin_cached = module.sin_cached.to(correct_dtype)
# pass
# pass
pass
# Add 1 to weight
@@ -358,8 +339,8 @@ class FastGemmaModel(FastLlamaModel):
# Must be in float32
# https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
# module = module.to(torch.float32)
# Don't convert to float32 since error analysis shows it makes it worse!!
module.weight += 1.0 # return output * (1 + self.weight)
# Leave + 1 to Triton kernel itself
# module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
pass
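
Two of the precision notes above are easy to verify by hand: the sqrt(hidden_size) embedding scale must be rounded to the embeddings' dtype before the multiply to match Gemma's reference outputs, and GemmaFixedRotaryEmbedding builds its cos/sin caches in float32 with an explicit division by the timescale. A short check (a sketch for illustration, not part of the patch):

import math
import torch

# 3072**0.5 rounds to 55.5 in bfloat16 but is 55.4256... in float32,
# which is why the scale is materialised as a tensor in the embeddings' dtype above.
print(torch.tensor(math.sqrt(3072), dtype = torch.bfloat16))   # tensor(55.5000, dtype=torch.bfloat16)
print(torch.tensor(math.sqrt(3072), dtype = torch.float32))    # tensor(55.4256)

# Gemma-style RoPE cache: explicit division (positions / timescale), all float32,
# mirroring GemmaFixedRotaryEmbedding._set_cos_sin_cache above.
dim, base, seq_len = 256, 10000.0, 8
freq_exponents = (2.0 / dim) * torch.arange(dim // 2, dtype = torch.int64).float()
timescale = base ** freq_exponents
positions = torch.arange(seq_len, dtype = torch.int64).float()
radians = positions[:, None] / timescale[None, :]
emb = torch.cat((radians, radians), dim = -1)
cos_cached, sin_cached = emb.cos(), emb.sin()   # kept in float32; cast only when applied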

File: llama.py

@@ -208,6 +208,7 @@ def LlamaAttention_fast_forward_inference(
pass
torch_nn_functional_silu = torch.nn.functional.silu
def fast_swiglu_inference(self, X):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
@@ -217,7 +218,7 @@ def fast_swiglu_inference(self, X):
gate = fast_linear_forward(self.gate_proj, X, out = temp[0])
up = fast_linear_forward(self. up_proj, X, out = temp[1])
gate = torch.nn.functional.silu(gate, inplace = True)
gate = torch_nn_functional_silu(gate, inplace = True)
gate *= up
# X = self.down_proj(gate)
@@ -509,7 +510,8 @@ def LlamaModel_fast_forward(
inputs_embeds = self.embed_tokens(input_ids)
# Normalized from Gemma
if self.config.model_type == "gemma":
IS_GEMMA = self.config.model_type == "gemma"
if IS_GEMMA:
inputs_requires_grad = inputs_embeds.requires_grad
if not inputs_embeds.is_leaf:
inputs_embeds = inputs_embeds.detach()
@@ -517,7 +519,12 @@ def LlamaModel_fast_forward(
elif inputs_requires_grad:
inputs_embeds.requires_grad_(False)
pass
inputs_embeds *= math_sqrt(self.config.hidden_size)
# Match Gemma exactly by casting to bfloat16 / float16
# inputs_embeds *= math_sqrt(self.config.hidden_size)
# Ie 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# & 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
inputs_embeds *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype)
# inputs_embeds *= math_sqrt(self.config.hidden_size)
if inputs_requires_grad: inputs_embeds.requires_grad_(True)
pass
@@ -619,7 +626,7 @@ def LlamaModel_fast_forward(
all_self_attns += (layer_outputs[1],)
pass
hidden_states = fast_rms_layernorm(self.norm, hidden_states)
hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)
# add hidden states from the last decoder layer
if output_hidden_states:
@@ -681,91 +688,94 @@ def LlamaModel_fast_forward_inference(
pass
def LlamaForCausalLM_fast_forward(
self,
input_ids: torch.LongTensor = None,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
def CausalLM_fast_forward(fast_forward_inference):
def _CausalLM_fast_forward(
self,
input_ids: torch.LongTensor = None,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
if causal_mask is None and past_key_values is None:
causal_mask = xformers.attn_bias.LowerTriangularMask()
if causal_mask is None and past_key_values is None:
causal_mask = xformers.attn_bias.LowerTriangularMask()
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
if past_key_values is not None and \
hasattr(self.model.layers[0].self_attn, "paged_attention"):
outputs = LlamaModel_fast_forward_inference(
self.model,
input_ids,
past_key_values,
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
else:
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pass
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
if bsz == 1 and q_len == 1:
logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
logits = logits.unsqueeze(0).unsqueeze(0)
else:
logits = self.lm_head(hidden_states)
pass
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
loss = None
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda")
if past_key_values is not None and \
hasattr(self.model.layers[0].self_attn, "paged_attention"):
outputs = fast_forward_inference(
self.model,
input_ids,
past_key_values,
)
else:
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pass
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
if bsz == 1 and q_len == 1:
logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
logits = logits.unsqueeze(0).unsqueeze(0)
else:
logits = self.lm_head(hidden_states)
pass
loss = None
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda")
pass
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
pass
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
pass
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return _CausalLM_fast_forward
pass
@@ -880,7 +890,7 @@ class FastLlamaModel:
LlamaFlashAttention2.forward = LlamaAttention_fast_forward
LlamaDecoderLayer .forward = LlamaDecoderLayer_fast_forward
LlamaModel .forward = LlamaModel_fast_forward
LlamaForCausalLM .forward = LlamaForCausalLM_fast_forward
LlamaForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
# Solves https://github.com/unslothai/unsloth/issues/168
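
The key refactor above turns the old LlamaForCausalLM_fast_forward into CausalLM_fast_forward, a factory that closes over the model-specific inference path (LlamaModel_fast_forward_inference or GemmaModel_fast_forward_inference) so Llama and Gemma share one causal-LM forward. The loss path keeps the trick of left-shifting the labels and padding with -100 instead of slicing the logits; a toy illustration of just that shift (values are made up):

import torch

# Shift labels left by one and pad with -100 (ignored by cross entropy),
# so logits and labels keep the same sequence length - the issue #10 fix above.
labels  = torch.tensor([[5, 6, 7, 8]])
ignored = torch.full((labels.shape[0], 1), -100)
shift_labels = torch.hstack((labels[..., 1:], ignored))
print(shift_labels)   # tensor([[   6,    7,    8, -100]])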

File: save.py

@@ -369,6 +369,7 @@ def unsloth_save_model(
# Switch to our fast saving modules if it's a slow PC!
n_cpus = psutil.cpu_count(logical = False)
if n_cpus is None: n_cpus = psutil.cpu_count()
if n_cpus is None: n_cpus = 1
if safe_serialization is None: