forked from mindspore-Ecosystem/mindspore
!28617 add masked_fill and expand_dims for Tensor
Merge pull request !28617 from 吕昱峰(Nate.River)/masked_fill
This commit is contained in:
commit
85a3ac1fd4
|
@ -24,6 +24,7 @@ from .._c_expression import Tensor as Tensor_
|
|||
from .._c_expression import CSRTensor as CSRTensor_
|
||||
from .._c_expression import PynativeExecutor_
|
||||
from .._checkparam import Validator as validator
|
||||
from .._checkparam import Rel
|
||||
|
||||
__all__ = ['Tensor', 'RowTensor', 'SparseTensor', 'CSRTensor']
|
||||
np_types = (np.int8, np.int16, np.int32, np.int64,
|
||||
|
@ -953,6 +954,44 @@ class Tensor(Tensor_):
|
|||
new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
|
||||
return tensor_operator_registry.get('reshape')()(self, new_shape)
|
||||
|
||||
def expand_dims(self, axis):
|
||||
"""
|
||||
Insert a dimension of shape 1 at the specified axis of Tensor
|
||||
|
||||
Args:
|
||||
axis (int): the axis at which to insert the singleton dimension.
|
||||
|
||||
Returns:
|
||||
Tensor, with inserted dimension of length 1.
|
||||
|
||||
Raises:
|
||||
TypeError: If axis is not an int.
|
||||
ValueError: If axis is not in range [-self.ndim - 1, self.ndim + 1).
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.ones((2,2), dtype=np.float32))
|
||||
>>> print(x)
|
||||
[[1. 1.]
|
||||
[1. 1.]]
|
||||
>>> print(x.shape)
|
||||
(2, 2)
|
||||
>>> y = x.expand_dims(axis=0)
|
||||
>>> print(y)
|
||||
[[[1. 1.]
|
||||
[1. 1.]]]
|
||||
>>> print(y.shape)
|
||||
(1, 2, 2)
|
||||
"""
|
||||
self._init_check()
|
||||
validator.check_is_int(axis, 'axis')
|
||||
validator.check_int_range(axis, -self.ndim - 1, self.ndim + 1, Rel.INC_LEFT, 'axis')
|
||||
return tensor_operator_registry.get('expand_dims')(self, axis)
|
||||
|
||||
def astype(self, dtype, copy=True):
|
||||
"""
|
||||
Return a copy of the tensor, cast to a specified type.
|
||||
|
@ -1299,6 +1338,45 @@ class Tensor(Tensor_):
|
|||
"but got {}.".format(type(value)))
|
||||
return tensor_operator_registry.get("fill")(self.dtype, self.shape, value)
|
||||
|
||||
def masked_fill(self, mask, value):
|
||||
"""
|
||||
Fills elements of self tensor with value where mask is True.
|
||||
The shape of mask must be equal to the shape of the underlying tensor.
|
||||
|
||||
Args:
|
||||
mask (Tensor[bool]): The boolean mask.
|
||||
value (Union[int, float]): The value to fill in with, which only supports a float or an int number.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same type and shape as self.
|
||||
|
||||
Raises:
|
||||
TypeError: If mask is not a tensor.
|
||||
TypeError: If mask is not bool.
|
||||
TypeError: If value is neither int nor float number.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> a = Tensor(np.arange(4)).astype('float32'))
|
||||
>>> print(a)
|
||||
[0. 1. 2. 3.]
|
||||
>>> mask = Tensor([False, False, True, True])
|
||||
>>> print(a.masked_fill(mask, 0.0))
|
||||
[0. 1. 0. 0.]
|
||||
"""
|
||||
if not isinstance(mask, Tensor):
|
||||
raise TypeError("For 'Tensor.masked_fill', the type of the argument 'mask' must be Tensor, but "
|
||||
"got {}.".format(type(mask)))
|
||||
validator.check_type_name('mask', mask.dtype, [mstype.bool_], "Tensor")
|
||||
mask_shape = validator.infer_out_shape(self.shape, mask.shape)
|
||||
mask = tensor_operator_registry.get('broadcast_to')(mask_shape)(mask)
|
||||
validator.check_value_type('value', value, [int, float], "Tensor")
|
||||
return tensor_operator_registry.get("_masked_fill")(self, mask, value)
|
||||
|
||||
def ptp(self, axis=None, keepdims=False):
|
||||
"""
|
||||
The name of the function comes from the acronym for "peak to peak".
|
||||
|
|
|
@ -216,3 +216,9 @@ def sequence_mask(lengths, maxlen=None, prim_name='sequence_mask'):
|
|||
mask = expand_op(lengths, -1)
|
||||
result = range_vector < mask
|
||||
return result
|
||||
|
||||
def _masked_fill(inputs, mask, value):
|
||||
masked_value = P.Fill()(inputs.dtype, inputs.shape, value)
|
||||
return P.Select()(mask, masked_value, inputs)
|
||||
|
||||
tensor_operator_registry.register('_masked_fill', _masked_fill)
|
||||
|
|
|
@ -630,6 +630,7 @@ tensor_operator_registry.register('__ge__', tensor_ge)
|
|||
tensor_operator_registry.register('__logical_not__', logical_not)
|
||||
tensor_operator_registry.register('shape', shape)
|
||||
tensor_operator_registry.register('squeeze', squeeze)
|
||||
tensor_operator_registry.register('expand_dims', expand_dims)
|
||||
# support GE backend for no compare operators
|
||||
tensor_operator_registry.register('cast', cast)
|
||||
tensor_operator_registry.register('shape_mul', shape_mul)
|
||||
|
|
Loading…
Reference in New Issue