diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc index 7bd2fa808de..4db912f63e8 100644 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc +++ b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc @@ -31,13 +31,13 @@ std::string GetOpPythonPath(const OperatorName &op_name) { const std::string inner_ops_module = INNER_OP_PATH; py::module mod = py::module::import(common::SafeCStr(ops_module)); py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module)); - if (!py::hasattr(mod, common::SafeCStr(op_name))) { - if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) { + if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) { + if (!py::hasattr(mod, common::SafeCStr(op_name))) { MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name; } - return inner_ops_module; + return ops_module; } - return ops_module; + return inner_ops_module; } ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) { diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index beed99f713b..dec223193aa 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, - Shape, Size, Slice, Split, + Shape, Size, Slice, Split, EmbeddingLookup, Squeeze, StridedSlice, Tile, TensorScatterUpdate, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, @@ -134,6 +134,7 @@ __all__ = [ 'OneHot', 'GatherV2', 'SparseGatherV2', + 'EmbeddingLookup', 'Concat', 'Pack', 'Unpack', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 1bb39d15472..72533791c61 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -567,6 +567,51 @@ class SparseGatherV2(GatherV2): >>> out = P.GatherV2()(input_params, input_indices, axis) """ +class EmbeddingLookup(PrimitiveWithInfer): + """ + Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar + functionality as GatherV2, but has one more inputs: `offset`. + This primitive runs on the acipu devices. + + Inputs: + - **params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + The Tensor slice, instead of the entire Tensor. + - **indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. + Specifies the indices of elements of the original Tensor. Values can be out of range of `params`, + and the exceeding part will be filled with 0 in the output. + The indices to do lookup operation whose data type should be mindspore.int32 or mindspore.int64. + - **offset** (int) - Specifies the offset value of this `params` slice. Thus the real indices + are equal to `indices` minus `offset`. + + + Outputs: + Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. + + Examples: + >>> params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) + >>> indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) + >>> offset = 4 + >>> out = P.EmbeddingLookup()(params, indices, offset) + [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] + """ + @prim_attr_register + def __init__(self): + """init index_select""" + self.init_prim_io_names(inputs=['params', 'indices', 'offset'], + outputs=['output']) + + def __infer__(self, params, indices, offset): + validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) + valid_types = (mstype.int32, mstype.int64) + validator.check_tensor_type_same({"indices": indices['dtype']}, valid_types, self.name) + validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) + params_shp = params['shape'] + out_shape = indices['shape'] + params_shp[1:] + out = {'shape': out_shape, + 'dtype': params['dtype'], + 'value': None} + return out + class Split(PrimitiveWithInfer): """