# candle/candle-pyo3/e5.py (106 lines, 4.3 KiB, Python)

from candle.utils import load_safetensors, save_gguf, load_gguf
from candle.models.bert import BertModel, Config
import json
from candle import Tensor
from tqdm import tqdm
from dataclasses import fields
import os
import time
from huggingface_hub import hf_hub_download
from transformers import BertTokenizer, AutoModel
import torch
def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    """Average the hidden states according to the attention mask.

    Positions whose mask entry is zero are zeroed out, the result is summed
    over dimension 1 and divided by the per-row sum of the mask, so masked
    (padding) positions do not contribute to the mean.

    Args:
        last_hidden_states: Hidden states to pool; assumed to be laid out as
            (batch, sequence, hidden) -- TODO confirm against callers.
        attention_mask: Mask with one entry per (batch, sequence) position;
            non-zero marks a position that should be averaged.

    Returns:
        The masked mean over dimension 1 of ``last_hidden_states``.
    """
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def _build_config(raw_config):
    """Return a ``Config`` with every field found in ``raw_config`` copied over."""
    config = Config()
    for field in fields(config):
        if field.name in raw_config:
            setattr(config, field.name, raw_config[field.name])
    return config


def _quantize_tensors(tensors):
    """Quantize every tensor in ``tensors`` and return them in a new dict.

    Attention / intermediate / output weights are quantized aggressively
    (``q4k`` when the last dimension is a multiple of 256, ``q5_0``
    otherwise); every other tensor is stored as ``q8_0``.
    """
    quantized_tensors = {}
    for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors to 5-Bit"):
        is_heavy_weight = name.endswith("weight") and any(
            part in name for part in ("attention", "intermediate", "output")
        )
        if not is_heavy_weight:
            quantized_tensors[name] = tensor.quantize("q8_0")
        elif tensor.shape[-1] % 256 == 0:
            # check if the tensor is k-quantizable
            quantized_tensors[name] = tensor.quantize("q4k")
        else:
            quantized_tensors[name] = tensor.quantize("q5_0")
    return quantized_tensors


def main():
    """Compare candle's BERT against the torch reference, before and after quantization.

    Downloads ``intfloat/e5-small-v2``, runs the same sentences through the
    Hugging Face model and the candle model, prints the mean absolute error
    of the pooled embeddings, then quantizes the weights, saves them to
    ``e5_small.gguf``, reloads them and prints the error again.
    """
    model_name = "intfloat/e5-small-v2"
    model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors")
    config_file = hf_hub_download(repo_id=model_name, filename="config.json")

    tensors = load_safetensors(model_file)
    with open(config_file, "r", encoding="utf-8") as f:
        raw_config = json.load(f)
    config = _build_config(raw_config)

    # Load the model
    model = BertModel(config)
    model.load_state_dict(tensors)

    hf_model = AutoModel.from_pretrained(model_name)
    tokenizer = BertTokenizer.from_pretrained(model_name)

    sentences = [
        "The cat sits outside",
        "A man is playing guitar",
        "I love pasta",
        "The new movie is awesome",
        "The cat plays in the garden",
        "A woman watches TV",
        "The new movie is so great",
        "Do you like pizza?",
    ]

    # Plain-list tokenization feeds candle; the "pt" tokenization feeds torch.
    tokenized = tokenizer(sentences, padding=True)
    tokens = Tensor(tokenized["input_ids"])
    token_type_ids = Tensor(tokenized["token_type_ids"])
    attention_mask = Tensor(tokenized["attention_mask"])
    encoder_out, _ = model.forward(tokens, token_type_ids, attention_mask=attention_mask)

    hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
    hf_result = hf_model(**hf_tokenized)["last_hidden_state"]
    hf_pooled = average_pool(hf_result, hf_tokenized["attention_mask"])
    candle_pooled = average_pool(torch.tensor(encoder_out.values()), hf_tokenized["attention_mask"])

    # L1Loss already reduces to a scalar mean with its default settings.
    loss = torch.nn.L1Loss()
    error = loss(hf_pooled, candle_pooled).item()
    print(f"Mean error between torch-reference and candle: {error}")

    # Quantize the attention/intermediate/output weights to 4/5 bit, the rest to 8 bit
    quantized_tensors = _quantize_tensors(tensors)

    print("Saving quantized tensors")
    # Remove all None values from the config
    config_to_save = {k: v for k, v in config.__dict__.items() if v is not None}
    # Save the model
    quantized_model_file = "e5_small.gguf"
    save_gguf(quantized_model_file, quantized_tensors, config_to_save)

    file_size_mb = os.path.getsize(model_file) / 1024 / 1024
    file_size_mb_compressed = os.path.getsize(quantized_model_file) / 1024 / 1024
    print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB")

    # Load the model from the gguf
    tensors, raw_config = load_gguf(quantized_model_file)
    config = _build_config(raw_config)
    model = BertModel(config)
    # "embeddings.position_ids" is missing in the gguf as it is i64
    model.load_state_dict(tensors, strict=False)

    # Run the model again, with the same attention mask as the first run so
    # that padding tokens are ignored and both errors are comparable.
    encoder_out_2, _ = model.forward(tokens, token_type_ids, attention_mask=attention_mask)
    encoder_out_2 = encoder_out_2.to_device("cpu")
    candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), hf_tokenized["attention_mask"])
    error = loss(hf_pooled, candle_pooled_2).item()
    print(f"Mean error between torch-reference and quantized-candle: {error}")


if __name__ == "__main__":
    main()