!12311 adjust second allreduce split indices for thor optimizer

From: @sl_wang
Reviewed-by: @wang_zi_dong,@kisnwang
Signed-off-by: @kisnwang
This commit is contained in:
mindspore-ci-bot 2021-02-10 11:26:29 +08:00 committed by Gitee
commit 725af9c1bb
1 changed file with 8 additions and 8 deletions

View File

@@ -281,9 +281,9 @@ class THOR_GPU(Optimizer):
 degree = _get_device_num()
 if self.conv_layer_count > 0:
 if not split_indices:
-self.split_indices = split_indices
-else:
 self.split_indices = [len(self.matrix_A) - 1]
+else:
+self.split_indices = split_indices
 auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
 auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
 auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
@@ -294,9 +294,9 @@ class THOR_GPU(Optimizer):
 self.grad_reducer_G = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=8)
 else:
 if not split_indices:
-self.split_indices = split_indices
-else:
 self.split_indices = [len(self.params) - 1]
+else:
+self.split_indices = split_indices
 auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum3")
 self.grad_reducer_g = DistributedGradReducer(self.params, mean, degree, fusion_type=3)
@@ -595,9 +595,9 @@ class THOR_Ascend(Optimizer):
 degree = _get_device_num()
 if self.conv_layer_count > 0:
 if not split_indices:
-self.split_indices = split_indices
-else:
 self.split_indices = [len(self.matrix_A) - 1]
+else:
+self.split_indices = split_indices
 auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
 auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
 auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
@@ -608,9 +608,9 @@ class THOR_Ascend(Optimizer):
 self.grad_reducer_G = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=8)
 else:
 if not split_indices:
-self.split_indices = split_indices
-else:
 self.split_indices = [len(self.params) - 1]
+else:
+self.split_indices = split_indices
 auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum3")
 self.grad_reducer_g = DistributedGradReducer(self.params, mean, degree, fusion_type=3)