Fix Gemma fast inference (#215)

* Update save.py

* save

* trainer

* spaces

* original

* Gemma

* Update pyproject.toml

* Update mapper.py

* Update fast_lora.py

* FastGemmaModel

* model_type

* Update llama.py

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update llama.py

* Update fast_lora.py

* Update llama.py

* Update llama.py

* Update cross_entropy_loss.py

* Update llama.py

* Update llama.py

* gemma

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update fast_lora.py

* Update fast_lora.py

* Fast CE Loss

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* CE

* Update llama.py

* Update llama.py

* Update cross_entropy_loss.py

* Update geglu.py

* Update cross_entropy_loss.py

* revert

* Update llama.py

* Update llama.py

* norm

* Update gemma.py

* Update gemma.py

* position_ids

* Update gemma.py

* Update gemma.py

* pos

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update cross_entropy_loss.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* revert

* revert

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update cross_entropy_loss.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* rope

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* llama

* Update llama.py

* gemma

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update save.py

* RoPE

* Update llama.py

* Update llama.py

* Update llama.py

* Update gemma.py

* correct_dtype

* Update gemma.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Chat Templates

* Update README.md

* Update README.md

* Update llama.py

* DoRA

* Update _utils.py

* Update chat_templates.py

* Update llama.py

* Hotfix - fix DoRA, Gemma prompt template (#202) (#203)

* Update save.py

* saving

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update __init__.py

* Update save.py

* Update save.py

* Update save.py

* save

* trainer

* spaces

* original

* Gemma

* Update pyproject.toml

* Update mapper.py

* Update fast_lora.py

* FastGemmaModel

* model_type

* Update llama.py

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update llama.py

* Update fast_lora.py

* Update llama.py

* Update llama.py

* Update cross_entropy_loss.py

* Update llama.py

* Update llama.py

* gemma

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update fast_lora.py

* Update fast_lora.py

* Fast CE Loss

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* CE

* Update llama.py

* Update llama.py

* Update cross_entropy_loss.py

* Update geglu.py

* Update cross_entropy_loss.py

* revert

* Update llama.py

* Update llama.py

* norm

* Update gemma.py

* Update gemma.py

* position_ids

* Update gemma.py

* Update gemma.py

* pos

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update cross_entropy_loss.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* revert

* revert

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update cross_entropy_loss.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* rope

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* llama

* Update llama.py

* gemma

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update gemma.py

* Update save.py

* RoPE

* Update llama.py

* Update llama.py

* Update llama.py

* Update gemma.py

* correct_dtype

* Update gemma.py

* Update cross_entropy_loss.py

* Update cross_entropy_loss.py

* Chat Templates

* Update README.md

* Update README.md

* Update llama.py

* DoRA

* Update _utils.py

* Update chat_templates.py

* Update pyproject.toml

* Small fixes

* Update pyproject.toml

* Approx gelu

* Update geglu.py

* Approx gelu

* Update llama.py

* Update __init__.py

* Update __init__.py

* Update _utils.py

* Update geglu.py

* Update gemma.py
This commit is contained in:
Daniel Han 2024-03-03 19:36:06 +11:00 committed by GitHub
parent fa2a43baf3
commit 7b7665d9d6

View file

@ -173,6 +173,94 @@ def GemmaModel_fast_forward_inference(
pass
def GemmaForCausalLM_fast_forward(
self,
input_ids: torch.LongTensor = None,
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
if causal_mask is None and past_key_values is None:
causal_mask = xformers.attn_bias.LowerTriangularMask()
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
if past_key_values is not None and \
hasattr(self.model.layers[0].self_attn, "paged_attention"):
outputs = GemmaModel_fast_forward_inference(
self.model,
input_ids,
past_key_values,
)
else:
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pass
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
if bsz == 1 and q_len == 1:
logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
logits = logits.unsqueeze(0).unsqueeze(0)
else:
logits = self.lm_head(hidden_states)
pass
loss = None
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda")
pass
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
pass
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
pass
class FastGemmaModel(FastLlamaModel):
@staticmethod
@ -182,7 +270,7 @@ class FastGemmaModel(FastLlamaModel):
GemmaFlashAttention2.forward = LlamaAttention_fast_forward
GemmaDecoderLayer .forward = GemmaDecoderLayer_fast_forward
GemmaModel .forward = LlamaModel_fast_forward
GemmaForCausalLM .forward = LlamaForCausalLM_fast_forward
GemmaForCausalLM .forward = GemmaForCausalLM_fast_forward
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.