!2514 add bert support for glue task

Merge pull request !2514 from yoonlee666/gelu2
This commit is contained in:
mindspore-ci-bot 2020-07-01 10:13:21 +08:00 committed by Gitee
commit 7b5b4837ff
4 changed files with 201 additions and 40 deletions

View File

@ -89,7 +89,7 @@ config.py:
optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
finetune_config.py:
task task type: NER | SQUAD | OTHERS
task task type: SeqLabeling | Regression | Classification | COLA | SQUAD
num_labels number of labels to do classification
data_file dataset file to load: PATH, default is "/your/path/train.tfrecord"
schema_file dataset schema file to load: PATH, default is "/your/path/schema.json"
@ -101,7 +101,7 @@ finetune_config.py:
optimizer optimizer used in fine-tune network: AdamWeigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
evaluation_config.py:
task task type: NER | SQUAD | OTHERS
task task type: SeqLabeling | Regression | Classification | COLA
num_labels number of labels to do classsification
data_file dataset file to load: PATH, default is "/your/path/evaluation.tfrecord"
schema_file dataset schema file to load: PATH, default is "/your/path/schema.json"

View File

@ -19,6 +19,7 @@ Bert evaluation script.
import os
import argparse
import math
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context
@ -29,19 +30,24 @@ import mindspore.dataset.transforms.c_transforms as C
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.evaluation_config import cfg, bert_net_cfg
from src.utils import BertNER, BertCLS
from src.utils import BertNER, BertCLS, BertReg
from src.CRF import postprocess
from src.cluener_evaluation import submit
from src.finetune_config import tag_to_index
class Accuracy():
'''
"""
calculate accuracy
'''
"""
def __init__(self):
self.acc_num = 0
self.total_num = 0
def update(self, logits, labels):
"""
Update accuracy
"""
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
@ -50,18 +56,20 @@ class Accuracy():
self.total_num += len(labels)
print("=========================accuracy is ", self.acc_num / self.total_num)
class F1():
'''
"""
calculate F1 score
'''
"""
def __init__(self):
self.TP = 0
self.FP = 0
self.FN = 0
def update(self, logits, labels):
'''
"""
update F1 score
'''
"""
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
if cfg.use_crf:
@ -80,10 +88,76 @@ class F1():
self.FP += np.sum(pos_eva&(~pos_label))
self.FN += np.sum((~pos_eva)&pos_label)
class MCC():
"""
Calculate Matthews Correlation Coefficient.
"""
def __init__(self):
self.TP = 0
self.FP = 0
self.FN = 0
self.TN = 0
def update(self, logits, labels):
"""
Update MCC score
"""
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
labels = labels.astype(np.bool)
logits = logits.asnumpy()
logit_id = np.argmax(logits, axis=-1)
logit_id = np.reshape(logit_id, -1)
logit_id = logit_id.astype(np.bool)
ornot = logit_id ^ labels
self.TP += (~ornot & labels).sum()
self.FP += (ornot & ~labels).sum()
self.FN += (ornot & labels).sum()
self.TN += (~ornot & ~labels).sum()
class Spearman_Correlation():
"""
calculate Spearman Correlation coefficient
"""
def __init__(self):
self.label = []
self.logit = []
def update(self, logits, labels):
"""
Update Spearman Correlation
"""
labels = labels.asnumpy()
labels = np.reshape(labels, -1)
logits = logits.asnumpy()
logits = np.reshape(logits, -1)
self.label.append(labels)
self.logit.append(logits)
def cal(self):
"""
Calculate Spearman Correlation
"""
label = np.concatenate(self.label)
logit = np.concatenate(self.logit)
sort_label = label.argsort()[::-1]
sort_logit = logit.argsort()[::-1]
n = len(label)
d_acc = 0
for i in range(n):
d = np.where(sort_label == i)[0] - np.where(sort_logit == i)[0]
d_acc += d**2
ps = 1 - 6*d_acc/n/(n**2-1)
return ps
def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
'''
"""
get dataset
'''
"""
_ = distribute_file
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
@ -92,7 +166,11 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
if cfg.task == "Regression":
type_cast_op_float = C.TypeCast(mstype.float32)
ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
else:
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
ds = ds.repeat(repeat_count)
# apply shuffle operation
@ -103,10 +181,11 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def bert_predict(Evaluation):
'''
"""
prediction function
'''
"""
target = args_opt.device_target
if target == "Ascend":
devid = int(os.getenv('DEVICE_ID'))
@ -131,15 +210,33 @@ def bert_predict(Evaluation):
return model, dataset
def test_eval():
'''
"""
evaluation function
'''
task_type = BertNER if cfg.task == "NER" else BertCLS
"""
if cfg.task == "SeqLabeling":
task_type = BertNER
elif cfg.task == "Regression":
task_type = BertReg
elif cfg.task == "Classification":
task_type = BertCLS
elif cfg.task == "COLA":
task_type = BertCLS
else:
raise ValueError("Task not supported.")
model, dataset = bert_predict(task_type)
if cfg.clue_benchmark:
submit(model, cfg.data_file, bert_net_cfg.seq_length)
else:
callback = F1() if cfg.task == "NER" else Accuracy()
if cfg.task == "SeqLabeling":
callback = F1()
elif cfg.task == "COLA":
callback = MCC()
elif cfg.task == "Regression":
callback = Spearman_Correlation()
else:
callback = Accuracy()
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
for data in dataset.create_dict_iterator():
input_data = []
@ -149,10 +246,19 @@ def test_eval():
logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
callback.update(logits, label_ids)
print("==============================================================")
if cfg.task == "NER":
if cfg.task == "SeqLabeling":
print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
print("F1 {:.6f} ".format(2*callback.TP / (2*callback.TP + callback.FP + callback.FN)))
elif cfg.task == "COLA":
TP = callback.TP
TN = callback.TN
FP = callback.FP
FN = callback.FN
mcc = (TP*TN-FP*FN)/math.sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN))
print("MCC: {:.6f}".format(mcc))
elif cfg.task == "Regression":
print("Spearman Correlation is {:.6f}".format(callback.cal()[0]))
else:
print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
callback.acc_num / callback.total_num))

View File

@ -13,13 +13,13 @@
# limitations under the License.
# ============================================================================
'''
"""
Bert finetune script.
'''
"""
import os
import argparse
from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell
from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell, BertReg
from src.finetune_config import cfg, bert_net_cfg, tag_to_index
import mindspore.common.dtype as mstype
from mindspore import context
@ -34,14 +34,14 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net
class LossCallBack(Callback):
'''
"""
Monitor the loss in training.
If the loss is NAN or INF, terminate training.
Note:
If per_print_times is 0, do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
'''
"""
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
@ -56,16 +56,20 @@ class LossCallBack(Callback):
f.write("\n")
def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
'''
"""
get dataset
'''
"""
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
"segment_ids", "label_ids"])
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
if cfg.task == "Regression":
type_cast_op_float = C.TypeCast(mstype.float32)
ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
else:
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
ds = ds.repeat(repeat_count)
# apply shuffle operation
@ -77,9 +81,9 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
return ds
def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
'''
"""
get SQuAD dataset
'''
"""
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids",
"start_positions", "end_positions",
"unique_ids", "is_impossible"])
@ -97,9 +101,9 @@ def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
return ds
def test_train():
'''
"""
finetune function
'''
"""
target = args_opt.device_target
if target == "Ascend":
devid = int(os.getenv('DEVICE_ID'))
@ -113,7 +117,7 @@ def test_train():
raise Exception("Target error, GPU or Ascend is supported.")
#BertCLSTrain for classification
#BertNERTrain for sequence labeling
if cfg.task == 'NER':
if cfg.task == 'SeqLabeling':
if cfg.use_crf:
netwithloss = BertNER(bert_net_cfg, True, num_labels=len(tag_to_index), use_crf=True,
tag_to_index=tag_to_index, dropout_prob=0.1)
@ -121,8 +125,12 @@ def test_train():
netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
elif cfg.task == 'SQUAD':
netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
else:
elif cfg.task == 'Regression':
netwithloss = BertReg(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
elif cfg.task == 'Classification':
netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
else:
raise Exception("Target error, GPU or Ascend is supported.")
if cfg.task == 'SQUAD':
dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
else:

View File

@ -13,9 +13,9 @@
# limitations under the License.
# ============================================================================
'''
"""
Functional Cells used in Bert finetune and evaluation.
'''
"""
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
@ -245,6 +245,32 @@ class BertSquadCell(nn.Cell):
ret = (loss, cond)
return F.depend(ret, succ)
class BertRegressionModel(nn.Cell):
"""
Bert finetune model for regression task
"""
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
super(BertRegressionModel, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
self.cast = P.Cast()
self.weight_init = TruncatedNormal(config.initializer_range)
self.log_softmax = P.LogSoftmax(axis=-1)
self.dtype = config.dtype
self.num_labels = num_labels
self.dropout = nn.Dropout(1 - dropout_prob)
self.dense_1 = nn.Dense(config.hidden_size, 1, weight_init=self.weight_init,
has_bias=True).to_float(mstype.float16)
def construct(self, input_ids, input_mask, token_type_id):
_, pooled_output, _ = self.bert(input_ids, token_type_id, input_mask)
cls = self.cast(pooled_output, self.dtype)
cls = self.dropout(cls)
logits = self.dense_1(cls)
logits = self.cast(logits, self.dtype)
return logits
class BertCLSModel(nn.Cell):
"""
This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
@ -274,9 +300,9 @@ class BertCLSModel(nn.Cell):
return log_probs
class BertSquadModel(nn.Cell):
'''
This class is responsible for SQuAD
'''
"""
Bert finetune model for SQuAD v1.1 task
"""
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
super(BertSquadModel, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
@ -401,9 +427,9 @@ class BertNER(nn.Cell):
return loss
class BertSquad(nn.Cell):
'''
"""
Train interface for SQuAD finetuning task.
'''
"""
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
super(BertSquad, self).__init__()
self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
@ -432,3 +458,24 @@ class BertSquad(nn.Cell):
end_logits = self.squeeze(logits[:, :, 1:2])
total_loss = (unique_id, start_logits, end_logits)
return total_loss
class BertReg(nn.Cell):
"""
Bert finetune model with loss for regression task
"""
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
super(BertReg, self).__init__()
self.bert = BertRegressionModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
self.loss = nn.MSELoss()
self.is_training = is_training
self.sigmoid = P.Sigmoid()
self.cast = P.Cast()
self.mul = P.Mul()
def construct(self, input_ids, input_mask, token_type_id, labels):
logits = self.bert(input_ids, input_mask, token_type_id)
if self.is_training:
loss = self.loss(logits, labels)
else:
loss = logits
return loss