forked from mindspore-Ecosystem/mindspore
add the reshape part of the embeddinglookup backward operator
This commit is contained in:
parent
5c4731b772
commit
1cfb52bc0e
|
@ -76,7 +76,7 @@ constexpr char DEPEND[] = "depend";
|
|||
constexpr char BATCH_PARALLEL[] = "BatchParallel";
|
||||
|
||||
constexpr char ACTIVATION_TYPE[] = "activation_type";
|
||||
constexpr char TARGET[] = "target";
|
||||
constexpr char TARGET[] = "primitive_target";
|
||||
constexpr char CPU[] = "CPU";
|
||||
constexpr char TRANSPOSE_A[] = "transpose_a";
|
||||
constexpr char TRANSPOSE_B[] = "transpose_b";
|
||||
|
|
|
@ -21,6 +21,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
|||
from ..._checkparam import Validator as validator, Rel
|
||||
from .._utils import get_concat_offset
|
||||
from ...common import dtype as mstype
|
||||
from .. import functional as F
|
||||
|
||||
|
||||
class AbsGrad(PrimitiveWithInfer):
|
||||
|
@ -1121,6 +1122,37 @@ class MirrorPadGrad(PrimitiveWithInfer):
|
|||
'value': None}
|
||||
|
||||
|
||||
class EmbeddingLookupCommGrad(PrimitiveWithInfer):
|
||||
"""
|
||||
Perform the gradient for the communication part of EmbeddingLookup operator.
|
||||
|
||||
This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
|
||||
this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on host.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
|
||||
self.add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def __infer__(self, dy, split_num):
|
||||
"""
|
||||
This primitive is implemented by three steps:
|
||||
1) Split the 'dy' along dimension 0 into 'split_num' parts.
|
||||
2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
|
||||
3) After HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
|
||||
along dimension 0.
|
||||
|
||||
The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
|
||||
"""
|
||||
dy_shape = tuple(dy['shape'])
|
||||
split_num_value = split_num['value']
|
||||
validator.check_value_type("split_num_value", split_num_value, [int], self.name)
|
||||
dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
|
||||
return {'shape': dy_shape_all,
|
||||
'dtype': dy['dtype'],
|
||||
'value': None}
|
||||
|
||||
|
||||
class RefToEmbed(Primitive):
|
||||
r"""
|
||||
Make a key from Ref.
|
||||
|
|
|
@ -614,7 +614,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|||
self.__setattr_flag__ = True
|
||||
self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
|
||||
outputs=['output'])
|
||||
self.add_prim_attr('target', 'CPU')
|
||||
self.add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
|
||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||
|
|
|
@ -45,11 +45,11 @@ class GradWrap(nn.Cell):
|
|||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None):
|
||||
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
|
||||
super().__init__()
|
||||
if shape is None:
|
||||
shape = [64, 64]
|
||||
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
|
||||
self.gatherv2 = P.GatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
|
||||
self.mul = P.Mul().set_strategy(strategy2)
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
self.axis = axis
|
||||
|
@ -188,7 +188,7 @@ def test_gatherv2_cpu0():
|
|||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((8, 1), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
|
@ -200,7 +200,7 @@ def test_gatherv2_cpu1():
|
|||
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((16, 1), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
|
@ -212,7 +212,7 @@ def test_gatherv2_cpu2():
|
|||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((1, 8), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
|
|
Loading…
Reference in New Issue