forked from mindspore-Ecosystem/mindspore
!1596 add bert SQuAD finetune and eval code in bert example
Merge pull request !1596 from yoonlee666/edit-example
This commit is contained in:
commit
0b3da2c787
|
@ -5,7 +5,7 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
|
|||
## Requirements
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
- Download the zhwiki dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path.
|
||||
- Download the CLUE dataset for fine-tuning and evaluation.
|
||||
- Download the CLUE/SQuAD v1.1 dataset for fine-tuning and evaluation.
|
||||
> Notes:
|
||||
If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file.
|
||||
|
||||
|
@ -36,11 +36,20 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
|
|||
### Evaluation
|
||||
- Set options in `evaluation_config.py`. Make sure the 'data_file', 'schema_file' and 'finetune_ckpt' are set to your own path.
|
||||
|
||||
- Run `evaluation.py` for evaluation of BERT-base and BERT-NEZHA model.
|
||||
- NER: Run `evaluation.py` for evaluation of BERT-base and BERT-NEZHA model.
|
||||
|
||||
```bash
|
||||
python evaluation.py
|
||||
```
|
||||
- SQuAD v1.1: Run `squadeval.py` and `SQuAD_postprocess.py` for evaluation of BERT-base and BERT-NEZHA model.
|
||||
|
||||
```bash
|
||||
python squadeval.py
|
||||
```
|
||||
|
||||
```bash
|
||||
python SQuAD_postprocess.py
|
||||
```
|
||||
|
||||
## Usage
|
||||
### Pre-Training
|
||||
|
@ -80,7 +89,7 @@ config.py:
|
|||
optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
|
||||
|
||||
finetune_config.py:
|
||||
task task type: NER | XNLI | LCQMC | SENTIi | OTHERS
|
||||
task task type: NER | SQUAD | OTHERS
|
||||
num_labels number of labels to do classification
|
||||
data_file dataset file to load: PATH, default is "/your/path/train.tfrecord"
|
||||
schema_file dataset schema file to load: PATH, default is "/your/path/schema.json"
|
||||
|
@ -92,7 +101,7 @@ finetune_config.py:
|
|||
optimizer optimizer used in fine-tune network: AdamWeigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
|
||||
|
||||
evaluation_config.py:
|
||||
task task type: NER | XNLI | LCQMC | SENTI | OTHERS
|
||||
task task type: NER | SQUAD | OTHERS
|
||||
num_labels number of labels to do classsification
|
||||
data_file dataset file to load: PATH, default is "/your/path/evaluation.tfrecord"
|
||||
schema_file dataset schema file to load: PATH, default is "/your/path/schema.json"
|
||||
|
|
|
@ -18,10 +18,9 @@ Bert finetune script.
|
|||
'''
|
||||
|
||||
import os
|
||||
from src.utils import BertFinetuneCell, BertCLS, BertNER
|
||||
from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell
|
||||
from src.finetune_config import cfg, bert_net_cfg, tag_to_index
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.communication.management as D
|
||||
from mindspore import context
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
|
@ -58,8 +57,6 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
|
|||
'''
|
||||
get dataset
|
||||
'''
|
||||
_ = distribute_file
|
||||
|
||||
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
|
||||
"segment_ids", "label_ids"])
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
|
@ -77,10 +74,29 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
|
|||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
||||
def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
|
||||
'''
|
||||
get SQuAD dataset
|
||||
'''
|
||||
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids",
|
||||
"start_positions", "end_positions",
|
||||
"unique_ids", "is_impossible"])
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="start_positions", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="end_positions", operations=type_cast_op)
|
||||
ds = ds.repeat(repeat_count)
|
||||
|
||||
buffer_size = 960
|
||||
ds = ds.shuffle(buffer_size=buffer_size)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
||||
def test_train():
|
||||
'''
|
||||
finetune function
|
||||
pytest -s finetune.py::test_train
|
||||
'''
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
||||
|
@ -92,9 +108,14 @@ def test_train():
|
|||
tag_to_index=tag_to_index, dropout_prob=0.1)
|
||||
else:
|
||||
netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
|
||||
elif cfg.task == 'SQUAD':
|
||||
netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
|
||||
else:
|
||||
netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
|
||||
dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
|
||||
if cfg.task == 'SQUAD':
|
||||
dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
|
||||
else:
|
||||
dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
|
||||
# optimizer
|
||||
steps_per_epoch = dataset.get_dataset_size()
|
||||
if cfg.optimizer == 'AdamWeightDecayDynamicLR':
|
||||
|
@ -103,13 +124,14 @@ def test_train():
|
|||
learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
|
||||
end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
|
||||
power=cfg.AdamWeightDecayDynamicLR.power,
|
||||
warmup_steps=steps_per_epoch,
|
||||
warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
|
||||
weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
|
||||
eps=cfg.AdamWeightDecayDynamicLR.eps)
|
||||
elif cfg.optimizer == 'Lamb':
|
||||
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=steps_per_epoch * cfg.epoch_num,
|
||||
start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
|
||||
power=cfg.Lamb.power, warmup_steps=steps_per_epoch, decay_filter=cfg.Lamb.decay_filter)
|
||||
power=cfg.Lamb.power, weight_decay=cfg.Lamb.weight_decay,
|
||||
warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1), decay_filter=cfg.Lamb.decay_filter)
|
||||
elif cfg.optimizer == 'Momentum':
|
||||
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
|
||||
momentum=cfg.Momentum.momentum)
|
||||
|
@ -122,10 +144,12 @@ def test_train():
|
|||
load_param_into_net(netwithloss, param_dict)
|
||||
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
|
||||
netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
if cfg.task == 'SQUAD':
|
||||
netwithgrads = BertSquadCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
else:
|
||||
netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
model = Model(netwithgrads)
|
||||
model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])
|
||||
D.release()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_train()
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Evaluation script for SQuAD task"""
|
||||
|
||||
import os
|
||||
import collections
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src import tokenization
|
||||
from src.evaluation_config import cfg, bert_net_cfg
|
||||
from src.utils import BertSquad
|
||||
from src.create_squad_data import read_squad_examples, convert_examples_to_features
|
||||
from src.run_squad import write_predictions
|
||||
|
||||
def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
|
||||
"""get SQuAD dataset from tfrecord"""
|
||||
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
|
||||
"segment_ids", "unique_ids"],
|
||||
shuffle=False)
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.repeat(repeat_count)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
||||
def test_eval():
|
||||
"""Evaluation function for SQuAD task"""
|
||||
tokenizer = tokenization.FullTokenizer(vocab_file="./vocab.txt", do_lower_case=True)
|
||||
input_file = "dataset/v1.1/dev-v1.1.json"
|
||||
eval_examples = read_squad_examples(input_file, False)
|
||||
eval_features = convert_examples_to_features(
|
||||
examples=eval_examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=384,
|
||||
doc_stride=128,
|
||||
max_query_length=64,
|
||||
is_training=False,
|
||||
output_fn=None,
|
||||
verbose_logging=False)
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id)
|
||||
dataset = get_squad_dataset(bert_net_cfg.batch_size, 1)
|
||||
net = BertSquad(bert_net_cfg, False, 2)
|
||||
net.set_train(False)
|
||||
param_dict = load_checkpoint(cfg.finetune_ckpt)
|
||||
load_param_into_net(net, param_dict)
|
||||
model = Model(net)
|
||||
output = []
|
||||
RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"]
|
||||
for data in dataset.create_dict_iterator():
|
||||
input_data = []
|
||||
for i in columns_list:
|
||||
input_data.append(Tensor(data[i]))
|
||||
input_ids, input_mask, segment_ids, unique_ids = input_data
|
||||
start_positions = Tensor([1], mstype.float32)
|
||||
end_positions = Tensor([1], mstype.float32)
|
||||
is_impossible = Tensor([1], mstype.float32)
|
||||
logits = model.predict(input_ids, input_mask, segment_ids, start_positions,
|
||||
end_positions, unique_ids, is_impossible)
|
||||
ids = logits[0].asnumpy()
|
||||
start = logits[1].asnumpy()
|
||||
end = logits[2].asnumpy()
|
||||
|
||||
for i in range(bert_net_cfg.batch_size):
|
||||
unique_id = int(ids[i])
|
||||
start_logits = [float(x) for x in start[i].flat]
|
||||
end_logits = [float(x) for x in end[i].flat]
|
||||
output.append(RawResult(
|
||||
unique_id=unique_id,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits))
|
||||
write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json",
|
||||
None, None, False, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_eval()
|
|
@ -43,6 +43,7 @@ cfg = edict({
|
|||
'start_learning_rate': 2e-5,
|
||||
'end_learning_rate': 1e-7,
|
||||
'power': 1.0,
|
||||
'weight_decay': 0.01,
|
||||
'decay_filter': lambda x: False,
|
||||
}),
|
||||
'Momentum': edict({
|
||||
|
|
|
@ -29,7 +29,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
|||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore import context
|
||||
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
|
||||
from .bert_model import BertModel
|
||||
from .bert_for_pre_training import clip_grad
|
||||
from .CRF import CRF
|
||||
|
||||
|
@ -131,6 +131,98 @@ class BertFinetuneCell(nn.Cell):
|
|||
ret = (loss, cond)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
class BertSquadCell(nn.Cell):
|
||||
"""
|
||||
specifically defined for finetuning where only four inputs tensor are needed.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertSquadCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.allreduce = P.AllReduce()
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("mirror_mean")
|
||||
degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
|
||||
name="loss_scale")
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
start_position,
|
||||
end_position,
|
||||
unique_id,
|
||||
is_impossible,
|
||||
sens=None):
|
||||
weights = self.weights
|
||||
init = self.alloc_status()
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
start_position,
|
||||
end_position,
|
||||
unique_id,
|
||||
is_impossible)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
start_position,
|
||||
end_position,
|
||||
unique_id,
|
||||
is_impossible,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
clear_before_grad = self.clear_before_grad(init)
|
||||
F.control_depend(loss, init)
|
||||
self.depend_parameter_use(clear_before_grad, scaling_sens)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
flag = self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
if self.is_distributed:
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
F.control_depend(grads, flag)
|
||||
F.control_depend(flag, flag_sum)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
class BertCLSModel(nn.Cell):
|
||||
"""
|
||||
This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
|
||||
|
@ -159,6 +251,30 @@ class BertCLSModel(nn.Cell):
|
|||
log_probs = self.log_softmax(logits)
|
||||
return log_probs
|
||||
|
||||
class BertSquadModel(nn.Cell):
|
||||
'''
|
||||
This class is responsible for SQuAD
|
||||
'''
|
||||
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
|
||||
super(BertSquadModel, self).__init__()
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.dense1 = nn.Dense(config.hidden_size, num_labels, weight_init=self.weight_init,
|
||||
has_bias=True).to_float(config.compute_type)
|
||||
self.num_labels = num_labels
|
||||
self.dtype = config.dtype
|
||||
self.log_softmax = P.LogSoftmax(axis=1)
|
||||
self.is_training = is_training
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id):
|
||||
sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask)
|
||||
batch_size, seq_length, hidden_size = P.Shape()(sequence_output)
|
||||
sequence = P.Reshape()(sequence_output, (-1, hidden_size))
|
||||
logits = self.dense1(sequence)
|
||||
logits = P.Cast()(logits, self.dtype)
|
||||
logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels))
|
||||
logits = self.log_softmax(logits)
|
||||
return logits
|
||||
|
||||
class BertNERModel(nn.Cell):
|
||||
"""
|
||||
|
@ -261,3 +377,36 @@ class BertNER(nn.Cell):
|
|||
else:
|
||||
loss = self.loss(logits, label_ids, self.num_labels)
|
||||
return loss
|
||||
|
||||
class BertSquad(nn.Cell):
|
||||
'''
|
||||
Train interface for SQuAD finetuning task.
|
||||
'''
|
||||
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
|
||||
super(BertSquad, self).__init__()
|
||||
self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
|
||||
self.loss = CrossEntropyCalculation(is_training)
|
||||
self.num_labels = num_labels
|
||||
self.seq_length = config.seq_length
|
||||
self.is_training = is_training
|
||||
self.total_num = Parameter(Tensor([0], mstype.float32), name='total_num')
|
||||
self.start_num = Parameter(Tensor([0], mstype.float32), name='start_num')
|
||||
self.end_num = Parameter(Tensor([0], mstype.float32), name='end_num')
|
||||
self.sum = P.ReduceSum()
|
||||
self.equal = P.Equal()
|
||||
self.argmax = P.ArgMaxWithValue(axis=1)
|
||||
self.squeeze = P.Squeeze(axis=-1)
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible):
|
||||
logits = self.bert(input_ids, input_mask, token_type_id)
|
||||
if self.is_training:
|
||||
unstacked_logits_0 = self.squeeze(logits[:, :, 0:1])
|
||||
unstacked_logits_1 = self.squeeze(logits[:, :, 1:2])
|
||||
start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length)
|
||||
end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length)
|
||||
total_loss = (start_loss + end_loss) / 2.0
|
||||
else:
|
||||
start_logits = self.squeeze(logits[:, :, 0:1])
|
||||
end_logits = self.squeeze(logits[:, :, 1:2])
|
||||
total_loss = (unique_id, start_logits, end_logits)
|
||||
return total_loss
|
||||
|
|
Loading…
Reference in New Issue