forked from mindspore-Ecosystem/mindspore
!244 recover embeddinglookup
Merge pull request !244 from wuxuejian/recover_embeddinglookup
This commit is contained in:
commit
cc4c40ce2e
|
@ -31,13 +31,13 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
|
|||
const std::string inner_ops_module = INNER_OP_PATH;
|
||||
py::module mod = py::module::import(common::SafeCStr(ops_module));
|
||||
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
|
||||
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
|
||||
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
|
||||
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
|
||||
MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
|
||||
}
|
||||
return inner_ops_module;
|
||||
}
|
||||
return ops_module;
|
||||
}
|
||||
return inner_ops_module;
|
||||
}
|
||||
|
||||
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
|
||||
|
|
|
@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
|
||||
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||
Shape, Size, Slice, Split,
|
||||
Shape, Size, Slice, Split, EmbeddingLookup,
|
||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||
|
@ -134,6 +134,7 @@ __all__ = [
|
|||
'OneHot',
|
||||
'GatherV2',
|
||||
'SparseGatherV2',
|
||||
'EmbeddingLookup',
|
||||
'Concat',
|
||||
'Pack',
|
||||
'Unpack',
|
||||
|
|
|
@ -567,6 +567,51 @@ class SparseGatherV2(GatherV2):
|
|||
>>> out = P.GatherV2()(input_params, input_indices, axis)
|
||||
"""
|
||||
|
||||
class EmbeddingLookup(PrimitiveWithInfer):
|
||||
"""
|
||||
Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar
|
||||
functionality as GatherV2, but has one more inputs: `offset`.
|
||||
This primitive runs on the acipu devices.
|
||||
|
||||
Inputs:
|
||||
- **params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
The Tensor slice, instead of the entire Tensor.
|
||||
- **indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||
Specifies the indices of elements of the original Tensor. Values can be out of range of `params`,
|
||||
and the exceeding part will be filled with 0 in the output.
|
||||
The indices to do lookup operation whose data type should be mindspore.int32 or mindspore.int64.
|
||||
- **offset** (int) - Specifies the offset value of this `params` slice. Thus the real indices
|
||||
are equal to `indices` minus `offset`.
|
||||
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
||||
|
||||
Examples:
|
||||
>>> params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
||||
>>> indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
|
||||
>>> offset = 4
|
||||
>>> out = P.EmbeddingLookup()(params, indices, offset)
|
||||
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init index_select"""
|
||||
self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
|
||||
outputs=['output'])
|
||||
|
||||
def __infer__(self, params, indices, offset):
|
||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||
valid_types = (mstype.int32, mstype.int64)
|
||||
validator.check_tensor_type_same({"indices": indices['dtype']}, valid_types, self.name)
|
||||
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
||||
params_shp = params['shape']
|
||||
out_shape = indices['shape'] + params_shp[1:]
|
||||
out = {'shape': out_shape,
|
||||
'dtype': params['dtype'],
|
||||
'value': None}
|
||||
return out
|
||||
|
||||
|
||||
class Split(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue