forked from mindspore-Ecosystem/mindspore
!1737 sparse feature backpropagation
Merge pull request !1737 from lirongzhen1/sparse
commit 5b0472683c

@@ -52,6 +52,31 @@ def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad):
    return grad


@reduce_opt.register("Function", "Number", "Bool", "Tuple")
def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
    """
    Apply mean and allgather on the gradient instead of allreduce, for a sparse feature.
    AllGather is a collective communication operation used in distributed deep learning.

    Args:
        mul (Primitive): Mul operation, used to scale the gradient by 1/degree.
        degree (int): The mean coefficient.
        allreduce_filter (bool): When it is True, allgather is applied.
        grad (Tuple): The indices, gradient tensor and tensor_shape before the operation.

    Returns:
        Tuple, the indices, gradient tensor and tensor_shape after the operation.
    """
    if allreduce_filter:
        indices = _all_gather(grad[0])
        degree = F.scalar_cast(degree, F.dtype(grad[1]))
        dout = _all_gather(grad[1])
        cast_op = P.Cast()
        dout = mul(dout, cast_op(F.scalar_to_array(1.0/degree), F.dtype(dout)))
        # Keep the tensor_shape from the incoming tuple; only the indices and
        # values were gathered (dout[2] would merely slice the gathered values).
        grad = (indices, dout, grad[2])
    return grad


@reduce_opt.register("Bool", "Tensor")
def _tensors_allreduce(allreduce_filter, grad):
    """
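A note on the representation: a sparse feature gradient here is the tuple (indices, values, tensor_shape) produced by embedding-style lookups, where each device only touches a few rows of the parameter. Allreducing a mostly-zero dense tensor would waste bandwidth, so the overload above gathers every device's indices and values and scales the values by 1/degree. A minimal numpy sketch of that arithmetic, with np.concatenate standing in for the AllGather collective (worker data is made up for illustration):

import numpy as np

degree = 2  # number of devices contributing gradients
# Each worker holds a sparse gradient: the embedding rows it touched (indices)
# and one gradient row per index (values).
w0_indices, w0_values = np.array([0, 3]), np.ones((2, 4), np.float32)
w1_indices, w1_values = np.array([3, 5]), np.ones((2, 4), np.float32)
dense_shape = (8, 4)  # shape of the full (dense) parameter gradient

indices = np.concatenate([w0_indices, w1_indices])                # "allgather" on indices
values = np.concatenate([w0_values, w1_values]) * (1.0 / degree)  # gather + mean
grad = (indices, values, dense_shape)  # duplicate indices are accumulated later
print(grad[0])  # [0 3 3 5]
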
@@ -69,6 +94,26 @@ def _tensors_allreduce(allreduce_filter, grad):
    return grad


@reduce_opt.register("Bool", "Tuple")
def _tensors_allreduce_with_sparse(allreduce_filter, grad):
    """
    Apply allgather on the gradient instead of allreduce, for a sparse feature.
    AllGather is a collective communication operation used in distributed deep learning.

    Args:
        allreduce_filter (bool): When it is True, allgather is applied.
        grad (Tuple): The indices, gradient tensor and tensor_shape before the operation.

    Returns:
        Tuple, the indices, gradient tensor and tensor_shape after the operation.
    """
    if allreduce_filter:
        indices = _all_gather(grad[0])
        dout = _all_gather(grad[1])
        # As above: the tensor_shape comes from the incoming tuple, not from dout.
        grad = (indices, dout, grad[2])
    return grad


_get_datatype = C.MultitypeFuncGraph("_get_datatype")
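reduce_opt is a C.MultitypeFuncGraph, so which overload runs is decided by the registered types of the arguments: ("Bool", "Tensor") selects the dense allreduce path, and ("Bool", "Tuple") selects the sparse allgather path above. A rough plain-Python analogue of that type-based dispatch (illustrative only; the real graph-mode mechanism is different):

class MultiDispatch:
    """Toy stand-in: map a tuple of runtime type names to an implementation."""
    def __init__(self):
        self._impls = {}

    def register(self, *type_names):
        def deco(fn):
            self._impls[type_names] = fn
            return fn
        return deco

    def __call__(self, *args):
        key = tuple(type(a).__name__ for a in args)
        return self._impls[key](*args)

reduce_opt_demo = MultiDispatch()

@reduce_opt_demo.register("bool", "tuple")
def _sparse_case(allreduce_filter, grad):
    return "sparse path (allgather)"

@reduce_opt_demo.register("bool", "float")
def _dense_case(allreduce_filter, grad):
    return "dense path (allreduce)"

print(reduce_opt_demo(True, (0, 1, 2)))  # sparse path (allgather)
print(reduce_opt_demo(True, 3.14))       # dense path (allreduce)
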
@@ -26,9 +26,10 @@ from .grad_base import bprop_getters

@bprop_getters.register(AllReduce)
def get_bprop_all_reduce(self):
    """Generate bprop for AllReduce: do allreduce, or allgather for a sparse feature."""

    all_reduce_grad = AllReduce(ReduceOp.SUM, self.group)
    all_gather = AllGather(group=self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_reduce_grad.set_prim_instance_name(instance_name)
@@ -42,15 +43,28 @@ def get_bprop_all_reduce(self):
    if self.op == ReduceOp.SUM:

        def bprop(x, out, dout):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
            else:
                indices = all_gather(dout[0])
                grad = all_gather(dout[1])
                dx = (indices, grad, dout[2])
            return (dx,)
    else:

        def bprop(x, out, dout):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
                z = equal(x, out)
                z = cast(z, dtype(dx))
                dx = mul(dx, z)
            else:
                indices = all_gather(dout[0])
                grad = all_gather(dout[1])
                z = equal(x, out)
                z = cast(z, dtype(grad))
                grad = mul(grad, z)
                dx = (indices, grad, dout[2])
            return (dx,)
    return bprop
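Two details in the hunk above: the gradient of a SUM allreduce is itself an allreduce of the upstream gradient, while for MAX/MIN only the elements where this device's input equals the reduced output actually contributed, so the gradient is masked with equal(x, out). A numpy sketch of that mask (values made up):

import numpy as np

x = np.array([1.0, 5.0, 2.0])      # this rank's input to the MAX allreduce
out = np.array([4.0, 5.0, 2.0])    # elementwise max across all ranks
dout = np.array([0.1, 0.2, 0.3])   # upstream gradient

mask = (x == out).astype(dout.dtype)  # z = cast(equal(x, out), dtype(dx))
dx = dout * mask                      # positions that lost the max get zero
print(dx)  # [0.  0.2 0.3]
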
@@ -147,12 +161,16 @@ def get_bprop_all_to_all(self):

@bprop_getters.register(_MirrorOperator)
def get_bprop_mirror_operator(self):
    """
    Backpropagator for _MirrorOperator: do allreduce or allgather for the devices in the
    group (only for one group); allgather is used for a sparse feature.
    """
    group = self.group
    dev_num = self.dev_num
    mean_flag = self.mean_flag

    all_reduce = AllReduce(group=group)
    all_gather = AllGather(group=group)
    mul = P.Mul()
    cast = P.Cast()
@@ -170,12 +188,25 @@ def get_bprop_mirror_operator(self):

    def bprop(x, out, dout):
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
                float_one = F.scalar_cast(1.0, F.dtype(dx))
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
            else:
                indices = all_gather(dout[0])
                grad = all_gather(dout[1])
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
                dx = (indices, grad, dout[2])
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
            else:
                indices = all_gather(dout[0])
                grad = all_gather(dout[1])
                dx = (indices, grad, dout[2])

        return (dx,)
    return bprop
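The mirror bprop follows the same pattern as the gradient reducer: when mean_flag is set, the summed (or gathered) gradient is scaled by 1/dev_num so each replica receives an average rather than a sum. A one-line numpy check of the scaling (assumed values):

import numpy as np

dev_num = 8
summed = np.full((2, 2), 8.0, dtype=np.float32)  # pretend allreduce output
dx = summed * (1.0 / dev_num)                    # the mean_flag scaling
print(dx)  # every entry is 1.0: the per-device average
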
@@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test sparse feature bprop """
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from mindspore.common.api import _executor
from mindspore.communication.management import HCCL_WORLD_COMM_GROUP

class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return C.grad_all(self.network)(x)


class VirtualGatherV2(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init index_select"""
        super(VirtualGatherV2, self).__init__('VirtualGatherV2')
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])

    def __infer__(self, params, indices, axis):
        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
        axis_v = axis['value']
        params_shp = params['shape']
        rank = len(params_shp)
        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
        if axis_v < 0:
            axis_v += rank
        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
        out = {'shape': out_shape,
               'dtype': params['dtype'],
               'value': None}
        return out


@bprop_getters.register(VirtualGatherV2)
def get_bprop_gather_v2(self):
    """Generate bprop for VirtualGatherV2."""

    def bprop(x, indices, axis, out, dout):
        return (indices, dout, x), axis, out

    return bprop

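VirtualGatherV2 is a test-only stand-in for GatherV2: __infer__ reproduces gather's output-shape rule, and the registered bprop deliberately returns a tuple as the params gradient so that the communication bprops above take their sparse branch. A worked example of the shape rule (values assumed for illustration):

# Gathering along axis 0 of an (8, 4) params with (2, 3) indices yields
# params_shp[:axis] + indices_shape + params_shp[axis + 1:] = (2, 3, 4).
params_shp = (8, 4)
indices_shape = (2, 3)
axis_v = 0
out_shape = params_shp[:axis_v] + indices_shape + params_shp[axis_v + 1:]
print(out_shape)  # (2, 3, 4)
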
def test_bprop_with_sparse_feature_allreduce():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            if shape is None:
                shape = [8, 8]
            self.all_reduce = AllReduce()
            self.gatherv2 = VirtualGatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis

        def construct(self, x):
            out = self.all_reduce(x)
            out = self.gatherv2(out, self.index, self.axis)
            return out

    net = GradWrap(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)

    _executor.compile(net, x)

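The test above only verifies that the graph compiles: the gradient of x flows back through VirtualGatherV2's bprop, which emits a tuple, and that tuple then arrives at get_bprop_all_reduce as dout, so the type check sends it down the allgather branch. A plain-Python sketch of that branch selection, with isinstance standing in for F.issubclass_(F.typeof(...)) (illustrative only):

def choose_branch(dout):
    # Tuple gradients mark sparse features; tensors take the dense path.
    if isinstance(dout, tuple):
        return "sparse: allgather indices and values"
    return "dense: allreduce the tensor"

print(choose_branch(("indices", "values", "shape")))  # sparse path
print(choose_branch(1.0))                             # dense path
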
def test_bprop_with_sparse_feature_mirror():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            if shape is None:
                shape = [8, 8]
            self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
            self.gatherv2 = VirtualGatherV2()
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis

        def construct(self, x):
            out = self.mirror(x)
            out = self.gatherv2(out, self.index, self.axis)
            return out

    net = GradWrap(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)

    _executor.compile(net, x)