Update llama.py

This commit is contained in:
Daniel Han-Chen 2024-02-08 03:39:45 +11:00
parent 31de486f1c
commit 601dc9ec4b

View file

@@ -183,8 +183,8 @@ def LlamaAttention_fast_forward_inference(
pass
# Grouped query attention
+    _, _, cached_len, _ = Knn.shape
     if n_groups != 1:
-        _, _, cached_len, _ = Knn.shape
         Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
         Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
         Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
@@ -195,7 +195,7 @@ def LlamaAttention_fast_forward_inference(
# pass
# Attention
-    A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:kv_seq_len])
+    A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
A *= self.scalar
A[:] = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch.matmul(A, Vnn, out = Qn)