detect if CPU supports float16
parent fe6169cb0f
commit f055a88bdd
loader.py (11 changed lines)

@@ -8,6 +8,7 @@ import shutil
 from model_config import ModelArgs
 from blackbox_model import Transformer
+from utils import device_supports_dtype

 # how are weights sharded in llama2 - by rows or columns
 join_dim = {
@@ -185,7 +186,15 @@ def add_lora(model_path, lora_path):
 checkpoint_key = f'layers.{layer}.attention.{local_path}.weight'
 a_key = f'{attn_key}_lora_{layer}.A.weight'
 b_key = f'{attn_key}_lora_{layer}.B.weight'
-lora = lora_weights[b_key].mm(lora_weights[a_key]) * lora_scale
+
+original_type = lora_weights[b_key].dtype
+
+if device_supports_dtype('cpu', original_type):
+    lora = lora_weights[b_key].mm(lora_weights[a_key]) * lora_scale
+else:
+    lora = lora_weights[b_key].to(torch.float32).mm(lora_weights[a_key].to(torch.float32)) * lora_scale
+    lora = lora.to(original_type)
+
 subset = get_w_subset(local_path, lora, shards, ci)
 checkpoint[checkpoint_key] = checkpoint[checkpoint_key] + lora[subset].to(torch.bfloat16)
 torch.save(checkpoint, checkpoint_path)
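For reference, the dtype fallback in the hunk above reduces to the standalone sketch below; the function name merge_lora_delta and the bare-tensor arguments are illustrative assumptions, not part of loader.py.

    import torch
    from utils import device_supports_dtype

    def merge_lora_delta(lora_a, lora_b, lora_scale, device='cpu'):
        # Compute the LoRA delta B @ A * scale. If the device cannot run a
        # matmul in the weights' own dtype (e.g. float16 on some CPU builds),
        # upcast to float32 for the matmul and cast the result back afterwards.
        original_type = lora_b.dtype
        if device_supports_dtype(device, original_type):
            return lora_b.mm(lora_a) * lora_scale
        delta = lora_b.to(torch.float32).mm(lora_a.to(torch.float32)) * lora_scale
        return delta.to(original_type)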
utils.py (5 changed lines)

@@ -9,7 +9,10 @@ def device_map(device):
 def device_supports_dtype(device, dtype):
     try:
-        tensor = torch.tensor([1.0, 2.0]).to(device).to(dtype)
+        a = torch.rand(2, 2).to(device).to(dtype)
+        b = torch.rand(2, 2).to(device).to(dtype)
+        c = a.mm(b)
         logging.debug(f'success, {device} supports {dtype}')
         return True
     except TypeError as e:
         return False
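One way to use the probe is to pick the first dtype a device can actually run a matmul in. The pick_dtype helper below is a hypothetical usage sketch, not something this commit adds.

    import torch
    from utils import device_supports_dtype

    def pick_dtype(device, preferred=(torch.float16, torch.bfloat16, torch.float32)):
        # Return the first dtype from the preference list that the device can
        # multiply in; fall back to float32 if none of the probes succeed.
        for dtype in preferred:
            if device_supports_dtype(device, dtype):
                return dtype
        return torch.float32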