!28703 add kObjectTypeTensorType support for narrow, expand_dims and masked_fill
Merge pull request !28703 from 吕昱峰(Nate.River)/master
This commit is contained in:
commit
b4c77e5803
|
@ -184,6 +184,9 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"reshape", std::string("reshape")}, // P.reshape()
|
||||
{"ravel", std::string("ravel")}, // P.reshape(,(-1,))
|
||||
{"swapaxes", std::string("swapaxes")}, // P.transpose()
|
||||
{"narrow", std::string("narrow")}, // narrow()
|
||||
{"masked_fill", std::string("masked_fill")}, // masked_fill()
|
||||
{"expand_dims", std::string("expand_dims")}, // P.expand_dims()
|
||||
{"squeeze", std::string("squeeze")}, // P.squeeze()
|
||||
{"astype", std::string("astype")}, // P.cast()
|
||||
{"cumsum", std::string("cumsum")}, // P.cumsum()
|
||||
|
|
|
@ -24,6 +24,7 @@ from mindspore import dtype as mstype
|
|||
from ..._checkparam import Validator as validator
|
||||
from ...ops import functional as F
|
||||
from ...ops import operations as P
|
||||
from ...ops import composite as C
|
||||
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
|
||||
zeros_like, ones_like, repeat_elements
|
||||
from ...ops.composite.base import _append, _insert
|
||||
|
@ -1501,6 +1502,34 @@ def expand_tensor_as(x, y):
|
|||
return broadcast_to(x)
|
||||
|
||||
|
||||
def expand_dims(x, axis):
|
||||
"""
|
||||
Insert a dimension of shape 1 at the specified axis of Tensor
|
||||
"""
|
||||
check_is_int(axis, 'axis')
|
||||
return P.ExpandDims()(x, axis)
|
||||
|
||||
|
||||
def masked_fill(x, mask, value):
|
||||
"""
|
||||
Fills elements of self tensor with value where mask is True.
|
||||
The shape of mask must be equal to the shape of the underlying tensor.
|
||||
"""
|
||||
check_is_tensor(mask)
|
||||
check_type_name('mask', mask.dtype, [mstype.bool_], "Tensor")
|
||||
mask_shape = infer_out_shape(x.shape, mask.shape)
|
||||
mask = P.BroadcastTo(mask_shape)(mask)
|
||||
check_value_type('value', value, [int, float], "Tensor")
|
||||
return C.array_ops.masked_fill(x, mask, value)
|
||||
|
||||
|
||||
def narrow(x, axis, start, length):
|
||||
"""
|
||||
Returns a narrowed tensor from input tensor.
|
||||
The dimension axis is input from start to start + length.
|
||||
"""
|
||||
return F.narrow(x, axis, start, length)
|
||||
|
||||
def view(x, *shape):
|
||||
"""Reshape tensor, if shape is -1, reshape tensor into one dimension"""
|
||||
shape = check_view_shape(shape)
|
||||
|
@ -1655,6 +1684,9 @@ check_axis_type = constexpr(validator.check_axis_type)
|
|||
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
|
||||
empty_compile = constexpr(validator.empty_compile)
|
||||
check_type_support = constexpr(validator.check_type_support)
|
||||
check_is_int = constexpr(validator.check_is_int)
|
||||
check_type_name = constexpr(validator.check_type_name)
|
||||
check_value_type = constexpr(validator.check_value_type)
|
||||
|
||||
|
||||
def tensor_bool(x):
|
||||
|
|
|
@ -1375,7 +1375,7 @@ class Tensor(Tensor_):
|
|||
mask_shape = validator.infer_out_shape(self.shape, mask.shape)
|
||||
mask = tensor_operator_registry.get('broadcast_to')(mask_shape)(mask)
|
||||
validator.check_value_type('value', value, [int, float], "Tensor")
|
||||
return tensor_operator_registry.get("_masked_fill")(self, mask, value)
|
||||
return tensor_operator_registry.get("masked_fill")(self, mask, value)
|
||||
|
||||
def ptp(self, axis=None, keepdims=False):
|
||||
"""
|
||||
|
|
|
@ -217,8 +217,8 @@ def sequence_mask(lengths, maxlen=None, prim_name='sequence_mask'):
|
|||
result = range_vector < mask
|
||||
return result
|
||||
|
||||
def _masked_fill(inputs, mask, value):
|
||||
def masked_fill(inputs, mask, value):
|
||||
masked_value = P.Fill()(inputs.dtype, inputs.shape, value)
|
||||
return P.Select()(mask, masked_value, inputs)
|
||||
|
||||
tensor_operator_registry.register('_masked_fill', _masked_fill)
|
||||
tensor_operator_registry.register('masked_fill', masked_fill)
|
||||
|
|
Loading…
Reference in New Issue