* faster saving & inference

* Update llama.py

* Update save.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py
This commit is contained in:
Daniel Han 2024-01-21 22:20:22 +11:00 committed by GitHub
parent a6f4fb0075
commit 3a9b2dee98
3 changed files with 84 additions and 48 deletions

View file

@@ -144,7 +144,7 @@ def LlamaAttention_fast_forward_inference(
A = torch.matmul(A, Vnn)
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, self.hidden_size)
A = original_apply_o(self, A)
A = self.o_proj(A)
return A, (Kn, Vn)
pass
@@ -187,10 +187,9 @@ def LlamaAttention_fast_forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
Q, K, V = self.apply_qkv(self, hidden_states)
# Check for inference
if use_cache and past_key_value is not None and q_len == 1:
if past_key_value is not None and q_len == 1:
A, past_key_value = LlamaAttention_fast_forward_inference(
self,
hidden_states,
@@ -206,6 +205,7 @@ def LlamaAttention_fast_forward(
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
@@ -304,29 +304,7 @@ def LlamaDecoderLayer_fast_forward(
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
bsz, q_len, hd = hidden_states.size()
if (self.training):
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
else:
if (past_key_value is not None and q_len == 1):
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
@@ -347,6 +325,26 @@ def LlamaDecoderLayer_fast_forward(
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
hidden_states = fast_mlp_inference(self.mlp, hidden_states)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
pass
outputs = (hidden_states,)
@@ -445,7 +443,7 @@ def LlamaModel_fast_forward(
# Ignore attention_mask
if attention_mask is None:
padding_mask = None
elif self.training:
elif True:#self.training:
attention_mask = None
padding_mask = None
else:
@@ -524,10 +522,11 @@ def LlamaModel_fast_forward(
all_self_attns += (layer_outputs[1],)
pass
if (self.training):
hidden_states = fast_rms_layernorm(self.norm, hidden_states)
else:
bsz, q_len, hd = hidden_states.size()
if (past_key_value is not None and q_len == 1):
hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
else:
hidden_states = fast_rms_layernorm(self.norm, hidden_states)
pass
# add hidden states from the last decoder layer

View file

@@ -47,10 +47,9 @@ def MistralAttention_fast_forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
Q, K, V = self.apply_qkv(self, hidden_states)
# Check for inference
if use_cache and past_key_value is not None and q_len == 1:
if past_key_value is not None and q_len == 1:
A, past_key_value = LlamaAttention_fast_forward_inference(
self,
hidden_states,
@@ -66,6 +65,7 @@ def MistralAttention_fast_forward(
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

View file

@@ -94,8 +94,9 @@ def fast_save_pickle(shard, name):
torch.save(
shard,
name,
pickle_module = pickle,
pickle_protocol = pickle.HIGHEST_PROTOCOL,
# HIGHEST_PROTOCOL seems to not work with Pytorch!
# pickle_module = pickle,
# pickle_protocol = pickle.HIGHEST_PROTOCOL,
)
return
pass
@@ -783,12 +784,27 @@ def unsloth_save_pretrained_gguf(
del arguments["quantization_method"]
# Non blocking install GGUF first
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["gguf", "protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
new_save_directory = unsloth_save_model(**arguments)
python_install.wait()
if not os.path.exists("llama.cpp"):
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["gguf", "protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
new_save_directory = unsloth_save_model(**arguments)
python_install.wait()
else:
try:
new_save_directory = unsloth_save_model(**arguments)
makefile = None
except:
# Retry by recloning llama.cpp
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["gguf", "protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
new_save_directory = unsloth_save_model(**arguments)
python_install.wait()
pass
pass
for _ in range(3):
gc.collect()
@@ -801,7 +817,10 @@ def unsloth_save_pretrained_gguf(
self, save_directory, token,
"GGUF converted", "gguf", file_location,
)
print(f"Saved to https://huggingface.co/{username}/{new_save_directory.lstrip('/.')}")
link = f"{username}/{new_save_directory.lstrip('/.')}" \
if username not in new_save_directory else \
new_save_directory.lstrip('/.')
print(f"Saved to https://huggingface.co/{link}")
pass
pass
@@ -863,16 +882,31 @@ def unsloth_push_to_hub_gguf(
del arguments["quantization_method"]
# Non blocking install GGUF first
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["gguf", "protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
new_save_directory = unsloth_save_model(**arguments)
if not os.path.exists("llama.cpp"):
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["gguf", "protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
new_save_directory = unsloth_save_model(**arguments)
python_install.wait()
else:
try:
new_save_directory = unsloth_save_model(**arguments)
makefile = None
except:
# Retry by recloning llama.cpp
git_clone = install_llama_cpp_clone_non_blocking()
python_install = install_python_non_blocking(["gguf", "protobuf"])
git_clone.wait()
makefile = install_llama_cpp_make_non_blocking()
new_save_directory = unsloth_save_model(**arguments)
python_install.wait()
pass
pass
for _ in range(3):
gc.collect()
python_install.wait()
file_location = save_to_gguf(new_save_directory, quantization_method, makefile)
print("Unsloth: Uploading GGUF to Huggingface Hub...")
@@ -880,7 +914,10 @@ def unsloth_push_to_hub_gguf(
self, repo_id, token,
"GGUF converted", "gguf", file_location,
)
print(f"Saved to https://huggingface.co/{username}/{new_save_directory.lstrip('/')}")
link = f"{username}/{new_save_directory.lstrip('/.')}" \
if username not in new_save_directory else \
new_save_directory.lstrip('/.')
print(f"Saved to https://huggingface.co/{link}")
pass