forked from mindspore-Ecosystem/mindspore
add bert finetune and eval scripts in bert_clue example dir
This commit is contained in:
parent
eb3f70a0c7
commit
6c190b49c0
|
@ -0,0 +1,177 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
'''
|
||||
CRF script.
|
||||
'''
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
class CRF(nn.Cell):
|
||||
'''
|
||||
Conditional Random Field
|
||||
Args:
|
||||
tag_to_index: The dict for tag to index mapping with extra "<START>" and "<STOP>"sign.
|
||||
batch_size: Batch size, i.e., the length of the first dimension.
|
||||
seq_length: Sequence length, i.e., the length of the second dimention.
|
||||
is_training: Specifies whether to use training mode.
|
||||
Returns:
|
||||
Training mode: Tensor, total loss.
|
||||
Evaluation mode: Tuple, the index for each step with the highest score; Tuple, the index for the last
|
||||
step with the highest score.
|
||||
'''
|
||||
def __init__(self, tag_to_index, batch_size=1, seq_length=128, is_training=True):
|
||||
|
||||
super(CRF, self).__init__()
|
||||
self.target_size = len(tag_to_index)
|
||||
self.is_training = is_training
|
||||
self.tag_to_index = tag_to_index
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.START_TAG = "<START>"
|
||||
self.STOP_TAG = "<STOP>"
|
||||
self.START_VALUE = Tensor(self.target_size-2, dtype=mstype.int32)
|
||||
self.STOP_VALUE = Tensor(self.target_size-1, dtype=mstype.int32)
|
||||
transitions = np.random.normal(size=(self.target_size, self.target_size)).astype(np.float32)
|
||||
transitions[tag_to_index[self.START_TAG], :] = -10000
|
||||
transitions[:, tag_to_index[self.STOP_TAG]] = -10000
|
||||
self.transitions = Parameter(Tensor(transitions), name="transition_matrix")
|
||||
self.cat = P.Concat(axis=-1)
|
||||
self.argmax = P.ArgMaxWithValue(axis=-1)
|
||||
self.log = P.Log()
|
||||
self.exp = P.Exp()
|
||||
self.sum = P.ReduceSum()
|
||||
self.tile = P.Tile()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=True)
|
||||
self.reshape = P.Reshape()
|
||||
self.expand = P.ExpandDims()
|
||||
self.mean = P.ReduceMean()
|
||||
init_alphas = np.ones(shape=(self.batch_size, self.target_size)) * -10000.0
|
||||
init_alphas[:, self.tag_to_index[self.START_TAG]] = 0.
|
||||
self.init_alphas = Tensor(init_alphas, dtype=mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
self.reduce_max = P.ReduceMax(keep_dims=True)
|
||||
self.on_value = Tensor(1.0, dtype=mstype.float32)
|
||||
self.off_value = Tensor(0.0, dtype=mstype.float32)
|
||||
self.onehot = P.OneHot()
|
||||
|
||||
def log_sum_exp(self, logits):
|
||||
'''
|
||||
Compute the log_sum_exp score for normalization factor.
|
||||
'''
|
||||
max_score = self.reduce_max(logits, -1) #16 5 5
|
||||
score = self.log(self.reduce_sum(self.exp(logits - max_score), -1))
|
||||
score = max_score + score
|
||||
return score
|
||||
|
||||
def _realpath_score(self, features, label):
|
||||
'''
|
||||
Compute the emission and transition score for the real path.
|
||||
'''
|
||||
label = label * 1
|
||||
concat_A = self.tile(self.reshape(self.START_VALUE, (1,)), (self.batch_size,))
|
||||
concat_A = self.reshape(concat_A, (self.batch_size, 1))
|
||||
labels = self.cat((concat_A, label))
|
||||
onehot_label = self.onehot(label, self.target_size, self.on_value, self.off_value)
|
||||
emits = features * onehot_label
|
||||
labels = self.onehot(labels, self.target_size, self.on_value, self.off_value)
|
||||
label1 = labels[:, 1:, :]
|
||||
label2 = labels[:, :self.seq_length, :]
|
||||
label1 = self.expand(label1, 3)
|
||||
label2 = self.expand(label2, 2)
|
||||
label_trans = label1 * label2
|
||||
transitions = self.expand(self.expand(self.transitions, 0), 0)
|
||||
trans = transitions * label_trans
|
||||
score = self.sum(emits, (1, 2)) + self.sum(trans, (1, 2, 3))
|
||||
stop_value_index = labels[:, (self.seq_length-1):self.seq_length, :]
|
||||
stop_value = self.transitions[(self.target_size-1):self.target_size, :]
|
||||
stop_score = stop_value * self.reshape(stop_value_index, (self.batch_size, self.target_size))
|
||||
score = score + self.sum(stop_score, 1)
|
||||
score = self.reshape(score, (self.batch_size, -1))
|
||||
return score
|
||||
|
||||
def _normalization_factor(self, features):
|
||||
'''
|
||||
Compute the total score for all the paths.
|
||||
'''
|
||||
forward_var = self.init_alphas
|
||||
forward_var = self.expand(forward_var, 1)
|
||||
for idx in range(self.seq_length):
|
||||
feat = features[:, idx:(idx+1), :]
|
||||
emit_score = self.reshape(feat, (self.batch_size, self.target_size, 1))
|
||||
next_tag_var = emit_score + self.transitions + forward_var
|
||||
forward_var = self.log_sum_exp(next_tag_var)
|
||||
forward_var = self.reshape(forward_var, (self.batch_size, 1, self.target_size))
|
||||
terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1))
|
||||
alpha = self.log_sum_exp(terminal_var)
|
||||
alpha = self.reshape(alpha, (self.batch_size, -1))
|
||||
return alpha
|
||||
|
||||
def _decoder(self, features):
|
||||
'''
|
||||
Viterbi decode for evaluation.
|
||||
'''
|
||||
backpointers = ()
|
||||
forward_var = self.init_alphas
|
||||
for idx in range(self.seq_length):
|
||||
feat = features[:, idx:(idx+1), :]
|
||||
feat = self.reshape(feat, (self.batch_size, self.target_size))
|
||||
bptrs_t = ()
|
||||
|
||||
next_tag_var = self.expand(forward_var, 1) + self.transitions
|
||||
best_tag_id, best_tag_value = self.argmax(next_tag_var)
|
||||
bptrs_t += (best_tag_id,)
|
||||
forward_var = best_tag_value + feat
|
||||
|
||||
backpointers += (bptrs_t,)
|
||||
terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1))
|
||||
best_tag_id, _ = self.argmax(terminal_var)
|
||||
return backpointers, best_tag_id
|
||||
|
||||
def construct(self, features, label):
|
||||
if self.is_training:
|
||||
forward_score = self._normalization_factor(features)
|
||||
gold_score = self._realpath_score(features, label)
|
||||
return_value = self.mean(forward_score - gold_score)
|
||||
else:
|
||||
path_list, tag = self._decoder(features)
|
||||
return_value = path_list, tag
|
||||
return return_value
|
||||
|
||||
def postprocess(backpointers, best_tag_id):
|
||||
'''
|
||||
Do postprocess
|
||||
'''
|
||||
best_tag_id = best_tag_id.asnumpy()
|
||||
batch_size = len(best_tag_id)
|
||||
best_path = []
|
||||
for i in range(batch_size):
|
||||
best_path.append([])
|
||||
best_local_id = best_tag_id[i]
|
||||
best_path[-1].append(best_local_id)
|
||||
for bptrs_t in reversed(backpointers):
|
||||
bptrs_t = bptrs_t[0].asnumpy()
|
||||
local_idx = bptrs_t[i]
|
||||
best_local_id = local_idx[best_local_id]
|
||||
best_path[-1].append(best_local_id)
|
||||
# Pop off the start tag (we dont want to return that to the caller)
|
||||
best_path[-1].pop()
|
||||
best_path[-1].reverse()
|
||||
return best_path
|
|
@ -17,21 +17,21 @@ kiextractor). Convert the dataset to TFRecord format and move the files to a spe
|
|||
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model.
|
||||
|
||||
``` bash
|
||||
sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_PATH
|
||||
sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
|
||||
```
|
||||
- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model.
|
||||
|
||||
``` bash
|
||||
sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH MINDSPORE_PATH
|
||||
sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH
|
||||
```
|
||||
|
||||
### Fine-Tuning
|
||||
- Set options in `finetune_config.py`. Make sure the 'data_file', 'schema_file' and 'ckpt_file' are set to your own path, set the 'pre_training_ckpt' to save the checkpoint files generated.
|
||||
- Set options in `finetune_config.py`. Make sure the 'data_file', 'schema_file' and 'pre_training_file' are set to your own path. Set the 'pre_training_ckpt' to a saved checkpoint file generated after pre-training.
|
||||
|
||||
- Run `finetune.py` for fine-tuning of BERT-base and BERT-NEZHA model.
|
||||
|
||||
```bash
|
||||
python finetune.py --backend=ms
|
||||
python finetune.py
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
@ -40,7 +40,7 @@ kiextractor). Convert the dataset to TFRecord format and move the files to a spe
|
|||
- Run `evaluation.py` for evaluation of BERT-base and BERT-NEZHA model.
|
||||
|
||||
```bash
|
||||
python evaluation.py --backend=ms
|
||||
python evaluation.py
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
@ -77,28 +77,33 @@ options:
|
|||
It contains of parameters of BERT model and options for training, which is set in file `config.py`, `finetune_config.py` and `evaluation_config.py` respectively.
|
||||
### Options:
|
||||
```
|
||||
Pre-Training:
|
||||
config.py:
|
||||
bert_network version of BERT model: base | nezha, default is base
|
||||
loss_scale_value initial value of loss scale: N, default is 2^32
|
||||
scale_factor factor used to update loss scale: N, default is 2
|
||||
scale_window steps for once updatation of loss scale: N, default is 1000
|
||||
optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
|
||||
|
||||
Fine-Tuning:
|
||||
task task type: NER | XNLI | LCQMC | SENTI
|
||||
data_file dataset file to load: PATH, default is "/your/path/cn-wiki-128"
|
||||
schema_file dataset schema file to load: PATH, default is "/your/path/datasetSchema.json"
|
||||
epoch_num repeat counts of training: N, default is 40
|
||||
finetune_config.py:
|
||||
task task type: NER | XNLI | LCQMC | SENTIi | OTHERS
|
||||
num_labels number of labels to do classification
|
||||
data_file dataset file to load: PATH, default is "/your/path/train.tfrecord"
|
||||
schema_file dataset schema file to load: PATH, default is "/your/path/schema.json"
|
||||
epoch_num repeat counts of training: N, default is 5
|
||||
ckpt_prefix prefix used to save checkpoint files: PREFIX, default is "bert"
|
||||
ckpt_dir path to save checkpoint files: PATH, default is None
|
||||
pre_training_ckpt checkpoint file to load: PATH, default is "/your/path/pre_training.ckpt"
|
||||
optimizer optimizer used in the network: AdamWeigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
|
||||
use_crf whether to use crf for evaluation. use_crf takes effect only when task type is NER, default is False
|
||||
optimizer optimizer used in fine-tune network: AdamWeigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"
|
||||
|
||||
Evaluation:
|
||||
task task type: NER | XNLI | LCQMC | SENTI
|
||||
evaluation_config.py:
|
||||
task task type: NER | XNLI | LCQMC | SENTI | OTHERS
|
||||
num_labels number of labels to do classsification
|
||||
data_file dataset file to load: PATH, default is "/your/path/evaluation.tfrecord"
|
||||
schema_file dataset schema file to load: PATH, default is "/your/path/schema.json"
|
||||
finetune_ckpt checkpoint file to load: PATH, default is "/your/path/your.ckpt"
|
||||
use_crf whether to use crf for evaluation. use_crf takes effect only when task type is NER, default is False
|
||||
clue_benchmark whether to use clue benchmark. clue_benchmark takes effect only when task type is NER, default is False
|
||||
```
|
||||
|
||||
### Parameters:
|
||||
|
@ -125,25 +130,24 @@ Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation):
|
|||
|
||||
Parameters for optimizer:
|
||||
AdamWeightDecayDynamicLR:
|
||||
decay_steps steps of the learning rate decay: N, default is 12276*3
|
||||
learning_rate value of learning rate: Q, default is 1e-5
|
||||
end_learning_rate value of end learning rate: Q, default is 0.0
|
||||
power power: Q, default is 10.0
|
||||
warmup_steps steps of the learning rate warm up: N, default is 2100
|
||||
weight_decay weight decay: Q, default is 1e-5
|
||||
eps term added to the denominator to improve numerical stability: Q, default is 1e-6
|
||||
decay_steps steps of the learning rate decay: N
|
||||
learning_rate value of learning rate: Q
|
||||
end_learning_rate value of end learning rate: Q, must be positive
|
||||
power power: Q
|
||||
warmup_steps steps of the learning rate warm up: N
|
||||
weight_decay weight decay: Q
|
||||
eps term added to the denominator to improve numerical stability: Q
|
||||
|
||||
Lamb:
|
||||
decay_steps steps of the learning rate decay: N, default is 12276*3
|
||||
learning_rate value of learning rate: Q, default is 1e-5
|
||||
end_learning_rate value of end learning rate: Q, default is 0.0
|
||||
power power: Q, default is 5.0
|
||||
warmup_steps steps of the learning rate warm up: N, default is 2100
|
||||
weight_decay weight decay: Q, default is 1e-5
|
||||
decay_filter function to determine whether to apply weight decay on parameters: FUNCTION, default is lambda x: False
|
||||
decay_steps steps of the learning rate decay: N
|
||||
learning_rate value of learning rate: Q
|
||||
end_learning_rate value of end learning rate: Q
|
||||
power power: Q
|
||||
warmup_steps steps of the learning rate warm up: N
|
||||
weight_decay weight decay: Q
|
||||
|
||||
Momentum:
|
||||
learning_rate value of learning rate: Q, default is 2e-5
|
||||
momentum momentum for the moving average: Q, default is 0.9
|
||||
learning_rate value of learning rate: Q
|
||||
momentum momentum for the moving average: Q
|
||||
```
|
||||
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
'''bert clue evaluation'''
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from evaluation_config import cfg
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from CRF import postprocess
|
||||
import tokenization
|
||||
from sample_process import label_generation, process_one_example_p
|
||||
|
||||
vocab_file = "./vocab.txt"
|
||||
tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file)
|
||||
|
||||
def process(model, text, sequence_length):
|
||||
"""
|
||||
process text.
|
||||
"""
|
||||
data = [text]
|
||||
features = []
|
||||
res = []
|
||||
ids = []
|
||||
for i in data:
|
||||
feature = process_one_example_p(tokenizer_, i, max_seq_len=sequence_length)
|
||||
features.append(feature)
|
||||
input_ids, input_mask, token_type_id = feature
|
||||
input_ids = Tensor(np.array(input_ids), mstype.int32)
|
||||
input_mask = Tensor(np.array(input_mask), mstype.int32)
|
||||
token_type_id = Tensor(np.array(token_type_id), mstype.int32)
|
||||
if cfg.use_crf:
|
||||
backpointers, best_tag_id = model.predict(input_ids, input_mask, token_type_id, Tensor(1))
|
||||
best_path = postprocess(backpointers, best_tag_id)
|
||||
logits = []
|
||||
for ele in best_path:
|
||||
logits.extend(ele)
|
||||
ids = logits
|
||||
else:
|
||||
logits = model.predict(input_ids, input_mask, token_type_id, Tensor(1))
|
||||
ids = logits.asnumpy()
|
||||
ids = np.argmax(ids, axis=-1)
|
||||
ids = list(ids)
|
||||
res = label_generation(text, ids)
|
||||
return res
|
||||
|
||||
def submit(model, path, sequence_length):
|
||||
"""
|
||||
submit task
|
||||
"""
|
||||
data = []
|
||||
for line in open(path):
|
||||
if not line.strip():
|
||||
continue
|
||||
oneline = json.loads(line.strip())
|
||||
res = process(model, oneline["text"], sequence_length)
|
||||
print("text", oneline["text"])
|
||||
print("res:", res)
|
||||
data.append(json.dumps({"label": res}, ensure_ascii=False))
|
||||
open("ner_predict.json", "w").write("\n".join(data))
|
|
@ -30,6 +30,7 @@ cfg = edict({
|
|||
'power': 5.0,
|
||||
'weight_decay': 1e-5,
|
||||
'eps': 1e-6,
|
||||
'warmup_steps': 10000,
|
||||
}),
|
||||
'Lamb': edict({
|
||||
'start_learning_rate': 3e-5,
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Bert evaluation script.
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from evaluation_config import cfg, bert_net_cfg
|
||||
from utils import BertNER, BertCLS
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from CRF import postprocess
|
||||
from cluener_evaluation import submit
|
||||
from finetune_config import tag_to_index
|
||||
|
||||
class Accuracy():
|
||||
'''
|
||||
calculate accuracy
|
||||
'''
|
||||
def __init__(self):
|
||||
self.acc_num = 0
|
||||
self.total_num = 0
|
||||
def update(self, logits, labels):
|
||||
labels = labels.asnumpy()
|
||||
labels = np.reshape(labels, -1)
|
||||
logits = logits.asnumpy()
|
||||
logit_id = np.argmax(logits, axis=-1)
|
||||
self.acc_num += np.sum(labels == logit_id)
|
||||
self.total_num += len(labels)
|
||||
print("=========================accuracy is ", self.acc_num / self.total_num)
|
||||
|
||||
class F1():
|
||||
'''
|
||||
calculate F1 score
|
||||
'''
|
||||
def __init__(self):
|
||||
self.TP = 0
|
||||
self.FP = 0
|
||||
self.FN = 0
|
||||
def update(self, logits, labels):
|
||||
'''
|
||||
update F1 score
|
||||
'''
|
||||
labels = labels.asnumpy()
|
||||
labels = np.reshape(labels, -1)
|
||||
if cfg.use_crf:
|
||||
backpointers, best_tag_id = logits
|
||||
best_path = postprocess(backpointers, best_tag_id)
|
||||
logit_id = []
|
||||
for ele in best_path:
|
||||
logit_id.extend(ele)
|
||||
else:
|
||||
logits = logits.asnumpy()
|
||||
logit_id = np.argmax(logits, axis=-1)
|
||||
logit_id = np.reshape(logit_id, -1)
|
||||
pos_eva = np.isin(logit_id, [i for i in range(1, cfg.num_labels)])
|
||||
pos_label = np.isin(labels, [i for i in range(1, cfg.num_labels)])
|
||||
self.TP += np.sum(pos_eva&pos_label)
|
||||
self.FP += np.sum(pos_eva&(~pos_label))
|
||||
self.FN += np.sum((~pos_eva)&pos_label)
|
||||
|
||||
def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
|
||||
'''
|
||||
get dataset
|
||||
'''
|
||||
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
|
||||
"segment_ids", "label_ids"])
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
|
||||
ds = ds.repeat(repeat_count)
|
||||
|
||||
# apply shuffle operation
|
||||
buffer_size = 960
|
||||
ds = ds.shuffle(buffer_size=buffer_size)
|
||||
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
||||
def bert_predict(Evaluation):
|
||||
'''
|
||||
prediction function
|
||||
'''
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
||||
dataset = get_dataset(bert_net_cfg.batch_size, 1)
|
||||
if cfg.use_crf:
|
||||
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
|
||||
tag_to_index=tag_to_index, dropout_prob=0.0)
|
||||
else:
|
||||
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels)
|
||||
net_for_pretraining.set_train(False)
|
||||
param_dict = load_checkpoint(cfg.finetune_ckpt)
|
||||
load_param_into_net(net_for_pretraining, param_dict)
|
||||
model = Model(net_for_pretraining)
|
||||
return model, dataset
|
||||
|
||||
def test_eval():
|
||||
'''
|
||||
evaluation function
|
||||
'''
|
||||
task_type = BertNER if cfg.task == "NER" else BertCLS
|
||||
model, dataset = bert_predict(task_type)
|
||||
if cfg.clue_benchmark:
|
||||
submit(model, cfg.data_file, bert_net_cfg.seq_length)
|
||||
else:
|
||||
callback = F1() if cfg.task == "NER" else Accuracy()
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
|
||||
for data in dataset.create_dict_iterator():
|
||||
input_data = []
|
||||
for i in columns_list:
|
||||
input_data.append(Tensor(data[i]))
|
||||
input_ids, input_mask, token_type_id, label_ids = input_data
|
||||
logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
|
||||
callback.update(logits, label_ids)
|
||||
print("==============================================================")
|
||||
if cfg.task == "NER":
|
||||
print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
|
||||
print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
|
||||
print("F1 {:.6f} ".format(2*callback.TP / (2*callback.TP + callback.FP + callback.FP)))
|
||||
else:
|
||||
print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
|
||||
callback.acc_num / callback.total_num))
|
||||
print("==============================================================")
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_labels = cfg.num_labels
|
||||
test_eval()
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
config settings, will be used in finetune.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.model_zoo.Bert_NEZHA import BertConfig
|
||||
|
||||
cfg = edict({
|
||||
'task': 'NER',
|
||||
'num_labels': 41,
|
||||
'data_file': '/your/path/evaluation.tfrecord',
|
||||
'schema_file': '/your/path/schema.json',
|
||||
'finetune_ckpt': '/your/path/your.ckpt',
|
||||
'use_crf': False,
|
||||
'clue_benchmark': False,
|
||||
})
|
||||
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=16 if not cfg.clue_benchmark else 1,
|
||||
seq_length=128,
|
||||
vocab_size=21128,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
)
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
'''
|
||||
Bert finetune script.
|
||||
'''
|
||||
|
||||
import os
|
||||
from utils import BertFinetuneCell, BertCLS, BertNER
|
||||
from finetune_config import cfg, bert_net_cfg, tag_to_index
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.communication.management as D
|
||||
from mindspore import context
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
class LossCallBack(Callback):
|
||||
'''
|
||||
Monitor the loss in training.
|
||||
If the loss is NAN or INF, terminate training.
|
||||
Note:
|
||||
If per_print_times is 0, do not print loss.
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
'''
|
||||
def __init__(self, per_print_times=1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be in and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
with open("./loss.log", "a+") as f:
|
||||
f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)))
|
||||
f.write("\n")
|
||||
|
||||
def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
|
||||
'''
|
||||
get dataset
|
||||
'''
|
||||
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
|
||||
"segment_ids", "label_ids"])
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
|
||||
ds = ds.repeat(repeat_count)
|
||||
|
||||
# apply shuffle operation
|
||||
buffer_size = 960
|
||||
ds = ds.shuffle(buffer_size=buffer_size)
|
||||
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
||||
def test_train():
|
||||
'''
|
||||
finetune function
|
||||
pytest -s finetune.py::test_train
|
||||
'''
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid,
|
||||
enable_mem_reuse=True, enable_task_sink=True)
|
||||
#BertCLSTrain for classification
|
||||
#BertNERTrain for sequence labeling
|
||||
if cfg.task == 'NER':
|
||||
if cfg.use_crf:
|
||||
netwithloss = BertNER(bert_net_cfg, True, num_labels=len(tag_to_index), use_crf=True,
|
||||
tag_to_index=tag_to_index, dropout_prob=0.1)
|
||||
else:
|
||||
netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
|
||||
else:
|
||||
netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
|
||||
dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
|
||||
# optimizer
|
||||
steps_per_epoch = dataset.get_dataset_size()
|
||||
if cfg.optimizer == 'AdamWeightDecayDynamicLR':
|
||||
optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
|
||||
decay_steps=steps_per_epoch * cfg.epoch_num,
|
||||
learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
|
||||
end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
|
||||
power=cfg.AdamWeightDecayDynamicLR.power,
|
||||
warmup_steps=steps_per_epoch,
|
||||
weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
|
||||
eps=cfg.AdamWeightDecayDynamicLR.eps)
|
||||
elif cfg.optimizer == 'Lamb':
|
||||
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=steps_per_epoch * cfg.epoch_num,
|
||||
start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
|
||||
power=cfg.Lamb.power, warmup_steps=steps_per_epoch, decay_filter=cfg.Lamb.decay_filter)
|
||||
elif cfg.optimizer == 'Momentum':
|
||||
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
|
||||
momentum=cfg.Momentum.momentum)
|
||||
else:
|
||||
raise Exception("Optimizer not supported.")
|
||||
# load checkpoint into network
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
|
||||
ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix, directory=cfg.ckpt_dir, config=ckpt_config)
|
||||
param_dict = load_checkpoint(cfg.pre_training_ckpt)
|
||||
load_param_into_net(netwithloss, param_dict)
|
||||
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
|
||||
netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
|
||||
model = Model(netwithgrads)
|
||||
model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])
|
||||
D.release()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_train()
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
config settings, will be used in finetune.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.model_zoo.Bert_NEZHA import BertConfig
|
||||
|
||||
cfg = edict({
|
||||
'task': 'NER',
|
||||
'num_labels': 41,
|
||||
'data_file': '/your/path/train.tfrecord',
|
||||
'schema_file': '/your/path/schema.json',
|
||||
'epoch_num': 5,
|
||||
'ckpt_prefix': 'bert',
|
||||
'ckpt_dir': None,
|
||||
'pre_training_ckpt': '/your/path/pre_training.ckpt',
|
||||
'use_crf': False,
|
||||
'optimizer': 'Lamb',
|
||||
'AdamWeightDecayDynamicLR': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'end_learning_rate': 1e-7,
|
||||
'power': 1.0,
|
||||
'weight_decay': 1e-5,
|
||||
'eps': 1e-6,
|
||||
}),
|
||||
'Lamb': edict({
|
||||
'start_learning_rate': 2e-5,
|
||||
'end_learning_rate': 1e-7,
|
||||
'power': 1.0,
|
||||
'decay_filter': lambda x: False,
|
||||
}),
|
||||
'Momentum': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'momentum': 0.9,
|
||||
}),
|
||||
})
|
||||
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=16,
|
||||
seq_length=128,
|
||||
vocab_size=21128,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
)
|
||||
|
||||
tag_to_index = {
|
||||
"O": 0,
|
||||
"S_address": 1,
|
||||
"B_address": 2,
|
||||
"M_address": 3,
|
||||
"E_address": 4,
|
||||
"S_book": 5,
|
||||
"B_book": 6,
|
||||
"M_book": 7,
|
||||
"E_book": 8,
|
||||
"S_company": 9,
|
||||
"B_company": 10,
|
||||
"M_company": 11,
|
||||
"E_company": 12,
|
||||
"S_game": 13,
|
||||
"B_game": 14,
|
||||
"M_game": 15,
|
||||
"E_game": 16,
|
||||
"S_government": 17,
|
||||
"B_government": 18,
|
||||
"M_government": 19,
|
||||
"E_government": 20,
|
||||
"S_movie": 21,
|
||||
"B_movie": 22,
|
||||
"M_movie": 23,
|
||||
"E_movie": 24,
|
||||
"S_name": 25,
|
||||
"B_name": 26,
|
||||
"M_name": 27,
|
||||
"E_name": 28,
|
||||
"S_organization": 29,
|
||||
"B_organization": 30,
|
||||
"M_organization": 31,
|
||||
"E_organization": 32,
|
||||
"S_position": 33,
|
||||
"B_position": 34,
|
||||
"M_position": 35,
|
||||
"E_position": 36,
|
||||
"S_scene": 37,
|
||||
"B_scene": 38,
|
||||
"M_scene": 39,
|
||||
"E_scene": 40,
|
||||
"<START>": 41,
|
||||
"<STOP>": 42
|
||||
}
|
|
@ -84,7 +84,6 @@ def run_pretrain():
|
|||
if args_opt.distribute == "true":
|
||||
device_num = args_opt.device_num
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_context(enable_hccl=True)
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
||||
device_num=device_num)
|
||||
D.init()
|
||||
|
@ -103,7 +102,7 @@ def run_pretrain():
|
|||
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
|
||||
start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
|
||||
power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay,
|
||||
eps=cfg.Lamb.eps, decay_filter=cfg.Lamb.decay_filter)
|
||||
eps=cfg.Lamb.eps)
|
||||
elif cfg.optimizer == 'Momentum':
|
||||
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
|
||||
momentum=cfg.Momentum.momentum)
|
||||
|
@ -114,7 +113,8 @@ def run_pretrain():
|
|||
end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
|
||||
power=cfg.AdamWeightDecayDynamicLR.power,
|
||||
weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
|
||||
eps=cfg.AdamWeightDecayDynamicLR.eps)
|
||||
eps=cfg.AdamWeightDecayDynamicLR.eps,
|
||||
warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
|
||||
else:
|
||||
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
|
||||
format(cfg.optimizer))
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""process txt"""
|
||||
|
||||
import re
|
||||
import json
|
||||
|
||||
def process_one_example_p(tokenizer, text, max_seq_len=128):
|
||||
"""process one testline"""
|
||||
textlist = list(text)
|
||||
tokens = []
|
||||
for _, word in enumerate(textlist):
|
||||
token = tokenizer.tokenize(word)
|
||||
tokens.extend(token)
|
||||
if len(tokens) >= max_seq_len - 1:
|
||||
tokens = tokens[0:(max_seq_len - 2)]
|
||||
ntokens = []
|
||||
segment_ids = []
|
||||
label_ids = []
|
||||
ntokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for _, token in enumerate(tokens):
|
||||
ntokens.append(token)
|
||||
segment_ids.append(0)
|
||||
ntokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
while len(input_ids) < max_seq_len:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
label_ids.append(0)
|
||||
ntokens.append("**NULL**")
|
||||
assert len(input_ids) == max_seq_len
|
||||
assert len(input_mask) == max_seq_len
|
||||
assert len(segment_ids) == max_seq_len
|
||||
|
||||
feature = (input_ids, input_mask, segment_ids)
|
||||
return feature
|
||||
|
||||
def label_generation(text, probs):
|
||||
"""generate label"""
|
||||
data = [text]
|
||||
probs = [probs]
|
||||
result = []
|
||||
label2id = json.loads(open("./label2id.json").read())
|
||||
id2label = [k for k, v in label2id.items()]
|
||||
|
||||
for index, prob in enumerate(probs):
|
||||
for v in prob[1:len(data[index]) + 1]:
|
||||
result.append(id2label[int(v)])
|
||||
|
||||
labels = {}
|
||||
start = None
|
||||
index = 0
|
||||
for _, t in zip("".join(data), result):
|
||||
if re.search("^[BS]", t):
|
||||
if start is not None:
|
||||
label = result[index - 1][2:]
|
||||
if labels.get(label):
|
||||
te_ = text[start:index]
|
||||
labels[label][te_] = [[start, index - 1]]
|
||||
else:
|
||||
te_ = text[start:index]
|
||||
labels[label] = {te_: [[start, index - 1]]}
|
||||
start = index
|
||||
if re.search("^O", t):
|
||||
if start is not None:
|
||||
label = result[index - 1][2:]
|
||||
if labels.get(label):
|
||||
te_ = text[start:index]
|
||||
labels[label][te_] = [[start, index - 1]]
|
||||
else:
|
||||
te_ = text[start:index]
|
||||
labels[label] = {te_: [[start, index - 1]]}
|
||||
start = None
|
||||
index += 1
|
||||
if start is not None:
|
||||
label = result[start][2:]
|
||||
if labels.get(label):
|
||||
te_ = text[start:index]
|
||||
labels[label][te_] = [[start, index - 1]]
|
||||
else:
|
||||
te_ = text[start:index]
|
||||
labels[label] = {te_: [[start, index - 1]]}
|
||||
return labels
|
|
@ -0,0 +1,264 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
'''
|
||||
Functional Cells used in Bert finetune and evaluation.
|
||||
'''
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore import context
|
||||
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
|
||||
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
|
||||
from CRF import CRF
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * reciprocal(scale)
|
||||
|
||||
class BertFinetuneCell(nn.Cell):
|
||||
"""
|
||||
Especifically defined for finetuning where only four inputs tensor are needed.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
|
||||
super(BertFinetuneCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad',
|
||||
get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.allreduce = P.AllReduce()
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("mirror_mean")
|
||||
degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
|
||||
name="loss_scale")
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids,
|
||||
sens=None):
|
||||
|
||||
|
||||
weights = self.weights
|
||||
init = self.alloc_status()
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
label_ids,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
clear_before_grad = self.clear_before_grad(init)
|
||||
F.control_depend(loss, init)
|
||||
self.depend_parameter_use(clear_before_grad, scaling_sens)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
flag = self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
if self.is_distributed:
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
F.control_depend(grads, flag)
|
||||
F.control_depend(flag, flag_sum)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
class BertCLSModel(nn.Cell):
|
||||
"""
|
||||
This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
|
||||
LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final
|
||||
logits as the results of log_softmax is propotional to that of softmax.
|
||||
"""
|
||||
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
|
||||
super(BertCLSModel, self).__init__()
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.cast = P.Cast()
|
||||
self.weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.log_softmax = P.LogSoftmax(axis=-1)
|
||||
self.dtype = config.dtype
|
||||
self.num_labels = num_labels
|
||||
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
|
||||
has_bias=True).to_float(config.compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id):
|
||||
_, pooled_output, _ = \
|
||||
self.bert(input_ids, token_type_id, input_mask)
|
||||
cls = self.cast(pooled_output, self.dtype)
|
||||
cls = self.dropout(cls)
|
||||
logits = self.dense_1(cls)
|
||||
logits = self.cast(logits, self.dtype)
|
||||
log_probs = self.log_softmax(logits)
|
||||
return log_probs
|
||||
|
||||
|
||||
class BertNERModel(nn.Cell):
|
||||
"""
|
||||
This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
|
||||
The returned output represents the final logits as the results of log_softmax is propotional to that of softmax.
|
||||
"""
|
||||
def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0,
|
||||
use_one_hot_embeddings=False):
|
||||
super(BertNERModel, self).__init__()
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.cast = P.Cast()
|
||||
self.weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.log_softmax = P.LogSoftmax(axis=-1)
|
||||
self.dtype = config.dtype
|
||||
self.num_labels = num_labels
|
||||
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
|
||||
has_bias=True).to_float(config.compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, config.hidden_size)
|
||||
self.use_crf = use_crf
|
||||
self.origin_shape = (config.batch_size, config.seq_length, self.num_labels)
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id):
|
||||
sequence_output, _, _ = \
|
||||
self.bert(input_ids, token_type_id, input_mask)
|
||||
seq = self.dropout(sequence_output)
|
||||
seq = self.reshape(seq, self.shape)
|
||||
logits = self.dense_1(seq)
|
||||
logits = self.cast(logits, self.dtype)
|
||||
if self.use_crf:
|
||||
return_value = self.reshape(logits, self.origin_shape)
|
||||
else:
|
||||
return_value = self.log_softmax(logits)
|
||||
return return_value
|
||||
|
||||
class CrossEntropyCalculation(nn.Cell):
|
||||
"""
|
||||
Cross Entropy loss
|
||||
"""
|
||||
def __init__(self, is_training=True):
|
||||
super(CrossEntropyCalculation, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.reshape = P.Reshape()
|
||||
self.last_idx = (-1,)
|
||||
self.neg = P.Neg()
|
||||
self.cast = P.Cast()
|
||||
self.is_training = is_training
|
||||
|
||||
def construct(self, logits, label_ids, num_labels):
|
||||
if self.is_training:
|
||||
label_ids = self.reshape(label_ids, self.last_idx)
|
||||
one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
|
||||
per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
|
||||
loss = self.reduce_mean(per_example_loss, self.last_idx)
|
||||
return_value = self.cast(loss, mstype.float32)
|
||||
else:
|
||||
return_value = logits * 1.0
|
||||
return return_value
|
||||
|
||||
class BertCLS(nn.Cell):
|
||||
"""
|
||||
Train interface for classification finetuning task.
|
||||
"""
|
||||
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
|
||||
super(BertCLS, self).__init__()
|
||||
self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
|
||||
self.loss = CrossEntropyCalculation(is_training)
|
||||
self.num_labels = num_labels
|
||||
def construct(self, input_ids, input_mask, token_type_id, label_ids):
|
||||
log_probs = self.bert(input_ids, input_mask, token_type_id)
|
||||
loss = self.loss(log_probs, label_ids, self.num_labels)
|
||||
return loss
|
||||
|
||||
|
||||
class BertNER(nn.Cell):
|
||||
"""
|
||||
Train interface for sequence labeling finetuning task.
|
||||
"""
|
||||
def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0,
|
||||
use_one_hot_embeddings=False):
|
||||
super(BertNER, self).__init__()
|
||||
self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings)
|
||||
if use_crf:
|
||||
if not tag_to_index:
|
||||
raise Exception("The dict for tag-index mapping should be provided for CRF.")
|
||||
self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training)
|
||||
else:
|
||||
self.loss = CrossEntropyCalculation(is_training)
|
||||
self.num_labels = num_labels
|
||||
self.use_crf = use_crf
|
||||
def construct(self, input_ids, input_mask, token_type_id, label_ids):
|
||||
logits = self.bert(input_ids, input_mask, token_type_id)
|
||||
if self.use_crf:
|
||||
loss = self.loss(logits, label_ids)
|
||||
else:
|
||||
loss = self.loss(logits, label_ids, self.num_labels)
|
||||
return loss
|
Loading…
Reference in New Issue