add the reshape part of the embeddinglookup backward operator

This commit is contained in:
Xiaoda Zhang 2020-06-01 11:38:09 +08:00
parent 5c4731b772
commit 1cfb52bc0e
4 changed files with 39 additions and 7 deletions

View File

@ -76,7 +76,7 @@ constexpr char DEPEND[] = "depend";
constexpr char BATCH_PARALLEL[] = "BatchParallel";
constexpr char ACTIVATION_TYPE[] = "activation_type";
constexpr char TARGET[] = "target";
constexpr char TARGET[] = "primitive_target";
constexpr char CPU[] = "CPU";
constexpr char TRANSPOSE_A[] = "transpose_a";
constexpr char TRANSPOSE_B[] = "transpose_b";

View File

@ -21,6 +21,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator, Rel
from .._utils import get_concat_offset
from ...common import dtype as mstype
from .. import functional as F
class AbsGrad(PrimitiveWithInfer):
@ -1121,6 +1122,37 @@ class MirrorPadGrad(PrimitiveWithInfer):
'value': None}
class EmbeddingLookupCommGrad(PrimitiveWithInfer):
"""
Perform the gradient for the communication part of EmbeddingLookup operator.
This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on host.
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
self.add_prim_attr('primitive_target', 'CPU')
def __infer__(self, dy, split_num):
"""
This primitive is implemented by three steps:
1) Split the 'dy' along dimension 0 into 'split_num' parts.
2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
3) After HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
along dimension 0.
The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
"""
dy_shape = tuple(dy['shape'])
split_num_value = split_num['value']
validator.check_value_type("split_num_value", split_num_value, [int], self.name)
dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
return {'shape': dy_shape_all,
'dtype': dy['dtype'],
'value': None}
class RefToEmbed(Primitive):
r"""
Make a key from Ref.

View File

@ -614,7 +614,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
outputs=['output'])
self.add_prim_attr('target', 'CPU')
self.add_prim_attr('primitive_target', 'CPU')
def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)

View File

@ -45,11 +45,11 @@ class GradWrap(nn.Cell):
class Net(nn.Cell):
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None):
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
super().__init__()
if shape is None:
shape = [64, 64]
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
self.gatherv2 = P.GatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
self.mul = P.Mul().set_strategy(strategy2)
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.axis = axis
@ -188,7 +188,7 @@ def test_gatherv2_cpu0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2))
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
@ -200,7 +200,7 @@ def test_gatherv2_cpu1():
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((16, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2))
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
@ -212,7 +212,7 @@ def test_gatherv2_cpu2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2))
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)