forked from mindspore-Ecosystem/mindspore
The teacher uses fp16 calculations to optimize the performance of TinyBERT on the GPU
This commit is contained in:
parent f0988c7b16
commit 9e2b38c961
@@ -87,13 +87,10 @@ def run_general_distill():
     enable_loss_scale = True
     if args_opt.device_target == "GPU":
-        if bert_teacher_net_cfg.compute_type != mstype.float32:
-            logger.warning('GPU only support fp32 temporarily, run with fp32.')
-            bert_teacher_net_cfg.compute_type = mstype.float32
         if bert_student_net_cfg.compute_type != mstype.float32:
-            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            logger.warning('Compute about the student only support float32 temporarily, run with float32.')
             bert_student_net_cfg.compute_type = mstype.float32
-        # Both the forward and backward of the network are calculated using fp32,
+        # Backward of the network are calculated using fp32,
         # and the loss scale is not necessary
         enable_loss_scale = False
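For reference, a minimal runnable sketch of the precision setup this hunk produces on GPU: the teacher config keeps its fp16 compute type while the student is forced to fp32 and loss scaling stays disabled. The SimpleNamespace stand-ins for the configs and the device_target variable are assumptions for illustration; the real script builds them from its own config module and argument parser.

from types import SimpleNamespace
import logging

import mindspore.common.dtype as mstype

logger = logging.getLogger(__name__)

# Stand-ins for the script's teacher/student configs and parsed args (assumptions).
bert_teacher_net_cfg = SimpleNamespace(compute_type=mstype.float16)  # teacher stays fp16 on GPU
bert_student_net_cfg = SimpleNamespace(compute_type=mstype.float16)
device_target = "GPU"  # stand-in for args_opt.device_target

enable_loss_scale = True
if device_target == "GPU":
    # Only the student is forced to fp32; the teacher keeps its fp16 compute type.
    if bert_student_net_cfg.compute_type != mstype.float32:
        logger.warning('Compute about the student only support float32 temporarily, run with float32.')
        bert_student_net_cfg.compute_type = mstype.float32
    # The backward pass (student) runs in fp32, so loss scaling is unnecessary.
    enable_loss_scale = False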
@@ -285,13 +285,10 @@ if __name__ == '__main__':
     enable_loss_scale = True
     if args_opt.device_target == "GPU":
-        if td_teacher_net_cfg.compute_type != mstype.float32:
-            logger.warning('GPU only support fp32 temporarily, run with fp32.')
-            td_teacher_net_cfg.compute_type = mstype.float32
         if td_student_net_cfg.compute_type != mstype.float32:
-            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            logger.warning('Compute about the student only support float32 temporarily, run with float32.')
             td_student_net_cfg.compute_type = mstype.float32
-        # Both the forward and backward of the network are calculated using fp32,
+        # Backward of the network are calculated using fp32,
         # and the loss scale is not necessary
         enable_loss_scale = False
@@ -37,4 +37,5 @@ mpirun --allow-run-as-root -n $RANK_SIZE \
     --save_ckpt_path="" \
     --data_dir=$DATA_DIR \
     --schema_dir=$SCHEMA_DIR \
+    --enable_data_sink=False \
     --load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 &
@@ -24,6 +24,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
+from mindspore import context
 from .fused_layer_norm import FusedLayerNorm
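The newly imported context module is what the model code below uses to detect the device at cell construction time; a minimal sketch of that check (the print is illustrative only, and it assumes MindSpore is installed and a context has been set):

from mindspore import context

# BertOutput branches on the active device target.
if context.get_context('device_target') == "GPU":
    print('GPU detected: LayerNorm will run in fp32 and be cast back to fp16.')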
@@ -250,11 +251,16 @@ class BertOutput(nn.Cell):
                                weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.add = P.TensorAdd()
-        if compute_type == mstype.float16:
-            self.layernorm = FusedLayerNorm((out_channels,),
-                                            use_batch_norm=enable_fused_layernorm).to_float(compute_type)
+        self.is_gpu = context.get_context('device_target') == "GPU"
+        if self.is_gpu:
+            self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
+            self.compute_type = compute_type
         else:
-            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
+            if compute_type == mstype.float16:
+                self.layernorm = FusedLayerNorm((out_channels,),
+                                                use_batch_norm=enable_fused_layernorm).to_float(compute_type)
+            else:
+                self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
         self.cast = P.Cast()
@@ -264,6 +270,8 @@ class BertOutput(nn.Cell):
         output = self.dropout(output)
         output = self.add(input_tensor, output)
         output = self.layernorm(output)
+        if self.is_gpu:
+            output = self.cast(output, self.compute_type)
         return output
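Putting the last two hunks together, here is a self-contained sketch of the pattern they implement: dense and residual add run in the cell's compute type, LayerNorm runs in fp32 on GPU, and the result is cast back so downstream fp16 cells are unaffected. The class name and the omission of dropout and the FusedLayerNorm branch are simplifications for illustration, not the repository's actual BertOutput.

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.ops import operations as P


class GpuFriendlyOutput(nn.Cell):
    """Sketch: dense and residual add in compute_type, LayerNorm in fp32 on GPU, cast back."""

    def __init__(self, in_channels, out_channels, compute_type=mstype.float16):
        super(GpuFriendlyOutput, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels).to_float(compute_type)
        self.add = P.TensorAdd()
        self.is_gpu = context.get_context('device_target') == "GPU"
        if self.is_gpu:
            # LayerNorm is kept in fp32 on GPU for numerical stability ...
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
            self.compute_type = compute_type
        else:
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.add(input_tensor, output)
        output = self.layernorm(output)
        if self.is_gpu:
            # ... and the result is cast back so later fp16 cells see their expected dtype.
            output = self.cast(output, self.compute_type)
        return output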