Mirror of https://github.com/unslothai/unsloth, synced 2026-04-21 13:37:39 +00:00
Update llama.py

parent a229db5a85
commit b7ddf962d2

1 changed file with 31 additions and 46 deletions
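In short, the change swaps hard-coded attribute chains like model.model.model.embed_tokens and model.model.lm_head for the standard transformers accessors model.get_input_embeddings() and model.get_output_embeddings(), which resolve the embedding modules no matter how many wrapper layers (e.g. PEFT) sit around the base model, and folds the dtype bookkeeping into a single new_dtype variable. A minimal sketch of why the accessor is the more robust spelling (the Inner/Wrapper classes below are hypothetical, for illustration only):

import torch.nn as nn

class Inner(nn.Module):
    # Stands in for the base LlamaModel that owns embed_tokens
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)
    def get_input_embeddings(self):
        return self.embed_tokens

class Wrapper(nn.Module):
    # Stands in for a wrapper such as PeftModel or LlamaForCausalLM
    def __init__(self, model):
        super().__init__()
        self.model = model
    def get_input_embeddings(self):
        # Delegates through however many wrapper layers exist
        return self.model.get_input_embeddings()

model = Wrapper(Wrapper(Inner()))
# The accessor and the hard-coded path agree at this nesting depth,
# but only the accessor keeps working if the depth changes:
assert model.get_input_embeddings() is model.model.model.embed_tokens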
@@ -1967,48 +1967,41 @@ class FastLlamaModel:
         if "embed_tokens" in new_target_modules:
             print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
 
-            dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype
-            # Now patch lm_head and embed_tokens
-            if dtype == torch.float16:
+            new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype
+            if new_dtype == torch.float16:
                 # See https://github.com/unslothai/unsloth/pull/1200
                 # Tesla T4 must use float32 and not float16
-                modules_to_save_dtype = torch.float32
-            else:
-                # Can be bfloat16
-                modules_to_save_dtype = dtype
+                new_dtype = torch.float32
             pass
 
-            model.model.model.embed_tokens.modules_to_save.default\
-                .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True)
-            model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
+            model.get_input_embeddings().modules_to_save.default\
+                .to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
+            model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
 
             # [TODO] Move old embed_tokens to CPU - should be disk!
-            model.model.model.embed_tokens.original_module\
+            model.get_input_embeddings().original_module\
                 .to(device = "cpu", non_blocking = True)
-            model.model.model.embed_tokens.original_module.requires_grad_(False)
+            model.get_input_embeddings().original_module.requires_grad_(False)
         pass
 
         if "lm_head" in new_target_modules:
             print("Unsloth: Training lm_head in mixed precision to save VRAM")
 
-            dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype
-            # Now patch lm_head and embed_tokens
-            if dtype == torch.float16:
+            new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype
+            if new_dtype == torch.float16:
                 # See https://github.com/unslothai/unsloth/pull/1200
                 # Tesla T4 must use float32 and not float16
-                modules_to_save_dtype = torch.float32
-            else:
-                # Can be bfloat16
-                modules_to_save_dtype = dtype
+                new_dtype = torch.float32
             pass
-            model.model.lm_head.modules_to_save.default\
-                .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True)
-            model.model.lm_head.modules_to_save.default.requires_grad_(True)
+
+            model.get_output_embeddings().modules_to_save.default\
+                .to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
+            model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
 
             # [TODO] Move old lm_head to CPU - should be disk!
-            model.model.lm_head.original_module\
+            model.get_output_embeddings().original_module\
                 .to(device = "cpu", non_blocking = True)
-            model.model.lm_head.original_module.requires_grad_(False)
+            model.get_output_embeddings().original_module.requires_grad_(False)
         pass
 
         return model
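Both hunks apply the same dtype rule: keep the trainable modules_to_save copy in its existing dtype, unless that dtype is float16, which must be upcast to float32 (per https://github.com/unslothai/unsloth/pull/1200, float16 misbehaves here on Tesla T4-class GPUs; bfloat16 may stay as-is). Stated standalone (pick_new_dtype is a hypothetical helper name, not part of the diff):

import torch

def pick_new_dtype(weight: torch.Tensor) -> torch.dtype:
    # float16 is upcast to float32 for training embed_tokens / lm_head
    # on cards like the Tesla T4 (see unsloth PR #1200);
    # bfloat16 and float32 are left unchanged.
    if weight.dtype == torch.float16:
        return torch.float32
    return weight.dtype

assert pick_new_dtype(torch.zeros(2, dtype = torch.float16))  == torch.float32
assert pick_new_dtype(torch.zeros(2, dtype = torch.bfloat16)) == torch.bfloat16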
@@ -2237,42 +2230,34 @@ class FastLlamaModel:
 
         if train_embed_tokens:
             print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
-            assert(hasattr(model.model.model.embed_tokens, "modules_to_save"))
+            assert(hasattr(model.get_input_embeddings(), "modules_to_save"))
 
-            dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype
-            # Now patch lm_head and embed_tokens
-            if dtype == torch.float16:
+            new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype
+            if new_dtype == torch.float16:
                 # See https://github.com/unslothai/unsloth/pull/1200
                 # Tesla T4 must use float32 and not float16
-                modules_to_save_dtype = torch.float32
-            else:
-                # Can be bfloat16
-                modules_to_save_dtype = dtype
+                new_dtype = torch.float32
             pass
 
-            model.model.model.embed_tokens.modules_to_save.default\
-                .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True)
-            model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
+            model.get_input_embeddings().modules_to_save.default\
+                .to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
+            model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
         pass
 
         if train_lm_head:
             print("Unsloth: Training lm_head in mixed precision to save VRAM")
-            assert(hasattr(model.model.lm_head, "modules_to_save"))
+            assert(hasattr(model.get_output_embeddings(), "modules_to_save"))
 
-            dtype = model.model.lm_head.modules_to_save.default.weight.dtype
-            # Now patch lm_head and embed_tokens
-            if dtype == torch.float16:
+            new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype
+            if new_dtype == torch.float16:
                 # See https://github.com/unslothai/unsloth/pull/1200
                 # Tesla T4 must use float32 and not float16
-                modules_to_save_dtype = torch.float32
-            else:
-                # Can be bfloat16
-                modules_to_save_dtype = dtype
+                new_dtype = torch.float32
             pass
 
-            model.model.lm_head.modules_to_save.default\
-                .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True)
-            model.model.lm_head.modules_to_save.default.requires_grad_(True)
+            model.get_output_embeddings().modules_to_save.default\
+                .to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
+            model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
         pass
 
         # Patch tokenizer to pad to the right
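After either code path runs, the trainable copies should sit on the GPU with gradients enabled (and, in the first hunk, the frozen originals on CPU). A quick sanity check along those lines, assuming the PEFT ModulesToSaveWrapper layout the diff relies on (check_patched is a hypothetical helper, not part of the commit):

import torch

def check_patched(model):
    # 'model' is assumed to be the PEFT-wrapped model patched above.
    for module in (model.get_input_embeddings(), model.get_output_embeddings()):
        weight = module.modules_to_save.default.weight
        assert weight.requires_grad
        # float16 should have been upcast; bfloat16/float32 are acceptable.
        assert weight.dtype in (torch.float32, torch.bfloat16)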