forked from mindspore-Ecosystem/mindspore
add the reshape part of the embeddinglookup backward operator
This commit is contained in:
parent
5c4731b772
commit
1cfb52bc0e
|
@ -76,7 +76,7 @@ constexpr char DEPEND[] = "depend";
|
||||||
constexpr char BATCH_PARALLEL[] = "BatchParallel";
|
constexpr char BATCH_PARALLEL[] = "BatchParallel";
|
||||||
|
|
||||||
constexpr char ACTIVATION_TYPE[] = "activation_type";
|
constexpr char ACTIVATION_TYPE[] = "activation_type";
|
||||||
constexpr char TARGET[] = "target";
|
constexpr char TARGET[] = "primitive_target";
|
||||||
constexpr char CPU[] = "CPU";
|
constexpr char CPU[] = "CPU";
|
||||||
constexpr char TRANSPOSE_A[] = "transpose_a";
|
constexpr char TRANSPOSE_A[] = "transpose_a";
|
||||||
constexpr char TRANSPOSE_B[] = "transpose_b";
|
constexpr char TRANSPOSE_B[] = "transpose_b";
|
||||||
|
|
|
@ -21,6 +21,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||||
from ..._checkparam import Validator as validator, Rel
|
from ..._checkparam import Validator as validator, Rel
|
||||||
from .._utils import get_concat_offset
|
from .._utils import get_concat_offset
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
|
from .. import functional as F
|
||||||
|
|
||||||
|
|
||||||
class AbsGrad(PrimitiveWithInfer):
|
class AbsGrad(PrimitiveWithInfer):
|
||||||
|
@ -1121,6 +1122,37 @@ class MirrorPadGrad(PrimitiveWithInfer):
|
||||||
'value': None}
|
'value': None}
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingLookupCommGrad(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Perform the gradient for the communication part of EmbeddingLookup operator.
|
||||||
|
|
||||||
|
This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
|
||||||
|
this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on host.
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
|
||||||
|
self.add_prim_attr('primitive_target', 'CPU')
|
||||||
|
|
||||||
|
def __infer__(self, dy, split_num):
|
||||||
|
"""
|
||||||
|
This primitive is implemented by three steps:
|
||||||
|
1) Split the 'dy' along dimension 0 into 'split_num' parts.
|
||||||
|
2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
|
||||||
|
3) After HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
|
||||||
|
along dimension 0.
|
||||||
|
|
||||||
|
The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
|
||||||
|
"""
|
||||||
|
dy_shape = tuple(dy['shape'])
|
||||||
|
split_num_value = split_num['value']
|
||||||
|
validator.check_value_type("split_num_value", split_num_value, [int], self.name)
|
||||||
|
dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
|
||||||
|
return {'shape': dy_shape_all,
|
||||||
|
'dtype': dy['dtype'],
|
||||||
|
'value': None}
|
||||||
|
|
||||||
|
|
||||||
class RefToEmbed(Primitive):
|
class RefToEmbed(Primitive):
|
||||||
r"""
|
r"""
|
||||||
Make a key from Ref.
|
Make a key from Ref.
|
||||||
|
|
|
@ -614,7 +614,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
||||||
self.__setattr_flag__ = True
|
self.__setattr_flag__ = True
|
||||||
self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
|
||||||
outputs=['output'])
|
outputs=['output'])
|
||||||
self.add_prim_attr('target', 'CPU')
|
self.add_prim_attr('primitive_target', 'CPU')
|
||||||
|
|
||||||
def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
|
def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
|
||||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||||
|
|
|
@ -45,11 +45,11 @@ class GradWrap(nn.Cell):
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None):
|
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if shape is None:
|
if shape is None:
|
||||||
shape = [64, 64]
|
shape = [64, 64]
|
||||||
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
|
self.gatherv2 = P.GatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
|
||||||
self.mul = P.Mul().set_strategy(strategy2)
|
self.mul = P.Mul().set_strategy(strategy2)
|
||||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||||
self.axis = axis
|
self.axis = axis
|
||||||
|
@ -188,7 +188,7 @@ def test_gatherv2_cpu0():
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
strategy1 = ((8, 1), (1, 1))
|
strategy1 = ((8, 1), (1, 1))
|
||||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
net = NetWithLoss(Net(0, strategy1, strategy2))
|
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||||
net.set_auto_parallel()
|
net.set_auto_parallel()
|
||||||
|
|
||||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||||
|
@ -200,7 +200,7 @@ def test_gatherv2_cpu1():
|
||||||
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
|
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
strategy1 = ((16, 1), (1, 1))
|
strategy1 = ((16, 1), (1, 1))
|
||||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
net = NetWithLoss(Net(0, strategy1, strategy2))
|
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||||
net.set_auto_parallel()
|
net.set_auto_parallel()
|
||||||
|
|
||||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||||
|
@ -212,7 +212,7 @@ def test_gatherv2_cpu2():
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
strategy1 = ((1, 8), (1, 1))
|
strategy1 = ((1, 8), (1, 1))
|
||||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
net = NetWithLoss(Net(0, strategy1, strategy2))
|
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||||
net.set_auto_parallel()
|
net.set_auto_parallel()
|
||||||
|
|
||||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||||
|
|
Loading…
Reference in New Issue