!18566 Remove communication fusion value for PanGu Model

Merge pull request !18566 from huangxinjing/rm_type_check
commit 03cc6449c2
Author: i-robot
Date: 2021-06-22 14:06:03 +00:00 (committed by Gitee)
1 changed file with 2 additions and 16 deletions


@@ -731,16 +731,10 @@ class PanguAlpha_Model(nn.Cell):
         for i in range(num_layers):
             per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
-            # Each layer will be remoputed in the backward process. The output activation of each layer will be saved,
-            # in other words, in backward process each block will be almosttotally recomputed.
+            # Each layer will be recomputed in the backward process. The output activation of each layer will be saved,
+            # in other words, in backward process each block will be almost totally recomputed.
             if config.use_recompute:
                 per_block.recompute()
-            if config.param_init_type == mstype.float16:
-                # If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
-                # the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
-                # so we fuse communications of layernorm to a large value(+100)
-                per_block.layernorm1.set_comm_fusion(int(int(i / fusion_group_size) + 100))
-                per_block.layernorm2.set_comm_fusion(int(int(i / fusion_group_size) + 100))
             self.blocks.append(per_block)
         if config.self_layernorm:
             self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(
@@ -753,11 +747,6 @@ class PanguAlpha_Model(nn.Cell):
             self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
             self.layernorm.gamma.parallel_optimizer = False
             self.layernorm.beta.parallel_optimizer = False
-        if config.param_init_type == mstype.float16:
-            # If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
-            # the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
-            # so we fuse communications of layernorm to a large value(+100)
-            self.layernorm.set_comm_fusion(int(num_layers / fusion_group_size + 100))
         self.use_past = config.use_past
         self.past = tuple([None] * config.num_layers)
         self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
@@ -795,10 +784,7 @@ class PanguAlpha_Model(nn.Cell):
             self.top_query_layer = QueryLayer(config)
             if config.use_recompute:
                 self.top_query_layer.recompute()
             self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 2)
-            self.top_query_layer.layernorm1.set_comm_fusion(int(config.num_layers / fusion_group_size + 100))
-            self.top_query_layer.layernorm2.set_comm_fusion(int(config.num_layers / fusion_group_size + 100))
         self.use_top_query_attention = config.use_top_query_attention