forked from OSSInnovation/mindspore
!6616 add tokenization and score file
Merge pull request !6616 from yoonlee666/token
Commit: b309850036
@@ -18,12 +18,11 @@ Bert finetune and evaluation script.
 '''

 import os
-import json
 import argparse
 from src.bert_for_finetune import BertFinetuneCell, BertNER
 from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
 from src.dataset import create_ner_dataset
-from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
+from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index
 from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
 import mindspore.common.dtype as mstype
 from mindspore import context

@@ -99,7 +98,7 @@ def eval_result_print(assessment_method="accuracy", callback=None):
         raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")


 def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_method="accuracy", data_file="",
-            load_checkpoint_path="", vocab_file="", label2id_file="", tag_to_index=None):
+            load_checkpoint_path="", vocab_file="", label_file="", tag_to_index=None):
     """ do eval """
     if load_checkpoint_path == "":
         raise ValueError("Finetune model missed, evaluation task must load finetune model!")

@@ -114,7 +113,8 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_meth

     if assessment_method == "clue_benchmark":
         from src.cluener_evaluation import submit
-        submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf, label2id_file=label2id_file)
+        submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf,
+               label_file=label_file, tag_to_index=tag_to_index)
     else:
         if assessment_method == "accuracy":
             callback = Accuracy()

@@ -161,7 +161,7 @@ def parse_args():
     parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"],
                         help="Enable eval data shuffle, default is false")
     parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark")
-    parser.add_argument("--label2id_file_path", type=str, default="", help="label2id file path, used in clue benchmark")
+    parser.add_argument("--label_file_path", type=str, default="", help="label file path, used in clue benchmark")
     parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
     parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")

@@ -180,10 +180,10 @@ def parse_args():
         raise ValueError("'eval_data_file_path' must be set when do evaluation task")
     if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.vocab_file_path == "":
         raise ValueError("'vocab_file_path' must be set to do clue benchmark")
-    if args_opt.use_crf.lower() == "true" and args_opt.label2id_file_path == "":
-        raise ValueError("'label2id_file_path' must be set to use crf")
-    if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label2id_file_path == "":
-        raise ValueError("'label2id_file_path' must be set to do clue benchmark")
+    if args_opt.use_crf.lower() == "true" and args_opt.label_file_path == "":
+        raise ValueError("'label_file_path' must be set to use crf")
+    if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label_file_path == "":
+        raise ValueError("'label_file_path' must be set to do clue benchmark")
     return args_opt


@@ -205,11 +205,12 @@ def run_ner():
         bert_net_cfg.compute_type = mstype.float32
     else:
         raise Exception("Target error, GPU or Ascend is supported.")
-    tag_to_index = None
+    label_list = []
+    with open(args_opt.label_file_path) as f:
+        for label in f:
+            label_list.append(label.strip())
+    tag_to_index = convert_labels_to_index(label_list)
     if args_opt.use_crf.lower() == "true":
-        with open(args_opt.label2id_file_path) as json_file:
-            tag_to_index = json.load(json_file)
         max_val = max(tag_to_index.values())
         tag_to_index["<START>"] = max_val + 1
         tag_to_index["<STOP>"] = max_val + 2

@@ -240,7 +241,7 @@ def run_ner():
                                       schema_file_path=args_opt.schema_file_path,
                                       do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
         do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,
-                load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index)
+                load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label_file_path, tag_to_index)


 if __name__ == "__main__":
     run_ner()
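With this change, the label2id JSON disappears from the NER flow: run_ner.py reads a plain-text label file (one entity type per line) and derives the tag map itself. A minimal sketch of the new behaviour, assuming a hypothetical ner_labels.txt; the file name and the use_crf flag below are placeholders, not part of this diff:

from src.utils import convert_labels_to_index

label_list = []
with open("ner_labels.txt") as f:                     # hypothetical label file, one entity type per line
    for label in f:
        label_list.append(label.strip())
tag_to_index = convert_labels_to_index(label_list)    # BMES-style tag map; see the utils hunk at the end of this diff
use_crf = True                                        # placeholder for args_opt.use_crf.lower() == "true"
if use_crf:
    max_val = max(tag_to_index.values())              # the CRF additionally needs transition anchors
    tag_to_index["<START>"] = max_val + 1
    tag_to_index["<STOP>"] = max_val + 2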
@@ -38,7 +38,7 @@ python ${PROJECT_DIR}/../run_ner.py \
     --train_data_shuffle="true" \
     --eval_data_shuffle="false" \
     --vocab_file_path="" \
-    --label2id_file_path="" \
+    --label_file_path="" \
     --save_finetune_checkpoint_path="" \
     --load_pretrain_checkpoint_path="" \
     --load_finetune_checkpoint_path="" \
@@ -23,9 +23,9 @@ from src import tokenization
 from src.sample_process import label_generation, process_one_example_p
 from src.CRF import postprocess
 from src.finetune_eval_config import bert_net_cfg
+from src.score import get_result

-def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
+def process(model=None, text="", tokenizer_=None, use_crf="", tag_to_index=None, vocab=""):
     """
     process text.
     """

@@ -34,7 +34,7 @@ def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
     res = []
     ids = []
     for i in data:
-        feature = process_one_example_p(tokenizer_, i, max_seq_len=bert_net_cfg.seq_length)
+        feature = process_one_example_p(tokenizer_, vocab, i, max_seq_len=bert_net_cfg.seq_length)
         features.append(feature)
         input_ids, input_mask, token_type_id = feature
         input_ids = Tensor(np.array(input_ids), mstype.int32)

@@ -52,10 +52,10 @@ def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
             ids = logits.asnumpy()
             ids = np.argmax(ids, axis=-1)
             ids = list(ids)
-    res = label_generation(text=text, probs=ids, label2id_file=label2id_file)
+    res = label_generation(text=text, probs=ids, tag_to_index=tag_to_index)
     return res


-def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""):
+def submit(model=None, path="", vocab_file="", use_crf="", label_file="", tag_to_index=None):
     """
     submit task
     """

@@ -66,8 +66,11 @@ def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""):
             continue
         oneline = json.loads(line.strip())
         res = process(model=model, text=oneline["text"], tokenizer_=tokenizer_,
-                      use_crf=use_crf, label2id_file=label2id_file)
-        print("text", oneline["text"])
-        print("res:", res)
+                      use_crf=use_crf, tag_to_index=tag_to_index, vocab=vocab_file)
         data.append(json.dumps({"label": res}, ensure_ascii=False))
     open("ner_predict.json", "w").write("\n".join(data))
+    labels = []
+    with open(label_file) as f:
+        for label in f:
+            labels.append(label.strip())
+    get_result(labels, "ner_predict.json", path)
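As a usage sketch (not part of this change set), the reworked submit() now takes the plain-text label file and a prebuilt tag map instead of a label2id JSON, and finishes by scoring its own predictions. The model below is assumed to be the finetuned BertNER network already wrapped in a mindspore Model (as do_eval builds it), and every path is a placeholder:

from src.cluener_evaluation import submit
from src.utils import convert_labels_to_index

labels = [line.strip() for line in open("ner_labels.txt")]   # hypothetical label file
tag_to_index = convert_labels_to_index(labels)
submit(model=eval_model,                    # placeholder: finetuned BertNER wrapped in a mindspore Model
       path="cluener_dev.json",             # placeholder: CLUE NER data file with "text" and "label" fields
       vocab_file="vocab.txt",              # placeholder: BERT vocabulary file
       use_crf="false",
       label_file="ner_labels.txt",
       tag_to_index=tag_to_index)
# writes ner_predict.json, then prints per-label and average F1 via src.score.get_result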
@@ -16,9 +16,9 @@
 """process txt"""

 import re
-import json
+from src.tokenization import convert_tokens_to_ids


-def process_one_example_p(tokenizer, text, max_seq_len=128):
+def process_one_example_p(tokenizer, vocab, text, max_seq_len=128):
     """process one testline"""
     textlist = list(text)
     tokens = []

@@ -37,7 +37,7 @@ def process_one_example_p(tokenizer, text, max_seq_len=128):
     segment_ids.append(0)
     ntokens.append("[SEP]")
     segment_ids.append(0)
-    input_ids = tokenizer.convert_tokens_to_ids(ntokens)
+    input_ids = convert_tokens_to_ids(vocab, ntokens)
     input_mask = [1] * len(input_ids)
     while len(input_ids) < max_seq_len:
         input_ids.append(0)

@@ -52,12 +52,12 @@ def process_one_example_p(tokenizer, text, max_seq_len=128):
     feature = (input_ids, input_mask, segment_ids)
     return feature


-def label_generation(text="", probs=None, label2id_file=""):
+def label_generation(text="", probs=None, tag_to_index=None):
     """generate label"""
     data = [text]
     probs = [probs]
     result = []
-    label2id = json.loads(open(label2id_file).read())
+    label2id = tag_to_index
     id2label = [k for k, v in label2id.items()]

     for index, prob in enumerate(probs):
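Token-to-id lookup in process_one_example_p now goes through src.tokenization with a vocab file path instead of a tokenizer method. A small sketch, with vocab.txt standing in for a real BERT vocabulary file that contains the tokens shown:

from src.tokenization import convert_tokens_to_ids, convert_ids_to_tokens

ntokens = ["[CLS]", "中", "国", "[SEP]"]
input_ids = convert_tokens_to_ids("vocab.txt", ntokens)       # row indices of these tokens in vocab.txt
assert convert_ids_to_tokens("vocab.txt", input_ids) == ntokens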
@@ -0,0 +1,79 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+Calculate average F1 score among labels.
+"""
+
+import json
+
+def get_f1_score_for_each_label(pre_lines, gold_lines, label):
+    """
+    Get F1 score for each label.
+    Args:
+        pre_lines: listed label info from pre_file.
+        gold_lines: listed label info from gold_file.
+        label: name of the label to compute F1 for.
+
+    Returns:
+        F1 score for this label.
+    """
+    TP = 0
+    FP = 0
+    FN = 0
+    index = 0
+    while index < len(pre_lines):
+        pre_line = pre_lines[index].get(label, {})
+        gold_line = gold_lines[index].get(label, {})
+        for sample in pre_line:
+            if sample in gold_line:
+                TP += 1
+            else:
+                FP += 1
+        for sample in gold_line:
+            if sample not in pre_line:
+                FN += 1
+        index += 1
+    f1 = 2 * TP / (2 * TP + FP + FN)
+    return f1
+
+
+def get_f1_score(labels, pre_file, gold_file):
+    """
+    Get F1 scores for each label.
+    Args:
+        labels: list of labels.
+        pre_file: prediction file.
+        gold_file: ground truth file.
+
+    Returns:
+        average F1 score on all labels.
+    """
+    pre_lines = [json.loads(line.strip())['label'] for line in open(pre_file) if line.strip()]
+    gold_lines = [json.loads(line.strip())['label'] for line in open(gold_file) if line.strip()]
+    if len(pre_lines) != len(gold_lines):
+        raise ValueError("pre file and gold file have different line count.")
+    f1_sum = 0
+    for label in labels:
+        f1 = get_f1_score_for_each_label(pre_lines, gold_lines, label)
+        print('label: %s, F1: %.6f' % (label, f1))
+        f1_sum += f1
+
+    return f1_sum / len(labels)
+
+
+def get_result(labels, pre_file, gold_file):
+    avg = get_f1_score(labels, pre_file=pre_file, gold_file=gold_file)
+    print("avg F1: %.6f" % avg)
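To make the new metric concrete, here is a small worked example that calls get_f1_score_for_each_label directly; the nested {mention: [spans]} layout mirrors the CLUE NER "label" field, and the sample values are made up:

from src.score import get_f1_score_for_each_label

pre_lines = [{"name": {"张三": [[0, 1]], "李四": [[5, 6]]}}]    # predicted "name" mentions for one sentence
gold_lines = [{"name": {"张三": [[0, 1]], "王五": [[8, 9]]}}]   # gold "name" mentions for the same sentence
# TP = 1 ("张三"), FP = 1 ("李四"), FN = 1 ("王五")
# F1 = 2*TP / (2*TP + FP + FN) = 2 / 4 = 0.5
print(get_f1_score_for_each_label(pre_lines, gold_lines, "name"))   # 0.5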
@@ -0,0 +1,329 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+Tokenization.
+"""
+
+import unicodedata
+import collections
+
+def convert_to_unicode(text):
+    """
+    Convert text into unicode type.
+    Args:
+        text: input str.
+
+    Returns:
+        input str in unicode.
+    """
+    ret = text
+    if isinstance(text, str):
+        ret = text
+    elif isinstance(text, bytes):
+        ret = text.decode("utf-8", "ignore")
+    else:
+        raise ValueError("Unsupported string type: %s" % (type(text)))
+    return ret
+
+
+def vocab_to_dict_key_token(vocab_file):
+    """Loads a vocab file into a dict, key is token."""
+    vocab = collections.OrderedDict()
+    index = 0
+    with open(vocab_file, "r") as reader:
+        while True:
+            token = convert_to_unicode(reader.readline())
+            if not token:
+                break
+            token = token.strip()
+            vocab[token] = index
+            index += 1
+    return vocab
+
+
+def vocab_to_dict_key_id(vocab_file):
+    """Loads a vocab file into a dict, key is id."""
+    vocab = collections.OrderedDict()
+    index = 0
+    with open(vocab_file, "r") as reader:
+        while True:
+            token = convert_to_unicode(reader.readline())
+            if not token:
+                break
+            token = token.strip()
+            vocab[index] = token
+            index += 1
+    return vocab
+
+
+def whitespace_tokenize(text):
+    """Runs basic whitespace cleaning and splitting on a piece of text."""
+    text = text.strip()
+    if not text:
+        return []
+    tokens = text.split()
+    return tokens
+
+
+def convert_tokens_to_ids(vocab_file, tokens):
+    """
+    Convert tokens to ids.
+    Args:
+        vocab_file: path to vocab.txt.
+        tokens: list of tokens.
+
+    Returns:
+        list of ids.
+    """
+    vocab_dict = vocab_to_dict_key_token(vocab_file)
+    output = []
+    for token in tokens:
+        output.append(vocab_dict[token])
+    return output
+
+
+def convert_ids_to_tokens(vocab_file, ids):
+    """
+    Convert ids to tokens.
+    Args:
+        vocab_file: path to vocab.txt.
+        ids: list of ids.
+
+    Returns:
+        list of tokens.
+    """
+    vocab_dict = vocab_to_dict_key_id(vocab_file)
+    output = []
+    for _id in ids:
+        output.append(vocab_dict[_id])
+    return output
+
+
+class FullTokenizer():
+    """
+    Full tokenizer
+    """
+    def __init__(self, vocab_file, do_lower_case=True):
+        self.vocab_dict = vocab_to_dict_key_token(vocab_file)
+        self.do_lower_case = do_lower_case
+        self.basic_tokenize = BasicTokenizer(do_lower_case)
+        self.wordpiece_tokenize = WordpieceTokenizer(self.vocab_dict)
+
+    def tokenize(self, text):
+        """
+        Do full tokenization.
+        Args:
+            text: str of text.
+
+        Returns:
+            list of tokens.
+        """
+        tokens_ret = []
+        text = convert_to_unicode(text)
+        for tokens in self.basic_tokenize.tokenize(text):
+            wordpiece_tokens = self.wordpiece_tokenize.tokenize(tokens)
+            tokens_ret.extend(wordpiece_tokens)
+        return tokens_ret
+
+
+class BasicTokenizer():
+    """
+    Basic tokenizer
+    """
+    def __init__(self, do_lower_case=True):
+        self.do_lower_case = do_lower_case
+
+    def tokenize(self, text):
+        """
+        Do basic tokenization.
+        Args:
+            text: text in unicode.
+
+        Returns:
+            a list of tokens split from text
+        """
+        text = self._clean_text(text)
+        text = self._tokenize_chinese_chars(text)
+
+        orig_tokens = whitespace_tokenize(text)
+        split_tokens = []
+        for token in orig_tokens:
+            if self.do_lower_case:
+                token = token.lower()
+                token = self._run_strip_accents(token)
+            aaa = self._run_split_on_punc(token)
+            split_tokens.extend(aaa)
+
+        output_tokens = whitespace_tokenize(" ".join(split_tokens))
+        return output_tokens
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        text = unicodedata.normalize("NFD", text)
+        output = []
+        for char in text:
+            cat = unicodedata.category(char)
+            if cat == "Mn":
+                continue
+            output.append(char)
+        return "".join(output)
+
+    def _run_split_on_punc(self, text):
+        """Splits punctuation on a piece of text."""
+        i = 0
+        start_new_word = True
+        output = []
+        for char in text:
+            if _is_punctuation(char):
+                output.append([char])
+                start_new_word = True
+            else:
+                if start_new_word:
+                    output.append([])
+                start_new_word = False
+                output[-1].append(char)
+            i += 1
+        return ["".join(x) for x in output]
+
+    def _clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if cp == 0 or cp == 0xfffd or _is_control(char):
+                continue
+            if _is_whitespace(char):
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+    def _tokenize_chinese_chars(self, text):
+        """Adds whitespace around any CJK character."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if self._is_chinese_char(cp):
+                output.append(" ")
+                output.append(char)
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and handled
+        # like all of the other languages.
+        if ((0x4E00 <= cp <= 0x9FFF) or
+                (0x3400 <= cp <= 0x4DBF) or
+                (0x20000 <= cp <= 0x2A6DF) or
+                (0x2A700 <= cp <= 0x2B73F) or
+                (0x2B740 <= cp <= 0x2B81F) or
+                (0x2B820 <= cp <= 0x2CEAF) or
+                (0xF900 <= cp <= 0xFAFF) or
+                (0x2F800 <= cp <= 0x2FA1F)):
+            return True
+
+        return False
+
+
+class WordpieceTokenizer():
+    """
+    Wordpiece tokenizer
+    """
+    def __init__(self, vocab):
+        self.vocab_dict = vocab
+
+    def tokenize(self, tokens):
+        """
+        Do word-piece tokenization
+        Args:
+            tokens: a word.
+
+        Returns:
+            a list of tokens that can be found in vocab dict.
+        """
+        output_tokens = []
+        tokens = convert_to_unicode(tokens)
+        for token in whitespace_tokenize(tokens):
+            chars = list(token)
+            len_chars = len(chars)
+            start = 0
+            end = len_chars
+            while start < len_chars:
+                while start < end:
+                    substr = "".join(token[start:end])
+                    if start != 0:
+                        substr = "##" + substr
+                    if substr in self.vocab_dict:
+                        output_tokens.append(substr)
+                        start = end
+                        end = len_chars
+                    else:
+                        end = end - 1
+                if start == end and start != len_chars:
+                    output_tokens.append("[UNK]")
+                    break
+        return output_tokens
+
+
+def _is_whitespace(char):
+    """Checks whether `chars` is a whitespace character."""
+    # \t, \n, and \r are technically control characters but we treat them
+    # as whitespace since they are generally considered as such.
+    whitespace_char = [" ", "\t", "\n", "\r"]
+    if char in whitespace_char:
+        return True
+    cat = unicodedata.category(char)
+    if cat == "Zs":
+        return True
+    return False
+
+
+def _is_control(char):
+    """Checks whether `chars` is a control character."""
+    # These are technically control characters but we count them as whitespace
+    # characters.
+    control_char = ["\t", "\n", "\r"]
+    if char in control_char:
+        return False
+    cat = unicodedata.category(char)
+    if cat in ("Cc", "Cf"):
+        return True
+    return False
+
+
+def _is_punctuation(char):
+    """Checks whether `chars` is a punctuation character."""
+    cp = ord(char)
+    # We treat all non-letter/number ASCII as punctuation.
+    # Characters such as "^", "$", and "`" are not in the Unicode
+    # Punctuation class but we treat them as punctuation anyways, for
+    # consistency.
+    if ((33 <= cp <= 47) or (58 <= cp <= 64) or
+            (91 <= cp <= 96) or (123 <= cp <= 126)):
+        return True
+    cat = unicodedata.category(char)
+    if cat.startswith("P"):
+        return True
+    return False
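A quick usage sketch for the new tokenizer; vocab.txt is a placeholder for a BERT vocabulary file with one token per line:

from src.tokenization import FullTokenizer

tokenizer = FullTokenizer("vocab.txt", do_lower_case=True)
# BasicTokenizer lower-cases, strips accents, splits punctuation and isolates CJK characters;
# WordpieceTokenizer then greedily matches the longest subwords present in the vocab.
tokens = tokenizer.tokenize("MindSpore在2020年开源")
# e.g. ['mind', '##sp', '##ore', '在', '2020', '年', '开', '源'], depending on the vocabulary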
@@ -19,6 +19,7 @@ Functional Cells used in Bert finetune and evaluation.
 import os
 import math
+import collections
 import numpy as np
 import mindspore.nn as nn
 from mindspore import log as logger

@@ -213,3 +214,19 @@ class BertLearningRate(LearningRateSchedule):
         else:
             lr = decay_lr
         return lr
+
+
+def convert_labels_to_index(label_list):
+    """
+    Convert label_list to indices for NER task.
+    """
+    label2id = collections.OrderedDict()
+    label2id["O"] = 0
+    prefix = ["S_", "B_", "M_", "E_"]
+    index = 0
+    for label in label_list:
+        for pre in prefix:
+            index += 1
+            sub_label = pre + label
+            label2id[sub_label] = index
+    return label2id
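For reference, a worked example of the new helper; the two entity types are hypothetical, only the "O" tag and the S_/B_/M_/E_ prefixes come from the code above:

from src.utils import convert_labels_to_index

tag_to_index = convert_labels_to_index(["address", "name"])
# OrderedDict([("O", 0),
#              ("S_address", 1), ("B_address", 2), ("M_address", 3), ("E_address", 4),
#              ("S_name", 5), ("B_name", 6), ("M_name", 7), ("E_name", 8)])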