!4517 The teacher use fp16 calculations to imporve the performance of tinybert on gpu

Merge pull request !4517 from hanhuifeng/tinybert_gpu_perf
This commit is contained in:
mindspore-ci-bot 2020-08-18 15:11:18 +08:00 committed by Gitee
commit 4815dcbb1d
4 changed files with 17 additions and 14 deletions

View File

@ -87,13 +87,10 @@ def run_general_distill():
enable_loss_scale = True
if args_opt.device_target == "GPU":
if bert_teacher_net_cfg.compute_type != mstype.float32:
logger.warning('GPU only support fp32 temporarily, run with fp32.')
bert_teacher_net_cfg.compute_type = mstype.float32
if bert_student_net_cfg.compute_type != mstype.float32:
logger.warning('GPU only support fp32 temporarily, run with fp32.')
logger.warning('Compute about the student only support float32 temporarily, run with float32.')
bert_student_net_cfg.compute_type = mstype.float32
# Both the forward and backward of the network are calculated using fp32,
# Backward of the network are calculated using fp32,
# and the loss scale is not necessary
enable_loss_scale = False

View File

@ -285,13 +285,10 @@ if __name__ == '__main__':
enable_loss_scale = True
if args_opt.device_target == "GPU":
if td_teacher_net_cfg.compute_type != mstype.float32:
logger.warning('GPU only support fp32 temporarily, run with fp32.')
td_teacher_net_cfg.compute_type = mstype.float32
if td_student_net_cfg.compute_type != mstype.float32:
logger.warning('GPU only support fp32 temporarily, run with fp32.')
logger.warning('Compute about the student only support float32 temporarily, run with float32.')
td_student_net_cfg.compute_type = mstype.float32
# Both the forward and backward of the network are calculated using fp32,
# Backward of the network are calculated using fp32,
# and the loss scale is not necessary
enable_loss_scale = False

View File

@ -37,4 +37,5 @@ mpirun --allow-run-as-root -n $RANK_SIZE \
--save_ckpt_path="" \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR \
--enable_data_sink=False \
--load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 &

View File

@ -24,6 +24,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore import context
from .fused_layer_norm import FusedLayerNorm
@ -250,6 +251,11 @@ class BertOutput(nn.Cell):
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
self.add = P.TensorAdd()
self.is_gpu = context.get_context('device_target') == "GPU"
if self.is_gpu:
self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
self.compute_type = compute_type
else:
if compute_type == mstype.float16:
self.layernorm = FusedLayerNorm((out_channels,),
use_batch_norm=enable_fused_layernorm).to_float(compute_type)
@ -264,6 +270,8 @@ class BertOutput(nn.Cell):
output = self.dropout(output)
output = self.add(input_tensor, output)
output = self.layernorm(output)
if self.is_gpu:
output = self.cast(output, self.compute_type)
return output