From 7a8ee4725b025d824e3fc3c75da105bc1e192010 Mon Sep 17 00:00:00 2001 From: yoonlee666 Date: Tue, 30 Jun 2020 21:18:11 +0800 Subject: [PATCH] add bert support for glue --- model_zoo/bert/README.md | 4 +- model_zoo/bert/evaluation.py | 140 ++++++++++++++++++++++++++++++----- model_zoo/bert/finetune.py | 36 +++++---- model_zoo/bert/src/utils.py | 61 +++++++++++++-- 4 files changed, 201 insertions(+), 40 deletions(-) diff --git a/model_zoo/bert/README.md b/model_zoo/bert/README.md index 1fa755e72f8..3ed2bf67835 100644 --- a/model_zoo/bert/README.md +++ b/model_zoo/bert/README.md @@ -89,7 +89,7 @@ config.py: optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb" finetune_config.py: - task task type: NER | SQUAD | OTHERS + task task type: SeqLabeling | Regression | Classification | COLA | SQUAD num_labels number of labels to do classification data_file dataset file to load: PATH, default is "/your/path/train.tfrecord" schema_file dataset schema file to load: PATH, default is "/your/path/schema.json" @@ -101,7 +101,7 @@ finetune_config.py: optimizer optimizer used in fine-tune network: AdamWeigtDecayDynamicLR | Lamb | Momentum, default is "Lamb" evaluation_config.py: - task task type: NER | SQUAD | OTHERS + task task type: SeqLabeling | Regression | Classification | COLA num_labels number of labels to do classsification data_file dataset file to load: PATH, default is "/your/path/evaluation.tfrecord" schema_file dataset schema file to load: PATH, default is "/your/path/schema.json" diff --git a/model_zoo/bert/evaluation.py b/model_zoo/bert/evaluation.py index 4877b60cef3..4e8b2a3aea8 100644 --- a/model_zoo/bert/evaluation.py +++ b/model_zoo/bert/evaluation.py @@ -19,6 +19,7 @@ Bert evaluation script. import os import argparse +import math import numpy as np import mindspore.common.dtype as mstype from mindspore import context @@ -29,19 +30,24 @@ import mindspore.dataset.transforms.c_transforms as C from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.evaluation_config import cfg, bert_net_cfg -from src.utils import BertNER, BertCLS +from src.utils import BertNER, BertCLS, BertReg from src.CRF import postprocess from src.cluener_evaluation import submit from src.finetune_config import tag_to_index + class Accuracy(): - ''' + """ calculate accuracy - ''' + """ def __init__(self): self.acc_num = 0 self.total_num = 0 + def update(self, logits, labels): + """ + Update accuracy + """ labels = labels.asnumpy() labels = np.reshape(labels, -1) logits = logits.asnumpy() @@ -50,18 +56,20 @@ class Accuracy(): self.total_num += len(labels) print("=========================accuracy is ", self.acc_num / self.total_num) + class F1(): - ''' + """ calculate F1 score - ''' + """ def __init__(self): self.TP = 0 self.FP = 0 self.FN = 0 + def update(self, logits, labels): - ''' + """ update F1 score - ''' + """ labels = labels.asnumpy() labels = np.reshape(labels, -1) if cfg.use_crf: @@ -80,10 +88,76 @@ class F1(): self.FP += np.sum(pos_eva&(~pos_label)) self.FN += np.sum((~pos_eva)&pos_label) + +class MCC(): + """ + Calculate Matthews Correlation Coefficient. + """ + def __init__(self): + self.TP = 0 + self.FP = 0 + self.FN = 0 + self.TN = 0 + + def update(self, logits, labels): + """ + Update MCC score + """ + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + labels = labels.astype(np.bool) + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + logit_id = np.reshape(logit_id, -1) + logit_id = logit_id.astype(np.bool) + ornot = logit_id ^ labels + + self.TP += (~ornot & labels).sum() + self.FP += (ornot & ~labels).sum() + self.FN += (ornot & labels).sum() + self.TN += (~ornot & ~labels).sum() + + +class Spearman_Correlation(): + """ + calculate Spearman Correlation coefficient + """ + def __init__(self): + self.label = [] + self.logit = [] + + def update(self, logits, labels): + """ + Update Spearman Correlation + """ + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + logits = logits.asnumpy() + logits = np.reshape(logits, -1) + self.label.append(labels) + self.logit.append(logits) + + def cal(self): + """ + Calculate Spearman Correlation + """ + label = np.concatenate(self.label) + logit = np.concatenate(self.logit) + sort_label = label.argsort()[::-1] + sort_logit = logit.argsort()[::-1] + n = len(label) + d_acc = 0 + for i in range(n): + d = np.where(sort_label == i)[0] - np.where(sort_logit == i)[0] + d_acc += d**2 + ps = 1 - 6*d_acc/n/(n**2-1) + return ps + + def get_dataset(batch_size=1, repeat_count=1, distribute_file=''): - ''' + """ get dataset - ''' + """ _ = distribute_file ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", @@ -92,7 +166,11 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''): ds = ds.map(input_columns="segment_ids", operations=type_cast_op) ds = ds.map(input_columns="input_mask", operations=type_cast_op) ds = ds.map(input_columns="input_ids", operations=type_cast_op) - ds = ds.map(input_columns="label_ids", operations=type_cast_op) + if cfg.task == "Regression": + type_cast_op_float = C.TypeCast(mstype.float32) + ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) + else: + ds = ds.map(input_columns="label_ids", operations=type_cast_op) ds = ds.repeat(repeat_count) # apply shuffle operation @@ -103,10 +181,11 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''): ds = ds.batch(batch_size, drop_remainder=True) return ds + def bert_predict(Evaluation): - ''' + """ prediction function - ''' + """ target = args_opt.device_target if target == "Ascend": devid = int(os.getenv('DEVICE_ID')) @@ -131,15 +210,33 @@ def bert_predict(Evaluation): return model, dataset def test_eval(): - ''' + """ evaluation function - ''' - task_type = BertNER if cfg.task == "NER" else BertCLS + """ + if cfg.task == "SeqLabeling": + task_type = BertNER + elif cfg.task == "Regression": + task_type = BertReg + elif cfg.task == "Classification": + task_type = BertCLS + elif cfg.task == "COLA": + task_type = BertCLS + else: + raise ValueError("Task not supported.") model, dataset = bert_predict(task_type) + if cfg.clue_benchmark: submit(model, cfg.data_file, bert_net_cfg.seq_length) else: - callback = F1() if cfg.task == "NER" else Accuracy() + if cfg.task == "SeqLabeling": + callback = F1() + elif cfg.task == "COLA": + callback = MCC() + elif cfg.task == "Regression": + callback = Spearman_Correlation() + else: + callback = Accuracy() + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] for data in dataset.create_dict_iterator(): input_data = [] @@ -149,10 +246,19 @@ def test_eval(): logits = model.predict(input_ids, input_mask, token_type_id, label_ids) callback.update(logits, label_ids) print("==============================================================") - if cfg.task == "NER": + if cfg.task == "SeqLabeling": print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) print("F1 {:.6f} ".format(2*callback.TP / (2*callback.TP + callback.FP + callback.FN))) + elif cfg.task == "COLA": + TP = callback.TP + TN = callback.TN + FP = callback.FP + FN = callback.FN + mcc = (TP*TN-FP*FN)/math.sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)) + print("MCC: {:.6f}".format(mcc)) + elif cfg.task == "Regression": + print("Spearman Correlation is {:.6f}".format(callback.cal()[0])) else: print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, callback.acc_num / callback.total_num)) diff --git a/model_zoo/bert/finetune.py b/model_zoo/bert/finetune.py index df16e3c91d9..eb1880b9cc8 100644 --- a/model_zoo/bert/finetune.py +++ b/model_zoo/bert/finetune.py @@ -13,13 +13,13 @@ # limitations under the License. # ============================================================================ -''' +""" Bert finetune script. -''' +""" import os import argparse -from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell +from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell, BertReg from src.finetune_config import cfg, bert_net_cfg, tag_to_index import mindspore.common.dtype as mstype from mindspore import context @@ -34,14 +34,14 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint from mindspore.train.serialization import load_checkpoint, load_param_into_net class LossCallBack(Callback): - ''' + """ Monitor the loss in training. If the loss is NAN or INF, terminate training. Note: If per_print_times is 0, do not print loss. Args: per_print_times (int): Print loss every times. Default: 1. - ''' + """ def __init__(self, per_print_times=1): super(LossCallBack, self).__init__() if not isinstance(per_print_times, int) or per_print_times < 0: @@ -56,16 +56,20 @@ class LossCallBack(Callback): f.write("\n") def get_dataset(batch_size=1, repeat_count=1, distribute_file=''): - ''' + """ get dataset - ''' + """ ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"]) type_cast_op = C.TypeCast(mstype.int32) ds = ds.map(input_columns="segment_ids", operations=type_cast_op) ds = ds.map(input_columns="input_mask", operations=type_cast_op) ds = ds.map(input_columns="input_ids", operations=type_cast_op) - ds = ds.map(input_columns="label_ids", operations=type_cast_op) + if cfg.task == "Regression": + type_cast_op_float = C.TypeCast(mstype.float32) + ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) + else: + ds = ds.map(input_columns="label_ids", operations=type_cast_op) ds = ds.repeat(repeat_count) # apply shuffle operation @@ -77,9 +81,9 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''): return ds def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''): - ''' + """ get SQuAD dataset - ''' + """ ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids", "start_positions", "end_positions", "unique_ids", "is_impossible"]) @@ -97,9 +101,9 @@ def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''): return ds def test_train(): - ''' + """ finetune function - ''' + """ target = args_opt.device_target if target == "Ascend": devid = int(os.getenv('DEVICE_ID')) @@ -113,7 +117,7 @@ def test_train(): raise Exception("Target error, GPU or Ascend is supported.") #BertCLSTrain for classification #BertNERTrain for sequence labeling - if cfg.task == 'NER': + if cfg.task == 'SeqLabeling': if cfg.use_crf: netwithloss = BertNER(bert_net_cfg, True, num_labels=len(tag_to_index), use_crf=True, tag_to_index=tag_to_index, dropout_prob=0.1) @@ -121,8 +125,12 @@ def test_train(): netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1) elif cfg.task == 'SQUAD': netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1) - else: + elif cfg.task == 'Regression': + netwithloss = BertReg(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1) + elif cfg.task == 'Classification': netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1) + else: + raise Exception("Target error, GPU or Ascend is supported.") if cfg.task == 'SQUAD': dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num) else: diff --git a/model_zoo/bert/src/utils.py b/model_zoo/bert/src/utils.py index 9b5383877bc..ec5651b2053 100644 --- a/model_zoo/bert/src/utils.py +++ b/model_zoo/bert/src/utils.py @@ -13,9 +13,9 @@ # limitations under the License. # ============================================================================ -''' +""" Functional Cells used in Bert finetune and evaluation. -''' +""" import mindspore.nn as nn from mindspore.common.initializer import TruncatedNormal @@ -245,6 +245,32 @@ class BertSquadCell(nn.Cell): ret = (loss, cond) return F.depend(ret, succ) + +class BertRegressionModel(nn.Cell): + """ + Bert finetune model for regression task + """ + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): + super(BertRegressionModel, self).__init__() + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cast = P.Cast() + self.weight_init = TruncatedNormal(config.initializer_range) + self.log_softmax = P.LogSoftmax(axis=-1) + self.dtype = config.dtype + self.num_labels = num_labels + self.dropout = nn.Dropout(1 - dropout_prob) + self.dense_1 = nn.Dense(config.hidden_size, 1, weight_init=self.weight_init, + has_bias=True).to_float(mstype.float16) + + def construct(self, input_ids, input_mask, token_type_id): + _, pooled_output, _ = self.bert(input_ids, token_type_id, input_mask) + cls = self.cast(pooled_output, self.dtype) + cls = self.dropout(cls) + logits = self.dense_1(cls) + logits = self.cast(logits, self.dtype) + return logits + + class BertCLSModel(nn.Cell): """ This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3), @@ -274,9 +300,9 @@ class BertCLSModel(nn.Cell): return log_probs class BertSquadModel(nn.Cell): - ''' - This class is responsible for SQuAD - ''' + """ + Bert finetune model for SQuAD v1.1 task + """ def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): super(BertSquadModel, self).__init__() self.bert = BertModel(config, is_training, use_one_hot_embeddings) @@ -401,9 +427,9 @@ class BertNER(nn.Cell): return loss class BertSquad(nn.Cell): - ''' + """ Train interface for SQuAD finetuning task. - ''' + """ def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): super(BertSquad, self).__init__() self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) @@ -432,3 +458,24 @@ class BertSquad(nn.Cell): end_logits = self.squeeze(logits[:, :, 1:2]) total_loss = (unique_id, start_logits, end_logits) return total_loss + + +class BertReg(nn.Cell): + """ + Bert finetune model with loss for regression task + """ + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): + super(BertReg, self).__init__() + self.bert = BertRegressionModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) + self.loss = nn.MSELoss() + self.is_training = is_training + self.sigmoid = P.Sigmoid() + self.cast = P.Cast() + self.mul = P.Mul() + def construct(self, input_ids, input_mask, token_type_id, labels): + logits = self.bert(input_ids, input_mask, token_type_id) + if self.is_training: + loss = self.loss(logits, labels) + else: + loss = logits + return loss