Mirror of https://github.com/unslothai/unsloth, synced 2026-04-21 13:37:39 +00:00
Fix Gemma (#223)
Squashed history of incremental commits, overwhelmingly repeated "Update <file>" entries across llama.py, gemma.py, cross_entropy_loss.py, fast_lora.py, geglu.py, rms_layernorm.py, chat_templates.py, save.py, mapper.py, _utils.py, __init__.py, pyproject.toml and README.md. The distinct milestones, in order:

* Fast CE Loss
* Gemma: norm, position_ids, pos, rope
* Chat Templates
* DoRA
* Hotfix - fix DoRA, Gemma prompt template (#202) (#203)
* saving, trainer, spaces fixes
* Gemma, FastGemmaModel, model_type
* RoPE (GemmaFixedRotaryEmbedding), correct_dtype
* Small fixes
* Approx gelu
* Fix Gemma merging
* Layernorms, Gemma precision, sqrt
* RoPE and Gemma precision
* Fix warning
This commit is contained in:
parent fedcafe281
commit 70f271b1d3

8 changed files with 250 additions and 195 deletions
pyproject.toml

```diff
@@ -43,6 +43,7 @@ huggingface = [
+    "psutil",
     "wheel>=0.42.0",
     "numpy",
     "triton",
 ]
 cu118only = [
     "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
@@ -104,6 +105,16 @@ cu121-torch211 = [
     "bitsandbytes",
     "unsloth[cu121onlytorch211]",
 ]
+cu118-torch212 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu118onlytorch212]",
+]
+cu121-torch212 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu121onlytorch212]",
+]
 cu118-torch220 = [
     "unsloth[huggingface]",
     "bitsandbytes",
```
unsloth/__init__.py

```diff
@@ -99,10 +99,14 @@ except:
     cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
     libcuda_dirs()
 except:
-    raise ImportError("Unsloth: CUDA is not linked properly.\n"\
-        "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
-        "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
-        "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.")
+    warnings.warn(
+        "Unsloth: CUDA is not linked properly.\n"\
+        "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
+        "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
+        "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
+        "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
+        "Unsloth will still run for now, but maybe it might crash - let's hope it works!"
+    )
 pass

 from .models import *
```
unsloth/chat_templates.py

```diff
@@ -257,6 +257,10 @@ def get_chat_template(
         assert("Unsloth: Can only map new tokens to EOS for now. Adding new tokens is not yet supported.")
     pass

+    if tokenizer.__class__.__name__.startswith("Gemma") and chat_template == "chatml":
+        chat_template = "gemma_chatml"
+    pass
+
     old_padding_side = tokenizer.padding_side

     if type(chat_template) in (list, tuple,):
```
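For context, a hedged usage sketch of the remap added above (the model id is illustrative; any tokenizer whose class name starts with "Gemma" triggers it):

```python
from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template

# Assumed model id for illustration.
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-7b-bnb-4bit")

# Requesting "chatml" on a Gemma tokenizer is now silently remapped to
# "gemma_chatml": Gemma ships its own <start_of_turn>/<end_of_turn> turn
# format rather than ChatML's <|im_start|>/<|im_end|> tokens.
tokenizer = get_chat_template(tokenizer, chat_template = "chatml")
```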
unsloth/kernels/rms_layernorm.py

```diff
@@ -53,6 +53,7 @@ def _rms_layernorm_forward(
 pass


+@triton.heuristics({"GEMMA": lambda args: args["GEMMA"],})
 @triton.jit
 def _rms_layernorm_backward(
     dY, dY_row_stride,
@@ -61,6 +62,7 @@ def _rms_layernorm_backward(
     r,   r_row_stride,
     dW, dW_row_stride,
     n_cols, eps,
+    GEMMA      : tl.constexpr,
     BLOCK_SIZE : tl.constexpr,
 ):
     """
@@ -84,16 +86,51 @@ def _rms_layernorm_backward(
     inv_var = tl.load(r).to(tl.float32)
     normed = X_row * inv_var

-    dY_W = dY_row * W_row
+    if GEMMA: dY_W = dY_row * (W_row + 1.0)
+    else:     dY_W = dY_row * W_row

     rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
     output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
     tl.store(dY + col_offsets, output, mask = mask)
 pass


+@triton.jit
+def _gemma_rms_layernorm_forward(
+    Y, Y_row_stride,
+    X, X_row_stride,
+    W, W_row_stride,
+    r, r_row_stride,
+    n_cols, eps,
+    BLOCK_SIZE : tl.constexpr,
+):
+    # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
+    # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
+    # exactly. Essentially all in float32!
+    row_idx = tl.program_id(0)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    Y += row_idx * Y_row_stride
+    X += row_idx * X_row_stride
+    r += row_idx * r_row_stride
+
+    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
+
+    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
+    inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl
+    tl.store(r, inv_var)
+    normed = X_row * inv_var
+    output = normed * (W_row + 1.0)
+
+    tl.store(Y + col_offsets, output, mask = mask)
+pass


 class Fast_RMS_Layernorm(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, X, W, eps):
+    def forward(ctx, X, W, eps, gemma = False):
         shape = X.shape
         dim = shape[-1]
         X = X.view(-1, dim)
```
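For reference, the backward kernel above stores the standard RMSNorm input gradient. With $r = \left(\tfrac{1}{n}\sum_i x_i^2 + \varepsilon\right)^{-1/2}$, $\hat{x} = r\,x$, and $g = \tfrac{\partial L}{\partial y} \odot W'$ (where $W' = W$ for Llama and $W' = W + 1$ for Gemma, matching the `GEMMA` branch):

$$\frac{\partial L}{\partial x} = \frac{r}{n}\left(n\,g - \hat{x}\sum_{i=1}^{n} g_i\,\hat{x}_i\right)$$

which is exactly `output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)`.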
unsloth/kernels/rms_layernorm.py (continued)

```diff
@@ -103,7 +140,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
         Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
         r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")

-        _rms_layernorm_forward[(n_rows,)](
+        fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
+        fx[(n_rows,)](
             Y, Y.stride(0),
             X, X.stride(0),
             W, W.stride(0),
@@ -115,6 +153,7 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
         ctx.eps = eps
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
+        ctx.GEMMA = gemma
         ctx.save_for_backward(X, W, r)
         return Y.view(*shape)
     pass
@@ -135,18 +174,19 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
             r,   r .stride(0),
             dW, dW.stride(0),
             n_cols, ctx.eps,
+            GEMMA = ctx.GEMMA,
             BLOCK_SIZE = ctx.BLOCK_SIZE,
             num_warps  = ctx.num_warps,
         )
         dX = dY.view(*shape)
-        return dX, None, None
+        return dX, None, None, None
     pass
 pass


-def fast_rms_layernorm(layernorm, X):
+def fast_rms_layernorm(layernorm, X, gemma = False):
     W   = layernorm.weight
     eps = layernorm.variance_epsilon
-    out = Fast_RMS_Layernorm.apply(X, W, eps)
+    out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
     return out
 pass
```
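A quick way to sanity-check the `gemma = True` path is to compare it against a plain PyTorch reference that follows the same recipe: all math in float32, `1/sqrt` rather than `rsqrt`, and the weight applied as `W + 1`. This is a hedged sketch, not part of the diff:

```python
import torch

def gemma_rms_layernorm_reference(X, W, eps = 1e-6):
    # Mirrors _gemma_rms_layernorm_forward: compute entirely in float32,
    # use 1/sqrt (as the kernel deliberately does) and scale by (W + 1).
    X32 = X.to(torch.float32)
    inv_var = 1.0 / torch.sqrt(X32.square().mean(-1, keepdim = True) + eps)
    out = (X32 * inv_var) * (W.to(torch.float32) + 1.0)
    return out.to(X.dtype)

# Example shapes (assumed): a (batch, seq, hidden) bfloat16 activation and a
# zero-initialized Gemma weight, so (W + 1) is an all-ones scale.
x = torch.randn(2, 8, 2048, dtype = torch.bfloat16)
w = torch.zeros(2048, dtype = torch.bfloat16)
ref = gemma_rms_layernorm_reference(x, w)
```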
unsloth/kernels/rope_embedding.py

```diff
@@ -39,24 +39,28 @@ def _rope_embedding(
     half_head_dim = head_dim // 2
     mask = col_offsets < half_head_dim

-    Q1 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
-                 half_head_dim*0 + col_offsets, mask = mask, other = 0)
-    Q2 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
-                 half_head_dim*1 + col_offsets, mask = mask, other = 0)
     sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
                    half_head_dim*0 + col_offsets, mask = mask, other = 0)
     cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
                    half_head_dim*0 + col_offsets, mask = mask, other = 0)

+    # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
+    Q1 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
+                 half_head_dim*0 + col_offsets, mask = mask, other = 0).to(sin1.dtype)
+    Q2 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
+                 half_head_dim*1 + col_offsets, mask = mask, other = 0).to(sin1.dtype)
+
     if BACKWARD_PASS:
         # See our blog post for more info.
         sin1 = -sin1
     pass

     tl.store(Q + row_position*Q_row_stride + head_position*head_dim + \
-             half_head_dim*0 + col_offsets, Q1*cos1 - Q2*sin1, mask = mask)
+             half_head_dim*0 + col_offsets,
+             Q1*cos1 - Q2*sin1, mask = mask)
     tl.store(Q + row_position*Q_row_stride + head_position*head_dim + \
-             half_head_dim*1 + col_offsets, Q2*cos1 + Q1*sin1, mask = mask)
+             half_head_dim*1 + col_offsets,
+             Q2*cos1 + Q1*sin1, mask = mask)
 pass
```
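The kernel applies the rotate-half form of RoPE in-place. A hedged PyTorch equivalent of the forward path (names illustrative) makes the dtype handling explicit:

```python
import torch

def rope_rotate_half_reference(q, cos, sin):
    # q: (..., head_dim); cos/sin: (..., head_dim // 2) - the same layout the
    # kernel loads. Casting q to cos/sin's dtype mirrors the .to(sin1.dtype)
    # added above, so Gemma can run RoPE in float32 even with bfloat16 weights.
    q = q.to(sin.dtype)
    q1, q2 = q.chunk(2, dim = -1)
    return torch.cat((q1*cos - q2*sin, q2*cos + q1*sin), dim = -1)
```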
unsloth/models/gemma.py

```diff
@@ -39,6 +39,7 @@ except:
 pass


+torch_nn_functional_gelu = torch.nn.functional.gelu
 def fast_geglu_inference(self, X):
     # gate = self.gate_proj(X)
     # up   = self.up_proj(X)
@@ -48,7 +49,7 @@ def fast_geglu_inference(self, X):

     gate = fast_linear_forward(self.gate_proj, X, out = temp[0])
     up   = fast_linear_forward(self.  up_proj, X, out = temp[1])
-    gate = torch.nn.functional.gelu(gate, approximate = "tanh")
+    gate = torch_nn_functional_gelu(gate, approximate = "tanh")
     gate *= up

     # X = self.down_proj(gate)
```
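The module-level `torch_nn_functional_gelu` alias just caches the attribute lookup on the hot inference path; the behavioral point is the `approximate = "tanh"` argument, since Gemma uses the tanh-approximated GELU rather than the exact erf-based one. A small hedged check that the two variants genuinely differ:

```python
import torch

x = torch.linspace(-4, 4, 9)
exact  = torch.nn.functional.gelu(x)                        # erf-based GELU
approx = torch.nn.functional.gelu(x, approximate = "tanh")  # Gemma's variant
print((exact - approx).abs().max())  # small but nonzero - not interchangeable
                                     # when matching Gemma's logits exactly
```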
unsloth/models/gemma.py (continued)

```diff
@@ -57,6 +58,18 @@ def fast_geglu_inference(self, X):
 pass


+def fast_rms_layernorm_inference_gemma(self, X, out_weight):
+    XX = X.to(torch.float32)
+    variance = XX.square().mean(-1, keepdim = True)
+    variance += self.variance_epsilon
+    XX *= variance.rsqrt_()
+    out_weight[:] = self.weight
+    out_weight += 1.0
+    XX *= out_weight
+    return XX.to(X.dtype)
+pass
+
+
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
 def GemmaDecoderLayer_fast_forward(
     self,
@@ -72,10 +85,11 @@ def GemmaDecoderLayer_fast_forward(
 ):
     if past_key_value is not None:
         do_prefill = not hasattr(self.self_attn, "paged_attention")
+        out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda")

         # Self Attention
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
         hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
             self.self_attn,
             hidden_states,
@@ -87,12 +101,12 @@ def GemmaDecoderLayer_fast_forward(

         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
         hidden_states = fast_geglu_inference(self.mlp, hidden_states)
         hidden_states += residual
     else:
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
         # hidden_states = self.input_layernorm(hidden_states)
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
@@ -108,7 +122,7 @@ def GemmaDecoderLayer_fast_forward(

         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
         # hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
@@ -137,15 +151,18 @@ def GemmaModel_fast_forward_inference(
 ):
     # Fix out of bounds tokenization
     input_ids = input_ids[:,:self.max_seq_length]
+    out_weight = torch.empty_like(self.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda")

     hidden_states = self.embed_tokens(input_ids)
-    hidden_states *= math_sqrt(self.config.hidden_size)
+    # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
+    # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
+    hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)

     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.layers):
         # Self Attention
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
         hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
             decoder_layer.self_attn,
             hidden_states,
```
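The comments in the hunk above are easy to verify: bfloat16 keeps only 8 mantissa bits, so the embedding scale sqrt(hidden_size) rounds differently depending on where the cast happens. Casting the scalar to the activation dtype first, as the diff now does, reproduces Gemma's reference numerics:

```python
import math
import torch

print(torch.tensor(math.sqrt(3072), dtype = torch.bfloat16))  # tensor(55.5000, dtype=torch.bfloat16)
print(math.sqrt(3072))                                        # 55.42562584220407
print(torch.tensor(math.sqrt(2048), dtype = torch.bfloat16))  # tensor(45.2500, dtype=torch.bfloat16)
print(math.sqrt(2048))                                        # 45.254833995939045
```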
unsloth/models/gemma.py (continued)

```diff
@@ -156,13 +173,13 @@ def GemmaModel_fast_forward_inference(

         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
         hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
         hidden_states += residual

         next_decoder_cache.append(present_key_value)
     pass
-    hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
+    hidden_states = fast_rms_layernorm_inference_gemma(self.norm, hidden_states, out_weight)

     return BaseModelOutputWithPast(
         last_hidden_state = hidden_states,
@@ -173,91 +190,54 @@ def GemmaModel_fast_forward_inference(
 pass


-def GemmaForCausalLM_fast_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    *args, **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
-
-    if causal_mask is None and past_key_values is None:
-        causal_mask = xformers.attn_bias.LowerTriangularMask()
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-    )
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    self.model._has_no_labels = labels is None
-
-    if past_key_values is not None and \
-        hasattr(self.model.layers[0].self_attn, "paged_attention"):
-        outputs = GemmaModel_fast_forward_inference(
-            self.model,
-            input_ids,
-            past_key_values,
-        )
-    else:
-        outputs = self.model(
-            input_ids=input_ids,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-    hidden_states = outputs[0]
-    bsz, q_len, hd = hidden_states.shape
-    if bsz == 1 and q_len == 1:
-        logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
-        logits = logits.unsqueeze(0).unsqueeze(0)
-    else:
-        logits = self.lm_head(hidden_states)
-
-    loss = None
-    if labels is not None:
-        shift_logits = logits
-        if not hasattr(self, "extra_ignored_labels"):
-            # Fixes https://github.com/unslothai/unsloth/issues/10
-            self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda")
-        pass
-
-        shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
-        loss = fast_cross_entropy_loss(
-            logits = shift_logits,
-            labels = shift_labels,
-        )
-    pass
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
+# Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45
+# Formulates cos and sin differently from Llama!
+class GemmaFixedRotaryEmbedding(torch.nn.Module):
+    # Fixes https://github.com/huggingface/transformers/pull/28837
+    # https://github.com/microsoft/DeepSpeed/issues/4932
+    # The precision of RoPE buffers is not correct, so we cast to int64.
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
+    pass
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
+        # in FP32. They are applied (multiplied) in FP32 as well.
+        self.max_seq_len_cached = seq_len
+
+        # The difference is we do division explicity instead of t * (1/x) ie we do t/x.
+        freq_exponents = (2.0 / self.dim) * (
+            torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
+        )
+        timescale = self.base**freq_exponents
+        positions = torch.arange(self.max_seq_len_cached, device = "cpu", dtype = torch.int64).float()
+        radians_new = positions[..., None] / timescale[None, None, :]
+        radians_new = radians_new.squeeze(0)
+
+        emb = torch.cat((radians_new, radians_new), dim = -1)
+        # We must do RoPE in float32!
+        cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype)
+        sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype)
+        self.register_buffer("cos_cached", cos, persistent = False)
+        self.register_buffer("sin_cached", sin, persistent = False)
+    pass
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+    pass
 pass
```
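A hedged illustration of what `_set_cos_sin_cache` produces, and why the explicit `t/x` division is called out (dimensions and values are illustrative only):

```python
import torch

dim, base, seq_len = 8, 10000, 4
# timescale[i] = base**(2i/dim), so radians = position / timescale - the same
# quantity Llama computes as position * inv_freq, but done via division, which
# can round differently in floating point and so matches DeepMind's reference.
freq_exponents = (2.0 / dim) * torch.arange(dim // 2, dtype = torch.int64).float()
timescale = base ** freq_exponents
positions = torch.arange(seq_len, dtype = torch.int64).float()
radians = positions[:, None] / timescale[None, :]
emb = torch.cat((radians, radians), dim = -1)
cos, sin = emb.cos(), emb.sin()  # kept in float32; cast to the model dtype only at use
print(cos.shape)                 # torch.Size([4, 8])
```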
unsloth/models/gemma.py (continued)

```diff
@@ -270,7 +250,7 @@ class FastGemmaModel(FastLlamaModel):
         GemmaFlashAttention2.forward = LlamaAttention_fast_forward
         GemmaDecoderLayer   .forward = GemmaDecoderLayer_fast_forward
         GemmaModel          .forward = LlamaModel_fast_forward
-        GemmaForCausalLM    .forward = GemmaForCausalLM_fast_forward
+        GemmaForCausalLM    .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
         PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
         # Solves https://github.com/unslothai/unsloth/issues/168
         # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
@@ -278,7 +258,7 @@ class FastGemmaModel(FastLlamaModel):
         # https://github.com/huggingface/transformers/pull/27931
         # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
         import transformers.models.gemma.modeling_gemma
-        transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = LlamaRotaryEmbedding
+        transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = GemmaFixedRotaryEmbedding
         return
     pass
@@ -329,13 +309,14 @@ class FastGemmaModel(FastLlamaModel):
             pass
         pass
-        # Downcast RoPE embedding to correct data type
-        if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
-            and (module.cos_cached.dtype != correct_dtype):
-
-            module.cos_cached = module.cos_cached.to(correct_dtype)
-            module.sin_cached = module.sin_cached.to(correct_dtype)
-            pass
-        pass
+        # RoPE must be done in float32 for Gemma
+        # if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
+        #     and (module.cos_cached.dtype != correct_dtype):
+
+        #     module.cos_cached = module.cos_cached.to(correct_dtype)
+        #     module.sin_cached = module.sin_cached.to(correct_dtype)
+        #     pass
+        # pass
         pass

         # Add 1 to weight
@@ -358,8 +339,8 @@ class FastGemmaModel(FastLlamaModel):
             # Must be in float32
             # https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
             # module = module.to(torch.float32)
-            # Don't convert to float32 since error analysis shows it makes it worse!!
-            module.weight += 1.0 # return output * (1 + self.weight)
+            # Leave + 1 to Triton kernel itself
+            # module.weight += 1.0 # return output * (1 + self.weight)
             if not hasattr(module, "variance_epsilon"):
                 module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
             pass
```
unsloth/models/llama.py

```diff
@@ -208,6 +208,7 @@ def LlamaAttention_fast_forward_inference(
 pass


+torch_nn_functional_silu = torch.nn.functional.silu
 def fast_swiglu_inference(self, X):
     # gate = self.gate_proj(X)
     # up   = self.up_proj(X)
@@ -217,7 +218,7 @@ def fast_swiglu_inference(self, X):

     gate = fast_linear_forward(self.gate_proj, X, out = temp[0])
     up   = fast_linear_forward(self.  up_proj, X, out = temp[1])
-    gate = torch.nn.functional.silu(gate, inplace = True)
+    gate = torch_nn_functional_silu(gate, inplace = True)
     gate *= up

     # X = self.down_proj(gate)
```
unsloth/models/llama.py (continued)

```diff
@@ -509,7 +510,8 @@ def LlamaModel_fast_forward(
         inputs_embeds = self.embed_tokens(input_ids)

     # Mormalized from Gemma
-    if self.config.model_type == "gemma":
+    IS_GEMMA = self.config.model_type == "gemma"
+    if IS_GEMMA:
         inputs_requires_grad = inputs_embeds.requires_grad
         if not inputs_embeds.is_leaf:
             inputs_embeds = inputs_embeds.detach()
@@ -517,7 +519,12 @@ def LlamaModel_fast_forward(
         elif inputs_requires_grad:
             inputs_embeds.requires_grad_(False)
         pass
-        inputs_embeds *= math_sqrt(self.config.hidden_size)
+        # Match Gemma exactly by casting to bfloat16 / float16
+        # inputs_embeds *= math_sqrt(self.config.hidden_size)
+        # Ie 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
+        # &  2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
+        inputs_embeds *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype)
+        # inputs_embeds *= math_sqrt(self.config.hidden_size)
         if inputs_requires_grad: inputs_embeds.requires_grad_(True)
         pass
@@ -619,7 +626,7 @@ def LlamaModel_fast_forward(
             all_self_attns += (layer_outputs[1],)
         pass

-    hidden_states = fast_rms_layernorm(self.norm, hidden_states)
+    hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)

     # add hidden states from the last decoder layer
     if output_hidden_states:
```
unsloth/models/llama.py (continued)

```diff
@@ -681,91 +688,94 @@ def LlamaModel_fast_forward_inference(
 pass


-def LlamaForCausalLM_fast_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    *args, **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
-
-    if causal_mask is None and past_key_values is None:
-        causal_mask = xformers.attn_bias.LowerTriangularMask()
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-    )
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    self.model._has_no_labels = labels is None
-
-    if past_key_values is not None and \
-        hasattr(self.model.layers[0].self_attn, "paged_attention"):
-        outputs = LlamaModel_fast_forward_inference(
-            self.model,
-            input_ids,
-            past_key_values,
-        )
-    else:
-        outputs = self.model(
-            input_ids=input_ids,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-    pass
-
-    hidden_states = outputs[0]
-    bsz, q_len, hd = hidden_states.shape
-    if bsz == 1 and q_len == 1:
-        logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
-        logits = logits.unsqueeze(0).unsqueeze(0)
-    else:
-        logits = self.lm_head(hidden_states)
-    pass
-
-    loss = None
-    if labels is not None:
-        shift_logits = logits
-        if not hasattr(self, "extra_ignored_labels"):
-            # Fixes https://github.com/unslothai/unsloth/issues/10
-            self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda")
-        pass
-
-        shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
-        loss = fast_cross_entropy_loss(
-            logits = shift_logits,
-            labels = shift_labels,
-        )
-    pass
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
+def CausalLM_fast_forward(fast_forward_inference):
+    def _CausalLM_fast_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        *args, **kwargs,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+        if causal_mask is None and past_key_values is None:
+            causal_mask = xformers.attn_bias.LowerTriangularMask()
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        self.model._has_no_labels = labels is None
+
+        if past_key_values is not None and \
+            hasattr(self.model.layers[0].self_attn, "paged_attention"):
+            outputs = fast_forward_inference(
+                self.model,
+                input_ids,
+                past_key_values,
+            )
+        else:
+            outputs = self.model(
+                input_ids=input_ids,
+                causal_mask=causal_mask,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        pass
+
+        hidden_states = outputs[0]
+        bsz, q_len, hd = hidden_states.shape
+        if bsz == 1 and q_len == 1:
+            logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
+            logits = logits.unsqueeze(0).unsqueeze(0)
+        else:
+            logits = self.lm_head(hidden_states)
+        pass
+
+        loss = None
+        if labels is not None:
+            shift_logits = logits
+            if not hasattr(self, "extra_ignored_labels"):
+                # Fixes https://github.com/unslothai/unsloth/issues/10
+                self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda")
+            pass
+
+            shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
+            loss = fast_cross_entropy_loss(
+                logits = shift_logits,
+                labels = shift_labels,
+            )
+        pass
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+    return _CausalLM_fast_forward
 pass
```
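This refactor replaces the near-identical per-model `LlamaForCausalLM_fast_forward` / `GemmaForCausalLM_fast_forward` copies with one closure factory: `CausalLM_fast_forward` captures the model-specific paged-attention decode path and returns a forward function. A hedged, stripped-down sketch of the pattern (names simplified, not the actual implementation):

```python
from typing import Callable

def causal_lm_fast_forward(fast_forward_inference: Callable) -> Callable:
    # Factory: capture the model-specific decode path once, then return a
    # forward() that can be assigned onto any *ForCausalLM class.
    def _forward(self, input_ids, past_key_values = None, **kwargs):
        if past_key_values is not None:
            # Token-by-token decoding: use the captured fast inference path.
            return fast_forward_inference(self.model, input_ids, past_key_values)
        # Prefill / training: fall back to the full model forward.
        return self.model(input_ids, **kwargs)
    return _forward

# Usage mirrors the next hunk:
#   LlamaForCausalLM.forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
#   GemmaForCausalLM.forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
```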
unsloth/models/llama.py (continued)

```diff
@@ -880,7 +890,7 @@ class FastLlamaModel:
         LlamaFlashAttention2.forward = LlamaAttention_fast_forward
         LlamaDecoderLayer   .forward = LlamaDecoderLayer_fast_forward
         LlamaModel          .forward = LlamaModel_fast_forward
-        LlamaForCausalLM    .forward = LlamaForCausalLM_fast_forward
+        LlamaForCausalLM    .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
         PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward

         # Solves https://github.com/unslothai/unsloth/issues/168
```
unsloth/save.py

```diff
@@ -369,6 +369,7 @@ def unsloth_save_model(

     # Switch to our fast saving modules if it's a slow PC!
     n_cpus = psutil.cpu_count(logical = False)
+    if n_cpus is None: n_cpus = psutil.cpu_count()
     if n_cpus is None: n_cpus = 1

     if safe_serialization is None:
```