mindspore/example/bert_clue/sample_process.py

101 lines
3.3 KiB
Python

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""process txt"""
import re
import json
def process_one_example_p(tokenizer, text, max_seq_len=128):
"""process one testline"""
textlist = list(text)
tokens = []
for _, word in enumerate(textlist):
token = tokenizer.tokenize(word)
tokens.extend(token)
if len(tokens) >= max_seq_len - 1:
tokens = tokens[0:(max_seq_len - 2)]
ntokens = []
segment_ids = []
label_ids = []
ntokens.append("[CLS]")
segment_ids.append(0)
for _, token in enumerate(tokens):
ntokens.append(token)
segment_ids.append(0)
ntokens.append("[SEP]")
segment_ids.append(0)
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < max_seq_len:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
label_ids.append(0)
ntokens.append("**NULL**")
assert len(input_ids) == max_seq_len
assert len(input_mask) == max_seq_len
assert len(segment_ids) == max_seq_len
feature = (input_ids, input_mask, segment_ids)
return feature
def label_generation(text, probs):
"""generate label"""
data = [text]
probs = [probs]
result = []
label2id = json.loads(open("./label2id.json").read())
id2label = [k for k, v in label2id.items()]
for index, prob in enumerate(probs):
for v in prob[1:len(data[index]) + 1]:
result.append(id2label[int(v)])
labels = {}
start = None
index = 0
for _, t in zip("".join(data), result):
if re.search("^[BS]", t):
if start is not None:
label = result[index - 1][2:]
if labels.get(label):
te_ = text[start:index]
labels[label][te_] = [[start, index - 1]]
else:
te_ = text[start:index]
labels[label] = {te_: [[start, index - 1]]}
start = index
if re.search("^O", t):
if start is not None:
label = result[index - 1][2:]
if labels.get(label):
te_ = text[start:index]
labels[label][te_] = [[start, index - 1]]
else:
te_ = text[start:index]
labels[label] = {te_: [[start, index - 1]]}
start = None
index += 1
if start is not None:
label = result[start][2:]
if labels.get(label):
te_ = text[start:index]
labels[label][te_] = [[start, index - 1]]
else:
te_ = text[start:index]
labels[label] = {te_: [[start, index - 1]]}
return labels