forked from OSSInnovation/mindspore
add ps filter
This commit is contained in:
parent
ea54018171
commit
e97bf5b8ec
|
@ -57,12 +57,15 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
|
|||
allreduce (Primitive): The communication operator for gradients.
|
||||
allreduce_filter (bool): When it is true, allreduce would apply.
|
||||
grad (Tensor): The gradient tensor before operation.
|
||||
ps_parameter(Bool): Use parameter server or not.
|
||||
ps_parameter (bool): Use parameter server or not.
|
||||
|
||||
Returns:
|
||||
Tensor, the gradient tensor after operation.
|
||||
"""
|
||||
if not ps_parameter and allreduce_filter:
|
||||
if ps_parameter:
|
||||
return grad
|
||||
|
||||
if allreduce_filter:
|
||||
grad = allreduce(grad)
|
||||
if mean:
|
||||
degree = F.scalar_cast(degree, F.dtype(grad))
|
||||
|
@ -73,8 +76,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
|
|||
return grad
|
||||
|
||||
|
||||
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices")
|
||||
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
|
||||
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool")
|
||||
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
|
||||
"""
|
||||
Apply allgather on gradient instead of allreduce for sparse feature.
|
||||
Allgather is a communication operation used for distributed deep learning.
|
||||
|
@ -86,10 +89,14 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
|
|||
allreduce (Primitive): The communication operator for gradients.
|
||||
allreduce_filter (bool): When it is true, allgather would apply.
|
||||
grad (tuple): The indices, gradient tensor and tensor_shape before operation.
|
||||
ps_parameter (bool): Use parameter server or not.
|
||||
|
||||
Returns:
|
||||
IndexedSlices, the gradient after operation.
|
||||
"""
|
||||
if ps_parameter:
|
||||
return grad
|
||||
|
||||
if allreduce_filter:
|
||||
indices = allgather(grad.indices())
|
||||
dout = allgather(grad.values())
|
||||
|
|
Loading…
Reference in New Issue