attention mask

This commit is contained in:
Daniel Han-Chen 2024-01-28 04:18:19 +11:00
parent 6c7f0dbcb4
commit 166f8c812e
2 changed files with 2 additions and 2 deletions

View file

@@ -489,7 +489,7 @@ def LlamaModel_fast_forward(
# Ignore attention_mask
if attention_mask is None:
padding_mask = None
elif False:#self.training:
elif self.training:
attention_mask = None
padding_mask = None
else:

View file

@@ -90,7 +90,7 @@ def MistralAttention_fast_forward(
past_key_value = (K, V) if use_cache else None
# Attention module
if (attention_mask is None and not HAS_FLASH_ATTENTION):
if (not HAS_FLASH_ATTENTION):
# Xformers memory efficient attention
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)