!18566 Remove communication fusion value for PanGu Model
Merge pull request !18566 from huangxinjing/rm_type_check
Commit 03cc6449c2
@@ -731,16 +731,10 @@ class PanguAlpha_Model(nn.Cell):
         for i in range(num_layers):
             per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
-            # Each layer will be remoputed in the backward process. The output activation of each layer will be saved,
-            # in other words, in backward process each block will be almosttotally recomputed.
+            # Each layer will be recomputed in the backward process. The output activation of each layer will be saved,
+            # in other words, in backward process each block will be almost totally recomputed.
             if config.use_recompute:
                 per_block.recompute()
-            if config.param_init_type == mstype.float16:
-                # If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
-                # the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
-                # so we fuse communications of layernorm to a large value(+100)
-                per_block.layernorm1.set_comm_fusion(int(int(i / fusion_group_size) + 100))
-                per_block.layernorm2.set_comm_fusion(int(int(i / fusion_group_size) + 100))
             self.blocks.append(per_block)
         if config.self_layernorm:
             self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(
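As a reading aid (not part of the patch): the surviving line `per_block.set_comm_fusion(int(i / fusion_group_size) + 2)` buckets every `fusion_group_size` consecutive blocks into one gradient-communication fusion group, offset by 2 (presumably reserving the lowest indices for other parameters). A minimal sketch of that arithmetic only, where `num_layers = 32` and `fusion_group_size = 8` are assumed example values rather than the model's defaults:

    # Sketch only: the fusion-group arithmetic from the kept line above.
    # num_layers and fusion_group_size are assumed example values.
    num_layers = 32
    fusion_group_size = 8

    for i in range(num_layers):
        fusion_index = int(i / fusion_group_size) + 2
        print(f"block {i + 1:2d} -> comm fusion group {fusion_index}")
    # Blocks 1-8 land in group 2, blocks 9-16 in group 3, and so on.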
@@ -753,11 +747,6 @@ class PanguAlpha_Model(nn.Cell):
             self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
         self.layernorm.gamma.parallel_optimizer = False
         self.layernorm.beta.parallel_optimizer = False
-        if config.param_init_type == mstype.float16:
-            # If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
-            # the bias parameter in linear models (fp16 gradient), causing dtype error for communication operators.
-            # so we fuse communications of layernorm to a large value(+100)
-            self.layernorm.set_comm_fusion(int(num_layers / fusion_group_size + 100))
         self.use_past = config.use_past
         self.past = tuple([None] * config.num_layers)
         self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
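For context on the deleted branch: with fp16 parameter initialization the layernorm parameters carry fp32 gradients, so they used to be forced into a separate high-index fusion group (`+ 100`) to keep them out of the fp16 gradient buckets; after this commit they simply keep whatever fusion index the enclosing block or model already set. A minimal before/after of the index arithmetic (a sketch, with `i = 10` and `fusion_group_size = 8` as assumed example values):

    # Sketch only: fusion index of a block's layernorm parameters before vs. after this patch.
    # i and fusion_group_size are assumed example values.
    i = 10
    fusion_group_size = 8

    block_fusion = int(i / fusion_group_size) + 2                  # still set on the whole block: 3
    old_layernorm_fusion = int(int(i / fusion_group_size) + 100)   # removed special case: 101
    print(block_fusion, old_layernorm_fusion)
    # After the patch the layernorm parameters keep the block's value (3 here)
    # instead of the separate group 101.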
@@ -795,10 +784,7 @@ class PanguAlpha_Model(nn.Cell):
         self.top_query_layer = QueryLayer(config)
         if config.use_recompute:
             self.top_query_layer.recompute()

         self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 2)
-        self.top_query_layer.layernorm1.set_comm_fusion(int(config.num_layers / fusion_group_size + 100))
-        self.top_query_layer.layernorm2.set_comm_fusion(int(config.num_layers / fusion_group_size + 100))
-
         self.use_top_query_attention = config.use_top_query_attention
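One way to inspect the fusion assignment that results from this change (a sketch, not part of the patch, assuming an already-constructed PanguAlpha_Model instance named `model` and MindSpore's per-parameter `comm_fusion` attribute):

    # Hypothetical check: count parameters per communication-fusion index.
    # Assumes `model` is an already-constructed PanguAlpha_Model instance.
    from collections import defaultdict

    groups = defaultdict(list)
    for param in model.get_parameters():
        groups[param.comm_fusion].append(param.name)

    for fusion_index in sorted(groups):
        print(fusion_index, len(groups[fusion_index]), "parameters")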