add tokenization and score file

This commit is contained in:
yoonlee666 2020-09-21 17:51:16 +08:00
parent fe219b5680
commit 01c9e8b373
7 changed files with 457 additions and 28 deletions

View File

@ -18,12 +18,11 @@ Bert finetune and evaluation script.
'''
import os
import json
import argparse
from src.bert_for_finetune import BertFinetuneCell, BertNER
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
from src.dataset import create_ner_dataset
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
import mindspore.common.dtype as mstype
from mindspore import context
@ -99,7 +98,7 @@ def eval_result_print(assessment_method="accuracy", callback=None):
raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_method="accuracy", data_file="",
load_checkpoint_path="", vocab_file="", label2id_file="", tag_to_index=None):
load_checkpoint_path="", vocab_file="", label_file="", tag_to_index=None):
""" do eval """
if load_checkpoint_path == "":
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
@ -114,7 +113,8 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_meth
if assessment_method == "clue_benchmark":
from src.cluener_evaluation import submit
submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf, label2id_file=label2id_file)
submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf,
label_file=label_file, tag_to_index=tag_to_index)
else:
if assessment_method == "accuracy":
callback = Accuracy()
@ -161,7 +161,7 @@ def parse_args():
parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"],
help="Enable eval data shuffle, default is false")
parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark")
parser.add_argument("--label2id_file_path", type=str, default="", help="label2id file path, used in clue benchmark")
parser.add_argument("--label_file_path", type=str, default="", help="label file path, used in clue benchmark")
parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")
@ -180,10 +180,10 @@ def parse_args():
raise ValueError("'eval_data_file_path' must be set when do evaluation task")
if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.vocab_file_path == "":
raise ValueError("'vocab_file_path' must be set to do clue benchmark")
if args_opt.use_crf.lower() == "true" and args_opt.label2id_file_path == "":
raise ValueError("'label2id_file_path' must be set to use crf")
if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label2id_file_path == "":
raise ValueError("'label2id_file_path' must be set to do clue benchmark")
if args_opt.use_crf.lower() == "true" and args_opt.label_file_path == "":
raise ValueError("'label_file_path' must be set to use crf")
if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label_file_path == "":
raise ValueError("'label_file_path' must be set to do clue benchmark")
return args_opt
@ -205,11 +205,12 @@ def run_ner():
bert_net_cfg.compute_type = mstype.float32
else:
raise Exception("Target error, GPU or Ascend is supported.")
tag_to_index = None
label_list = []
with open(args_opt.label_file_path) as f:
for label in f:
label_list.append(label.strip())
tag_to_index = convert_labels_to_index(label_list)
if args_opt.use_crf.lower() == "true":
with open(args_opt.label2id_file_path) as json_file:
tag_to_index = json.load(json_file)
max_val = max(tag_to_index.values())
tag_to_index["<START>"] = max_val + 1
tag_to_index["<STOP>"] = max_val + 2
@ -240,7 +241,7 @@ def run_ner():
schema_file_path=args_opt.schema_file_path,
do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,
load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index)
load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label_file_path, tag_to_index)
if __name__ == "__main__":
run_ner()

View File

@ -38,7 +38,7 @@ python ${PROJECT_DIR}/../run_ner.py \
--train_data_shuffle="true" \
--eval_data_shuffle="false" \
--vocab_file_path="" \
--label2id_file_path="" \
--label_file_path="" \
--save_finetune_checkpoint_path="" \
--load_pretrain_checkpoint_path="" \
--load_finetune_checkpoint_path="" \

View File

@ -23,9 +23,9 @@ from src import tokenization
from src.sample_process import label_generation, process_one_example_p
from src.CRF import postprocess
from src.finetune_eval_config import bert_net_cfg
from src.score import get_result
def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
def process(model=None, text="", tokenizer_=None, use_crf="", tag_to_index=None, vocab=""):
"""
process text.
"""
@ -34,7 +34,7 @@ def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
res = []
ids = []
for i in data:
feature = process_one_example_p(tokenizer_, i, max_seq_len=bert_net_cfg.seq_length)
feature = process_one_example_p(tokenizer_, vocab, i, max_seq_len=bert_net_cfg.seq_length)
features.append(feature)
input_ids, input_mask, token_type_id = feature
input_ids = Tensor(np.array(input_ids), mstype.int32)
@ -52,10 +52,10 @@ def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
ids = logits.asnumpy()
ids = np.argmax(ids, axis=-1)
ids = list(ids)
res = label_generation(text=text, probs=ids, label2id_file=label2id_file)
res = label_generation(text=text, probs=ids, tag_to_index=tag_to_index)
return res
def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""):
def submit(model=None, path="", vocab_file="", use_crf="", label_file="", tag_to_index=None):
"""
submit task
"""
@ -66,8 +66,11 @@ def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""):
continue
oneline = json.loads(line.strip())
res = process(model=model, text=oneline["text"], tokenizer_=tokenizer_,
use_crf=use_crf, label2id_file=label2id_file)
print("text", oneline["text"])
print("res:", res)
use_crf=use_crf, tag_to_index=tag_to_index, vocab=vocab_file)
data.append(json.dumps({"label": res}, ensure_ascii=False))
open("ner_predict.json", "w").write("\n".join(data))
labels = []
with open(label_file) as f:
for label in f:
labels.append(label.strip())
get_result(labels, "ner_predict.json", path)

View File

@ -16,9 +16,9 @@
"""process txt"""
import re
import json
from src.tokenization import convert_tokens_to_ids
def process_one_example_p(tokenizer, text, max_seq_len=128):
def process_one_example_p(tokenizer, vocab, text, max_seq_len=128):
"""process one testline"""
textlist = list(text)
tokens = []
@ -37,7 +37,7 @@ def process_one_example_p(tokenizer, text, max_seq_len=128):
segment_ids.append(0)
ntokens.append("[SEP]")
segment_ids.append(0)
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
input_ids = convert_tokens_to_ids(vocab, ntokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < max_seq_len:
input_ids.append(0)
@ -52,12 +52,12 @@ def process_one_example_p(tokenizer, text, max_seq_len=128):
feature = (input_ids, input_mask, segment_ids)
return feature
def label_generation(text="", probs=None, label2id_file=""):
def label_generation(text="", probs=None, tag_to_index=None):
"""generate label"""
data = [text]
probs = [probs]
result = []
label2id = json.loads(open(label2id_file).read())
label2id = tag_to_index
id2label = [k for k, v in label2id.items()]
for index, prob in enumerate(probs):

View File

@ -0,0 +1,79 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Calculate average F1 score among labels.
"""
import json
def get_f1_score_for_each_label(pre_lines, gold_lines, label):
"""
Get F1 score for each label.
Args:
pre_lines: listed label info from pre_file.
gold_lines: listed label info from gold_file.
label:
Returns:
F1 score for this label.
"""
TP = 0
FP = 0
FN = 0
index = 0
while index < len(pre_lines):
pre_line = pre_lines[index].get(label, {})
gold_line = gold_lines[index].get(label, {})
for sample in pre_line:
if sample in gold_line:
TP += 1
else:
FP += 1
for sample in gold_line:
if sample not in pre_line:
FN += 1
index += 1
f1 = 2 * TP / (2 * TP + FP + FN)
return f1
def get_f1_score(labels, pre_file, gold_file):
"""
Get F1 scores for each label.
Args:
labels: list of labels.
pre_file: prediction file.
gold_file: ground truth file.
Returns:
average F1 score on all labels.
"""
pre_lines = [json.loads(line.strip())['label'] for line in open(pre_file) if line.strip()]
gold_lines = [json.loads(line.strip())['label'] for line in open(gold_file) if line.strip()]
if len(pre_lines) != len(gold_lines):
raise ValueError("pre file and gold file have different line count.")
f1_sum = 0
for label in labels:
f1 = get_f1_score_for_each_label(pre_lines, gold_lines, label)
print('label: %s, F1: %.6f' % (label, f1))
f1_sum += f1
return f1_sum/len(labels)
def get_result(labels, pre_file, gold_file):
avg = get_f1_score(labels, pre_file=pre_file, gold_file=gold_file)
print("avg F1: %.6f" % avg)

View File

@ -0,0 +1,329 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Tokenization.
"""
import unicodedata
import collections
def convert_to_unicode(text):
"""
Convert text into unicode type.
Args:
text: input str.
Returns:
input str in unicode.
"""
ret = text
if isinstance(text, str):
ret = text
elif isinstance(text, bytes):
ret = text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
return ret
def vocab_to_dict_key_token(vocab_file):
"""Loads a vocab file into a dict, key is token."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def vocab_to_dict_key_id(vocab_file):
"""Loads a vocab file into a dict, key is id."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[index] = token
index += 1
return vocab
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
def convert_tokens_to_ids(vocab_file, tokens):
"""
Convert tokens to ids.
Args:
vocab_file: path to vocab.txt.
tokens: list of tokens.
Returns:
list of ids.
"""
vocab_dict = vocab_to_dict_key_token(vocab_file)
output = []
for token in tokens:
output.append(vocab_dict[token])
return output
def convert_ids_to_tokens(vocab_file, ids):
"""
Convert ids to tokens.
Args:
vocab_file: path to vocab.txt.
ids: list of ids.
Returns:
list of tokens.
"""
vocab_dict = vocab_to_dict_key_id(vocab_file)
output = []
for _id in ids:
output.append(vocab_dict[_id])
return output
class FullTokenizer():
"""
Full tokenizer
"""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab_dict = vocab_to_dict_key_token(vocab_file)
self.do_lower_case = do_lower_case
self.basic_tokenize = BasicTokenizer(do_lower_case)
self.wordpiece_tokenize = WordpieceTokenizer(self.vocab_dict)
def tokenize(self, text):
"""
Do full tokenization.
Args:
text: str of text.
Returns:
list of tokens.
"""
tokens_ret = []
text = convert_to_unicode(text)
for tokens in self.basic_tokenize.tokenize(text):
wordpiece_tokens = self.wordpiece_tokenize.tokenize(tokens)
tokens_ret.extend(wordpiece_tokens)
return tokens_ret
class BasicTokenizer():
"""
Basic tokenizer
"""
def __init__(self, do_lower_case=True):
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""
Do basic tokenization.
Args:
text: text in unicode.
Returns:
a list of tokens split from text
"""
text = self._clean_text(text)
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
aaa = self._run_split_on_punc(token)
split_tokens.extend(aaa)
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
i = 0
start_new_word = True
output = []
for char in text:
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((0x4E00 <= cp <= 0x9FFF) or
(0x3400 <= cp <= 0x4DBF) or
(0x20000 <= cp <= 0x2A6DF) or
(0x2A700 <= cp <= 0x2B73F) or
(0x2B740 <= cp <= 0x2B81F) or
(0x2B820 <= cp <= 0x2CEAF) or
(0xF900 <= cp <= 0xFAFF) or
(0x2F800 <= cp <= 0x2FA1F)):
return True
return False
class WordpieceTokenizer():
"""
Wordpiece tokenizer
"""
def __init__(self, vocab):
self.vocab_dict = vocab
def tokenize(self, tokens):
"""
Do word-piece tokenization
Args:
tokens: a word.
Returns:
a list of tokens that can be found in vocab dict.
"""
output_tokens = []
tokens = convert_to_unicode(tokens)
for token in whitespace_tokenize(tokens):
chars = list(token)
len_chars = len(chars)
start = 0
end = len_chars
while start < len_chars:
while start < end:
substr = "".join(token[start:end])
if start != 0:
substr = "##" + substr
if substr in self.vocab_dict:
output_tokens.append(substr)
start = end
end = len_chars
else:
end = end - 1
if start == end and start != len_chars:
output_tokens.append("[UNK]")
break
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
whitespace_char = [" ", "\t", "\n", "\r"]
if char in whitespace_char:
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
control_char = ["\t", "\n", "\r"]
if char in control_char:
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((33 <= cp <= 47) or (58 <= cp <= 64) or
(91 <= cp <= 96) or (123 <= cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False

View File

@ -19,6 +19,7 @@ Functional Cells used in Bert finetune and evaluation.
import os
import math
import collections
import numpy as np
import mindspore.nn as nn
from mindspore import log as logger
@ -213,3 +214,19 @@ class BertLearningRate(LearningRateSchedule):
else:
lr = decay_lr
return lr
def convert_labels_to_index(label_list):
"""
Convert label_list to indices for NER task.
"""
label2id = collections.OrderedDict()
label2id["O"] = 0
prefix = ["S_", "B_", "M_", "E_"]
index = 0
for label in label_list:
for pre in prefix:
index += 1
sub_label = pre + label
label2id[sub_label] = index
return label2id