From cc18b206c90fe6cbe65f1921df710316a93b575a Mon Sep 17 00:00:00 2001
From: shibeiji <shibeiji@huawei.com>
Date: Fri, 29 Jan 2021 17:58:39 +0800
Subject: [PATCH] bert ner for adaption of MSRA dataset

---
 model_zoo/official/nlp/bert/README.md         |  8 +-
 model_zoo/official/nlp/bert/README_CN.md      |  8 +-
 model_zoo/official/nlp/bert/run_ner.py        | 28 ++++--
 .../official/nlp/bert/scripts/run_ner.sh      |  4 +-
 .../nlp/bert/src/assessment_method.py         | 93 +++++-------------
 model_zoo/official/nlp/bert/src/dataset.py    |  4 +-
 .../nlp/bert/src/finetune_data_preprocess.py  | 96 ++++++++++++-------
 model_zoo/official/nlp/bert/src/utils.py      |  5 +-
 8 files changed, 115 insertions(+), 131 deletions(-)

diff --git a/model_zoo/official/nlp/bert/README.md b/model_zoo/official/nlp/bert/README.md
index 9e202cc1859..910a7e6d024 100644
--- a/model_zoo/official/nlp/bert/README.md
+++ b/model_zoo/official/nlp/bert/README.md
@@ -551,7 +551,7 @@ F1 0.920507
 For preprocessing, you can first convert the original txt format of the MSRA dataset into mindrecord by running the command below:
 
 ```python
-python src/finetune_data_preprocess.py ----data_dir=/path/msra_dataset.txt --vocab_file=/path/vacab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len
+python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.0 --split_end=1.0
 ```
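+
+The `--class_filter` option keeps only the listed entity classes (all other tokens are tagged `O`), while `--split_begin` and `--split_end` take fractions in [0, 1] that select a contiguous slice of the sentences, so one source file can be split into training and evaluation subsets. For example (a sketch with placeholder paths and output names), a 90/10 train/eval split could be produced as:
+
+```python
+python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_train.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.0 --split_end=0.9
+python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_eval.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.9 --split_end=1.0
+```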
 
 For finetuning and evaluation, just run
@@ -562,12 +562,10 @@ bash scripts/ner.sh
 
 The command above runs in the background; you can view the training logs in ner_log.txt.
 
-If you choose SpanF1 as assessment method and mode use_crf is set to be "true", the result will be as follows if evaluation is done after finetuning 10 epoches:
+If you choose MF1 (multi-label F1 score) as the assessment method, the result will be as follows if evaluation is done after finetuning for 10 epochs:
 
 ```text
-Precision 0.953826
-Recall 0.957749
-F1 0.955784
+F1 0.931243
 ```
 
 #### evaluation on squad v1.1 dataset when running on Ascend
diff --git a/model_zoo/official/nlp/bert/README_CN.md b/model_zoo/official/nlp/bert/README_CN.md
index 9972bf61842..db407721e7c 100644
--- a/model_zoo/official/nlp/bert/README_CN.md
+++ b/model_zoo/official/nlp/bert/README_CN.md
@@ -515,7 +515,7 @@ F1 0.920507
 您可以采用如下方式,先将MSRA数据集的原始格式在预处理流程中转换为mindrecord格式以提升性能:
 
 ```python
-python src/finetune_data_preprocess.py ----data_dir=/path/msra_dataset.txt --vocab_file=/path/vacab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len
+python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.0 --split_end=1.0
 ```
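+
+The `--class_filter` option keeps only the listed entity classes (other tokens are tagged `O`); `--split_begin` and `--split_end` take fractions in [0, 1] selecting a contiguous slice of the data, which can split one source file into training and evaluation subsets. A sketch with placeholder paths, producing a 90/10 train/eval split:
+
+```python
+python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_train.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.0 --split_end=0.9
+python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_eval.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.9 --split_end=1.0
+```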
 
 此后,您可以进行微调再训练和推理流程,
@@ -525,12 +525,10 @@ bash scripts/ner.sh
 ```
 
 以上命令后台运行,您可以在ner_log.txt中查看训练日志。
-如您选择SpanF1作为评估方法并且模型结构中配置CRF模式,在微调训练10个epoch之后进行推理,可得到如下结果:
+If you choose MF1 (multi-label F1 score) as the assessment method and run inference after finetuning for 10 epochs, you can obtain the following result:
 
 ```text
-Precision 0.953826
-Recall 0.957749
-F1 0.955784
+F1 0.931243
 ```
 
 #### Ascend处理器上运行后评估squad v1.1数据集
diff --git a/model_zoo/official/nlp/bert/run_ner.py b/model_zoo/official/nlp/bert/run_ner.py
index ce1c3188ba4..c0324a87dc5 100644
--- a/model_zoo/official/nlp/bert/run_ner.py
+++ b/model_zoo/official/nlp/bert/run_ner.py
@@ -19,11 +19,12 @@ Bert finetune and evaluation script.
 
 import os
 import argparse
+import time
 from src.bert_for_finetune import BertFinetuneCell, BertNER
 from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
 from src.dataset import create_ner_dataset
 from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index
-from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation, SpanF1
+from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
 import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore import log as logger
@@ -79,17 +80,22 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
     netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
     model = Model(netwithgrads)
     callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]
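+    # measure wall-clock time of the full training run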
+    train_begin = time.time()
     model.train(epoch_num, dataset, callbacks=callbacks)
+    train_end = time.time()
+    print("latency: {:.6f} s".format(train_end - train_begin))
 
 def eval_result_print(assessment_method="accuracy", callback=None):
     """print eval result"""
     if assessment_method == "accuracy":
         print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                   callback.acc_num / callback.total_num))
-    elif assessment_method in ("f1", "spanf1"):
+    elif assessment_method == "bf1":
         print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
         print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
         print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
+    elif assessment_method == "mf1":
+        print("F1 {:.6f} ".format(callback.eval()[0]))
     elif assessment_method == "mcc":
         print("MCC {:.6f} ".format(callback.cal()))
     elif assessment_method == "spearman_correlation":
@@ -116,10 +122,10 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=41, assessment_met
     else:
         if assessment_method == "accuracy":
             callback = Accuracy()
-        elif assessment_method == "f1":
+        elif assessment_method == "bf1":
             callback = F1((use_crf.lower() == "true"), num_class)
-        elif assessment_method == "spanf1":
-            callback = SpanF1((use_crf.lower() == "true"), tag_to_index)
+        elif assessment_method == "mf1":
+            callback = F1((use_crf.lower() == "true"), num_labels=num_class, mode="MultiLabel")
         elif assessment_method == "mcc":
             callback = MCC()
         elif assessment_method == "spearman_correlation":
@@ -145,8 +151,8 @@ def parse_args():
     parser = argparse.ArgumentParser(description="run ner")
     parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
                         help="Device type, default is Ascend")
-    parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark", "SpanF1"],
-                        help="assessment_method include: [F1, clue_benchmark, SpanF1], default is F1")
+    parser.add_argument("--assessment_method", type=str, default="BF1", choices=["BF1", "clue_benchmark", "MF1"],
+                        help="assessment_method include: [BF1, clue_benchmark, MF1], default is BF1")
     parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"],
                         help="Eable train, default is false")
     parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"],
@@ -231,6 +237,12 @@ def run_ner():
                                 assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
                                 schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format,
                                 do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
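+        # print a benchmark-style header before training (the key names below are an assumed log convention)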
+        print("==============================================================")
+        print("processor_name: {}".format(args_opt.device_target))
+        print("test_name: BERT Finetune Training")
+        print("model_name: {}".format("BERT+MLP+CRF" if args_opt.use_crf.lower() == "true" else "BERT + MLP"))
+        print("batch_size: {}".format(args_opt.train_batch_size))
+
         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
 
         if args_opt.do_eval.lower() == "true":
@@ -245,7 +257,7 @@ def run_ner():
         ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
                                 assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
                                 schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format,
-                                do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
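+                                # drop_remainder=False keeps the final partial batch so every eval sample is scored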
+                                do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), drop_remainder=False)
         do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method,
                 args_opt.eval_data_file_path, load_finetune_checkpoint_path, args_opt.vocab_file_path,
                 args_opt.label_file_path, tag_to_index, args_opt.eval_batch_size)
diff --git a/model_zoo/official/nlp/bert/scripts/run_ner.sh b/model_zoo/official/nlp/bert/scripts/run_ner.sh
index 542b5b64b75..d108be3f86a 100644
--- a/model_zoo/official/nlp/bert/scripts/run_ner.sh
+++ b/model_zoo/official/nlp/bert/scripts/run_ner.sh
@@ -18,7 +18,7 @@ echo "==========================================================================
 echo "Please run the script as: "
 echo "bash scripts/run_ner.sh"
 echo "for example: bash scripts/run_ner.sh"
-echo "assessment_method include: [F1, SpanF1, clue_benchmark]"
+echo "assessment_method include: [BF1, MF1, clue_benchmark]"
 echo "=============================================================================================================="
 
 mkdir -p ms_log
@@ -30,7 +30,7 @@ python ${PROJECT_DIR}/../run_ner.py  \
     --device_target="Ascend" \
     --do_train="true" \
     --do_eval="false" \
-    --assessment_method="F1" \
+    --assessment_method="BF1" \
     --use_crf="false" \
     --device_id=0 \
     --epoch_num=5 \
diff --git a/model_zoo/official/nlp/bert/src/assessment_method.py b/model_zoo/official/nlp/bert/src/assessment_method.py
index 5d5d7873741..bd840654a00 100644
--- a/model_zoo/official/nlp/bert/src/assessment_method.py
+++ b/model_zoo/official/nlp/bert/src/assessment_method.py
@@ -18,6 +18,7 @@ Bert evaluation assessment method script.
 '''
 import math
 import numpy as np
+from mindspore.nn.metrics import ConfusionMatrixMetric
 from .CRF import postprocess
 
 class Accuracy():
@@ -39,12 +40,18 @@ class F1():
     '''
     calculate F1 score
     '''
-    def __init__(self, use_crf=False, num_labels=2):
+    def __init__(self, use_crf=False, num_labels=2, mode="Binary"):
         self.TP = 0
         self.FP = 0
         self.FN = 0
         self.use_crf = use_crf
         self.num_labels = num_labels
+        self.mode = mode
+        if self.mode.lower() not in ("binary", "multilabel"):
+            raise ValueError("Assessment mode not supported, support: [Binary, MultiLabel]")
+        if self.mode.lower() != "binary":
+            self.metric = ConfusionMatrixMetric(skip_channel=False, metric_name="f1 score",
+                                                calculation_method=False, decrease="mean")
 
     def update(self, logits, labels):
         '''
@@ -62,78 +69,24 @@ class F1():
             logits = logits.asnumpy()
             logit_id = np.argmax(logits, axis=-1)
             logit_id = np.reshape(logit_id, -1)
-        pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)])
-        pos_label = np.isin(labels, [i for i in range(1, self.num_labels)])
-        self.TP += np.sum(pos_eva&pos_label)
-        self.FP += np.sum(pos_eva&(~pos_label))
-        self.FN += np.sum((~pos_eva)&pos_label)
 
-
-class SpanF1():
-    '''
-    calculate F1、precision and recall score in span manner for NER
-    '''
-    def __init__(self, use_crf=False, label2id=None):
-        self.TP = 0
-        self.FP = 0
-        self.FN = 0
-        self.use_crf = use_crf
-        self.label2id = label2id
-        if label2id is None:
-            raise ValueError("label2id info should not be empty")
-        self.id2label = {}
-        for key, value in label2id.items():
-            self.id2label[value] = key
-
-    def tag2span(self, ids):
-        '''
-        conbert ids list to span mode
-        '''
-        labels = np.array([self.id2label[id] for id in ids])
-        spans = []
-        prev_label = None
-        for idx, tag in enumerate(labels):
-            tag = tag.lower()
-            cur_label, label = tag[:1], tag[2:]
-            if cur_label in ('b', 's'):
-                spans.append((label, [idx, idx]))
-            elif cur_label in ('m', 'e') and prev_label in ('b', 'm') and label == spans[-1][0]:
-                spans[-1][1][1] = idx
-            elif cur_label == 'o':
-                pass
-            else:
-                spans.append((label, [idx, idx]))
-            prev_label = cur_label
-        return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans]
-
-
-    def update(self, logits, labels):
-        '''
-        update span F1 score
-        '''
-        labels = labels.asnumpy()
-        labels = np.reshape(labels, -1)
-        if self.use_crf:
-            backpointers, best_tag_id = logits
-            best_path = postprocess(backpointers, best_tag_id)
-            logit_id = []
-            for ele in best_path:
-                logit_id.extend(ele)
+        if self.mode.lower() == "binary":
+            pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)])
+            pos_label = np.isin(labels, [i for i in range(1, self.num_labels)])
+            self.TP += np.sum(pos_eva&pos_label)
+            self.FP += np.sum(pos_eva&(~pos_label))
+            self.FN += np.sum((~pos_eva)&pos_label)
         else:
-            logits = logits.asnumpy()
-            logit_id = np.argmax(logits, axis=-1)
-            logit_id = np.reshape(logit_id, -1)
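+            # one-hot encode label/prediction ids so ConfusionMatrixMetric can count per-class TP/FP/FN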
+            target = np.zeros((len(labels), self.num_labels), dtype=np.int32)
+            pred = np.zeros((len(logit_id), self.num_labels), dtype=np.int32)
+            for i, label in enumerate(labels):
+                target[i][label] = 1
+            for i, label in enumerate(logit_id):
+                pred[i][label] = 1
+            self.metric.update(pred, target)
 
-        label_spans = self.tag2span(labels)
-        pred_spans = self.tag2span(logit_id)
-        for span in pred_spans:
-            if span in label_spans:
-                self.TP += 1
-                label_spans.remove(span)
-            else:
-                self.FP += 1
-        for span in label_spans:
-            self.FN += 1
+    def eval(self):
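+        """Compute the accumulated metric result; only used in MultiLabel mode."""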
+        return self.metric.eval()
 
 
 class MCC():
diff --git a/model_zoo/official/nlp/bert/src/dataset.py b/model_zoo/official/nlp/bert/src/dataset.py
index 5710713f49a..7024088bc95 100644
--- a/model_zoo/official/nlp/bert/src/dataset.py
+++ b/model_zoo/official/nlp/bert/src/dataset.py
@@ -53,7 +53,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
 
 
 def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", data_file_path=None,
-                       dataset_format="mindrecord", schema_file_path=None, do_shuffle=True):
+                       dataset_format="mindrecord", schema_file_path=None, do_shuffle=True, drop_remainder=True):
     """create finetune or evaluation dataset"""
     type_cast_op = C.TypeCast(mstype.int32)
     if dataset_format == "mindrecord":
@@ -74,7 +74,7 @@ def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy
     dataset = dataset.map(operations=type_cast_op, input_columns="input_ids")
     dataset = dataset.repeat(repeat_count)
     # apply batch operations
-    dataset = dataset.batch(batch_size, drop_remainder=True)
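+    # evaluation passes drop_remainder=False to keep the final incomplete batch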
+    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
     return dataset
 
 
diff --git a/model_zoo/official/nlp/bert/src/finetune_data_preprocess.py b/model_zoo/official/nlp/bert/src/finetune_data_preprocess.py
index 2a5c9d626e4..84269982e0b 100755
--- a/model_zoo/official/nlp/bert/src/finetune_data_preprocess.py
+++ b/model_zoo/official/nlp/bert/src/finetune_data_preprocess.py
@@ -20,6 +20,7 @@ sample script of processing CLUE classification dataset using mindspore.dataset.
 import os
 import argparse
 import numpy as np
+from lxml import etree
 
 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
@@ -139,46 +140,60 @@ def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage
     return dataset
 
 
-def process_cluener_msra(data_file):
-    """process MSRA dataset for CLUE"""
-    content = []
-    labels = []
-    for line in open(data_file):
-        line = line.strip()
-        if line:
-            word = line.split("\t")[0]
-            if len(line.split("\t")) == 1:
-                label = "O"
+def process_cluener_msra(data_file, class_filter=None, split_begin=None, split_end=None):
+    """
+    Data pre-process for MSRA dataset
+    Args:
+        data_file (path): The original dataset file path.
+        class_filter (list of str): If not None, only entities whose tag is in class_filter are labeled;
+                     all other tokens are tagged "O".
+        split_begin (float): Fraction in [0, 1]; only data from this point onward is counted. Used to split
+                     the dataset into training and evaluation subsets if needed.
+        split_end (float): Fraction in [0, 1]; only data before this point is counted. Used to split
+                     the dataset into training and evaluation subsets if needed.
+    """
+    tree = etree.parse(data_file)
+    root = tree.getroot()
+    print("original dataset length: ", len(root))
+    dataset_size = len(root)
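+    # split_begin/split_end select a contiguous fraction of the sentences (e.g. 0.0-0.9 as a training split)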
+    beg = 0 if split_begin is None or not 0 <= split_begin <= 1.0 else int(dataset_size * split_begin)
+    end = dataset_size if split_end is None or not 0 <= split_end <= 1.0 else int(dataset_size * split_end)
+    print("preporcessed dataset_size: ", end - beg)
+    for i in range(beg, end):
+        sentence = root[i]
+        tags = []
+        content = ""
+        for phrases in sentence:
+            labeled_words = [word for word in phrases]
+            if labeled_words:
+                for words in phrases:
+                    name = words.tag
+                    label = words.get("TYPE")
+                    words = words.text
+                    if not words:
+                        continue
+                    content += words
+                    if class_filter and name not in class_filter:
+                        tags += ["O" for _ in words]
+                    else:
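+                        # BMES tagging: S_ marks a single-character entity; longer entities get B_ ... M_ ... E_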
+                        length = len(words)
+                        labels = ["S_"] if length == 1 else ["B_"] + ["M_" for i in range(length - 2)] + ["E_"]
+                        tags += [ele + label for ele in labels]
             else:
-                label = line.split("\t")[1].split("\n")[0]
-                if label[0] != "O":
-                    label = label[0] + "_" + label[2:]
-                if label[0] == "I":
-                    label = "M" + label[1:]
-            content.append(word)
-            labels.append(label)
-        else:
-            for i in range(1, len(labels) - 1):
-                if labels[i][0] == "B" and labels[i+1][0] != "M":
-                    labels[i] = "S" + labels[i][1:]
-                elif labels[i][0] == "M" and labels[i+1][0] != labels[i][0]:
-                    labels[i] = "E" + labels[i][1:]
-            last = len(labels) - 1
-            if labels[last][0] == "B":
-                labels[last] = "S" + labels[last][1:]
-            elif labels[last][0] == "M":
-                labels[last] = "E" + labels[last][1:]
-
-            yield (np.array("".join(content)), np.array(list(labels)))
-            content.clear()
-            labels.clear()
-            continue
+                phrases = phrases.text
+                if phrases:
+                    content += phrases
+                    tags += ["O" for ele in phrases]
+        if len(content) != len(tags):
+            raise ValueError("Mismathc length of content: ", len(content), " and label: ", len(tags))
+        yield (np.array("".join(content)), np.array(list(tags)))
 
 
-def process_msra_clue_dataset(data_dir, label_list, bert_vocab_path, max_seq_len=128):
+def process_msra_clue_dataset(data_dir, label_list, bert_vocab_path, max_seq_len=128, class_filter=None,
+                              split_begin=None, split_end=None):
     """Process MSRA dataset"""
     ### Loading MSRA from CLUEDataset
-    dataset = ds.GeneratorDataset(process_cluener_msra(data_dir), column_names=['text', 'label'])
+    dataset = ds.GeneratorDataset(process_cluener_msra(data_dir, class_filter, split_begin, split_end),
+                                  column_names=['text', 'label'])
 
     ### Processing label
     label_vocab = text.Vocab.from_list(label_list)
@@ -213,10 +228,16 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="create mindrecord")
     parser.add_argument("--data_dir", type=str, default="", help="dataset path")
     parser.add_argument("--vocab_file", type=str, default="", help="Vocab file path")
-    parser.add_argument("--max_seq_len", type=int, default=128, help="Sequence length")
     parser.add_argument("--save_path", type=str, default="./my.mindrecord", help="Path to save mindrecord")
     parser.add_argument("--label2id", type=str, default="",
                         help="Label2id file path, must be set for cluener2020 task")
+    parser.add_argument("--max_seq_len", type=int, default=128, help="Sequence length")
+    parser.add_argument("--class_filter", nargs='*', help="Specified classes will be counted, if empty all in counted")
+    parser.add_argument("--split_begin", type=float, default=None, help="Specified subsets of date will be counted,"
+                        "if not None, the data will counted begin from split_begin")
+    parser.add_argument("--split_end", type=float, default=None, help="Specified subsets of date will be counted,"
+                        "if not None, the data will counted before split_before")
+
     args_opt = parser.parse_args()
     if args_opt.label2id == "":
         raise ValueError("label2id should not be empty")
@@ -225,5 +246,6 @@ if __name__ == "__main__":
         for tag in f:
             labels_list.append(tag.strip())
     tag_to_index = list(convert_labels_to_index(labels_list).keys())
-    ds = process_msra_clue_dataset(args_opt.data_dir, tag_to_index, args_opt.vocab_file, args_opt.max_seq_len)
+    ds = process_msra_clue_dataset(args_opt.data_dir, tag_to_index, args_opt.vocab_file, args_opt.max_seq_len,
+                                   args_opt.class_filter, args_opt.split_begin, args_opt.split_end)
     ds.save(args_opt.save_path)
diff --git a/model_zoo/official/nlp/bert/src/utils.py b/model_zoo/official/nlp/bert/src/utils.py
index e03d221cc2f..36633cce483 100644
--- a/model_zoo/official/nlp/bert/src/utils.py
+++ b/model_zoo/official/nlp/bert/src/utils.py
@@ -106,10 +106,11 @@ class LossCallBack(Callback):
                 percent = 1
                 epoch_num -= 1
             print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
-                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
+                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)),
+                  flush=True)
         else:
             print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
-                                                               str(cb_params.net_outputs)))
+                                                               str(cb_params.net_outputs)), flush=True)
 
 def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
     """