diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py index 47e3458c032..e67e74d9efe 100644 --- a/mindspore/nn/wrap/grad_reducer.py +++ b/mindspore/nn/wrap/grad_reducer.py @@ -45,8 +45,35 @@ def _init_allreduce_operators(length): return op_list +@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor") +def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad): + """ + Apply allreduce on gradient. + + Args: + degree (int): The mean coefficient. + mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. + allgather (Primitive): The communication operator for sparse gradients. + allreduce (Primitive): The communication operator for gradients. + allreduce_filter (bool): When it is true, allreduce would apply. + grad (Tensor): The gradient tensor before operation. + + Returns: + Tensor, the gradient tensor after operation. + """ + if allreduce_filter: + grad = allreduce(grad) + if mean: + degree = F.scalar_cast(degree, F.dtype(grad)) + cast_op = P.Cast() + mul_op = P.Mul() + grad = mul_op(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad))) + return grad + return grad + + @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") -def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): +def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): """ Apply allreduce on gradient. @@ -76,8 +103,37 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra return grad +@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") +def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): + """ + Apply allgather on gradient instead of allreduce for sparse feature. + Allgather is a communication operation used for distributed deep learning. + + Args: + degree (int): The mean coefficient. + mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. + allgather (Primitive): The communication operator for sparse gradients. + allreduce (Primitive): The communication operator for gradients. + allreduce_filter (bool): When it is true, allgather would apply. + grad (tuple): The indices, gradient tensor and tensor_shape before operation. + + Returns: + IndexedSlices, the gradient after operation. + """ + if allreduce_filter: + indices = allgather(grad.indices()) + dout = allgather(grad.values()) + if mean: + degree = F.scalar_cast(degree, F.dtype(grad.values())) + cast_op = P.Cast() + mul_op = P.Mul() + dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) + grad = IndexedSlices(indices, dout, grad.dense_shape()) + return grad + + @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool") -def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): +def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): """ Apply allgather on gradient instead of allreduce for sparse feature. Allgather is a communication operation used for distributed deep learning. @@ -269,6 +325,7 @@ class DistributedGradReducer(Cell): self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP) ps_filter = lambda x: x.is_param_ps self.ps_parameters = tuple(ps_filter(x) for x in parameters) + self.enable_parameter_server = any(self.ps_parameters) def construct(self, grads): """ @@ -285,10 +342,18 @@ class DistributedGradReducer(Cell): datatypes = self.map_(F.partial(_get_datatype), grads) grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) if self.split_fusion: - new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), - self.opt_list, self.allreduce_filter, grads, self.ps_parameters) + if self.enable_parameter_server: + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), + self.opt_list, self.allreduce_filter, grads, self.ps_parameters) + else: + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), + self.opt_list, self.allreduce_filter, grads) else: - new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, - self.allreduce), self.allreduce_filter, grads, self.ps_parameters) + if self.enable_parameter_server: + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, + self.allreduce), self.allreduce_filter, grads, self.ps_parameters) + else: + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, + self.allreduce), self.allreduce_filter, grads) new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) return new_grad