forked from mindspore-Ecosystem/mindspore
!9258 add squad for bert
From: @yoonlee666 Reviewed-by: @c_34,@guoqi1024 Signed-off-by: @c_34
This commit is contained in:
commit
989744c61a
|
@ -144,12 +144,14 @@ def run_classifier():
|
|||
parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"],
|
||||
help="Enable eval, default is false")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.")
|
||||
parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.")
|
||||
parser.add_argument("--epoch_num", type=int, default=3, help="Epoch number, default is 3.")
|
||||
parser.add_argument("--num_class", type=int, default=2, help="The number of class, default is 2.")
|
||||
parser.add_argument("--train_data_shuffle", type=str, default="true", choices=["true", "false"],
|
||||
help="Enable train data shuffle, default is true")
|
||||
parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"],
|
||||
help="Enable eval data shuffle, default is false")
|
||||
parser.add_argument("--train_batch_size", type=int, default=32, help="Train batch size, default is 32")
|
||||
parser.add_argument("--eval_batch_size", type=int, default=1, help="Eval batch size, default is 1")
|
||||
parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
||||
parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||
|
@ -188,7 +190,7 @@ def run_classifier():
|
|||
assessment_method=assessment_method)
|
||||
|
||||
if args_opt.do_train.lower() == "true":
|
||||
ds = create_classification_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1,
|
||||
ds = create_classification_dataset(batch_size=args_opt.train_batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method,
|
||||
data_file_path=args_opt.train_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path,
|
||||
|
@ -204,7 +206,7 @@ def run_classifier():
|
|||
ds.get_dataset_size(), epoch_num, "classifier")
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
ds = create_classification_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1,
|
||||
ds = create_classification_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method,
|
||||
data_file_path=args_opt.eval_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path,
|
||||
|
|
|
@ -97,14 +97,12 @@ def eval_result_print(assessment_method="accuracy", callback=None):
|
|||
else:
|
||||
raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
|
||||
|
||||
def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_method="accuracy", data_file="",
|
||||
load_checkpoint_path="", vocab_file="", label_file="", tag_to_index=None):
|
||||
def do_eval(dataset=None, network=None, use_crf="", num_class=41, assessment_method="accuracy", data_file="",
|
||||
load_checkpoint_path="", vocab_file="", label_file="", tag_to_index=None, batch_size=1):
|
||||
""" do eval """
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
|
||||
if assessment_method == "clue_benchmark":
|
||||
optimizer_cfg.batch_size = 1
|
||||
net_for_pretraining = network(bert_net_cfg, optimizer_cfg.batch_size, False, num_class,
|
||||
net_for_pretraining = network(bert_net_cfg, batch_size, False, num_class,
|
||||
use_crf=(use_crf.lower() == "true"), tag_to_index=tag_to_index)
|
||||
net_for_pretraining.set_train(False)
|
||||
param_dict = load_checkpoint(load_checkpoint_path)
|
||||
|
@ -142,7 +140,7 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_meth
|
|||
|
||||
def parse_args():
|
||||
"""set and check parameters."""
|
||||
parser = argparse.ArgumentParser(description="run classifier")
|
||||
parser = argparse.ArgumentParser(description="run ner")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
|
||||
help="Device type, default is Ascend")
|
||||
parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark"],
|
||||
|
@ -154,12 +152,14 @@ def parse_args():
|
|||
parser.add_argument("--use_crf", type=str, default="false", choices=["true", "false"],
|
||||
help="Use crf, default is false")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.")
|
||||
parser.add_argument("--num_class", type=int, default="41", help="The number of class, default is 41.")
|
||||
parser.add_argument("--epoch_num", type=int, default=5, help="Epoch number, default is 5.")
|
||||
parser.add_argument("--num_class", type=int, default=41, help="The number of class, default is 41.")
|
||||
parser.add_argument("--train_data_shuffle", type=str, default="true", choices=["true", "false"],
|
||||
help="Enable train data shuffle, default is true")
|
||||
parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"],
|
||||
help="Enable eval data shuffle, default is false")
|
||||
parser.add_argument("--train_batch_size", type=int, default=32, help="Train batch size, default is 32")
|
||||
parser.add_argument("--eval_batch_size", type=int, default=1, help="Eval batch size, default is 1")
|
||||
parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark")
|
||||
parser.add_argument("--label_file_path", type=str, default="", help="label file path, used in clue benchmark")
|
||||
parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
||||
|
@ -184,6 +184,8 @@ def parse_args():
|
|||
raise ValueError("'label_file_path' must be set to use crf")
|
||||
if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label_file_path == "":
|
||||
raise ValueError("'label_file_path' must be set to do clue benchmark")
|
||||
if args_opt.assessment_method.lower() == "clue_benchmark":
|
||||
args_opt.eval_batch_size = 1
|
||||
return args_opt
|
||||
|
||||
|
||||
|
@ -217,11 +219,11 @@ def run_ner():
|
|||
number_labels = len(tag_to_index)
|
||||
else:
|
||||
number_labels = args_opt.num_class
|
||||
netwithloss = BertNER(bert_net_cfg, optimizer_cfg.batch_size, True, num_labels=number_labels,
|
||||
use_crf=(args_opt.use_crf.lower() == "true"),
|
||||
tag_to_index=tag_to_index, dropout_prob=0.1)
|
||||
if args_opt.do_train.lower() == "true":
|
||||
ds = create_ner_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1,
|
||||
netwithloss = BertNER(bert_net_cfg, args_opt.train_batch_size, True, num_labels=number_labels,
|
||||
use_crf=(args_opt.use_crf.lower() == "true"),
|
||||
tag_to_index=tag_to_index, dropout_prob=0.1)
|
||||
ds = create_ner_dataset(batch_size=args_opt.train_batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path,
|
||||
do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
|
||||
|
@ -236,12 +238,13 @@ def run_ner():
|
|||
ds.get_dataset_size(), epoch_num, "ner")
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
ds = create_ner_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1,
|
||||
ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
|
||||
assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path,
|
||||
do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
|
||||
do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,
|
||||
load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label_file_path, tag_to_index)
|
||||
do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method,
|
||||
args_opt.eval_data_file_path, load_finetune_checkpoint_path, args_opt.vocab_file_path,
|
||||
args_opt.label_file_path, tag_to_index, args_opt.eval_batch_size)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_ner()
|
||||
|
|
|
@ -22,9 +22,6 @@ import collections
|
|||
from src.bert_for_finetune import BertSquadCell, BertSquad
|
||||
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
|
||||
from src.dataset import create_squad_dataset
|
||||
from src import tokenization
|
||||
from src.create_squad_data import read_squad_examples, convert_examples_to_features
|
||||
from src.run_squad import write_predictions
|
||||
from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
|
@ -85,22 +82,10 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
|
|||
model.train(epoch_num, dataset, callbacks=callbacks)
|
||||
|
||||
|
||||
def do_eval(dataset=None, vocab_file="", eval_json="", load_checkpoint_path="", seq_length=384):
|
||||
def do_eval(dataset=None, load_checkpoint_path="", eval_batch_size=1):
|
||||
""" do eval """
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
|
||||
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
|
||||
eval_examples = read_squad_examples(eval_json, False)
|
||||
eval_features = convert_examples_to_features(
|
||||
examples=eval_examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=seq_length,
|
||||
doc_stride=128,
|
||||
max_query_length=64,
|
||||
is_training=False,
|
||||
output_fn=None,
|
||||
verbose_logging=False)
|
||||
|
||||
net = BertSquad(bert_net_cfg, False, 2)
|
||||
net.set_train(False)
|
||||
param_dict = load_checkpoint(load_checkpoint_path)
|
||||
|
@ -123,7 +108,7 @@ def do_eval(dataset=None, vocab_file="", eval_json="", load_checkpoint_path="",
|
|||
start = logits[1].asnumpy()
|
||||
end = logits[2].asnumpy()
|
||||
|
||||
for i in range(optimizer_cfg.batch_size):
|
||||
for i in range(eval_batch_size):
|
||||
unique_id = int(ids[i])
|
||||
start_logits = [float(x) for x in start[i].flat]
|
||||
end_logits = [float(x) for x in end[i].flat]
|
||||
|
@ -131,11 +116,11 @@ def do_eval(dataset=None, vocab_file="", eval_json="", load_checkpoint_path="",
|
|||
unique_id=unique_id,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits))
|
||||
write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json", None, None)
|
||||
return output
|
||||
|
||||
def run_squad():
|
||||
"""run squad task"""
|
||||
parser = argparse.ArgumentParser(description="run classifier")
|
||||
parser = argparse.ArgumentParser(description="run squad")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
|
||||
help="Device type, default is Ascend")
|
||||
parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"],
|
||||
|
@ -143,12 +128,14 @@ def run_squad():
|
|||
parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"],
|
||||
help="Eable eval, default is false")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.")
|
||||
parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.")
|
||||
parser.add_argument("--epoch_num", type=int, default=3, help="Epoch number, default is 1.")
|
||||
parser.add_argument("--num_class", type=int, default=2, help="The number of class, default is 2.")
|
||||
parser.add_argument("--train_data_shuffle", type=str, default="true", choices=["true", "false"],
|
||||
help="Enable train data shuffle, default is true")
|
||||
parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"],
|
||||
help="Enable eval data shuffle, default is false")
|
||||
parser.add_argument("--train_batch_size", type=int, default=32, help="Train batch size, default is 32")
|
||||
parser.add_argument("--eval_batch_size", type=int, default=1, help="Eval batch size, default is 1")
|
||||
parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path")
|
||||
parser.add_argument("--eval_json_path", type=str, default="", help="Evaluation json file path, can be eval.json")
|
||||
parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
||||
|
@ -156,8 +143,6 @@ def run_squad():
|
|||
parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--train_data_file_path", type=str, default="",
|
||||
help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--eval_data_file_path", type=str, default="",
|
||||
help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--schema_file_path", type=str, default="",
|
||||
help="Schema path, it is better to use absolute path")
|
||||
args_opt = parser.parse_args()
|
||||
|
@ -171,8 +156,6 @@ def run_squad():
|
|||
if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
|
||||
raise ValueError("'train_data_file_path' must be set when do finetune task")
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
if args_opt.eval_data_file_path == "":
|
||||
raise ValueError("'eval_data_file_path' must be set when do evaluation task")
|
||||
if args_opt.vocab_file_path == "":
|
||||
raise ValueError("'vocab_file_path' must be set when do evaluation task")
|
||||
if args_opt.eval_json_path == "":
|
||||
|
@ -193,7 +176,7 @@ def run_squad():
|
|||
netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
|
||||
|
||||
if args_opt.do_train.lower() == "true":
|
||||
ds = create_squad_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1,
|
||||
ds = create_squad_dataset(batch_size=args_opt.train_batch_size, repeat_count=1,
|
||||
data_file_path=args_opt.train_data_file_path,
|
||||
schema_file_path=args_opt.schema_file_path,
|
||||
do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
|
||||
|
@ -207,12 +190,29 @@ def run_squad():
|
|||
ds.get_dataset_size(), epoch_num, "squad")
|
||||
|
||||
if args_opt.do_eval.lower() == "true":
|
||||
ds = create_squad_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1,
|
||||
data_file_path=args_opt.eval_data_file_path,
|
||||
from src import tokenization
|
||||
from src.create_squad_data import read_squad_examples, convert_examples_to_features
|
||||
from src.squad_get_predictions import write_predictions
|
||||
from src.squad_postprocess import SQuad_postprocess
|
||||
tokenizer = tokenization.FullTokenizer(vocab_file=args_opt.vocab_file_path, do_lower_case=True)
|
||||
eval_examples = read_squad_examples(args_opt.eval_json_path, False)
|
||||
eval_features = convert_examples_to_features(
|
||||
examples=eval_examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=bert_net_cfg.seq_length,
|
||||
doc_stride=128,
|
||||
max_query_length=64,
|
||||
is_training=False,
|
||||
output_fn=None,
|
||||
vocab_file=args_opt.vocab_file_path)
|
||||
ds = create_squad_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
|
||||
data_file_path=eval_features,
|
||||
schema_file_path=args_opt.schema_file_path, is_training=False,
|
||||
do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
|
||||
do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path,
|
||||
load_finetune_checkpoint_path, bert_net_cfg.seq_length)
|
||||
outputs = do_eval(ds, load_finetune_checkpoint_path, args_opt.eval_batch_size)
|
||||
all_predictions = write_predictions(eval_examples, eval_features, outputs, 20, 30, True)
|
||||
SQuad_postprocess(args_opt.eval_json_path, all_predictions, output_metrics="output.json")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_squad()
|
||||
|
|
|
@ -32,13 +32,15 @@ python ${PROJECT_DIR}/../run_classifier.py \
|
|||
--do_eval="false" \
|
||||
--assessment_method="Accuracy" \
|
||||
--device_id=0 \
|
||||
--epoch_num=1 \
|
||||
--epoch_num=3 \
|
||||
--num_class=2 \
|
||||
--train_data_shuffle="true" \
|
||||
--eval_data_shuffle="false" \
|
||||
--train_batch_size=32 \
|
||||
--eval_batch_size=1 \
|
||||
--save_finetune_checkpoint_path="" \
|
||||
--load_pretrain_checkpoint_path="" \
|
||||
--load_finetune_checkpoint_path="" \
|
||||
--train_data_file_path="" \
|
||||
--eval_data_file_path="" \
|
||||
--schema_file_path="" > classfifier_log.txt 2>&1 &
|
||||
--schema_file_path="" > classifier_log.txt 2>&1 &
|
||||
|
|
|
@ -33,10 +33,12 @@ python ${PROJECT_DIR}/../run_ner.py \
|
|||
--assessment_method="F1" \
|
||||
--use_crf="false" \
|
||||
--device_id=0 \
|
||||
--epoch_num=1 \
|
||||
--num_class=2 \
|
||||
--epoch_num=5 \
|
||||
--num_class=41 \
|
||||
--train_data_shuffle="true" \
|
||||
--eval_data_shuffle="false" \
|
||||
--train_batch_size=32 \
|
||||
--eval_batch_size=1 \
|
||||
--vocab_file_path="" \
|
||||
--label_file_path="" \
|
||||
--save_finetune_checkpoint_path="" \
|
||||
|
|
|
@ -31,15 +31,16 @@ python ${PROJECT_DIR}/../run_squad.py \
|
|||
--do_train="true" \
|
||||
--do_eval="false" \
|
||||
--device_id=0 \
|
||||
--epoch_num=1 \
|
||||
--epoch_num=3 \
|
||||
--num_class=2 \
|
||||
--train_data_shuffle="true" \
|
||||
--eval_data_shuffle="false" \
|
||||
--train_batch_size=32 \
|
||||
--eval_batch_size=1 \
|
||||
--vocab_file_path="" \
|
||||
--eval_json_path="" \
|
||||
--save_finetune_checkpoint_path="" \
|
||||
--load_pretrain_checkpoint_path="" \
|
||||
--load_finetune_checkpoint_path="" \
|
||||
--train_data_file_path="" \
|
||||
--eval_data_file_path="" \
|
||||
--eval_json_path="" \
|
||||
--schema_file_path="" > squad_log.txt 2>&1 &
|
||||
|
|
|
@ -325,6 +325,8 @@ class BertSquad(nn.Cell):
|
|||
total_loss = (start_loss + end_loss) / 2.0
|
||||
else:
|
||||
start_logits = self.squeeze(logits[:, :, 0:1])
|
||||
start_logits = start_logits + 100 * input_mask
|
||||
end_logits = self.squeeze(logits[:, :, 1:2])
|
||||
end_logits = end_logits + 100 * input_mask
|
||||
total_loss = (unique_id, start_logits, end_logits)
|
||||
return total_loss
|
||||
|
|
|
@ -0,0 +1,309 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""create squad data"""
|
||||
|
||||
import collections
|
||||
import json
|
||||
from src import tokenization
|
||||
|
||||
|
||||
class SquadExample():
|
||||
"""extract column contents from raw data"""
|
||||
def __init__(self,
|
||||
qas_id,
|
||||
question_text,
|
||||
doc_tokens,
|
||||
orig_answer_text=None,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=False):
|
||||
self.qas_id = qas_id
|
||||
self.question_text = question_text
|
||||
self.doc_tokens = doc_tokens
|
||||
self.orig_answer_text = orig_answer_text
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
|
||||
|
||||
class InputFeatures():
|
||||
"""A single set of features of data."""
|
||||
def __init__(self,
|
||||
unique_id,
|
||||
example_index,
|
||||
doc_span_index,
|
||||
tokens,
|
||||
token_to_orig_map,
|
||||
token_is_max_context,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None):
|
||||
self.unique_id = unique_id
|
||||
self.example_index = example_index
|
||||
self.doc_span_index = doc_span_index
|
||||
self.tokens = tokens
|
||||
self.token_to_orig_map = token_to_orig_map
|
||||
self.token_is_max_context = token_is_max_context
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
|
||||
|
||||
def read_squad_examples(input_file, is_training, version_2_with_negative=False):
|
||||
"""Return list of SquadExample from input_data or input_file (SQuAD json file)"""
|
||||
with open(input_file, "r") as reader:
|
||||
input_data = json.load(reader)["data"]
|
||||
|
||||
def is_whitespace(c):
|
||||
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
|
||||
return True
|
||||
return False
|
||||
|
||||
def token_offset(text):
|
||||
doc_tokens = []
|
||||
char_to_word_offset = []
|
||||
prev_is_whitespace = True
|
||||
for c in paragraph_text:
|
||||
if is_whitespace(c):
|
||||
prev_is_whitespace = True
|
||||
else:
|
||||
if prev_is_whitespace:
|
||||
doc_tokens.append(c)
|
||||
else:
|
||||
doc_tokens[-1] += c
|
||||
prev_is_whitespace = False
|
||||
char_to_word_offset.append(len(doc_tokens) - 1)
|
||||
return (doc_tokens, char_to_word_offset)
|
||||
|
||||
|
||||
def process_one_example(qa, is_training, version_2_with_negative, doc_tokens, char_to_word_offset):
|
||||
qas_id = qa["id"]
|
||||
question_text = qa["question"]
|
||||
start_position = -1
|
||||
end_position = -1
|
||||
orig_answer_text = ""
|
||||
is_impossible = False
|
||||
if is_training:
|
||||
if version_2_with_negative:
|
||||
is_impossible = qa["is_impossible"]
|
||||
if (len(qa["answers"]) != 1) and (not is_impossible):
|
||||
raise ValueError("For training, each question should have exactly 1 answer.")
|
||||
if not is_impossible:
|
||||
answer = qa["answers"][0]
|
||||
orig_answer_text = answer["text"]
|
||||
answer_offset = answer["answer_start"]
|
||||
answer_length = len(orig_answer_text)
|
||||
start_position = char_to_word_offset[answer_offset]
|
||||
end_position = char_to_word_offset[answer_offset + answer_length - 1]
|
||||
actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
|
||||
cleaned_answer_text = " ".join(tokenization.whitespace_tokenize(orig_answer_text))
|
||||
if actual_text.find(cleaned_answer_text) == -1:
|
||||
return None
|
||||
example = SquadExample(
|
||||
qas_id=qas_id,
|
||||
question_text=question_text,
|
||||
doc_tokens=doc_tokens,
|
||||
orig_answer_text=orig_answer_text,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=is_impossible)
|
||||
return example
|
||||
|
||||
|
||||
examples = []
|
||||
for entry in input_data:
|
||||
for paragraph in entry["paragraphs"]:
|
||||
paragraph_text = paragraph["context"]
|
||||
doc_tokens, char_to_word_offset = token_offset(paragraph_text)
|
||||
for qa in paragraph["qas"]:
|
||||
one_example = process_one_example(qa, is_training, version_2_with_negative,
|
||||
doc_tokens, char_to_word_offset)
|
||||
if one_example is not None:
|
||||
examples.append(one_example)
|
||||
return examples
|
||||
|
||||
def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||
"""Check if this is the 'max context' doc span for the token."""
|
||||
best_score = None
|
||||
best_span_index = None
|
||||
for (span_index, doc_span) in enumerate(doc_spans):
|
||||
end = doc_span.start + doc_span.length - 1
|
||||
if position < doc_span.start:
|
||||
continue
|
||||
if position > end:
|
||||
continue
|
||||
num_left_context = position - doc_span.start
|
||||
num_right_context = end - position
|
||||
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
|
||||
if best_score is None or score > best_score:
|
||||
best_score = score
|
||||
best_span_index = span_index
|
||||
return cur_span_index == best_span_index
|
||||
|
||||
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
|
||||
orig_answer_text):
|
||||
"""Returns tokenized answer spans that better match the annotated answer."""
|
||||
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
|
||||
|
||||
for new_start in range(input_start, input_end + 1):
|
||||
for new_end in range(input_end, new_start - 1, -1):
|
||||
text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
|
||||
if text_span == tok_answer_text:
|
||||
return (new_start, new_end)
|
||||
|
||||
return (input_start, input_end)
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride,
|
||||
max_query_length, is_training, output_fn, vocab_file):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
unique_id = 1000000000
|
||||
output = []
|
||||
for (example_index, example) in enumerate(examples):
|
||||
query_tokens = tokenizer.tokenize(example.question_text)
|
||||
|
||||
if len(query_tokens) > max_query_length:
|
||||
query_tokens = query_tokens[0:max_query_length]
|
||||
|
||||
tok_to_orig_index = []
|
||||
orig_to_tok_index = []
|
||||
all_doc_tokens = []
|
||||
for (i, token) in enumerate(example.doc_tokens):
|
||||
orig_to_tok_index.append(len(all_doc_tokens))
|
||||
sub_tokens = tokenizer.tokenize(token)
|
||||
for sub_token in sub_tokens:
|
||||
tok_to_orig_index.append(i)
|
||||
all_doc_tokens.append(sub_token)
|
||||
|
||||
tok_start_position = None
|
||||
tok_end_position = None
|
||||
if is_training and example.is_impossible:
|
||||
tok_start_position = -1
|
||||
tok_end_position = -1
|
||||
if is_training and not example.is_impossible:
|
||||
tok_start_position = orig_to_tok_index[example.start_position]
|
||||
if example.end_position < len(example.doc_tokens) - 1:
|
||||
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
|
||||
else:
|
||||
tok_end_position = len(all_doc_tokens) - 1
|
||||
(tok_start_position, tok_end_position) = _improve_answer_span(
|
||||
all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
|
||||
example.orig_answer_text)
|
||||
|
||||
# The -3 accounts for [CLS], [SEP] and [SEP]
|
||||
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
|
||||
|
||||
# We can have documents that are longer than the maximum sequence length.
|
||||
# To deal with this we do a sliding window approach, where we take chunks
|
||||
# of the up to our max length with a stride of `doc_stride`.
|
||||
_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
|
||||
doc_spans = []
|
||||
start_offset = 0
|
||||
while start_offset < len(all_doc_tokens):
|
||||
length = len(all_doc_tokens) - start_offset
|
||||
if length > max_tokens_for_doc:
|
||||
length = max_tokens_for_doc
|
||||
doc_spans.append(_DocSpan(start=start_offset, length=length))
|
||||
if start_offset + length == len(all_doc_tokens):
|
||||
break
|
||||
start_offset += min(length, doc_stride)
|
||||
|
||||
for (doc_span_index, doc_span) in enumerate(doc_spans):
|
||||
tokens = []
|
||||
token_to_orig_map = {}
|
||||
token_is_max_context = {}
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in query_tokens:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
for i in range(doc_span.length):
|
||||
split_token_index = doc_span.start + i
|
||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||
|
||||
is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index)
|
||||
token_is_max_context[len(tokens)] = is_max_context
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
input_ids = tokenization.convert_tokens_to_ids(vocab_file, tokens)
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask = [1] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
start_position = None
|
||||
end_position = None
|
||||
if is_training and not example.is_impossible:
|
||||
# For training, if our document chunk does not contain an annotation
|
||||
# we throw it out, since there is nothing to predict.
|
||||
doc_start = doc_span.start
|
||||
doc_end = doc_span.start + doc_span.length - 1
|
||||
out_of_span = False
|
||||
if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
|
||||
out_of_span = True
|
||||
if out_of_span:
|
||||
start_position = 0
|
||||
end_position = 0
|
||||
else:
|
||||
doc_offset = len(query_tokens) + 2
|
||||
start_position = tok_start_position - doc_start + doc_offset
|
||||
end_position = tok_end_position - doc_start + doc_offset
|
||||
|
||||
if is_training and example.is_impossible:
|
||||
start_position = 0
|
||||
end_position = 0
|
||||
|
||||
feature = InputFeatures(
|
||||
unique_id=unique_id,
|
||||
example_index=example_index,
|
||||
doc_span_index=doc_span_index,
|
||||
tokens=tokens,
|
||||
token_to_orig_map=token_to_orig_map,
|
||||
token_is_max_context=token_is_max_context,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
segment_ids=segment_ids,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=example.is_impossible)
|
||||
|
||||
# Run callback
|
||||
output.append(feature)
|
||||
unique_id += 1
|
||||
return output
|
|
@ -92,6 +92,11 @@ def create_classification_dataset(batch_size=1, repeat_count=1, assessment_metho
|
|||
return ds
|
||||
|
||||
|
||||
def generator_squad(data_features):
|
||||
for feature in data_features:
|
||||
yield (feature.input_ids, feature.input_mask, feature.segment_ids, feature.unique_id)
|
||||
|
||||
|
||||
def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None,
|
||||
is_training=True, do_shuffle=True):
|
||||
"""create finetune or evaluation dataset"""
|
||||
|
@ -104,11 +109,12 @@ def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, sche
|
|||
ds = ds.map(operations=type_cast_op, input_columns="start_positions")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="end_positions")
|
||||
else:
|
||||
ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "unique_ids"])
|
||||
ds = de.GeneratorDataset(generator_squad(data_file_path), shuffle=do_shuffle,
|
||||
column_names=["input_ids", "input_mask", "segment_ids", "unique_ids"])
|
||||
ds = ds.map(operations=type_cast_op, input_columns="segment_ids")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="input_mask")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="input_ids")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="unique_ids")
|
||||
ds = ds.repeat(repeat_count)
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
|
|
|
@ -22,7 +22,6 @@ import mindspore.common.dtype as mstype
|
|||
from .bert_model import BertConfig
|
||||
|
||||
optimizer_cfg = edict({
|
||||
'batch_size': 16,
|
||||
'optimizer': 'Lamb',
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 2e-5,
|
||||
|
|
|
@ -0,0 +1,258 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""get predictions for squad"""
|
||||
|
||||
import math
|
||||
import collections
|
||||
import six
|
||||
from src import tokenization
|
||||
|
||||
|
||||
def get_prelim_predictions(features, unique_id_to_result, n_best_size, max_answer_length):
|
||||
"""get prelim predictions"""
|
||||
_PrelimPrediction = collections.namedtuple(
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
||||
prelim_predictions = []
|
||||
# keep track of the minimum score of null start+end of position 0
|
||||
for (feature_index, feature) in enumerate(features):
|
||||
if feature.unique_id not in unique_id_to_result:
|
||||
continue
|
||||
result = unique_id_to_result[feature.unique_id]
|
||||
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
||||
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
||||
# if we could have irrelevant answers, get the min score of irrelevant
|
||||
for start_index in start_indexes:
|
||||
for end_index in end_indexes:
|
||||
# We could hypothetically create invalid predictions, e.g., predict
|
||||
# that the start of the span is in the question. We throw out all
|
||||
# invalid predictions.
|
||||
if start_index >= len(feature.tokens):
|
||||
continue
|
||||
if end_index >= len(feature.tokens):
|
||||
continue
|
||||
if start_index not in feature.token_to_orig_map:
|
||||
continue
|
||||
if end_index not in feature.token_to_orig_map:
|
||||
continue
|
||||
if not feature.token_is_max_context.get(start_index, False):
|
||||
continue
|
||||
if end_index < start_index:
|
||||
continue
|
||||
length = end_index - start_index + 1
|
||||
if length > max_answer_length:
|
||||
continue
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
feature_index=feature_index,
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
start_logit=result.start_logits[start_index],
|
||||
end_logit=result.end_logits[end_index]))
|
||||
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x.start_logit + x.end_logit),
|
||||
reverse=True)
|
||||
return prelim_predictions
|
||||
|
||||
|
||||
def get_nbest(prelim_predictions, features, example, n_best_size, do_lower_case):
|
||||
"""get nbest predictions"""
|
||||
_NbestPrediction = collections.namedtuple(
|
||||
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
||||
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
for pred in prelim_predictions:
|
||||
if len(nbest) >= n_best_size:
|
||||
break
|
||||
feature = features[pred.feature_index]
|
||||
if pred.start_index > 0: # this is a non-null prediction
|
||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||
tok_text = " ".join(tok_tokens)
|
||||
|
||||
# De-tokenize WordPieces that have been split off.
|
||||
tok_text = tok_text.replace(" ##", "")
|
||||
tok_text = tok_text.replace("##", "")
|
||||
|
||||
# Clean whitespace
|
||||
tok_text = tok_text.strip()
|
||||
tok_text = " ".join(tok_text.split())
|
||||
orig_text = " ".join(orig_tokens)
|
||||
final_text = get_final_text(tok_text, orig_text, do_lower_case)
|
||||
if final_text in seen_predictions:
|
||||
continue
|
||||
|
||||
seen_predictions[final_text] = True
|
||||
else:
|
||||
final_text = ""
|
||||
seen_predictions[final_text] = True
|
||||
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text=final_text,
|
||||
start_logit=pred.start_logit,
|
||||
end_logit=pred.end_logit))
|
||||
|
||||
# In very rare edge cases we could have no valid predictions. So we
|
||||
# just create a nonce prediction in this case to avoid failure.
|
||||
if not nbest:
|
||||
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||
|
||||
assert len(nbest) >= 1
|
||||
return nbest
|
||||
|
||||
|
||||
def get_predictions(all_examples, all_features, all_results, n_best_size, max_answer_length, do_lower_case):
|
||||
"""Get final predictions"""
|
||||
example_index_to_features = collections.defaultdict(list)
|
||||
for feature in all_features:
|
||||
example_index_to_features[feature.example_index].append(feature)
|
||||
|
||||
unique_id_to_result = {}
|
||||
for result in all_results:
|
||||
unique_id_to_result[result.unique_id] = result
|
||||
all_predictions = collections.OrderedDict()
|
||||
|
||||
for (example_index, example) in enumerate(all_examples):
|
||||
features = example_index_to_features[example_index]
|
||||
prelim_predictions = get_prelim_predictions(features, unique_id_to_result, n_best_size, max_answer_length)
|
||||
nbest = get_nbest(prelim_predictions, features, example, n_best_size, do_lower_case)
|
||||
|
||||
total_scores = []
|
||||
best_non_null_entry = None
|
||||
for entry in nbest:
|
||||
total_scores.append(entry.start_logit + entry.end_logit)
|
||||
if not best_non_null_entry:
|
||||
if entry.text:
|
||||
best_non_null_entry = entry
|
||||
|
||||
probs = _compute_softmax(total_scores)
|
||||
|
||||
nbest_json = []
|
||||
for (i, entry) in enumerate(nbest):
|
||||
output = collections.OrderedDict()
|
||||
output["text"] = entry.text
|
||||
output["probability"] = probs[i]
|
||||
output["start_logit"] = entry.start_logit
|
||||
output["end_logit"] = entry.end_logit
|
||||
nbest_json.append(output)
|
||||
|
||||
assert len(nbest_json) >= 1
|
||||
|
||||
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
||||
return all_predictions
|
||||
|
||||
|
||||
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
max_answer_length, do_lower_case):
|
||||
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||
|
||||
all_predictions = get_predictions(all_examples, all_features, all_results,
|
||||
n_best_size, max_answer_length, do_lower_case)
|
||||
return all_predictions
|
||||
|
||||
|
||||
def get_final_text(pred_text, orig_text, do_lower_case):
|
||||
"""Project the tokenized prediction back to the original text."""
|
||||
def _strip_spaces(text):
|
||||
ns_chars = []
|
||||
ns_to_s_map = collections.OrderedDict()
|
||||
for (i, c) in enumerate(text):
|
||||
if c == " ":
|
||||
continue
|
||||
ns_to_s_map[len(ns_chars)] = i
|
||||
ns_chars.append(c)
|
||||
ns_text = "".join(ns_chars)
|
||||
return (ns_text, ns_to_s_map)
|
||||
|
||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
|
||||
tok_text = " ".join(tokenizer.tokenize(orig_text))
|
||||
|
||||
start_position = tok_text.find(pred_text)
|
||||
if start_position == -1:
|
||||
return orig_text
|
||||
end_position = start_position + len(pred_text) - 1
|
||||
|
||||
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
|
||||
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
|
||||
|
||||
if len(orig_ns_text) != len(tok_ns_text):
|
||||
return orig_text
|
||||
|
||||
tok_s_to_ns_map = {}
|
||||
for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
|
||||
tok_s_to_ns_map[tok_index] = i
|
||||
|
||||
orig_start_position = None
|
||||
if start_position in tok_s_to_ns_map:
|
||||
ns_start_position = tok_s_to_ns_map[start_position]
|
||||
if ns_start_position in orig_ns_to_s_map:
|
||||
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
||||
|
||||
if orig_start_position is None:
|
||||
return orig_text
|
||||
|
||||
orig_end_position = None
|
||||
if end_position in tok_s_to_ns_map:
|
||||
ns_end_position = tok_s_to_ns_map[end_position]
|
||||
if ns_end_position in orig_ns_to_s_map:
|
||||
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
||||
|
||||
if orig_end_position is None:
|
||||
return orig_text
|
||||
|
||||
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
||||
return output_text
|
||||
|
||||
|
||||
def _get_best_indexes(logits, n_best_size):
|
||||
"""Get the n-best logits from a list."""
|
||||
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
||||
|
||||
best_indexes = []
|
||||
for (i, score) in enumerate(index_and_score):
|
||||
if i >= n_best_size:
|
||||
break
|
||||
best_indexes.append(score[0])
|
||||
return best_indexes
|
||||
|
||||
|
||||
def _compute_softmax(scores):
|
||||
"""Compute softmax probability over raw logits."""
|
||||
if not scores:
|
||||
return []
|
||||
|
||||
max_score = None
|
||||
for score in scores:
|
||||
if max_score is None or score > max_score:
|
||||
max_score = score
|
||||
|
||||
exp_scores = []
|
||||
total_sum = 0.0
|
||||
for score in scores:
|
||||
x = math.exp(score - max_score)
|
||||
exp_scores.append(x)
|
||||
total_sum += x
|
||||
|
||||
probs = []
|
||||
for score in exp_scores:
|
||||
probs.append(score / total_sum)
|
||||
return probs
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""evaluation script for SQuAD v1.1"""
|
||||
|
||||
from collections import Counter
|
||||
import string
|
||||
import re
|
||||
import json
|
||||
import sys
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
def remove_articles(text):
|
||||
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return ' '.join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return ''.join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
def f1_score(prediction, ground_truth):
|
||||
"""calculate f1 score"""
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
|
||||
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
scores_for_ground_truths = []
|
||||
for ground_truth in ground_truths:
|
||||
score = metric_fn(prediction, ground_truth)
|
||||
scores_for_ground_truths.append(score)
|
||||
return max(scores_for_ground_truths)
|
||||
|
||||
def evaluate(dataset, predictions):
|
||||
"""do evaluation"""
|
||||
f1 = exact_match = total = 0
|
||||
for article in dataset:
|
||||
for paragraph in article['paragraphs']:
|
||||
for qa in paragraph['qas']:
|
||||
total += 1
|
||||
if qa['id'] not in predictions:
|
||||
message = 'Unanswered question ' + qa['id'] + \
|
||||
' will receive score 0.'
|
||||
print(message, file=sys.stderr)
|
||||
continue
|
||||
ground_truths = list(map(lambda x: x['text'], qa['answers']))
|
||||
if not ground_truths:
|
||||
continue
|
||||
prediction = predictions[qa['id']]
|
||||
exact_match += metric_max_over_ground_truths(
|
||||
exact_match_score, prediction, ground_truths)
|
||||
f1 += metric_max_over_ground_truths(
|
||||
f1_score, prediction, ground_truths)
|
||||
|
||||
exact_match = 100.0 * exact_match / total
|
||||
f1 = 100.0 * f1 / total
|
||||
return {'exact_match': exact_match, 'f1': f1}
|
||||
|
||||
|
||||
def SQuad_postprocess(dataset_file, all_predictions, output_metrics="output.json"):
|
||||
with open(dataset_file) as ds:
|
||||
dataset_json = json.load(ds)
|
||||
dataset = dataset_json['data']
|
||||
re_json = evaluate(dataset, all_predictions)
|
||||
print(json.dumps(re_json))
|
||||
with open(output_metrics, 'w') as wr:
|
||||
wr.write(json.dumps(re_json))
|
Loading…
Reference in New Issue