add ps filter

This commit is contained in:
jinyaohui 2020-07-21 09:32:35 +08:00
parent ea54018171
commit e97bf5b8ec
1 changed files with 11 additions and 4 deletions

View File

@ -57,12 +57,15 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
allreduce (Primitive): The communication operator for gradients.
allreduce_filter (bool): When it is true, allreduce would apply.
grad (Tensor): The gradient tensor before operation.
ps_parameter(Bool): Use parameter server or not.
ps_parameter (bool): Use parameter server or not.
Returns:
Tensor, the gradient tensor after operation.
"""
if not ps_parameter and allreduce_filter:
if ps_parameter:
return grad
if allreduce_filter:
grad = allreduce(grad)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad))
@ -73,8 +76,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
return grad
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
"""
Apply allgather on gradient instead of allreduce for sparse feature.
Allgather is a communication operation used for distributed deep learning.
@ -86,10 +89,14 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
allreduce (Primitive): The communication operator for gradients.
allreduce_filter (bool): When it is true, allgather would apply.
grad (tuple): The indices, gradient tensor and tensor_shape before operation.
ps_parameter (bool): Use parameter server or not.
Returns:
IndexedSlices, the gradient after operation.
"""
if ps_parameter:
return grad
if allreduce_filter:
indices = allgather(grad.indices())
dout = allgather(grad.values())