fix attention bug in eval_loss.py

Author: liang.zhao
Date:   2023-12-09 17:05:44 +08:00
Parent: 8dfbdb4c8a
Commit: a33d059030
2 changed files with 6 additions and 5 deletions

@@ -1,4 +1,4 @@
-for LOSS_DATA in zh_finance zh_general zh_government zh_movie zh_news zh_tech
+for LOSS_DATA in zh_finance zh_general zh_government zh_movie zh_game zh_tech
 do
 export HF_MODEL_PATH=YOUR_SKYWORK_HF_BASE_MODEL
 export FLAG=skywork-13b-base

eval_loss.py

@@ -52,7 +52,7 @@ def compute_loss(tokenized_texts, attention_mask, model, tokenizer, add_start_to
     with torch.no_grad():
         logits = model(tokenized_texts, attention_mask=attention_mask).logits
     logits = logits[:, :-1]
-    loss = loss_func(logits.transpose(1, 2), labels) * attention_mask[:, :-1]
+    loss = loss_func(logits.transpose(1, 2), labels) * attention_mask[:, 1:]
     num_tokens = torch.sum(attention_mask).item() - attention_mask.size(0)
     return torch.sum(loss).item(), num_tokens
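
Why this is the right mask: once the logits are trimmed with logits[:, :-1], position t predicts token t+1, so a position's loss should be kept only when its *target* token is real. attention_mask[:, :-1] masks by the input token and therefore keeps a loss term whose target is padding; attention_mask[:, 1:] masks by the target. A minimal sketch of the alignment, assuming labels = tokenized_texts[:, 1:] (consistent with the shift above) and toy shapes chosen purely for illustration:

```python
import torch

# Toy batch: one sequence of 4 tokens, the last of which is padding.
tokenized_texts = torch.tensor([[11, 12, 13, 0]])  # (B, T)
attention_mask = torch.tensor([[1, 1, 1, 0]])      # (B, T)
labels = tokenized_texts[:, 1:]                    # target at position t is token t + 1

logits = torch.randn(1, 4, 32)[:, :-1]             # (B, T-1, vocab), like logits[:, :-1] above
loss_func = torch.nn.CrossEntropyLoss(reduction="none")
per_token = loss_func(logits.transpose(1, 2), labels)  # (B, T-1)

# Buggy mask: position 2 (whose target is the pad token) still contributes loss.
print(per_token * attention_mask[:, :-1])
# Fixed mask: exactly that position is zeroed out.
print(per_token * attention_mask[:, 1:])
```

The num_tokens line agrees with this accounting: each sequence yields one fewer prediction than it has real tokens, hence the attention_mask.size(0) subtraction.
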
@@ -63,7 +63,6 @@ def load_model_tokenizer_config():
     config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
     tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True)
     model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", config=config, trust_remote_code=True).eval()
-    model.generation_config = GenerationConfig.from_pretrained(args.model_path, trust_remote_code=True)
     while config.num_attention_heads % args.n_gpus != 0:
         args.n_gpus //= 2
         args.batch_size //= 2
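
The retained loop at the end of this hunk halves the GPU count (and the batch size with it) until the attention heads divide evenly across devices, presumably so the heads can be sharded without remainder. A standalone restatement of that logic (the script itself mutates args in place; the function name here is illustrative):

```python
def shrink_to_fit(num_attention_heads: int, n_gpus: int, batch_size: int) -> tuple[int, int]:
    # Same loop as the hunk above: halve until the heads shard evenly.
    while num_attention_heads % n_gpus != 0:
        n_gpus //= 2
        batch_size //= 2
    return n_gpus, batch_size

# 40 heads on 16 GPUs: 40 % 16 != 0, so fall back to 8 GPUs (40 % 8 == 0).
print(shrink_to_fit(40, 16, 32))  # (8, 16)
```
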
@@ -72,13 +71,15 @@ def load_model_tokenizer_config():
 def main():
     model, tokenizer, config = load_model_tokenizer_config()
-    tokenizer.padding_side = "right"
+    if not 'chatglm3' in args.model_path.lower():
+        tokenizer.padding_side = "right"
     if "qwen-14b" in args.model_path.lower():
         tokenizer.pad_token = '<|extra_0|>'
         tokenizer.eos_token = '<|endoftext|>'
         args.batch_size = 1
         args.n_gpus = 1
+    elif 'chatglm3' in args.model_path.lower():
+        print(tokenizer.pad_token)
     else:
         tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token is not None else "[PAD]"
     os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
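
The new branching leaves the chatglm3 tokenizer untouched (the elif just prints the pad token it already ships with), pins Qwen-14B's special tokens explicitly, and otherwise falls back to the eos token for padding. Right padding also matters for the mask fix above: it keeps the zeros at the tail of attention_mask, which the [:, 1:] shift assumes. A hypothetical check of that fallback (gpt2 is a stand-in model, not part of this diff):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in; any causal-LM tokenizer works
tok.pad_token = tok.eos_token if tok.eos_token is not None else "[PAD]"  # same fallback as the else branch
tok.padding_side = "right"

batch = tok(["hi", "a somewhat longer sentence"], return_tensors="pt", padding=True)
print(batch["attention_mask"])  # zeros only at the tail, e.g. [[1, 0, 0, 0], [1, 1, 1, 1]]
```
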