forked from OSSInnovation/mindspore
add ps filter
parent ea54018171
commit e97bf5b8ec
@@ -57,12 +57,15 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
         allreduce (Primitive): The communication operator for gradients.
         allreduce_filter (bool): When it is true, allreduce would apply.
         grad (Tensor): The gradient tensor before operation.
-        ps_parameter(Bool): Use parameter server or not.
+        ps_parameter (bool): Use parameter server or not.

     Returns:
         Tensor, the gradient tensor after operation.
     """
-    if not ps_parameter and allreduce_filter:
+    if ps_parameter:
+        return grad
+
+    if allreduce_filter:
         grad = allreduce(grad)
         if mean:
             degree = F.scalar_cast(degree, F.dtype(grad))
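Note: the hunk above makes the dense-gradient reducer skip communication entirely for parameters hosted on a parameter server. A minimal plain-Python sketch of the resulting control flow; reduce_dense_grad and the fake_allreduce stand-in are illustrative assumptions, not MindSpore code:

# Illustrative sketch only -- plain Python, not the MindSpore implementation.
def reduce_dense_grad(degree, mean, allreduce, allreduce_filter, grad, ps_parameter):
    """Mirror the filtering logic of _tensors_allreduce after this commit."""
    if ps_parameter:
        # A parameter hosted on a parameter server keeps its local gradient;
        # aggregation happens on the server, so no AllReduce is issued here.
        return grad
    if allreduce_filter:
        grad = allreduce(grad)     # sum the gradient across devices
        if mean:
            grad = grad / degree   # average over the number of devices
    return grad

# Stand-in AllReduce that pretends two devices hold the same gradient.
fake_allreduce = lambda g: g * 2
print(reduce_dense_grad(2, False, fake_allreduce, True, 3.0, ps_parameter=False))  # 6.0: allreduced
print(reduce_dense_grad(2, False, fake_allreduce, True, 3.0, ps_parameter=True))   # 3.0: left for the PS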
@@ -73,8 +76,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
     return grad


-@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices")
-def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
+@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool")
+def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
     """
     Apply allgather on gradient instead of allreduce for sparse feature.
     Allgather is a communication operation used for distributed deep learning.
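Note on the signature change above: adding "Bool" to the reduce_opt.register type list means each call through the dispatcher must now supply a per-parameter ps flag alongside each gradient. A rough plain-Python analogue of that pairing; apply_reduce_opt, demo_reduce, and the zip loop are assumptions for illustration only, since in MindSpore the dispatch is done by MultitypeFuncGraph/hyper_map against the registered type signatures:

# Rough analogue only: shows how per-parameter ps flags line up with gradients
# once the registered overload takes an extra boolean argument.
def apply_reduce_opt(reduce_fn, degree, mean, allreduce, allreduce_filters, grads, ps_parameters):
    # One entry per parameter: its allreduce filter, its gradient, and its ps flag.
    return tuple(
        reduce_fn(degree, mean, allreduce, flt, grad, ps)
        for flt, grad, ps in zip(allreduce_filters, grads, ps_parameters)
    )

# Minimal per-gradient rule for the demo: skip ps-hosted parameters, otherwise reduce.
def demo_reduce(degree, mean, allreduce, allreduce_filter, grad, ps_parameter):
    if ps_parameter or not allreduce_filter:
        return grad
    grad = allreduce(grad)
    return grad / degree if mean else grad

out = apply_reduce_opt(demo_reduce, 2, False, lambda g: g * 2,
                       (True, True), (1.0, 5.0), (False, True))
print(out)  # (2.0, 5.0): first gradient reduced, second left to the parameter server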
@@ -86,10 +89,14 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
         allreduce (Primitive): The communication operator for gradients.
         allreduce_filter (bool): When it is true, allgather would apply.
         grad (tuple): The indices, gradient tensor and tensor_shape before operation.
+        ps_parameter (bool): Use parameter server or not.

     Returns:
         IndexedSlices, the gradient after operation.
     """
+    if ps_parameter:
+        return grad
+
     if allreduce_filter:
         indices = allgather(grad.indices())
         dout = allgather(grad.values())
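Note: the sparse overload gets the same early return; when the parameter is not on a parameter server, indices and values are gathered separately. A small self-contained sketch of that behaviour; SimpleIndexedSlices and fake_allgather are stand-ins for this illustration (MindSpore's IndexedSlices exposes indices()/values() as methods):

# Stand-ins for this sketch only; not MindSpore classes or operators.
from collections import namedtuple

SimpleIndexedSlices = namedtuple("SimpleIndexedSlices", ["indices", "values", "dense_shape"])

def fake_allgather(x):
    # Pretend two devices hold identical shards: concatenating lists mimics AllGather.
    return x + x

def reduce_sparse_grad(allgather, allreduce_filter, grad, ps_parameter):
    if ps_parameter:
        # ps-hosted embedding: keep the local indices/values untouched.
        return grad
    if allreduce_filter:
        indices = allgather(grad.indices)  # gather row indices from every device
        values = allgather(grad.values)    # gather the matching gradient rows
        grad = SimpleIndexedSlices(indices, values, grad.dense_shape)
    return grad

g = SimpleIndexedSlices([0, 3], [0.1, 0.2], (8, 4))
print(reduce_sparse_grad(fake_allgather, True, g, ps_parameter=False).indices)  # [0, 3, 0, 3]
print(reduce_sparse_grad(fake_allgather, True, g, ps_parameter=True).indices)   # [0, 3]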