!28703 add kObjectTypeTensorType support for narrow, expand_dims and masked_fill

Merge pull request !28703 from 吕昱峰(Nate.River)/master
This commit is contained in:
i-robot 2022-01-08 06:55:55 +00:00 committed by Gitee
commit b4c77e5803
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 38 additions and 3 deletions

View File

@ -184,6 +184,9 @@ BuiltInTypeMap &GetMethodMap() {
{"reshape", std::string("reshape")}, // P.reshape()
{"ravel", std::string("ravel")}, // P.reshape(,(-1,))
{"swapaxes", std::string("swapaxes")}, // P.transpose()
{"narrow", std::string("narrow")}, // narrow()
{"masked_fill", std::string("masked_fill")}, // masked_fill()
{"expand_dims", std::string("expand_dims")}, // P.expand_dims()
{"squeeze", std::string("squeeze")}, // P.squeeze()
{"astype", std::string("astype")}, // P.cast()
{"cumsum", std::string("cumsum")}, // P.cumsum()

View File

@ -24,6 +24,7 @@ from mindspore import dtype as mstype
from ..._checkparam import Validator as validator
from ...ops import functional as F
from ...ops import operations as P
from ...ops import composite as C
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
zeros_like, ones_like, repeat_elements
from ...ops.composite.base import _append, _insert
@ -1501,6 +1502,34 @@ def expand_tensor_as(x, y):
return broadcast_to(x)
def expand_dims(x, axis):
"""
Insert a dimension of shape 1 at the specified axis of Tensor
"""
check_is_int(axis, 'axis')
return P.ExpandDims()(x, axis)
def masked_fill(x, mask, value):
"""
Fills elements of self tensor with value where mask is True.
The shape of mask must be equal to the shape of the underlying tensor.
"""
check_is_tensor(mask)
check_type_name('mask', mask.dtype, [mstype.bool_], "Tensor")
mask_shape = infer_out_shape(x.shape, mask.shape)
mask = P.BroadcastTo(mask_shape)(mask)
check_value_type('value', value, [int, float], "Tensor")
return C.array_ops.masked_fill(x, mask, value)
def narrow(x, axis, start, length):
"""
Returns a narrowed tensor from input tensor.
The dimension axis is input from start to start + length.
"""
return F.narrow(x, axis, start, length)
def view(x, *shape):
"""Reshape tensor, if shape is -1, reshape tensor into one dimension"""
shape = check_view_shape(shape)
@ -1655,6 +1684,9 @@ check_axis_type = constexpr(validator.check_axis_type)
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
empty_compile = constexpr(validator.empty_compile)
check_type_support = constexpr(validator.check_type_support)
check_is_int = constexpr(validator.check_is_int)
check_type_name = constexpr(validator.check_type_name)
check_value_type = constexpr(validator.check_value_type)
def tensor_bool(x):

View File

@ -1375,7 +1375,7 @@ class Tensor(Tensor_):
mask_shape = validator.infer_out_shape(self.shape, mask.shape)
mask = tensor_operator_registry.get('broadcast_to')(mask_shape)(mask)
validator.check_value_type('value', value, [int, float], "Tensor")
return tensor_operator_registry.get("_masked_fill")(self, mask, value)
return tensor_operator_registry.get("masked_fill")(self, mask, value)
def ptp(self, axis=None, keepdims=False):
"""

View File

@ -217,8 +217,8 @@ def sequence_mask(lengths, maxlen=None, prim_name='sequence_mask'):
result = range_vector < mask
return result
def _masked_fill(inputs, mask, value):
def masked_fill(inputs, mask, value):
masked_value = P.Fill()(inputs.dtype, inputs.shape, value)
return P.Select()(mask, masked_value, inputs)
tensor_operator_registry.register('_masked_fill', _masked_fill)
tensor_operator_registry.register('masked_fill', masked_fill)