mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
past_key_value
This commit is contained in:
parent
d347db0944
commit
5da05558a0
2 changed files with 4 additions and 4 deletions
|
|
@@ -310,7 +310,7 @@ def LlamaAttention_fast_forward(
|
|||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# Check for inference
|
||||
- if past_key_value is not None and q_len == 1:
|
||||
+ if past_key_value is not None:
|
||||
A, past_key_value = LlamaAttention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states,
|
||||
|
|
@@ -426,7 +426,7 @@ def LlamaDecoderLayer_fast_forward(
|
|||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
bsz, q_len, hd = hidden_states.size()
|
||||
- if (past_key_value is not None and q_len == 1):
|
||||
+ if past_key_value is not None:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
|
||||
|
|
@@ -656,7 +656,7 @@ def LlamaModel_fast_forward(
|
|||
pass
|
||||
|
||||
bsz, q_len, hd = hidden_states.size()
|
||||
- if (past_key_values is not None and q_len == 1):
|
||||
+ if past_key_values is not None:
|
||||
hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
|
||||
else:
|
||||
hidden_states = fast_rms_layernorm(self.norm, hidden_states)
|
||||
|
|
|
|||
|
|
@@ -49,7 +49,7 @@ def MistralAttention_fast_forward(
|
|||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# Check for inference
|
||||
- if past_key_value is not None and q_len == 1:
|
||||
+ if past_key_value is not None:
|
||||
A, past_key_value = LlamaAttention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states,
|
||||
|
|
|
|||
Loading…
Reference in a new issue