detect if CPU supports float16

Oleksandr Kuvshynov 2024-03-25 20:18:00 -04:00
parent fe6169cb0f
commit f055a88bdd
2 changed files with 14 additions and 2 deletions

View File

@@ -8,6 +8,7 @@ import shutil
 from model_config import ModelArgs
 from blackbox_model import Transformer
+from utils import device_supports_dtype
 # how are weights sharded in llama2 - by rows or columns
 join_dim = {
@@ -185,7 +186,15 @@ def add_lora(model_path, lora_path):
 checkpoint_key = f'layers.{layer}.attention.{local_path}.weight'
 a_key = f'{attn_key}_lora_{layer}.A.weight'
 b_key = f'{attn_key}_lora_{layer}.B.weight'
-lora = lora_weights[b_key].mm(lora_weights[a_key]) * lora_scale
+original_type = lora_weights[b_key].dtype
+if device_supports_dtype('cpu', original_type):
+    lora = lora_weights[b_key].mm(lora_weights[a_key]) * lora_scale
+else:
+    lora = lora_weights[b_key].to(torch.float32).mm(lora_weights[a_key].to(torch.float32)) * lora_scale
+    lora = lora.to(original_type)
 subset = get_w_subset(local_path, lora, shards, ci)
 checkpoint[checkpoint_key] = checkpoint[checkpoint_key] + lora[subset].to(torch.bfloat16)
 torch.save(checkpoint, checkpoint_path)
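
For context, the change above is the standard LoRA merge, W' = W + lora_scale * (B @ A), with the matrix product upcast to float32 whenever the current device cannot run mm in the checkpoint's dtype (float16 matmul is not implemented on every CPU build). A minimal standalone sketch of the same fallback, assuming the device_supports_dtype helper from this commit; the function name merge_lora_delta and the bare tensors are illustrative, not from the repo:

import torch

from utils import device_supports_dtype  # helper touched by this commit

def merge_lora_delta(w, a, b, lora_scale, device='cpu'):
    # W' = W + lora_scale * (B @ A); compute the product in float32 when the
    # device cannot multiply matrices in the checkpoint dtype.
    original_type = b.dtype
    if device_supports_dtype(device, original_type):
        lora = b.mm(a) * lora_scale
    else:
        lora = b.to(torch.float32).mm(a.to(torch.float32)) * lora_scale
        lora = lora.to(original_type)
    return w + lora.to(w.dtype)

# Example with rank-4 LoRA factors: B is (out, r), A is (r, in).
w = torch.zeros(16, 16, dtype=torch.bfloat16)
b = torch.randn(16, 4, dtype=torch.float16)
a = torch.randn(4, 16, dtype=torch.float16)
w = merge_lora_delta(w, a, b, lora_scale=2.0)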

View File

@@ -9,7 +9,10 @@ def device_map(device):
 def device_supports_dtype(device, dtype):
     try:
-        tensor = torch.tensor([1.0, 2.0]).to(device).to(dtype)
+        a = torch.rand(2, 2).to(device).to(dtype)
+        b = torch.rand(2, 2).to(device).to(dtype)
+        c = a.mm(b)
         logging.debug(f'success, {device} supports {dtype}')
         return True
     except TypeError as e:
         return False
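
The probe now runs an actual matrix multiply instead of only constructing a tensor, because creating a float16 tensor on CPU can succeed even when mm for that dtype is not implemented. A self-contained sketch of the same check with a small usage example; the broadened except clause is an assumption, since some torch builds raise RuntimeError rather than TypeError for an unsupported matmul dtype:

import logging
import torch

def device_supports_dtype(device, dtype):
    # Probe with a real 2x2 matmul in the target dtype on the target device.
    try:
        a = torch.rand(2, 2).to(device).to(dtype)
        b = torch.rand(2, 2).to(device).to(dtype)
        a.mm(b)
        logging.debug(f'success, {device} supports {dtype}')
        return True
    except (TypeError, RuntimeError):
        return False

if __name__ == '__main__':
    for dtype in (torch.float16, torch.bfloat16, torch.float32):
        print(dtype, device_supports_dtype('cpu', dtype))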