From e97bf5b8ec6c4acc84a1f6c59a446485b611e74c Mon Sep 17 00:00:00 2001 From: jinyaohui Date: Tue, 21 Jul 2020 09:32:35 +0800 Subject: [PATCH] add ps filter --- mindspore/nn/wrap/grad_reducer.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py index 930cabf478..47e3458c03 100644 --- a/mindspore/nn/wrap/grad_reducer.py +++ b/mindspore/nn/wrap/grad_reducer.py @@ -57,12 +57,15 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra allreduce (Primitive): The communication operator for gradients. allreduce_filter (bool): When it is true, allreduce would apply. grad (Tensor): The gradient tensor before operation. - ps_parameter(Bool): Use parameter server or not. + ps_parameter (bool): Use parameter server or not. Returns: Tensor, the gradient tensor after operation. """ - if not ps_parameter and allreduce_filter: + if ps_parameter: + return grad + + if allreduce_filter: grad = allreduce(grad) if mean: degree = F.scalar_cast(degree, F.dtype(grad)) @@ -73,8 +76,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra return grad -@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") -def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): +@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool") +def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): """ Apply allgather on gradient instead of allreduce for sparse feature. Allgather is a communication operation used for distributed deep learning. @@ -86,10 +89,14 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce allreduce (Primitive): The communication operator for gradients. allreduce_filter (bool): When it is true, allgather would apply. grad (tuple): The indices, gradient tensor and tensor_shape before operation. + ps_parameter (bool): Use parameter server or not. Returns: IndexedSlices, the gradient after operation. """ + if ps_parameter: + return grad + if allreduce_filter: indices = allgather(grad.indices()) dout = allgather(grad.values())