Update utils.py

This commit is contained in:
Daniel Han-Chen 2024-02-04 14:02:26 +11:00
parent 990068b977
commit 63ed23ae98

View file

@ -204,13 +204,13 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
# Add in LoRA weights
if lora_A is not None:
out_dim = out.shape[2]
dtype = X.dtype
if not hasattr(lora_A, "_fast_lora"):
lora_A._fast_lora = lora_A.to(dtype)
lora_B._fast_lora = lora_B.to(dtype)
pass
dtype = X.dtype
if bsz == 1:
out = out.view(out_dim)
temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)