forked from mindspore-Ecosystem/mindspore
!244 recover embeddinglookup
Merge pull request !244 from wuxuejian/recover_embeddinglookup
This commit is contained in:
commit
cc4c40ce2e
|
@ -31,13 +31,13 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
|
||||||
const std::string inner_ops_module = INNER_OP_PATH;
|
const std::string inner_ops_module = INNER_OP_PATH;
|
||||||
py::module mod = py::module::import(common::SafeCStr(ops_module));
|
py::module mod = py::module::import(common::SafeCStr(ops_module));
|
||||||
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
|
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
|
||||||
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
|
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
|
||||||
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
|
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
|
||||||
MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
|
MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
|
||||||
}
|
}
|
||||||
return inner_ops_module;
|
return ops_module;
|
||||||
}
|
}
|
||||||
return ops_module;
|
return inner_ops_module;
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
|
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
|
||||||
|
|
|
@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||||
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
|
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
|
||||||
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||||
Shape, Size, Slice, Split,
|
Shape, Size, Slice, Split, EmbeddingLookup,
|
||||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||||
|
@ -134,6 +134,7 @@ __all__ = [
|
||||||
'OneHot',
|
'OneHot',
|
||||||
'GatherV2',
|
'GatherV2',
|
||||||
'SparseGatherV2',
|
'SparseGatherV2',
|
||||||
|
'EmbeddingLookup',
|
||||||
'Concat',
|
'Concat',
|
||||||
'Pack',
|
'Pack',
|
||||||
'Unpack',
|
'Unpack',
|
||||||
|
|
|
@ -567,6 +567,51 @@ class SparseGatherV2(GatherV2):
|
||||||
>>> out = P.GatherV2()(input_params, input_indices, axis)
|
>>> out = P.GatherV2()(input_params, input_indices, axis)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class EmbeddingLookup(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar
|
||||||
|
functionality as GatherV2, but has one more inputs: `offset`.
|
||||||
|
This primitive runs on the acipu devices.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||||
|
The Tensor slice, instead of the entire Tensor.
|
||||||
|
- **indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||||
|
Specifies the indices of elements of the original Tensor. Values can be out of range of `params`,
|
||||||
|
and the exceeding part will be filled with 0 in the output.
|
||||||
|
The indices to do lookup operation whose data type should be mindspore.int32 or mindspore.int64.
|
||||||
|
- **offset** (int) - Specifies the offset value of this `params` slice. Thus the real indices
|
||||||
|
are equal to `indices` minus `offset`.
|
||||||
|
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
||||||
|
>>> indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
|
||||||
|
>>> offset = 4
|
||||||
|
>>> out = P.EmbeddingLookup()(params, indices, offset)
|
||||||
|
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init index_select"""
|
||||||
|
self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
|
||||||
|
outputs=['output'])
|
||||||
|
|
||||||
|
def __infer__(self, params, indices, offset):
|
||||||
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||||
|
valid_types = (mstype.int32, mstype.int64)
|
||||||
|
validator.check_tensor_type_same({"indices": indices['dtype']}, valid_types, self.name)
|
||||||
|
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
||||||
|
params_shp = params['shape']
|
||||||
|
out_shape = indices['shape'] + params_shp[1:]
|
||||||
|
out = {'shape': out_shape,
|
||||||
|
'dtype': params['dtype'],
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Split(PrimitiveWithInfer):
|
class Split(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue