past_key_value

Daniel Han-Chen 2024-01-31 03:50:47 +11:00
parent d347db0944
commit 5da05558a0
2 changed files with 4 additions and 4 deletions

unsloth/models/llama.py

@@ -310,7 +310,7 @@ def LlamaAttention_fast_forward(
     bsz, q_len, _ = hidden_states.size()
     # Check for inference
-    if past_key_value is not None and q_len == 1:
+    if past_key_value is not None:
         A, past_key_value = LlamaAttention_fast_forward_inference(
             self,
             hidden_states,
@@ -426,7 +426,7 @@ def LlamaDecoderLayer_fast_forward(
         past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
     """
     bsz, q_len, hd = hidden_states.size()
-    if (past_key_value is not None and q_len == 1):
+    if past_key_value is not None:
         # Self Attention
         residual = hidden_states
         hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
@@ -656,7 +656,7 @@ def LlamaModel_fast_forward(
     pass
     bsz, q_len, hd = hidden_states.size()
-    if (past_key_values is not None and q_len == 1):
+    if past_key_values is not None:
         hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
     else:
         hidden_states = fast_rms_layernorm(self.norm, hidden_states)

unsloth/models/mistral.py

@@ -49,7 +49,7 @@ def MistralAttention_fast_forward(
     bsz, q_len, _ = hidden_states.size()
     # Check for inference
-    if past_key_value is not None and q_len == 1:
+    if past_key_value is not None:
         A, past_key_value = LlamaAttention_fast_forward_inference(
             self,
             hidden_states,
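
The change is the same in all four hunks: the fused inference path is now taken whenever a KV cache is present, not only when exactly one new token is being decoded. A minimal sketch of the old and new predicates, using hypothetical helper names that are not part of the Unsloth codebase:

    # Sketch of the gating change above. `should_take_inference_path_*`
    # are illustrative helpers, not functions from unsloth/models/llama.py.

    def should_take_inference_path_old(past_key_value, q_len):
        # Pre-commit gate: fused inference kernels ran only for
        # single-token decoding steps.
        return past_key_value is not None and q_len == 1

    def should_take_inference_path_new(past_key_value, q_len):
        # Post-commit gate: any call carrying a KV cache takes the
        # inference path, regardless of the query length.
        return past_key_value is not None

    # Example: a cached call with q_len == 2 previously fell through to
    # the training-style path; after this commit it takes the fast path.
    assert not should_take_inference_path_old(("k", "v"), q_len=2)
    assert should_take_inference_path_new(("k", "v"), q_len=2)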