add ps filter

This commit is contained in:
jinyaohui 2020-07-21 09:32:35 +08:00
parent ea54018171
commit e97bf5b8ec
1 changed files with 11 additions and 4 deletions

View File

@ -57,12 +57,15 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
allreduce (Primitive): The communication operator for gradients. allreduce (Primitive): The communication operator for gradients.
allreduce_filter (bool): When it is true, allreduce would apply. allreduce_filter (bool): When it is true, allreduce would apply.
grad (Tensor): The gradient tensor before operation. grad (Tensor): The gradient tensor before operation.
ps_parameter(Bool): Use parameter server or not. ps_parameter (bool): Use parameter server or not.
Returns: Returns:
Tensor, the gradient tensor after operation. Tensor, the gradient tensor after operation.
""" """
if not ps_parameter and allreduce_filter: if ps_parameter:
return grad
if allreduce_filter:
grad = allreduce(grad) grad = allreduce(grad)
if mean: if mean:
degree = F.scalar_cast(degree, F.dtype(grad)) degree = F.scalar_cast(degree, F.dtype(grad))
@ -73,8 +76,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
return grad return grad
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
""" """
Apply allgather on gradient instead of allreduce for sparse feature. Apply allgather on gradient instead of allreduce for sparse feature.
Allgather is a communication operation used for distributed deep learning. Allgather is a communication operation used for distributed deep learning.
@ -86,10 +89,14 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
allreduce (Primitive): The communication operator for gradients. allreduce (Primitive): The communication operator for gradients.
allreduce_filter (bool): When it is true, allgather would apply. allreduce_filter (bool): When it is true, allgather would apply.
grad (tuple): The indices, gradient tensor and tensor_shape before operation. grad (tuple): The indices, gradient tensor and tensor_shape before operation.
ps_parameter (bool): Use parameter server or not.
Returns: Returns:
IndexedSlices, the gradient after operation. IndexedSlices, the gradient after operation.
""" """
if ps_parameter:
return grad
if allreduce_filter: if allreduce_filter:
indices = allgather(grad.indices()) indices = allgather(grad.indices())
dout = allgather(grad.values()) dout = allgather(grad.values())