mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Fix Gemma fast inference (#215)
* Update save.py * save * trainer * spaces * original * Gemma * Update pyproject.toml * Update mapper.py * Update fast_lora.py * FastGemmaModel * model_type * Update llama.py * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update llama.py * Update fast_lora.py * Update llama.py * Update llama.py * Update cross_entropy_loss.py * Update llama.py * Update llama.py * gemma * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update fast_lora.py * Update fast_lora.py * Fast CE Loss * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * CE * Update llama.py * Update llama.py * Update cross_entropy_loss.py * Update geglu.py * Update cross_entropy_loss.py * revert * Update llama.py * Update llama.py * norm * Update gemma.py * Update gemma.py * position_ids * Update gemma.py * Update gemma.py * pos * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update cross_entropy_loss.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * revert * revert * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update cross_entropy_loss.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * rope * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * llama * Update llama.py * gemma * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update save.py * RoPE * Update llama.py * Update llama.py * Update llama.py * Update gemma.py * correct_dtype * Update gemma.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Chat Templates * Update README.md * Update README.md * Update llama.py * DoRA * Update _utils.py * Update chat_templates.py * Update llama.py * Hotfix - fix DoRA, Gemma prompt template (#202) (#203) * Update save.py * saving * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update __init__.py * Update save.py * Update save.py * Update save.py * save * trainer * spaces * original * Gemma * Update pyproject.toml * Update mapper.py * Update fast_lora.py * FastGemmaModel * model_type * Update llama.py * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update llama.py * Update fast_lora.py * Update llama.py * Update llama.py * Update cross_entropy_loss.py * Update llama.py * Update llama.py * gemma * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update fast_lora.py * Update fast_lora.py * Fast CE Loss * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * CE * Update llama.py * Update llama.py * Update cross_entropy_loss.py * Update geglu.py * Update cross_entropy_loss.py * revert * Update llama.py * Update llama.py * norm * Update gemma.py * Update gemma.py * position_ids * Update gemma.py * Update gemma.py * pos * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update cross_entropy_loss.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * revert * revert * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update cross_entropy_loss.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * rope * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * llama * Update llama.py * gemma * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update save.py * RoPE * Update llama.py * Update llama.py * Update llama.py * Update gemma.py * correct_dtype * Update gemma.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Chat Templates * Update README.md * Update README.md * Update llama.py * DoRA * Update _utils.py * Update chat_templates.py * Update pyproject.toml * Small fixes * Update pyproject.toml * Approx gelu * Update geglu.py * Approx gelu * Update llama.py * Update __init__.py * Update __init__.py * Update _utils.py * Update geglu.py * Update gemma.py
This commit is contained in:
parent
fa2a43baf3
commit
7b7665d9d6
1 changed files with 89 additions and 1 deletions
|
|
@ -173,6 +173,94 @@ def GemmaModel_fast_forward_inference(
|
|||
pass
|
||||
|
||||
|
||||
def GemmaForCausalLM_fast_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
*args, **kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
|
||||
if causal_mask is None and past_key_values is None:
|
||||
causal_mask = xformers.attn_bias.LowerTriangularMask()
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
self.model._has_no_labels = labels is None
|
||||
|
||||
if past_key_values is not None and \
|
||||
hasattr(self.model.layers[0].self_attn, "paged_attention"):
|
||||
outputs = GemmaModel_fast_forward_inference(
|
||||
self.model,
|
||||
input_ids,
|
||||
past_key_values,
|
||||
)
|
||||
else:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
pass
|
||||
|
||||
hidden_states = outputs[0]
|
||||
bsz, q_len, hd = hidden_states.shape
|
||||
if bsz == 1 and q_len == 1:
|
||||
logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
|
||||
logits = logits.unsqueeze(0).unsqueeze(0)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
pass
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
shift_logits = logits
|
||||
if not hasattr(self, "extra_ignored_labels"):
|
||||
# Fixes https://github.com/unslothai/unsloth/issues/10
|
||||
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda")
|
||||
pass
|
||||
|
||||
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
|
||||
loss = fast_cross_entropy_loss(
|
||||
logits = shift_logits,
|
||||
labels = shift_labels,
|
||||
)
|
||||
pass
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class FastGemmaModel(FastLlamaModel):
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -182,7 +270,7 @@ class FastGemmaModel(FastLlamaModel):
|
|||
GemmaFlashAttention2.forward = LlamaAttention_fast_forward
|
||||
GemmaDecoderLayer .forward = GemmaDecoderLayer_fast_forward
|
||||
GemmaModel .forward = LlamaModel_fast_forward
|
||||
GemmaForCausalLM .forward = LlamaForCausalLM_fast_forward
|
||||
GemmaForCausalLM .forward = GemmaForCausalLM_fast_forward
|
||||
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
|
||||
# Solves https://github.com/unslothai/unsloth/issues/168
|
||||
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
|
||||
|
|
|
|||
Loading…
Reference in a new issue