forked from mindspore-Ecosystem/mindspore
add tokenization and score file
This commit is contained in:
parent
fe219b5680
commit
01c9e8b373
|
@ -18,12 +18,11 @@ Bert finetune and evaluation script.
|
|||
'''
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
from src.bert_for_finetune import BertFinetuneCell, BertNER
|
||||
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
|
||||
from src.dataset import create_ner_dataset
|
||||
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
|
||||
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index
|
||||
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
|
@ -99,7 +98,7 @@ def eval_result_print(assessment_method="accuracy", callback=None):
|
|||
raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
|
||||
|
||||
def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_method="accuracy", data_file="",
|
||||
load_checkpoint_path="", vocab_file="", label2id_file="", tag_to_index=None):
|
||||
load_checkpoint_path="", vocab_file="", label_file="", tag_to_index=None):
|
||||
""" do eval """
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
|
||||
|
@ -114,7 +113,8 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_meth
|
|||
|
||||
if assessment_method == "clue_benchmark":
|
||||
from src.cluener_evaluation import submit
|
||||
submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf, label2id_file=label2id_file)
|
||||
submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf,
|
||||
label_file=label_file, tag_to_index=tag_to_index)
|
||||
else:
|
||||
if assessment_method == "accuracy":
|
||||
callback = Accuracy()
|
||||
|
@ -161,7 +161,7 @@ def parse_args():
|
|||
parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"],
|
||||
help="Enable eval data shuffle, default is false")
|
||||
parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark")
|
||||
parser.add_argument("--label2id_file_path", type=str, default="", help="label2id file path, used in clue benchmark")
|
||||
parser.add_argument("--label_file_path", type=str, default="", help="label file path, used in clue benchmark")
|
||||
parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
||||
parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||
|
@ -180,10 +180,10 @@ def parse_args():
|
|||
raise ValueError("'eval_data_file_path' must be set when do evaluation task")
|
||||
if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.vocab_file_path == "":
|
||||
raise ValueError("'vocab_file_path' must be set to do clue benchmark")
|
||||
if args_opt.use_crf.lower() == "true" and args_opt.label2id_file_path == "":
|
||||
raise ValueError("'label2id_file_path' must be set to use crf")
|
||||
if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label2id_file_path == "":
|
||||
raise ValueError("'label2id_file_path' must be set to do clue benchmark")
|
||||
if args_opt.use_crf.lower() == "true" and args_opt.label_file_path == "":
|
||||
raise ValueError("'label_file_path' must be set to use crf")
|
||||
if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label_file_path == "":
|
||||
raise ValueError("'label_file_path' must be set to do clue benchmark")
|
||||
return args_opt
|
||||
|
||||
|
||||
|
@ -205,11 +205,12 @@ def run_ner():
|
|||
bert_net_cfg.compute_type = mstype.float32
|
||||
else:
|
||||
raise Exception("Target error, GPU or Ascend is supported.")
|
||||
|
||||
tag_to_index = None
|
||||
label_list = []
|
||||
with open(args_opt.label_file_path) as f:
|
||||
for label in f:
|
||||
label_list.append(label.strip())
|
||||
tag_to_index = convert_labels_to_index(label_list)
|
||||
if args_opt.use_crf.lower() == "true":
|
||||
with open(args_opt.label2id_file_path) as json_file:
|
||||
tag_to_index = json.load(json_file)
|
||||
max_val = max(tag_to_index.values())
|
||||
tag_to_index["<START>"] = max_val + 1
|
||||
tag_to_index["<STOP>"] = max_val + 2
|
||||
|
@ -240,7 +241,7 @@ def run_ner():
|
|||
schema_file_path=args_opt.schema_file_path,
|
||||
do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
|
||||
do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,
|
||||
load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index)
|
||||
load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label_file_path, tag_to_index)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_ner()
|
||||
|
|
|
@ -38,7 +38,7 @@ python ${PROJECT_DIR}/../run_ner.py \
|
|||
--train_data_shuffle="true" \
|
||||
--eval_data_shuffle="false" \
|
||||
--vocab_file_path="" \
|
||||
--label2id_file_path="" \
|
||||
--label_file_path="" \
|
||||
--save_finetune_checkpoint_path="" \
|
||||
--load_pretrain_checkpoint_path="" \
|
||||
--load_finetune_checkpoint_path="" \
|
||||
|
|
|
@ -23,9 +23,9 @@ from src import tokenization
|
|||
from src.sample_process import label_generation, process_one_example_p
|
||||
from src.CRF import postprocess
|
||||
from src.finetune_eval_config import bert_net_cfg
|
||||
from src.score import get_result
|
||||
|
||||
|
||||
def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
|
||||
def process(model=None, text="", tokenizer_=None, use_crf="", tag_to_index=None, vocab=""):
|
||||
"""
|
||||
process text.
|
||||
"""
|
||||
|
@ -34,7 +34,7 @@ def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
|
|||
res = []
|
||||
ids = []
|
||||
for i in data:
|
||||
feature = process_one_example_p(tokenizer_, i, max_seq_len=bert_net_cfg.seq_length)
|
||||
feature = process_one_example_p(tokenizer_, vocab, i, max_seq_len=bert_net_cfg.seq_length)
|
||||
features.append(feature)
|
||||
input_ids, input_mask, token_type_id = feature
|
||||
input_ids = Tensor(np.array(input_ids), mstype.int32)
|
||||
|
@ -52,10 +52,10 @@ def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""):
|
|||
ids = logits.asnumpy()
|
||||
ids = np.argmax(ids, axis=-1)
|
||||
ids = list(ids)
|
||||
res = label_generation(text=text, probs=ids, label2id_file=label2id_file)
|
||||
res = label_generation(text=text, probs=ids, tag_to_index=tag_to_index)
|
||||
return res
|
||||
|
||||
def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""):
|
||||
def submit(model=None, path="", vocab_file="", use_crf="", label_file="", tag_to_index=None):
|
||||
"""
|
||||
submit task
|
||||
"""
|
||||
|
@ -66,8 +66,11 @@ def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""):
|
|||
continue
|
||||
oneline = json.loads(line.strip())
|
||||
res = process(model=model, text=oneline["text"], tokenizer_=tokenizer_,
|
||||
use_crf=use_crf, label2id_file=label2id_file)
|
||||
print("text", oneline["text"])
|
||||
print("res:", res)
|
||||
use_crf=use_crf, tag_to_index=tag_to_index, vocab=vocab_file)
|
||||
data.append(json.dumps({"label": res}, ensure_ascii=False))
|
||||
open("ner_predict.json", "w").write("\n".join(data))
|
||||
labels = []
|
||||
with open(label_file) as f:
|
||||
for label in f:
|
||||
labels.append(label.strip())
|
||||
get_result(labels, "ner_predict.json", path)
|
||||
|
|
|
@ -16,9 +16,9 @@
|
|||
"""process txt"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from src.tokenization import convert_tokens_to_ids
|
||||
|
||||
def process_one_example_p(tokenizer, text, max_seq_len=128):
|
||||
def process_one_example_p(tokenizer, vocab, text, max_seq_len=128):
|
||||
"""process one testline"""
|
||||
textlist = list(text)
|
||||
tokens = []
|
||||
|
@ -37,7 +37,7 @@ def process_one_example_p(tokenizer, text, max_seq_len=128):
|
|||
segment_ids.append(0)
|
||||
ntokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
|
||||
input_ids = convert_tokens_to_ids(vocab, ntokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
while len(input_ids) < max_seq_len:
|
||||
input_ids.append(0)
|
||||
|
@ -52,12 +52,12 @@ def process_one_example_p(tokenizer, text, max_seq_len=128):
|
|||
feature = (input_ids, input_mask, segment_ids)
|
||||
return feature
|
||||
|
||||
def label_generation(text="", probs=None, label2id_file=""):
|
||||
def label_generation(text="", probs=None, tag_to_index=None):
|
||||
"""generate label"""
|
||||
data = [text]
|
||||
probs = [probs]
|
||||
result = []
|
||||
label2id = json.loads(open(label2id_file).read())
|
||||
label2id = tag_to_index
|
||||
id2label = [k for k, v in label2id.items()]
|
||||
|
||||
for index, prob in enumerate(probs):
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Calculate average F1 score among labels.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
def get_f1_score_for_each_label(pre_lines, gold_lines, label):
|
||||
"""
|
||||
Get F1 score for each label.
|
||||
Args:
|
||||
pre_lines: listed label info from pre_file.
|
||||
gold_lines: listed label info from gold_file.
|
||||
label:
|
||||
|
||||
Returns:
|
||||
F1 score for this label.
|
||||
"""
|
||||
TP = 0
|
||||
FP = 0
|
||||
FN = 0
|
||||
index = 0
|
||||
while index < len(pre_lines):
|
||||
pre_line = pre_lines[index].get(label, {})
|
||||
gold_line = gold_lines[index].get(label, {})
|
||||
for sample in pre_line:
|
||||
if sample in gold_line:
|
||||
TP += 1
|
||||
else:
|
||||
FP += 1
|
||||
for sample in gold_line:
|
||||
if sample not in pre_line:
|
||||
FN += 1
|
||||
index += 1
|
||||
f1 = 2 * TP / (2 * TP + FP + FN)
|
||||
return f1
|
||||
|
||||
|
||||
def get_f1_score(labels, pre_file, gold_file):
|
||||
"""
|
||||
Get F1 scores for each label.
|
||||
Args:
|
||||
labels: list of labels.
|
||||
pre_file: prediction file.
|
||||
gold_file: ground truth file.
|
||||
|
||||
Returns:
|
||||
average F1 score on all labels.
|
||||
"""
|
||||
pre_lines = [json.loads(line.strip())['label'] for line in open(pre_file) if line.strip()]
|
||||
gold_lines = [json.loads(line.strip())['label'] for line in open(gold_file) if line.strip()]
|
||||
if len(pre_lines) != len(gold_lines):
|
||||
raise ValueError("pre file and gold file have different line count.")
|
||||
f1_sum = 0
|
||||
for label in labels:
|
||||
f1 = get_f1_score_for_each_label(pre_lines, gold_lines, label)
|
||||
print('label: %s, F1: %.6f' % (label, f1))
|
||||
f1_sum += f1
|
||||
|
||||
return f1_sum/len(labels)
|
||||
|
||||
|
||||
def get_result(labels, pre_file, gold_file):
|
||||
avg = get_f1_score(labels, pre_file=pre_file, gold_file=gold_file)
|
||||
print("avg F1: %.6f" % avg)
|
|
@ -0,0 +1,329 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Tokenization.
|
||||
"""
|
||||
|
||||
import unicodedata
|
||||
import collections
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""
|
||||
Convert text into unicode type.
|
||||
Args:
|
||||
text: input str.
|
||||
|
||||
Returns:
|
||||
input str in unicode.
|
||||
"""
|
||||
ret = text
|
||||
if isinstance(text, str):
|
||||
ret = text
|
||||
elif isinstance(text, bytes):
|
||||
ret = text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
return ret
|
||||
|
||||
|
||||
def vocab_to_dict_key_token(vocab_file):
|
||||
"""Loads a vocab file into a dict, key is token."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def vocab_to_dict_key_id(vocab_file):
|
||||
"""Loads a vocab file into a dict, key is id."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[index] = token
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab_file, tokens):
|
||||
"""
|
||||
Convert tokens to ids.
|
||||
Args:
|
||||
vocab_file: path to vocab.txt.
|
||||
tokens: list of tokens.
|
||||
|
||||
Returns:
|
||||
list of ids.
|
||||
"""
|
||||
vocab_dict = vocab_to_dict_key_token(vocab_file)
|
||||
output = []
|
||||
for token in tokens:
|
||||
output.append(vocab_dict[token])
|
||||
return output
|
||||
|
||||
|
||||
def convert_ids_to_tokens(vocab_file, ids):
|
||||
"""
|
||||
Convert ids to tokens.
|
||||
Args:
|
||||
vocab_file: path to vocab.txt.
|
||||
ids: list of ids.
|
||||
|
||||
Returns:
|
||||
list of tokens.
|
||||
"""
|
||||
vocab_dict = vocab_to_dict_key_id(vocab_file)
|
||||
output = []
|
||||
for _id in ids:
|
||||
output.append(vocab_dict[_id])
|
||||
return output
|
||||
|
||||
|
||||
class FullTokenizer():
|
||||
"""
|
||||
Full tokenizer
|
||||
"""
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab_dict = vocab_to_dict_key_token(vocab_file)
|
||||
self.do_lower_case = do_lower_case
|
||||
self.basic_tokenize = BasicTokenizer(do_lower_case)
|
||||
self.wordpiece_tokenize = WordpieceTokenizer(self.vocab_dict)
|
||||
|
||||
def tokenize(self, text):
|
||||
"""
|
||||
Do full tokenization.
|
||||
Args:
|
||||
text: str of text.
|
||||
|
||||
Returns:
|
||||
list of tokens.
|
||||
"""
|
||||
tokens_ret = []
|
||||
text = convert_to_unicode(text)
|
||||
for tokens in self.basic_tokenize.tokenize(text):
|
||||
wordpiece_tokens = self.wordpiece_tokenize.tokenize(tokens)
|
||||
tokens_ret.extend(wordpiece_tokens)
|
||||
return tokens_ret
|
||||
|
||||
|
||||
class BasicTokenizer():
|
||||
"""
|
||||
Basic tokenizer
|
||||
"""
|
||||
def __init__(self, do_lower_case=True):
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
def tokenize(self, text):
|
||||
"""
|
||||
Do basic tokenization.
|
||||
Args:
|
||||
text: text in unicode.
|
||||
|
||||
Returns:
|
||||
a list of tokens split from text
|
||||
"""
|
||||
text = self._clean_text(text)
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
aaa = self._run_split_on_punc(token)
|
||||
split_tokens.extend(aaa)
|
||||
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
for char in text:
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like the all of the other languages.
|
||||
if ((0x4E00 <= cp <= 0x9FFF) or
|
||||
(0x3400 <= cp <= 0x4DBF) or
|
||||
(0x20000 <= cp <= 0x2A6DF) or
|
||||
(0x2A700 <= cp <= 0x2B73F) or
|
||||
(0x2B740 <= cp <= 0x2B81F) or
|
||||
(0x2B820 <= cp <= 0x2CEAF) or
|
||||
(0xF900 <= cp <= 0xFAFF) or
|
||||
(0x2F800 <= cp <= 0x2FA1F)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class WordpieceTokenizer():
|
||||
"""
|
||||
Wordpiece tokenizer
|
||||
"""
|
||||
def __init__(self, vocab):
|
||||
self.vocab_dict = vocab
|
||||
|
||||
def tokenize(self, tokens):
|
||||
"""
|
||||
Do word-piece tokenization
|
||||
Args:
|
||||
tokens: a word.
|
||||
|
||||
Returns:
|
||||
a list of tokens that can be found in vocab dict.
|
||||
"""
|
||||
output_tokens = []
|
||||
tokens = convert_to_unicode(tokens)
|
||||
for token in whitespace_tokenize(tokens):
|
||||
chars = list(token)
|
||||
len_chars = len(chars)
|
||||
start = 0
|
||||
end = len_chars
|
||||
while start < len_chars:
|
||||
while start < end:
|
||||
substr = "".join(token[start:end])
|
||||
if start != 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab_dict:
|
||||
output_tokens.append(substr)
|
||||
start = end
|
||||
end = len_chars
|
||||
else:
|
||||
end = end - 1
|
||||
if start == end and start != len_chars:
|
||||
output_tokens.append("[UNK]")
|
||||
break
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically contorl characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
whitespace_char = [" ", "\t", "\n", "\r"]
|
||||
if char in whitespace_char:
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
control_char = ["\t", "\n", "\r"]
|
||||
if char in control_char:
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat in ("Cc", "Cf"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((33 <= cp <= 47) or (58 <= cp <= 64) or
|
||||
(91 <= cp <= 96) or (123 <= cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
|
@ -19,6 +19,7 @@ Functional Cells used in Bert finetune and evaluation.
|
|||
|
||||
import os
|
||||
import math
|
||||
import collections
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import log as logger
|
||||
|
@ -213,3 +214,19 @@ class BertLearningRate(LearningRateSchedule):
|
|||
else:
|
||||
lr = decay_lr
|
||||
return lr
|
||||
|
||||
|
||||
def convert_labels_to_index(label_list):
|
||||
"""
|
||||
Convert label_list to indices for NER task.
|
||||
"""
|
||||
label2id = collections.OrderedDict()
|
||||
label2id["O"] = 0
|
||||
prefix = ["S_", "B_", "M_", "E_"]
|
||||
index = 0
|
||||
for label in label_list:
|
||||
for pre in prefix:
|
||||
index += 1
|
||||
sub_label = pre + label
|
||||
label2id[sub_label] = index
|
||||
return label2id
|
||||
|
|
Loading…
Reference in New Issue