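"""Load the e5-small-v2 embedding model with candle, compare its embeddings
against the transformers reference implementation, then quantize the weights
to gguf, reload them and check that the quantized model still matches."""
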
from candle.utils import load_safetensors, save_gguf, load_gguf
from candle.models.bert import BertModel, Config
import json
from candle import Tensor
from tqdm import tqdm
from dataclasses import fields
import os

from huggingface_hub import hf_hub_download
from transformers import BertTokenizer, AutoModel
import torch

if __name__ == "__main__":
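    # Download the reference weights and config from the Hugging Face Hub.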
    model_name = "intfloat/e5-small-v2"
    model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors")
    config_file = hf_hub_download(repo_id=model_name, filename="config.json")

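    # Load the safetensors weights and fill the candle Config from config.json.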
    tensors = load_safetensors(model_file)
    config = Config()
    with open(config_file, "r") as f:
        raw_config = json.load(f)
        for field in fields(config):
            if field.name in raw_config:
                setattr(config, field.name, raw_config[field.name])

    # Load the model
    model = BertModel(config)
    model.load_state_dict(tensors)

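    # Load the transformers implementation of the same model as a reference.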
    hf_model = AutoModel.from_pretrained(model_name)
    tokenizer = BertTokenizer.from_pretrained(model_name)

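    # A small batch of example sentences to embed.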
    sentences = [
        "The cat sits outside",
        "A man is playing guitar",
        "I love pasta",
        "The new movie is awesome",
        "The cat plays in the garden",
        "A woman watches TV",
        "The new movie is so great",
        "Do you like pizza?",
    ]

    def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor):
        """Average the hidden states according to the attention mask"""
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

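    # Tokenize the batch once and build candle tensors from the ids, type-ids and mask.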
    tokenized = tokenizer(sentences, padding=True)
    tokens = Tensor(tokenized["input_ids"])
    token_type_ids = Tensor(tokenized["token_type_ids"])
    attention_mask = Tensor(tokenized["attention_mask"])
    encoder_out, _ = model.forward(tokens, token_type_ids, attention_mask=attention_mask)

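    # Run the transformers reference model on the same batch, as PyTorch tensors.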
    hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
    hf_result = hf_model(**hf_tokenized)["last_hidden_state"]

    hf_pooled = average_pool(hf_result, hf_tokenized["attention_mask"])
    candle_pooled = average_pool(torch.tensor(encoder_out.values()), hf_tokenized["attention_mask"])

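    # Mean absolute error between the two pooled embeddings.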
    loss = torch.nn.L1Loss()
    error = loss(hf_pooled, candle_pooled).mean().item()
    print(f"Mean error between torch-reference and candle: {error}")

    # Quantize the weights of all linear layers (attention, intermediate and
    # output); everything else is quantized to 8-bit.
    quantized_tensors = {}
    for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors"):
        if name.endswith("weight") and ("attention" in name or "intermediate" in name or "output" in name):
            # k-quants only work when the last dimension is a multiple of the
            # 256-element super-block size; fall back to q5_0 otherwise.
            if tensor.shape[-1] % 256 == 0:
                new_tensor = tensor.quantize("q4k")
            else:
                new_tensor = tensor.quantize("q5_0")
            quantized_tensors[name] = new_tensor
        else:
            quantized_tensors[name] = tensor.quantize("q8_0")

print(f"Saving quantized tensors")
|
|
|
|
# Remove all None values from the config
|
|
|
|
config_to_save = {k: v for k, v in config.__dict__.items() if v is not None}
|
|
|
|
# Save the model
|
|
|
|
quantized_model_file = "e5_small.gguf"
|
|
|
|
save_gguf(quantized_model_file, quantized_tensors, config_to_save)
|
|
|
|
|
|
|
|
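    # Compare the original and quantized file sizes.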
    file_size_mb = os.path.getsize(model_file) / 1024 / 1024
    file_size_mb_compressed = os.path.getsize(quantized_model_file) / 1024 / 1024
    print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB")

    # Load the model back from the gguf
    tensors, raw_config = load_gguf(quantized_model_file)
    config = Config()
    for field in fields(config):
        if field.name in raw_config:
            setattr(config, field.name, raw_config[field.name])
    model = BertModel(config)
    # "embeddings.position_ids" is missing from the gguf as it is i64,
    # so load the state dict non-strictly.
    model.load_state_dict(tensors, strict=False)

    # Run the quantized model again on the same inputs
    encoder_out_2, pooled_output_2 = model.forward(tokens, token_type_ids, attention_mask=attention_mask)
    encoder_out_2, pooled_output_2 = encoder_out_2.to_device("cpu"), pooled_output_2.to_device("cpu")

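    # Pool the quantized model's output and compare against the torch reference.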
    candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), hf_tokenized["attention_mask"])
    error = loss(hf_pooled, candle_pooled_2).mean().item()
    print(f"Mean error between torch-reference and quantized-candle: {error}")