mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Update llama.py
This commit is contained in:
parent
998097394a
commit
81128a4504
1 changed file with 2 additions and 2 deletions
|
|
@@ -183,8 +183,8 @@ def LlamaAttention_fast_forward_inference(
|
|||
pass
|
||||
|
||||
# Grouped query attention
|
||||
_, _, cached_len, _ = Knn.shape
|
||||
if n_groups != 1:
|
||||
_, _, cached_len, _ = Knn.shape
|
||||
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
|
||||
|
|
@@ -195,7 +195,7 @@ def LlamaAttention_fast_forward_inference(
|
|||
# pass
|
||||
|
||||
# Attention
|
||||
A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:kv_seq_len])
|
||||
A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
|
||||
A *= self.scalar
|
||||
A[:] = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
|
||||
A = torch.matmul(A, Vnn, out = Qn)
|
||||
|
|
|
|||
Loading…
Reference in a new issue