!28617 add masked_fill and expand_dims for Tensor

Merge pull request !28617 from 吕昱峰(Nate.River)/masked_fill
i-robot 2022-01-06 11:20:27 +00:00 committed by Gitee
commit 85a3ac1fd4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 85 additions and 0 deletions


@@ -24,6 +24,7 @@ from .._c_expression import Tensor as Tensor_
from .._c_expression import CSRTensor as CSRTensor_
from .._c_expression import PynativeExecutor_
from .._checkparam import Validator as validator
from .._checkparam import Rel
__all__ = ['Tensor', 'RowTensor', 'SparseTensor', 'CSRTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,
@@ -953,6 +954,44 @@ class Tensor(Tensor_):
        new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
        return tensor_operator_registry.get('reshape')()(self, new_shape)

    def expand_dims(self, axis):
        """
        Insert a dimension of shape 1 at the specified axis of the Tensor.

        Args:
            axis (int): The axis at which to insert the singleton dimension.

        Returns:
            Tensor, with inserted dimension of length 1.

        Raises:
            TypeError: If axis is not an int.
            ValueError: If axis is not in range [-self.ndim - 1, self.ndim + 1).

        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``

        Examples:
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> x = Tensor(np.ones((2, 2), dtype=np.float32))
            >>> print(x)
            [[1. 1.]
             [1. 1.]]
            >>> print(x.shape)
            (2, 2)
            >>> y = x.expand_dims(axis=0)
            >>> print(y)
            [[[1. 1.]
              [1. 1.]]]
            >>> print(y.shape)
            (1, 2, 2)
        """
        self._init_check()
        validator.check_is_int(axis, 'axis')
        validator.check_int_range(axis, -self.ndim - 1, self.ndim + 1, Rel.INC_LEFT, 'axis')
        return tensor_operator_registry.get('expand_dims')(self, axis)
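For illustration (assuming the 'expand_dims' entry registered in functional.py below resolves to ops.ExpandDims), the new method is a validated wrapper over that primitive, and a negative axis counts from the end; a hedged doctest-style sketch:

    >>> import numpy as np
    >>> from mindspore import Tensor, ops
    >>> x = Tensor(np.ones((2, 2), dtype=np.float32))
    >>> print(x.expand_dims(axis=-1).shape)
    (2, 2, 1)
    >>> print(ops.ExpandDims()(x, -1).shape)   # the primitive the registry entry is assumed to point to
    (2, 2, 1)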
    def astype(self, dtype, copy=True):
        """
        Return a copy of the tensor, cast to a specified type.
@@ -1299,6 +1338,45 @@ class Tensor(Tensor_):
                            "but got {}.".format(type(value)))
        return tensor_operator_registry.get("fill")(self.dtype, self.shape, value)
    def masked_fill(self, mask, value):
        """
        Fills elements of self tensor with value where mask is True.

        The shape of mask must be broadcastable to the shape of the underlying tensor.

        Args:
            mask (Tensor[bool]): The boolean mask.
            value (Union[int, float]): The value to fill in with, which only supports a float or an int number.

        Returns:
            Tensor, has the same type and shape as self.

        Raises:
            TypeError: If mask is not a Tensor.
            TypeError: If dtype of mask is not bool.
            TypeError: If value is neither int nor float number.

        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``

        Examples:
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> a = Tensor(np.arange(4)).astype('float32')
            >>> print(a)
            [0. 1. 2. 3.]
            >>> mask = Tensor([False, False, True, True])
            >>> print(a.masked_fill(mask, 0.0))
            [0. 1. 0. 0.]
        """
        if not isinstance(mask, Tensor):
            raise TypeError("For 'Tensor.masked_fill', the type of the argument 'mask' must be Tensor, but "
                            "got {}.".format(type(mask)))
        validator.check_type_name('mask', mask.dtype, [mstype.bool_], "Tensor")
        mask_shape = validator.infer_out_shape(self.shape, mask.shape)
        mask = tensor_operator_registry.get('broadcast_to')(mask_shape)(mask)
        validator.check_value_type('value', value, [int, float], "Tensor")
        return tensor_operator_registry.get("_masked_fill")(self, mask, value)
    def ptp(self, axis=None, keepdims=False):
        """
        The name of the function comes from the acronym for "peak to peak".


@@ -216,3 +216,9 @@ def sequence_mask(lengths, maxlen=None, prim_name='sequence_mask'):
    mask = expand_op(lengths, -1)
    result = range_vector < mask
    return result


def _masked_fill(inputs, mask, value):
    masked_value = P.Fill()(inputs.dtype, inputs.shape, value)
    return P.Select()(mask, masked_value, inputs)

tensor_operator_registry.register('_masked_fill', _masked_fill)
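For reference, the Fill + Select pair above has the same semantics as numpy.full followed by numpy.where; a minimal, self-contained NumPy sketch (the helper name masked_fill_ref is hypothetical, for illustration only):

    import numpy as np

    def masked_fill_ref(inputs, mask, value):
        # Analogue of P.Fill(): a constant array of `value` with inputs' shape and dtype.
        filled = np.full(inputs.shape, value, dtype=inputs.dtype)
        # Analogue of P.Select(): take `filled` where mask is True, else the original element.
        return np.where(mask, filled, inputs)

    print(masked_fill_ref(np.arange(4, dtype=np.float32),
                          np.array([False, False, True, True]), 0.0))
    # [0. 1. 0. 0.]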


@@ -630,6 +630,7 @@ tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('__logical_not__', logical_not)
tensor_operator_registry.register('shape', shape)
tensor_operator_registry.register('squeeze', squeeze)
tensor_operator_registry.register('expand_dims', expand_dims)
# support GE backend for no compare operators
tensor_operator_registry.register('cast', cast)
tensor_operator_registry.register('shape_mul', shape_mul)