!244 recover embeddinglookup

Merge pull request !244 from wuxuejian/recover_embeddinglookup
mindspore-ci-bot 2020-06-28 09:49:32 +08:00 committed by Gitee
commit cc4c40ce2e
3 changed files with 51 additions and 5 deletions


@@ -31,13 +31,13 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
   const std::string inner_ops_module = INNER_OP_PATH;
   py::module mod = py::module::import(common::SafeCStr(ops_module));
   py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
-  if (!py::hasattr(mod, common::SafeCStr(op_name))) {
-    if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
+  if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
+    if (!py::hasattr(mod, common::SafeCStr(op_name))) {
       MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
     }
-    return inner_ops_module;
+    return ops_module;
   }
-  return ops_module;
+  return inner_ops_module;
 }
 
 ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
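With this change, GetOpPythonPath resolves an operator name against the inner ops module first and only falls back to the public ops module, raising an exception when neither defines it. A minimal Python sketch of that resolution order; the module paths are assumed stand-ins for the OP_PATH and INNER_OP_PATH constants, which are defined elsewhere in the repository and not shown in this diff:

import importlib

# Assumed stand-ins for OP_PATH / INNER_OP_PATH; not taken from this diff.
OPS_MODULE = "mindspore.ops.operations"
INNER_OPS_MODULE = "mindspore.ops.operations.inner_ops"

def get_op_python_path(op_name):
    """Return the module path defining op_name, preferring the inner ops module."""
    mod = importlib.import_module(OPS_MODULE)
    inner_mod = importlib.import_module(INNER_OPS_MODULE)
    if not hasattr(inner_mod, op_name):
        if not hasattr(mod, op_name):
            raise RuntimeError("%s or %s don't have op: %s" % (OPS_MODULE, INNER_OPS_MODULE, op_name))
        return OPS_MODULE
    return INNER_OPS_MODULE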


@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
                         SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
                         ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
-                        Shape, Size, Slice, Split,
+                        Shape, Size, Slice, Split, EmbeddingLookup,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
@@ -134,6 +134,7 @@ __all__ = [
     'OneHot',
     'GatherV2',
     'SparseGatherV2',
+    'EmbeddingLookup',
     'Concat',
     'Pack',
     'Unpack',
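Re-exporting EmbeddingLookup from the operations package and listing it in __all__ makes the primitive reachable through the public operations namespace again. A short usage sketch under that assumption, using the P alias common in MindSpore examples:

from mindspore.ops import operations as P

# After this __init__.py change, the primitive is constructed through the public namespace.
embedding_lookup = P.EmbeddingLookup()
print('EmbeddingLookup' in P.__all__)  # expected to print True once this change is applied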


@@ -567,6 +567,51 @@ class SparseGatherV2(GatherV2):
     >>> out = P.GatherV2()(input_params, input_indices, axis)
     """
 
+
+class EmbeddingLookup(PrimitiveWithInfer):
+    """
+    Returns a slice of the input tensor based on the specified indices.
+
+    This primitive has similar functionality to GatherV2, but takes one more input: `offset`.
+    This primitive runs on the AICPU devices.
+
+    Inputs:
+        - **params** (Tensor) - The shape of the tensor is :math:`(x_1, x_2, ..., x_R)`.
+          This represents a Tensor slice, instead of the entire Tensor.
+        - **indices** (Tensor) - The shape of the tensor is :math:`(y_1, y_2, ..., y_S)`.
+          Specifies the indices of elements of the original Tensor; the data type must be
+          mindspore.int32 or mindspore.int64. Values can be out of the range of `params`,
+          and the out-of-range part will be filled with 0 in the output.
+        - **offset** (int) - Specifies the offset value of this `params` slice. Thus the real
+          indices are equal to `indices` minus `offset`.
+
+    Outputs:
+        Tensor, the shape of the tensor is :math:`(z_1, z_2, ..., z_N)`.
+
+    Examples:
+        >>> params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
+        >>> indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
+        >>> offset = 4
+        >>> out = P.EmbeddingLookup()(params, indices, offset)
+        [[[10, 11], [0, 0]], [[0, 0], [10, 11]]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init EmbeddingLookup"""
+        self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
+                                outputs=['output'])
+
+    def __infer__(self, params, indices, offset):
+        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
+        valid_types = (mstype.int32, mstype.int64)
+        validator.check_tensor_type_same({"indices": indices['dtype']}, valid_types, self.name)
+        validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
+        params_shp = params['shape']
+        out_shape = indices['shape'] + params_shp[1:]
+        out = {'shape': out_shape,
+               'dtype': params['dtype'],
+               'value': None}
+        return out
+
+
 class Split(PrimitiveWithInfer):
     """