diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc index b2283e5c3c5..4b8654ed84d 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc @@ -26,13 +26,13 @@ namespace mindspore { namespace kernel { namespace { std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { - auto parallel_context_instance = parallel::ParallelContext::GetInstance(); - MS_EXCEPTION_IF_NULL(parallel_context_instance); - if (parallel_context_instance->enable_parallel_optimizer()) { - return kOpFormat_DEFAULT; - } const std::set kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0}; auto op_name = AnfAlgo::GetCNodeName(kernel_node); + auto parallel_context_instance = parallel::ParallelContext::GetInstance(); + MS_EXCEPTION_IF_NULL(parallel_context_instance); + if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) { + return kOpFormat_DEFAULT; + } auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); if (op_name != kReduceScatter && op_name != kAllGatherOpName) { return format; diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index 9da8f0a65ba..a0e4805b1af 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -65,6 +65,8 @@ void ParallelContext::Reset() { strategy_ckpt_load_file_ = ""; strategy_ckpt_save_file_ = ""; enable_parallel_optimizer_ = false; + all_reduce_fusion_split_indices_.clear(); + all_reduce_fusion_split_sizes_.clear(); } void ParallelContext::set_device_num(int32_t device_num) { diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index bd0d18347a5..a8033e1cfdd 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -371,5 +371,5 @@ class AdamWeightDecay(Optimizer): self.parameters, self.moments1, self.moments2, gradients, self.decay_flags, self.optim_filter) if self.use_parallel: - optim_result = self.broadcast_params(optim_result) + self.broadcast_params(optim_result) return optim_result diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index 0d2552b8c14..7fc79bf4183 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -312,7 +312,7 @@ class Lamb(Optimizer): self.decay_flags, self.optim_filter) if self.use_parallel: - optim_result = self.broadcast_params(optim_result) + self.broadcast_params(optim_result) if not self.dynamic_lr: F.control_depend(lr, self.assignadd(self.global_step, 1)) diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 281bd045c82..9a4bf0fce2d 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -466,7 +466,7 @@ class Optimizer(Cell): param_group.append(F.make_tuple()) key_group.append(F.make_tuple()) for i in range(self.param_length): - param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],) + param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],) key = P.MakeRefKey(self.param_names[i])() key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,) new_param_group = [] @@ -476,9 +476,9 @@ class Optimizer(Cell): new_param_group.append(next_params) for i in range(F.tuple_len(next_params)): F.assign(key_group[root][i], next_params[i]) - status = True + status = F.control_depend(optim_result, new_param_group[0][0]) for i in range(self.dev_num - 1): - status = F.control_depend(new_param_group[i][0], new_param_group[i+1]) + status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status) return status diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py index e67e74d9efe..f2481aab3a7 100644 --- a/mindspore/nn/wrap/grad_reducer.py +++ b/mindspore/nn/wrap/grad_reducer.py @@ -25,7 +25,7 @@ import mindspore.common.dtype as mstype reduce_opt = C.MultitypeFuncGraph("reduce_opt") -def _init_allreduce_operators(length): +def _init_allreduce_operators(length, split_indices): """ initialize allreduce communication operators""" group = 1 fusion = () @@ -318,7 +318,7 @@ class DistributedGradReducer(Cell): split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() if is_parallel_optimizer and split_indices: self.split_fusion = True - self.op_list = _init_allreduce_operators(len(parameters)) + self.op_list = _init_allreduce_operators(len(parameters), split_indices) else: self.split_fusion = False self.allreduce = AllReduce().add_prim_attr('fusion', 1) @@ -344,10 +344,10 @@ class DistributedGradReducer(Cell): if self.split_fusion: if self.enable_parameter_server: new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), - self.opt_list, self.allreduce_filter, grads, self.ps_parameters) + self.op_list, self.allreduce_filter, grads, self.ps_parameters) else: new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), - self.opt_list, self.allreduce_filter, grads) + self.op_list, self.allreduce_filter, grads) else: if self.enable_parameter_server: new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, diff --git a/mindspore/parallel/_dp_allreduce_fusion.py b/mindspore/parallel/_dp_allreduce_fusion.py index 3c7039dbd6d..ad78595d139 100644 --- a/mindspore/parallel/_dp_allreduce_fusion.py +++ b/mindspore/parallel/_dp_allreduce_fusion.py @@ -16,8 +16,6 @@ import ctypes -from mindspore import log as logger - _MAX_GROUP_NAME_LEN = 127 _HCCL_LIB = 'libhccl.so' @@ -25,8 +23,8 @@ _HCCL_LIB = 'libhccl.so' def _load_lib(): try: hccl_lib = ctypes.CDLL(_HCCL_LIB) - except RuntimeError: - logger.error('Get hccl lib error') + except Exception: + raise RuntimeError('Get hccl lib error') return hccl_lib @@ -69,8 +67,9 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"): try: lib_ctype = _load_lib() except RuntimeError: - logger.error('Load HCCL lib failed') - + import hccl_test.manage.api as hccl + hccl.set_fusion_strategy_by_idx() + return if isinstance(group, (str)): group_len = len(group) if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0): @@ -126,7 +125,9 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"): try: lib_ctype = _load_lib() except RuntimeError: - logger.error('Load HCCL lib failed') + import hccl_test.manage.api as hccl + hccl.set_fusion_strategy_by_size() + return if isinstance(group, (str)): group_len = len(group) if group_len > _MAX_GROUP_NAME_LEN or group_len == 0: diff --git a/tests/ut/python/hccl_test/manage/api.py b/tests/ut/python/hccl_test/manage/api.py index f6b60b3d2ea..e44f824ce2c 100644 --- a/tests/ut/python/hccl_test/manage/api.py +++ b/tests/ut/python/hccl_test/manage/api.py @@ -86,3 +86,13 @@ def create_group(group, rank_size, rank_ids): # pylint: disable=unused-argument def destroy_group(group): pass + + +# pylint: disable=unused-argument +def set_fusion_strategy_by_idx(): + pass + + +# pylint: disable=unused-argument +def set_fusion_strategy_by_size(): + pass diff --git a/tests/ut/python/parallel/test_parallel_optimizer.py b/tests/ut/python/parallel/test_parallel_optimizer.py index f6173b24c40..ca5fe0ac3e7 100644 --- a/tests/ut/python/parallel/test_parallel_optimizer.py +++ b/tests/ut/python/parallel/test_parallel_optimizer.py @@ -23,7 +23,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb from mindspore.ops import operations as P from mindspore import context - +from mindspore.parallel._auto_parallel_context import auto_parallel_context class Net(nn.Cell): """Net definition""" @@ -64,6 +64,7 @@ def test_AdamWeightDecay(): net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, optimizer) _executor.compile(train_network, inputs, label) + context.reset_auto_parallel_context() def test_lamb_compile(): @@ -79,8 +80,25 @@ def test_lamb_compile(): net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, optimizer) _executor.compile(train_network, inputs, label) + context.reset_auto_parallel_context() +def test_lamb_split_fusion(): + """ test_Lamb_split_fusion """ + context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8]) + inputs = Tensor(np.ones([32, 128]).astype(np.float32)) + label = Tensor(np.zeros([32, 768]).astype(np.float32)) + net = Net() + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + optimizer = Lamb(net.trainable_params(), learning_rate=0.1) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label) + context.reset_auto_parallel_context() + def test_edge_case(): """ test_edge_case """ context.set_auto_parallel_context(enable_parallel_optimizer=True) @@ -93,3 +111,4 @@ def test_edge_case(): with pytest.raises(RuntimeError): context.set_auto_parallel_context(device_num=16) Lamb(net.trainable_params(), learning_rate=0.1) + context.reset_auto_parallel_context()