!22067 squad add gpu train and eval

Merge pull request !22067 from chenweitao_295/bert_squad_gpu
i-robot 2021-08-23 03:41:09 +00:00 committed by Gitee
commit 0b11b86bd7
7 changed files with 138 additions and 41 deletions

README.md

@@ -92,17 +92,17 @@ bash scripts/run_distributed_pretrain_ascend.sh /path/cn-wiki-128 /path/hccl.jso
 - Set bert network config and optimizer hyperparameters in `finetune_eval_config.py`.
 - Classification task: Set task related hyperparameters in scripts/run_classifier.sh.
-- Run `bash scripts/run_classifier.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- Run `bash scripts/run_classifier.sh` for fine-tuning of BERT-base and BERT-NEZHA model.
     bash scripts/run_classifier.sh
 - NER task: Set task related hyperparameters in scripts/run_ner.sh.
-- Run `bash scripts/run_ner.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- Run `bash scripts/run_ner.sh` for fine-tuning of BERT-base and BERT-NEZHA model.
     bash scripts/run_ner.sh
 - SQuAD task: Set task related hyperparameters in scripts/run_squad.sh.
-- Run `bash scripts/run_squad.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- Run `bash scripts/run_squad.sh` for fine-tuning of BERT-base and BERT-NEZHA model.
     bash scripts/run_squad.sh
 ```
@@ -121,19 +121,19 @@ bash scripts/run_distributed_pretrain_for_gpu.sh 8 40 /path/cn-wiki-128
 - Set bert network config and optimizer hyperparameters in `finetune_eval_config.py`.
 - Classification task: Set task related hyperparameters in scripts/run_classifier.sh.
-- Run `bash scripts/run_classifier.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- Run `bash scripts/run_classifier.sh` for fine-tuning of BERT-base and BERT-NEZHA model.
     bash scripts/run_classifier.sh
 - NER task: Set task related hyperparameters in scripts/run_ner.sh.
-- Run `bash scripts/run_ner.py` for fine-tuning of BERT-base and BERT-NEZHA model.
+- Run `bash scripts/run_ner.sh` for fine-tuning of BERT-base and BERT-NEZHA model.
     bash scripts/run_ner.sh
-- SQuAD task: Set task related hyperparameters in scripts/run_squad.sh.
-- Run `bash scripts/run_squad.py` for fine-tuning of BERT-base and BERT-NEZHA model.
-    bash scripts/run_squad.sh
+- SQuAD task: Set task related hyperparameters in scripts/run_squad_gpu.sh.
+- Run `bash scripts/run_squad_gpu.sh` for fine-tuning of BERT-base and BERT-NEZHA model.
+    bash scripts/run_squad_gpu.sh
 ```
 - running on ModelArts
@@ -268,7 +268,8 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol
 ├─README.md
 ├─run_classifier.sh                  # shell script for standalone classifier task on ascend or gpu
 ├─run_ner.sh                         # shell script for standalone NER task on ascend or gpu
-├─run_squad.sh                       # shell script for standalone SQUAD task on ascend or gpu
+├─run_squad.sh                       # shell script for standalone SQUAD task on ascend
+├─run_squad_gpu.sh                   # shell script for standalone SQUAD task on gpu
 ├─run_standalone_pretrain_ascend.sh  # shell script for standalone pretrain on ascend
 ├─run_distributed_pretrain_ascend.sh # shell script for distributed pretrain on ascend
 ├─run_distributed_pretrain_gpu.sh    # shell script for distributed pretrain on gpu

README_CN.md

@@ -95,17 +95,17 @@ bash scripts/run_distributed_pretrain_ascend.sh /path/cn-wiki-128 /path/hccl.jso
 - Set the BERT network config and optimizer hyperparameters in `finetune_eval_config.py`.
 - Classification task: set task-related hyperparameters in scripts/run_classifier.sh.
-- Run `bash scripts/run_classifier.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- Run `bash scripts/run_classifier.sh` to fine-tune the BERT-base and BERT-NEZHA models.
     bash scripts/run_classifier.sh
 - NER task: set task-related hyperparameters in scripts/run_ner.sh.
-- Run `bash scripts/run_ner.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- Run `bash scripts/run_ner.sh` to fine-tune the BERT-base and BERT-NEZHA models.
     bash scripts/run_ner.sh
 - SQuAD task: set task-related hyperparameters in scripts/run_squad.sh.
-- Run `bash scripts/run_squad.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- Run `bash scripts/run_squad.sh` to fine-tune the BERT-base and BERT-NEZHA models.
     bash scripts/run_squad.sh
 ```
@@ -124,19 +124,19 @@ bash scripts/run_distributed_pretrain_for_gpu.sh 8 40 /path/cn-wiki-128
 - Set the BERT network config and optimizer hyperparameters in `finetune_eval_config.py`.
 - Classification task: set task-related hyperparameters in scripts/run_classifier.sh.
-- Run `bash scripts/run_classifier.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- Run `bash scripts/run_classifier.sh` to fine-tune the BERT-base and BERT-NEZHA models.
     bash scripts/run_classifier.sh
 - NER task: set task-related hyperparameters in scripts/run_ner.sh.
-- Run `bash scripts/run_ner.py` to fine-tune the BERT-base and BERT-NEZHA models.
+- Run `bash scripts/run_ner.sh` to fine-tune the BERT-base and BERT-NEZHA models.
     bash scripts/run_ner.sh
-- SQuAD task: set task-related hyperparameters in scripts/run_squad.sh.
-- Run `bash scripts/run_squad.py` to fine-tune the BERT-base and BERT-NEZHA models.
-    bash scripts/run_squad.sh
+- SQuAD task: set task-related hyperparameters in scripts/run_squad_gpu.sh.
+- Run `bash scripts/run_squad_gpu.sh` to fine-tune the BERT-base and BERT-NEZHA models.
+    bash scripts/run_squad_gpu.sh
 ```
 - Running on ModelArts (to run on ModelArts, refer to the following document: [modelarts](https://support.huaweicloud.com/modelarts/))
@@ -266,7 +266,8 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol
 ├─README.md
 ├─run_classifier.sh                  # shell script for standalone classifier task on Ascend or GPU
 ├─run_ner.sh                         # shell script for standalone NER task on Ascend or GPU
-├─run_squad.sh                       # shell script for standalone SQuAD task on Ascend or GPU
+├─run_squad.sh                       # shell script for standalone SQuAD task on Ascend
+├─run_squad_gpu.sh                   # shell script for standalone SQuAD task on GPU
 ├─run_standalone_pretrain_ascend.sh  # shell script for standalone pretraining on Ascend
 ├─run_distributed_pretrain_ascend.sh # shell script for distributed pretraining on Ascend
 ├─run_distributed_pretrain_gpu.sh    # shell script for distributed pretraining on GPU

run_squad.py

@@ -78,7 +78,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
     param_dict = load_checkpoint(load_checkpoint_path)
     load_param_into_net(network, param_dict)
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
+    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
     netwithgrads = BertSquadCell(network, optimizer=optimizer, scale_update_cell=update_cell)
     model = Model(netwithgrads)
     callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]
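For context: `DynamicLossScaleUpdateCell` maintains the loss scale that the wrapped train cell consults each step. With these settings the scale starts at `2 ** 32`, is divided by `scale_factor=2` whenever a step reports an overflow, and is multiplied by 2 again after `scale_window=1000` consecutive overflow-free steps; the overflow flag it reacts to is the `cond` value that `BertSquadCell` computes from the gradient status.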
@@ -157,6 +157,7 @@ def run_squad():
         context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
     elif target == "GPU":
         context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+        context.set_context(enable_graph_kernel=True)
         if bert_net_cfg.compute_type != mstype.float32:
             logger.warning('GPU only support fp32 temporarily, run with fp32.')
             bert_net_cfg.compute_type = mstype.float32
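The only functional change here is enabling graph kernel fusion on GPU before the network is built. A minimal, self-contained sketch of the resulting context setup (the surrounding `run_squad` argument handling is omitted; `set_run_context` is a name introduced only for this sketch):

```python
# Sketch of the device-context setup run_squad.py now performs; not the
# repository's exact code. enable_graph_kernel turns on operator fusion,
# which mainly helps elementwise-heavy graphs such as BERT's on GPU.
from mindspore import context


def set_run_context(target, device_id=0):
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        context.set_context(enable_graph_kernel=True)
    else:
        raise ValueError("Unsupported device target: " + target)
```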

scripts/run_squad_gpu.sh

@@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_squad_gpu.sh"
echo "for example: bash scripts/run_squad_gpu.sh"
echo "assessment_method include: [Accuracy]"
echo "=============================================================================================================="
mkdir -p ms_log
CUR_DIR=`pwd`
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_squad.py \
    --config_path="../../task_squad_config.yaml" \
    --device_target="GPU" \
    --do_train="true" \
    --do_eval="false" \
    --device_id=0 \
    --epoch_num=3 \
    --num_class=2 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --train_batch_size=32 \
    --eval_batch_size=1 \
    --vocab_file_path="" \
    --save_finetune_checkpoint_path="" \
    --load_pretrain_checkpoint_path="" \
    --load_finetune_checkpoint_path="" \
    --train_data_file_path="" \
    --eval_json_path="" \
    --schema_file_path="" > squad_log.txt 2>&1 &
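Note that the path flags above are intentionally left empty: fill in `vocab_file_path`, the three checkpoint paths, `train_data_file_path`, and `eval_json_path` for your environment before launching. The script backgrounds the Python process and redirects all output to `squad_log.txt`, so after running `bash scripts/run_squad_gpu.sh` from the model directory you can follow training with `tail -f squad_log.txt`.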

src/bert_for_finetune.py

@@ -32,21 +32,26 @@ from .bert_for_pre_training import clip_grad
 from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel
 from .utils import CrossEntropyCalculation
+
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
+
 grad_scale = C.MultitypeFuncGraph("grad_scale")
 reciprocal = P.Reciprocal()
+
 @grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale(scale, grad):
     return grad * reciprocal(scale)
+
 _grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
 grad_overflow = P.FloatStatus()
+
 @_grad_overflow.register("Tensor")
 def _tensor_grad_overflow(grad):
     return grad_overflow(grad)

 class BertFinetuneCell(nn.Cell):
     """
     Especially defined for finetuning where only four inputs tensor are needed.
@@ -61,6 +66,7 @@ class BertFinetuneCell(nn.Cell):
         optimizer (Optimizer): Optimizer for updating the weights.
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
+
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(BertFinetuneCell, self).__init__(auto_prefix=False)
@@ -156,10 +162,12 @@ class BertFinetuneCell(nn.Cell):
             self.optimizer(grads)
         return (loss, cond)

+
 class BertSquadCell(nn.Cell):
     """
     specifically defined for finetuning where only four inputs tensor are needed.
     """
+
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(BertSquadCell, self).__init__(auto_prefix=False)
         self.network = network
@@ -179,6 +187,13 @@ class BertSquadCell(nn.Cell):
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
         self.cast = P.Cast()
-        self.alloc_status = P.NPUAllocFloatStatus()
-        self.get_status = P.NPUGetFloatStatus()
-        self.clear_status = P.NPUClearFloatStatus()
+        self.gpu_target = False
+        if context.get_context("device_target") == "GPU":
+            self.gpu_target = True
+            self.float_status = P.FloatStatus()
+            self.addn = P.AddN()
+            self.reshape = P.Reshape()
+        else:
+            self.alloc_status = P.NPUAllocFloatStatus()
+            self.get_status = P.NPUGetFloatStatus()
+            self.clear_status = P.NPUClearFloatStatus()
@@ -202,7 +217,7 @@
                   sens=None):
         """BertSquad"""
         weights = self.weights
-        init = self.alloc_status()
+        init = False
         loss = self.network(input_ids,
                             input_mask,
                             token_type_id,
@@ -214,6 +229,8 @@
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
-        init = F.depend(init, loss)
-        clear_status = self.clear_status(init)
-        scaling_sens = F.depend(scaling_sens, clear_status)
+        if not self.gpu_target:
+            init = self.alloc_status()
+            init = F.depend(init, loss)
+            clear_status = self.clear_status(init)
+            scaling_sens = F.depend(scaling_sens, clear_status)
@@ -230,10 +247,15 @@
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
-        init = F.depend(init, grads)
-        get_status = self.get_status(init)
-        init = F.depend(init, get_status)
-        flag_sum = self.reduce_sum(init, (0,))
+        if not self.gpu_target:
+            init = F.depend(init, grads)
+            get_status = self.get_status(init)
+            init = F.depend(init, get_status)
+            flag_sum = self.reduce_sum(init, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
+            flag_sum = self.addn(flag_sum)
+            flag_sum = self.reshape(flag_sum, (()))
         if self.is_distributed:
             flag_reduce = self.allreduce(flag_sum)
             cond = self.less_equal(self.base, flag_reduce)
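Summarizing the new control flow: on Ascend the cell still reads the NPU float-status registers, while on GPU it maps `FloatStatus` over every gradient and sums the per-tensor flags into one overflow indicator. A condensed sketch of the GPU path, mirroring the module-level `_grad_overflow` map defined at the top of this file (the `GpuOverflowCheck` wrapper is introduced only for illustration):

```python
# Sketch of the GPU overflow check used by BertSquadCell; not the
# repository's exact code, but the same operators it relies on.
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F

_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    # FloatStatus yields a one-element tensor that is non-zero
    # when the gradient contains inf/nan.
    return grad_overflow(grad)


class GpuOverflowCheck(nn.Cell):
    """Reduce per-gradient inf/nan flags to a single scalar, as the GPU branch does."""

    def __init__(self):
        super(GpuOverflowCheck, self).__init__()
        self.hyper_map = C.HyperMap()
        self.addn = P.AddN()
        self.reshape = P.Reshape()

    def construct(self, grads):
        flags = self.hyper_map(F.partial(_grad_overflow), grads)
        flag_sum = self.addn(flags)
        # A non-zero scalar means at least one gradient overflowed.
        return self.reshape(flag_sum, ())
```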
@@ -246,10 +268,12 @@
             self.optimizer(grads)
         return (loss, cond)

+
 class BertCLS(nn.Cell):
     """
     Train interface for classification finetuning task.
     """
+
     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False,
                  assessment_method=""):
         super(BertCLS, self).__init__()
@@ -259,6 +283,7 @@ class BertCLS(nn.Cell):
         self.num_labels = num_labels
         self.assessment_method = assessment_method
         self.is_training = is_training
+
     def construct(self, input_ids, input_mask, token_type_id, label_ids):
         logits = self.bert(input_ids, input_mask, token_type_id)
         if self.assessment_method == "spearman_correlation":
@@ -275,6 +300,7 @@ class BertNER(nn.Cell):
     """
     Train interface for sequence labeling finetuning task.
     """
+
     def __init__(self, config, batch_size, is_training, num_labels=11, use_crf=False,
                  tag_to_index=None, dropout_prob=0.0, use_one_hot_embeddings=False):
         super(BertNER, self).__init__()
@@ -288,6 +314,7 @@ class BertNER(nn.Cell):
         self.loss = CrossEntropyCalculation(is_training)
         self.num_labels = num_labels
         self.use_crf = use_crf
+
     def construct(self, input_ids, input_mask, token_type_id, label_ids):
         logits = self.bert(input_ids, input_mask, token_type_id)
         if self.use_crf:
@@ -296,10 +323,12 @@
             loss = self.loss(logits, label_ids, self.num_labels)
         return loss

+
 class BertSquad(nn.Cell):
     '''
     Train interface for SQuAD finetuning task.
     '''
+
     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
         super(BertSquad, self).__init__()
         self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)

src/finetune_eval_model.py

@@ -20,14 +20,18 @@ Bert finetune and evaluation model script.
 import mindspore.nn as nn
 from mindspore.common.initializer import TruncatedNormal
 from mindspore.ops import operations as P
+from mindspore import context
 from .bert_model import BertModel

+
 class BertCLSModel(nn.Cell):
     """
     This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
     LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final
     logits as the results of log_softmax is proportional to that of softmax.
     """
+
     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False,
                  assessment_method=""):
         super(BertCLSModel, self).__init__()
@@ -46,8 +50,7 @@ class BertCLSModel(nn.Cell):
         self.assessment_method = assessment_method

     def construct(self, input_ids, input_mask, token_type_id):
-        _, pooled_output, _ = \
-            self.bert(input_ids, token_type_id, input_mask)
+        _, pooled_output, _ = self.bert(input_ids, token_type_id, input_mask)
         cls = self.cast(pooled_output, self.dtype)
         cls = self.dropout(cls)
         logits = self.dense_1(cls)
@@ -56,10 +59,12 @@
         logits = self.log_softmax(logits)
         return logits

+
 class BertSquadModel(nn.Cell):
     '''
     This class is responsible for SQuAD
     '''
+
     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
         super(BertSquadModel, self).__init__()
         if not is_training:
@@ -73,22 +78,36 @@
         self.dtype = config.dtype
         self.log_softmax = P.LogSoftmax(axis=1)
         self.is_training = is_training
+        self.gpu_target = context.get_context("device_target") == "GPU"
+        self.cast = P.Cast()
+        self.reshape = P.Reshape()
+        self.transpose = P.Transpose()
+        self.shape = (-1, config.hidden_size)
+        self.origin_shape = (-1, config.seq_length, self.num_labels)
+        self.transpose_shape = (-1, self.num_labels, config.seq_length)

     def construct(self, input_ids, input_mask, token_type_id):
+        """Return the final logits as the results of log_softmax."""
         sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask)
-        batch_size, seq_length, hidden_size = P.Shape()(sequence_output)
-        sequence = P.Reshape()(sequence_output, (-1, hidden_size))
+        sequence = self.reshape(sequence_output, self.shape)
         logits = self.dense1(sequence)
-        logits = P.Cast()(logits, self.dtype)
-        logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels))
-        logits = self.log_softmax(logits)
+        logits = self.cast(logits, self.dtype)
+        logits = self.reshape(logits, self.origin_shape)
+        if self.gpu_target:
+            logits = self.transpose(logits, (0, 2, 1))
+            logits = self.log_softmax(self.reshape(logits, (-1, self.transpose_shape[-1])))
+            logits = self.transpose(self.reshape(logits, self.transpose_shape), (0, 2, 1))
+        else:
+            logits = self.log_softmax(logits)
         return logits

+
 class BertNERModel(nn.Cell):
     """
     This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
     The returned output represents the final logits as the results of log_softmax is proportional to that of softmax.
     """
+
     def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0,
                  use_one_hot_embeddings=False):
         super(BertNERModel, self).__init__()
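The GPU branch of `BertSquadModel.construct` routes the logits through a transpose/reshape detour around `LogSoftmax(axis=1)`, presumably because the GPU kernel at the time only handled softmax over the last axis; the result is numerically identical to the Ascend branch's direct call. A quick NumPy check of that equivalence (NumPy and the `log_softmax` helper below are assumptions for illustration only):

```python
# Verifies that the GPU transpose/reshape detour equals applying
# log_softmax along axis 1 of the (batch, seq_length, num_labels)
# logits directly, as the Ascend branch does.
import numpy as np


def log_softmax(x, axis):
    x_max = x.max(axis=axis, keepdims=True)
    shifted = x - x_max
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


batch, seq_length, num_labels = 2, 5, 2
logits = np.random.randn(batch, seq_length, num_labels).astype(np.float32)

# Ascend branch: softmax across the sequence dimension directly.
direct = log_softmax(logits, axis=1)

# GPU branch: move seq_length to the last axis, flatten to 2-D, apply
# log_softmax on the last axis, then restore the original layout.
t = logits.transpose(0, 2, 1).reshape(-1, seq_length)
detour = log_softmax(t, axis=1).reshape(-1, num_labels, seq_length).transpose(0, 2, 1)

assert np.allclose(direct, detour, atol=1e-5)
```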
@@ -111,8 +130,7 @@
     def construct(self, input_ids, input_mask, token_type_id):
         """Return the final logits as the results of log_softmax."""
-        sequence_output, _, _ = \
-            self.bert(input_ids, token_type_id, input_mask)
+        sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask)
         seq = self.dropout(sequence_output)
         seq = self.reshape(seq, self.shape)
         logits = self.dense_1(seq)

task_squad_config.yaml

@@ -36,7 +36,7 @@ export_file_name: 'bert_squad'
 file_format: 'AIR'

 optimizer_cfg:
-    optimizer: 'Lamb'
+    optimizer: 'Momentum'
     AdamWeightDecay:
         learning_rate: 0.0001  # 1e-4
         end_learning_rate: 0.00000000001  # 1e-11
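Finally, the SQuAD task's default optimizer switches from Lamb to Momentum. The finetune entry scripts pick the optimizer class from this `optimizer` field; a hedged sketch of how such a dispatch typically looks (the `cfg` object, `lr_schedule`, and the `momentum=0.9` value are illustrative assumptions, not the repository's exact code):

```python
import mindspore.nn as nn


def build_optimizer(cfg, params, lr_schedule):
    """Map the optimizer_cfg 'optimizer' field to a MindSpore optimizer."""
    if cfg.optimizer == 'Lamb':
        return nn.Lamb(params, learning_rate=lr_schedule)
    if cfg.optimizer == 'Momentum':
        # momentum=0.9 is an illustrative choice, not taken from the config above.
        return nn.Momentum(params, learning_rate=lr_schedule, momentum=0.9)
    if cfg.optimizer == 'AdamWeightDecay':
        return nn.AdamWeightDecay(params, learning_rate=lr_schedule)
    raise ValueError("Unsupported optimizer: " + cfg.optimizer)
```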