Mirror of https://github.com/unslothai/unsloth (synced 2026-04-21 13:37:39 +00:00)
Fix inference attention mask (#142)
* faster saving & inference
* Update llama.py
* Update save.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update mistral.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* fast inference
* Update llama.py
* Update save.py
* Update llama.py
* Mistral correct RoPE scaling
* Max sequence lengths
* Apache 2
* fast_linear_forward
* Update utils.py
* Update utils.py
* No print
* Update utils.py
* Update utils.py
* inference
* Update llama.py
* Fast inference RoPE
* Update llama.py
* Update llama.py
* RoPE
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* LoRA
* Fast LoRA saving
* Update llama.py
* hidden_states
* q_len == 1
* q_len issue
* Update mistral.py
* Update mistral.py
* incorrect inference
* Update to transformers 4.37
* Graceful FA2 error + torch 2.1.1
* Update mapper.py
* Update pyproject.toml
* Fix saving and bnb-4bit
* Update fast_lora.py
* Update fast_lora.py
* remove patching
* Update llama.py
* Update llama.py
* Update swiglu.py
* Repatch
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update llama.py
* Update fast_lora.py
* Update llama.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update save.py
* Update fast_lora.py
* Update utils.py
* Update llama.py
* Update fast_lora.py
* Update swiglu.py
* Update save.py
* Update save.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Revert "Update llama.py"
This reverts commit a208ec46e0.
* Update llama.py
* Works?
* Update pyproject.toml
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Swiglu
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* attention_mask
* Update llama.py
* Update llama.py
* labels
* Update mistral.py
* Update llama.py
* attention mask
* Update save.py
* Update save.py
* Update mistral.py
* attention mask
* Update llama.py
* Update llama.py
* Update mistral.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update dpo.py
* Patch saving
* Update save.py
* Update save.py
* patch_saving_functions
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* print
* Mistral patch
* Update mistral.py
* Update save.py
* saving
* Update llama.py
* Update llama.py
This commit is contained in:
parent 90309ca8dc
commit a3a2ad9382
1 changed file with 4 additions and 1 deletion
@@ -488,7 +488,10 @@ def LlamaModel_fast_forward(
     # Fix up attention mask by setting elements to 0
     # Specifically for DPO
-    if self._has_no_labels and attention_mask is not None:
+    if self._has_no_labels and attention_mask is not None and \
+        attention_mask.shape[1] == seq_length:
+        # Careful for inference the attention_mask is size (1, kv_seq_len)
+        # Whilst the input_embeds is size (1, 1, 4096)
         inputs_requires_grad = inputs_embeds.requires_grad
         if inputs_requires_grad: inputs_embeds.requires_grad_(False)
         inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
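The added comments note that during incremental inference the attention_mask spans the whole KV cache while inputs_embeds holds only the current token, so the DPO masking multiply must be skipped when the shapes disagree. Below is a minimal, self-contained sketch (not unsloth's actual code; the tensor sizes and variable names are illustrative assumptions) of the shape mismatch that the new attention_mask.shape[1] == seq_length guard avoids.

import torch

# Illustrative sketch only; hypothetical sizes, not the unsloth implementation.
# Training/DPO: attention_mask is (batch, seq_len) and inputs_embeds is
# (batch, seq_len, hidden), so zeroing padded positions by broadcasting is safe.
# Incremental inference: attention_mask covers the whole KV cache,
# (batch, kv_seq_len), while inputs_embeds holds only the newest token,
# (batch, 1, hidden); the same multiply would broadcast to
# (batch, kv_seq_len, hidden) and corrupt the single-token embedding
# instead of masking anything.

batch, kv_seq_len, hidden = 1, 17, 4096          # hypothetical sizes
attention_mask = torch.ones(batch, kv_seq_len)   # mask over all cached positions
inputs_embeds  = torch.randn(batch, 1, hidden)   # embedding of the newest token only
seq_length     = inputs_embeds.shape[1]          # q_len == 1 during generation

if attention_mask.shape[1] == seq_length:
    # Shapes agree (training path): mask becomes (batch, seq_len, 1) and
    # broadcasts over the hidden dimension, zeroing padded positions.
    mask = attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
    inputs_embeds = inputs_embeds * mask
else:
    # Shapes disagree (inference path): skip the multiply and let the
    # attention kernel handle masking via the KV cache instead.
    pass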