!1596 add bert SQuAD finetune and eval code in bert example

Merge pull request !1596 from yoonlee666/edit-example
This commit is contained in:
mindspore-ci-bot 2020-05-29 14:59:17 +08:00 committed by Gitee
commit 0b3da2c787
5 changed files with 297 additions and 15 deletions

View File

@ -5,7 +5,7 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the zhwiki dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path.
- Download the CLUE dataset for fine-tuning and evaluation.
- Download the CLUE/SQuAD v1.1 dataset for fine-tuning and evaluation.
> Notes:
If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file.
@ -36,11 +36,20 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
### Evaluation
- Set options in `evaluation_config.py`. Make sure the 'data_file', 'schema_file' and 'finetune_ckpt' are set to your own path.
- Run `evaluation.py` for evaluation of BERT-base and BERT-NEZHA model.
- NER: Run `evaluation.py` for evaluation of BERT-base and BERT-NEZHA model.
```bash
python evaluation.py
```
- SQuAD v1.1: Run `squadeval.py` and `SQuAD_postprocess.py` for evaluation of BERT-base and BERT-NEZHA model.
```bash
python squadeval.py
```
```bash
python SQuAD_postprocess.py
```
## Usage
### Pre-Training
@ -80,7 +89,7 @@ config.py:
optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
finetune_config.py:
task task type: NER | XNLI | LCQMC | SENTIi | OTHERS
task task type: NER | SQUAD | OTHERS
num_labels number of labels to do classification
data_file dataset file to load: PATH, default is "/your/path/train.tfrecord"
schema_file dataset schema file to load: PATH, default is "/your/path/schema.json"
@ -92,7 +101,7 @@ finetune_config.py:
optimizer optimizer used in fine-tune network: AdamWeigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
evaluation_config.py:
task task type: NER | XNLI | LCQMC | SENTI | OTHERS
task task type: NER | SQUAD | OTHERS
num_labels number of labels to do classsification
data_file dataset file to load: PATH, default is "/your/path/evaluation.tfrecord"
schema_file dataset schema file to load: PATH, default is "/your/path/schema.json"

View File

@ -18,10 +18,9 @@ Bert finetune script.
'''
import os
from src.utils import BertFinetuneCell, BertCLS, BertNER
from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell
from src.finetune_config import cfg, bert_net_cfg, tag_to_index
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore import context
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
@ -58,8 +57,6 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
'''
get dataset
'''
_ = distribute_file
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
"segment_ids", "label_ids"])
type_cast_op = C.TypeCast(mstype.int32)
@ -77,10 +74,29 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
'''
get SQuAD dataset
'''
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids",
"start_positions", "end_positions",
"unique_ids", "is_impossible"])
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="start_positions", operations=type_cast_op)
ds = ds.map(input_columns="end_positions", operations=type_cast_op)
ds = ds.repeat(repeat_count)
buffer_size = 960
ds = ds.shuffle(buffer_size=buffer_size)
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def test_train():
'''
finetune function
pytest -s finetune.py::test_train
'''
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
@ -92,8 +108,13 @@ def test_train():
tag_to_index=tag_to_index, dropout_prob=0.1)
else:
netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
elif cfg.task == 'SQUAD':
netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
else:
netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
if cfg.task == 'SQUAD':
dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
else:
dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
# optimizer
steps_per_epoch = dataset.get_dataset_size()
@ -103,13 +124,14 @@ def test_train():
learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
power=cfg.AdamWeightDecayDynamicLR.power,
warmup_steps=steps_per_epoch,
warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
eps=cfg.AdamWeightDecayDynamicLR.eps)
elif cfg.optimizer == 'Lamb':
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=steps_per_epoch * cfg.epoch_num,
start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
power=cfg.Lamb.power, warmup_steps=steps_per_epoch, decay_filter=cfg.Lamb.decay_filter)
power=cfg.Lamb.power, weight_decay=cfg.Lamb.weight_decay,
warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1), decay_filter=cfg.Lamb.decay_filter)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
momentum=cfg.Momentum.momentum)
@ -122,10 +144,12 @@ def test_train():
load_param_into_net(netwithloss, param_dict)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
if cfg.task == 'SQUAD':
netwithgrads = BertSquadCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
else:
netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
model = Model(netwithgrads)
model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])
D.release()
if __name__ == "__main__":
test_train()

View File

@ -0,0 +1,99 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation script for SQuAD task"""
import os
import collections
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src import tokenization
from src.evaluation_config import cfg, bert_net_cfg
from src.utils import BertSquad
from src.create_squad_data import read_squad_examples, convert_examples_to_features
from src.run_squad import write_predictions
def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
"""get SQuAD dataset from tfrecord"""
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
"segment_ids", "unique_ids"],
shuffle=False)
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.repeat(repeat_count)
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def test_eval():
"""Evaluation function for SQuAD task"""
tokenizer = tokenization.FullTokenizer(vocab_file="./vocab.txt", do_lower_case=True)
input_file = "dataset/v1.1/dev-v1.1.json"
eval_examples = read_squad_examples(input_file, False)
eval_features = convert_examples_to_features(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=384,
doc_stride=128,
max_query_length=64,
is_training=False,
output_fn=None,
verbose_logging=False)
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id)
dataset = get_squad_dataset(bert_net_cfg.batch_size, 1)
net = BertSquad(bert_net_cfg, False, 2)
net.set_train(False)
param_dict = load_checkpoint(cfg.finetune_ckpt)
load_param_into_net(net, param_dict)
model = Model(net)
output = []
RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"]
for data in dataset.create_dict_iterator():
input_data = []
for i in columns_list:
input_data.append(Tensor(data[i]))
input_ids, input_mask, segment_ids, unique_ids = input_data
start_positions = Tensor([1], mstype.float32)
end_positions = Tensor([1], mstype.float32)
is_impossible = Tensor([1], mstype.float32)
logits = model.predict(input_ids, input_mask, segment_ids, start_positions,
end_positions, unique_ids, is_impossible)
ids = logits[0].asnumpy()
start = logits[1].asnumpy()
end = logits[2].asnumpy()
for i in range(bert_net_cfg.batch_size):
unique_id = int(ids[i])
start_logits = [float(x) for x in start[i].flat]
end_logits = [float(x) for x in end[i].flat]
output.append(RawResult(
unique_id=unique_id,
start_logits=start_logits,
end_logits=end_logits))
write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json",
None, None, False, False)
if __name__ == "__main__":
test_eval()

View File

@ -43,6 +43,7 @@ cfg = edict({
'start_learning_rate': 2e-5,
'end_learning_rate': 1e-7,
'power': 1.0,
'weight_decay': 0.01,
'decay_filter': lambda x: False,
}),
'Momentum': edict({

View File

@ -29,7 +29,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
from .bert_model import BertModel
from .bert_for_pre_training import clip_grad
from .CRF import CRF
@ -131,6 +131,98 @@ class BertFinetuneCell(nn.Cell):
ret = (loss, cond)
return F.depend(ret, succ)
class BertSquadCell(nn.Cell):
"""
specifically defined for finetuning where only four inputs tensor are needed.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(BertSquadCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
def construct(self,
input_ids,
input_mask,
token_type_id,
start_position,
end_position,
unique_id,
is_impossible,
sens=None):
weights = self.weights
init = self.alloc_status()
loss = self.network(input_ids,
input_mask,
token_type_id,
start_position,
end_position,
unique_id,
is_impossible)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
start_position,
end_position,
unique_id,
is_impossible,
self.cast(scaling_sens,
mstype.float32))
clear_before_grad = self.clear_before_grad(init)
F.control_depend(loss, init)
self.depend_parameter_use(clear_before_grad, scaling_sens)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if self.reducer_flag:
grads = self.grad_reducer(grads)
flag = self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
F.control_depend(grads, flag)
F.control_depend(flag, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond)
return F.depend(ret, succ)
class BertCLSModel(nn.Cell):
"""
This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
@ -159,6 +251,30 @@ class BertCLSModel(nn.Cell):
log_probs = self.log_softmax(logits)
return log_probs
class BertSquadModel(nn.Cell):
'''
This class is responsible for SQuAD
'''
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
super(BertSquadModel, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
self.weight_init = TruncatedNormal(config.initializer_range)
self.dense1 = nn.Dense(config.hidden_size, num_labels, weight_init=self.weight_init,
has_bias=True).to_float(config.compute_type)
self.num_labels = num_labels
self.dtype = config.dtype
self.log_softmax = P.LogSoftmax(axis=1)
self.is_training = is_training
def construct(self, input_ids, input_mask, token_type_id):
sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask)
batch_size, seq_length, hidden_size = P.Shape()(sequence_output)
sequence = P.Reshape()(sequence_output, (-1, hidden_size))
logits = self.dense1(sequence)
logits = P.Cast()(logits, self.dtype)
logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels))
logits = self.log_softmax(logits)
return logits
class BertNERModel(nn.Cell):
"""
@ -261,3 +377,36 @@ class BertNER(nn.Cell):
else:
loss = self.loss(logits, label_ids, self.num_labels)
return loss
class BertSquad(nn.Cell):
'''
Train interface for SQuAD finetuning task.
'''
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
super(BertSquad, self).__init__()
self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
self.loss = CrossEntropyCalculation(is_training)
self.num_labels = num_labels
self.seq_length = config.seq_length
self.is_training = is_training
self.total_num = Parameter(Tensor([0], mstype.float32), name='total_num')
self.start_num = Parameter(Tensor([0], mstype.float32), name='start_num')
self.end_num = Parameter(Tensor([0], mstype.float32), name='end_num')
self.sum = P.ReduceSum()
self.equal = P.Equal()
self.argmax = P.ArgMaxWithValue(axis=1)
self.squeeze = P.Squeeze(axis=-1)
def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible):
logits = self.bert(input_ids, input_mask, token_type_id)
if self.is_training:
unstacked_logits_0 = self.squeeze(logits[:, :, 0:1])
unstacked_logits_1 = self.squeeze(logits[:, :, 1:2])
start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length)
end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length)
total_loss = (start_loss + end_loss) / 2.0
else:
start_logits = self.squeeze(logits[:, :, 0:1])
end_logits = self.squeeze(logits[:, :, 1:2])
total_loss = (unique_id, start_logits, end_logits)
return total_loss