!1737 sparse feature backpropagation

Merge pull request !1737 from lirongzhen1/sparse
mindspore-ci-bot 2020-06-12 09:51:41 +08:00 committed by Gitee
commit 5b0472683c
3 changed files with 206 additions and 12 deletions


@@ -52,6 +52,31 @@ def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad):
    return grad


@reduce_opt.register("Function", "Number", "Bool", "Tuple")
def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
    """
    Apply mean and allgather on the gradient instead of allreduce for a sparse feature.
    Allgather is a communication operation used in distributed deep learning.

    Args:
        mul (Primitive): Mul operation; multiplying by 1.0/degree implements the division by the mean coefficient.
        degree (int): The mean coefficient.
        allreduce_filter (bool): When it is True, allgather is applied.
        grad (Tuple): The indices, gradient tensor and tensor_shape before the operation.

    Returns:
        Tuple, the indices, the gradient tensor and tensor_shape after the operation.
    """
    if allreduce_filter:
        indices = _all_gather(grad[0])
        degree = F.scalar_cast(degree, F.dtype(grad[1]))
        dout = _all_gather(grad[1])
        cast_op = P.Cast()
        dout = mul(dout, cast_op(F.scalar_to_array(1.0/degree), F.dtype(dout)))
        grad = (indices, dout, grad[2])  # keep the original tensor_shape slot
    return grad
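To make the sparse branch concrete: each device contributes its own (indices, values, dense_shape) triple, allgather concatenates the pieces, and multiplying by 1.0/degree applies the mean. A minimal NumPy sketch of that data flow (simulated_allgather_mean and the two-device data are illustrative, not MindSpore API):

import numpy as np

def simulated_allgather_mean(per_device_grads, degree):
    """per_device_grads: list of (indices, values, dense_shape) tuples."""
    # allgather is simulated by concatenating every device's piece along axis 0
    indices = np.concatenate([g[0] for g in per_device_grads], axis=0)
    values = np.concatenate([g[1] for g in per_device_grads], axis=0)
    values = values * (1.0 / degree)      # the mean step: mul by 1/degree
    dense_shape = per_device_grads[0][2]  # identical on every device
    return indices, values, dense_shape

# Two devices, each contributing two rows of an (8, 4) embedding gradient.
dev0 = (np.array([0, 3]), np.ones((2, 4)), (8, 4))
dev1 = (np.array([1, 3]), np.ones((2, 4)), (8, 4))
print(simulated_allgather_mean([dev0, dev1], degree=2))

Note that a row index appearing on two devices (here, 3) stays duplicated: allgather does not merge overlapping indices.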
@reduce_opt.register("Bool", "Tensor")
def _tensors_allreduce(allreduce_filter, grad):
"""
@ -69,6 +94,26 @@ def _tensors_allreduce(allreduce_filter, grad):
return grad
@reduce_opt.register("Bool", "Tuple")
def _tensors_allreduce_with_sparse(allreduce_filter, grad):
"""
Apply mean and allgather on gradient instead of allreduce for sparse feature.
Allgather is a communication operation used for distributed deep learning.
Args:
allreduce_filter (bool): When it is true, allgather would apply.
grad (Tuple): The indices, gradient tensor and tensor_shape before operation.
Returns:
Tuple, include indices, the gradient tensor and tensor_shape after operation.
"""
if allreduce_filter:
indices = _all_gather(grad[0])
dout = _all_gather(grad[1])
grad = (indices, dout, dout[2])
return grad
_get_datatype = C.MultitypeFuncGraph("_get_datatype")
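For readers unfamiliar with reduce_opt's dispatch: MultitypeFuncGraph picks a registered overload by the argument types at the call site, which is how a dense Tensor gradient and a sparse Tuple gradient end up on different paths. A standalone toy registration in the same pattern (demo_opt and the function names are illustrative):

from mindspore.ops import composite as C

demo_opt = C.MultitypeFuncGraph("demo_opt")

@demo_opt.register("Bool", "Tensor")
def _demo_dense(flag, grad):
    # a dense gradient would take the allreduce-style path
    return grad

@demo_opt.register("Bool", "Tuple")
def _demo_sparse(flag, grad):
    # a sparse gradient arrives as (indices, values, dense_shape)
    return grad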


@@ -26,9 +26,10 @@ from .grad_base import bprop_getters
@bprop_getters.register(AllReduce)
def get_bprop_all_reduce(self):
    """Generate bprop for AllReduce, doing allreduce or allgather; allgather is used for a sparse feature."""
    all_reduce_grad = AllReduce(ReduceOp.SUM, self.group)
    all_gather = AllGather(group=self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_reduce_grad.set_prim_instance_name(instance_name)
@@ -42,15 +43,28 @@ def get_bprop_all_reduce(self):
    if self.op == ReduceOp.SUM:
        def bprop(x, out, dout):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
            else:
                indices = all_gather(dout[0])
                grad = all_gather(dout[1])
                dx = (indices, grad, dout[2])
            return (dx,)
    else:
        def bprop(x, out, dout):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
                z = equal(x, out)
                z = cast(z, dtype(dx))
                dx = mul(dx, z)
            else:
                indices = all_gather(dout[0])
                grad = all_gather(dout[1])
                z = equal(x, out)
                z = cast(z, dtype(grad))
                grad = mul(grad, z)
                dx = (indices, grad, dout[2])
            return (dx,)
    return bprop
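The dense SUM rule can be checked by hand: if every device computes y = allreduce_sum(x), each device's x feeds every device's y, so dL/dx is the sum of all devices' incoming dout, i.e. another allreduce of dout. A two-device NumPy simulation (all names illustrative):

import numpy as np

xs = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]     # x on devices 0 and 1
y = sum(xs)                                           # forward: allreduce(SUM)
douts = [np.array([0.1, 0.1]), np.array([0.2, 0.2])]  # per-device dL/dy
dx = sum(douts)                                       # bprop: allreduce(SUM) of dout
print(y, dx)                                          # [4. 6.] [0.3 0.3]

The sparse else-branch applies the same idea, but the gradient travels as (indices, values, shape), so the pieces are allgathered instead of summed elementwise.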
@@ -147,12 +161,16 @@ def get_bprop_all_to_all(self):
@bprop_getters.register(_MirrorOperator)
def get_bprop_mirror_operator(self):
    """
    Backpropagator for _MirrorOperator, doing allreduce or allgather for the devices in the group
    (only for one group); allgather is used for a sparse feature.
    """
    group = self.group
    dev_num = self.dev_num
    mean_flag = self.mean_flag

    all_reduce = AllReduce(group=group)
    all_gather = AllGather(group=group)
    mul = P.Mul()
    cast = P.Cast()
@@ -170,12 +188,25 @@ def get_bprop_mirror_operator(self):
    def bprop(x, out, dout):
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
                float_one = F.scalar_cast(1.0, F.dtype(dx))
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
            else:
                indices = all_gather(dout[0])
                grad = all_gather(dout[1])
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
                dx = (indices, grad, dout[2])
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
            else:
                indices = all_gather(dout[0])
                grad = all_gather(dout[1])
                dx = (indices, grad, dout[2])
        return (dx,)
    return bprop
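Worked numbers for the mean_flag branch: with dev_num = 8, multiplying by float_one/num = 1/8 turns the summed gradient into the mean of the per-device gradients. A NumPy sketch (illustrative):

import numpy as np

dev_num = 8
per_device = [np.full((2,), float(i)) for i in range(dev_num)]
summed = sum(per_device)         # what AllReduce(SUM) produces: [28. 28.]
mean = summed * (1.0 / dev_num)  # the mul(dx, 1/dev_num) step:  [3.5 3.5]
print(summed, mean)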


@@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test sparse feature bprop """
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from mindspore.common.api import _executor
from mindspore.communication.management import HCCL_WORLD_COMM_GROUP


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return C.grad_all(self.network)(x)


class VirtualGatherV2(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init VirtualGatherV2, a gather-like primitive used only for this test"""
        super(VirtualGatherV2, self).__init__('VirtualGatherV2')
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])

    def __infer__(self, params, indices, axis):
        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
        axis_v = axis['value']
        params_shp = params['shape']
        rank = len(params_shp)
        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
        if axis_v < 0:
            axis_v += rank
        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
        out = {'shape': out_shape,
               'dtype': params['dtype'],
               'value': None}
        return out
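The out_shape rule in __infer__ is the standard gather shape rule. Plugging in the shapes the tests below use makes it concrete (plain Python, no MindSpore needed):

# An (8, 8) indices tensor gathered along axis 0 of a (64, 64) params tensor:
params_shp = [64, 64]
indices_shape = [8, 8]
axis_v = 0
out_shape = params_shp[:axis_v] + indices_shape + params_shp[axis_v + 1:]
print(out_shape)  # [8, 8, 64] == [] + indices shape + trailing params dims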


@bprop_getters.register(VirtualGatherV2)
def get_bprop_gather_v2(self):
    """Generate bprop for VirtualGatherV2."""

    def bprop(x, indices, axis, out, dout):
        return (indices, dout, x), axis, out

    return bprop
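VirtualGatherV2 has three inputs (params, indices, axis), so this bprop returns three entries; the gradient for params is the sparse triple that the communication bprops above read back as dout[0] (indices), dout[1] (values) and dout[2] (the shape-carrying slot). A plain-Python illustration of that unpacking (the values are made up):

dout = ([0, 3], [[0.1, 0.2], [0.3, 0.4]], (8, 2))  # (indices, values, shape slot)
indices, values, shape_slot = dout[0], dout[1], dout[2]
print(indices, values, shape_slot)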


def test_bprop_with_sparse_feature_allreduce():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            if shape is None:
                shape = [8, 8]
            self.all_reduce = AllReduce()
            self.gatherv2 = VirtualGatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis

        def construct(self, x):
            out = self.all_reduce(x)
            out = self.gatherv2(out, self.index, self.axis)
            return out

    net = GradWrap(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    _executor.compile(net, x)


def test_bprop_with_sparse_feature_mirror():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            if shape is None:
                shape = [8, 8]
            self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
            self.gatherv2 = VirtualGatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis

        def construct(self, x):
            out = self.mirror(x)
            out = self.gatherv2(out, self.index, self.axis)
            return out

    net = GradWrap(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    _executor.compile(net, x)
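Both cases only trace and compile the graphs via _executor.compile, so no real collective communication runs. As a convenience outside pytest (not part of the original change, and assuming the local build compiles these graphs), they could be invoked directly:

if __name__ == "__main__":
    test_bprop_with_sparse_feature_allreduce()
    test_bprop_with_sparse_feature_mirror()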