forked from OSSInnovation/mindspore
!2163 [Auto parallel] Move 'EmbeddingLookup' to internal
Merge pull request !2163 from Xiaoda/3-changing-embeddinglookup-internal
This commit is contained in:
commit
f513bb8f83
|
@ -191,7 +191,7 @@ def get_bprop_tile(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.EmbeddingLookup)
|
@bprop_getters.register(inner.EmbeddingLookup)
|
||||||
def get_bprop_embedding_lookup(self):
|
def get_bprop_embedding_lookup(self):
|
||||||
"""Generate bprop for EmbeddingLookup"""
|
"""Generate bprop for EmbeddingLookup"""
|
||||||
host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU')
|
host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU')
|
||||||
|
|
|
@ -26,7 +26,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||||
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
|
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
|
||||||
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||||
Shape, Size, Slice, Split, EmbeddingLookup,
|
Shape, Size, Slice, Split,
|
||||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||||
|
@ -138,7 +138,6 @@ __all__ = [
|
||||||
'ReduceSum',
|
'ReduceSum',
|
||||||
'ReduceMean',
|
'ReduceMean',
|
||||||
'LayerNorm',
|
'LayerNorm',
|
||||||
'EmbeddingLookup',
|
|
||||||
'Rank',
|
'Rank',
|
||||||
'Less',
|
'Less',
|
||||||
'LessEqual',
|
'LessEqual',
|
||||||
|
|
|
@ -258,3 +258,73 @@ class AscendDequant(PrimitiveWithInfer):
|
||||||
validator.check_type_name("x", x_type, [mstype.int32], self.name)
|
validator.check_type_name("x", x_type, [mstype.int32], self.name)
|
||||||
validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
|
validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
|
||||||
return mstype.float16
|
return mstype.float16
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingLookup(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Returns a slice of input tensor based on the specified indices.
|
||||||
|
|
||||||
|
This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs:
|
||||||
|
`offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||||
|
The Tensor slice, instead of the entire Tensor.
|
||||||
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||||
|
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
|
||||||
|
and the exceeding part will be filled with 0 in the output.
|
||||||
|
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
|
||||||
|
are equal to `input_indices` minus `offset`.
|
||||||
|
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
|
||||||
|
Only constant value is allowed.
|
||||||
|
- **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable
|
||||||
|
is used only if `reduce_scatter_flag` is True. Only constant value is allowed.
|
||||||
|
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
||||||
|
>>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
|
||||||
|
>>> offset = 4
|
||||||
|
>>> reduce_scatter_flag = False
|
||||||
|
>>> split_num = 1
|
||||||
|
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num)
|
||||||
|
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init index_select"""
|
||||||
|
self.__setattr_flag__ = True
|
||||||
|
self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'],
|
||||||
|
outputs=['output'])
|
||||||
|
self.add_prim_attr('primitive_target', 'CPU')
|
||||||
|
|
||||||
|
def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2):
|
||||||
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||||
|
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
|
||||||
|
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
||||||
|
validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)
|
||||||
|
if split_num['value'] < 1:
|
||||||
|
raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num)
|
||||||
|
params_shp = params['shape']
|
||||||
|
out_shape = indices['shape'] + params_shp[1:]
|
||||||
|
if reduce_scatter_flag is None:
|
||||||
|
raise ValueError("The value of 'reduce_scatter_flag' is None.")
|
||||||
|
reduce_scatter_flag_value = reduce_scatter_flag['value']
|
||||||
|
if split_num is None:
|
||||||
|
raise ValueError("The value of 'split_num_value' is None.")
|
||||||
|
split_num_value = split_num['value']
|
||||||
|
if reduce_scatter_flag_value is True:
|
||||||
|
# Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
|
||||||
|
# (split_num * 8)
|
||||||
|
if out_shape[0] % (split_num_value * 8) != 0:
|
||||||
|
raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." %
|
||||||
|
(out_shape[0], (split_num_value * 8)))
|
||||||
|
# After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
|
||||||
|
out_shape[0] = out_shape[0] // 8
|
||||||
|
out = {'shape': out_shape,
|
||||||
|
'dtype': params['dtype'],
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
|
@ -558,76 +558,6 @@ class SparseGatherV2(GatherV2):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingLookup(PrimitiveWithInfer):
|
|
||||||
"""
|
|
||||||
Returns a slice of input tensor based on the specified indices.
|
|
||||||
|
|
||||||
This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs:
|
|
||||||
`offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices.
|
|
||||||
|
|
||||||
Inputs:
|
|
||||||
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
||||||
The Tensor slice, instead of the entire Tensor.
|
|
||||||
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
|
||||||
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
|
|
||||||
and the exceeding part will be filled with 0 in the output.
|
|
||||||
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
|
|
||||||
are equal to `input_indices` minus `offset`.
|
|
||||||
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
|
|
||||||
Only constant value is allowed.
|
|
||||||
- **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable
|
|
||||||
is used only if `reduce_scatter_flag` is True. Only constant value is allowed.
|
|
||||||
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
|
||||||
>>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
|
|
||||||
>>> offset = 4
|
|
||||||
>>> reduce_scatter_flag = False
|
|
||||||
>>> split_num = 1
|
|
||||||
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num)
|
|
||||||
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
|
|
||||||
"""
|
|
||||||
@prim_attr_register
|
|
||||||
def __init__(self):
|
|
||||||
"""init index_select"""
|
|
||||||
self.__setattr_flag__ = True
|
|
||||||
self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'],
|
|
||||||
outputs=['output'])
|
|
||||||
self.add_prim_attr('primitive_target', 'CPU')
|
|
||||||
|
|
||||||
def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2):
|
|
||||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
|
||||||
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
|
|
||||||
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
|
||||||
validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)
|
|
||||||
if split_num['value'] < 1:
|
|
||||||
raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num)
|
|
||||||
params_shp = params['shape']
|
|
||||||
out_shape = indices['shape'] + params_shp[1:]
|
|
||||||
if reduce_scatter_flag is None:
|
|
||||||
raise ValueError("The value of 'reduce_scatter_flag' is None.")
|
|
||||||
reduce_scatter_flag_value = reduce_scatter_flag['value']
|
|
||||||
if split_num is None:
|
|
||||||
raise ValueError("The value of 'split_num_value' is None.")
|
|
||||||
split_num_value = split_num['value']
|
|
||||||
if reduce_scatter_flag_value is True:
|
|
||||||
# Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
|
|
||||||
# (split_num * 8)
|
|
||||||
if out_shape[0] % (split_num_value * 8) != 0:
|
|
||||||
raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." %
|
|
||||||
(out_shape[0], (split_num_value * 8)))
|
|
||||||
# After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
|
|
||||||
out_shape[0] = out_shape[0] // 8
|
|
||||||
out = {'shape': out_shape,
|
|
||||||
'dtype': params['dtype'],
|
|
||||||
'value': None}
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Split(PrimitiveWithInfer):
|
class Split(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Splits input tensor into output_num of tensors along the given axis and output numbers.
|
Splits input tensor into output_num of tensors along the given axis and output numbers.
|
||||||
|
|
|
@ -19,6 +19,7 @@ import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.common.api import _executor
|
from mindspore.common.api import _executor
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops.operations import _inner_ops as inner
|
||||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,7 +40,7 @@ class Net(nn.Cell):
|
||||||
self.offset = offset
|
self.offset = offset
|
||||||
self.reduce_scatter_flag = reduce_scatter_flag
|
self.reduce_scatter_flag = reduce_scatter_flag
|
||||||
self.split_num = split_num
|
self.split_num = split_num
|
||||||
self.elu = P.EmbeddingLookup()
|
self.elu = inner.EmbeddingLookup()
|
||||||
self.mm = P.BatchMatMul()
|
self.mm = P.BatchMatMul()
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
|
|
Loading…
Reference in New Issue