Add function and tensor interfaces.

This commit is contained in:
liqiliang 2022-05-25 10:04:06 +08:00
parent e65fe8fabe
commit 73e4ee0f95
13 changed files with 468 additions and 99 deletions

View File

@ -14,13 +14,13 @@ mindspore.nn.Softsign
**输入:**
**input_x** (Tensor) - shape为 :math:`(N, *)` 的Tensor, 其中 :math:`*` 表示任意个数的维度。它的数据类型必须为float16或float32。
**x** (Tensor) - shape为 :math:`(N, *)` 的Tensor, 其中 :math:`*` 表示任意个数的维度。它的数据类型必须为float16或float32。
**输出:**
Tensor数据类型和shape与 `input_x` 相同。
Tensor数据类型和shape与 `x` 相同。
**异常:**
- **TypeError** - `input_x` 不是Tensor。
- **TypeError** - `input_x` 的数据类型既不是float16也不是float32。
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `x` 的数据类型既不是float16也不是float32。

View File

@ -218,6 +218,10 @@ BuiltInTypeMap &GetMethodMap() {
{"choose", std::string("choose")}, // P.Select()
{"diagonal", std::string("diagonal")}, // P.Eye()
{"matrix_diag", std::string("matrix_diag")}, // matrix_diag()
{"inv", std::string("inv")}, // inv()
{"invert", std::string("invert")}, // invert()
{"matrix_band_part", std::string("matrix_band_part")}, // matrix_band_part()
{"padding", std::string("padding")}, // padding()
{"searchsorted", std::string("searchsorted")}, // P.Select()
{"take", std::string("take")}, // P.GatherNd()
{"tensor_scatter_add", std::string("tensor_scatter_add")}, // P.TensorScatterAdd()

View File

@ -832,6 +832,34 @@ def matrix_diag(x, k=0, num_rows=-1, num_cols=-1, padding_value=0, align="RIGHT_
return F.matrix_diag(x, k, num_rows, num_cols, padding_value, align)
def inv(x):
"""
Computes Reciprocal of input tensor element-wise.
"""
return F.inv(x)
def invert(x):
"""
Flips all bits of input tensor element-wise.
"""
return F.invert(x)
def matrix_band_part(x, lower, upper):
"""
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
"""
return F.matrix_band_part(x, lower, upper)
def padding(x, pad_dim_size=8):
"""
Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.
"""
return F.padding(x, pad_dim_size)
def trace(x, offset=0, axis1=0, axis2=1, dtype=None):
"""
Returns the sum along diagonals of the array.

View File

@ -1058,6 +1058,130 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('matrix_diag')(self, k, num_rows, num_cols, padding_value, align)
def inv(self):
r"""
Computes Reciprocal of input tensor element-wise.
.. math::
out_i = \frac{1}{x_{i} }
Returns:
Tensor, has the same type and shape as input shape value.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is not one of float16, float32, int32.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore.ops import functional as F
>>> x = Tensor(np.array([0.25, 0.4, 0.31, 0.52]), mindspore.float32)
>>> output = x.inv()
>>> print(output)
[4. 2.5 3.2258065 1.923077 ]
"""
self._init_check()
return tensor_operator_registry.get('inv')(self)
def invert(self):
r"""
Flips all bits of input tensor element-wise.
.. math::
out_i = ~x_{i}
Returns:
Tensor, has the same shape as `x`.
Raises:
TypeError: If dtype of `x` is neither int16 nor uint16.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore.ops import functional as F
>>> x = Tensor(np.array([25, 4, 13, 9]), mindspore.int16)
>>> output = x.invert()
>>> print(output)
[-26 -5 -14 -10]
"""
self._init_check()
return tensor_operator_registry.get('invert')(self)
def matrix_band_part(self, lower, upper):
r"""
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
Args:
lower (int): Number of subdiagonals to keep. It must be int32 or int64.
If negative, keep entire lower triangle.
upper (int): Number of superdiagonals to keep. It must be int32 or int64.
If negative, keep entire upper triangle.
Returns:
Tensor, has the same type and shape as input shape value.
Raises:
TypeError: If dtype of `x` is not one of float16, float32, float64, int32 or int64.
TypeError: If dtype of `lower` is not int32 or int64.
TypeError: If dtype of `upper` is not int32 or int64.
ValueError: If the shape of `x` is not greater than or equal to 2D.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> from mindspore.ops import functional as F
>>> x = Tensor(np.ones([2, 4, 4]).astype(np.float32))
>>> output = x.matrix_band_part(2, 1)
>>> print(output)
[[[1. 1. 0. 0.]
[1. 1. 1. 0.]
[1. 1. 1. 1.]
[0. 1. 1. 1.]]
[[1. 1. 0. 0.]
[1. 1. 1. 0.]
[1. 1. 1. 1.]
[0. 1. 1. 1.]]]
"""
self._init_check()
return tensor_operator_registry.get('matrix_band_part')(self, lower, upper)
def padding(self, pad_dim_size):
r"""
Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.
Args:
pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive.
Default: 8.
Returns:
Tensor, has the same type and shape as input shape value.
Raises:
TypeError: If `pad_dim_size` is not an int.
ValueError: If `pad_dim_size` is less than 1.
ValueError: If last dim of `x` is not equal to 1.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore.ops import functional as F
>>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
>>> pad_dim_size = 4
>>> output = x.padding(pad_dim_size)
>>> print(output)
[[ 8. 0. 0. 0.]
[10. 0. 0. 0.]]
"""
self._init_check()
return tensor_operator_registry.get('padding')(self, pad_dim_size)
def pow(self, power):
r"""
Calculate the power of Tensor.
@ -2236,7 +2360,6 @@ class Tensor(Tensor_):
self.init.seed = slice_index + Tensor.delta_seed
Tensor.delta_seed += self._device_num
def __exit__(self, ptype, value, trace):
if self.need_set_seed:
np.random.seed(self._np_seed)

View File

@ -706,27 +706,26 @@ class Softsign(Cell):
Softsign is defined as:
.. math::
\text{SoftSign}(x) = \frac{x}{1 + |x|}
Inputs:
- **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
- **x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
additional dimensions, with float16 or float32 data type.
Outputs:
Tensor, with the same type and shape as the `input_x`.
Tensor, with the same type and shape as the `x`.
Raises:
TypeError: If `input_x` is not a Tensor.
TypeError: If dtype of `input_x` is neither float16 nor float32.
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32)
>>> x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32)
>>> softsign = nn.Softsign()
>>> output = softsign(input_x)
>>> output = softsign(x)
>>> print(output)
[ 0. -0.5 0.6666667 0.9677419 -0.9677419]
"""

View File

@ -19,28 +19,173 @@ Function operator.
A collection of function to build neural networks or to compute functions.
"""
from . import array_func, parameter_func, math_func, nn_func, clip_func
from .array_func import (unique, eye, matrix_band_part, fill, fill_, tile, size, ones, ones_like, shape, shape_, ger,
dyn_shape, rank, reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor,
tuple_to_array, expand_dims, transpose, scatter_nd, scatter_nd_add, scatter_nd_sub, gather,
gather_d, gather_nd, scalar_cast, masked_fill, tensor_scatter_add, tensor_scatter_sub,
tensor_scatter_mul, unique_consecutive,
tensor_scatter_div, scatter_max, scatter_min, nonzero, space_to_batch_nd, range, select,
one_hot, matrix_diag, diag, masked_select, meshgrid)
from .array_func import fills
from .parameter_func import assign, assign_add, assign_sub, index_add
from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le, lerp,
lp_norm, round, tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div,
tensor_floordiv, floor_div, floordiv, tensor_pow, pow, pows, tensor_mod, floor_mod, floormod,
tensor_exp, exp, tensor_expm1, expm1, equal, not_equal, ne, isfinite, isnan, same_type_shape,
log, log_matrix_determinant, matrix_determinant, maximum, logaddexp, logaddexp2,
invert, minimum, floor, logical_not, logical_or, logical_and, sin, cos, tan,
asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, bitwise_and, bitwise_or,
bitwise_xor, erf, erfc, cdist, bessel_i0, bessel_i0e, bessel_j0, bessel_j1, bessel_k0,
bessel_k0e, bessel_y0, bessel_y1, bessel_i1, bessel_i1e, bessel_k1, bessel_k1e, exp2, deg2rad)
from .nn_func import (fast_gelu, hardshrink)
from .linalg_func import svd
from .clip_func import (clip_by_norm)
from . import (
array_func,
parameter_func,
math_func,
nn_func,
linalg_func,
clip_func,
)
from .array_func import (
unique,
eye,
matrix_band_part,
padding,
fill,
fill_,
tile,
size,
ones,
ones_like,
shape,
shape_,
ger,
dyn_shape,
rank,
reshape,
reshape_,
tensor_slice,
slice,
scalar_to_array,
scalar_to_tensor,
tuple_to_array,
expand_dims,
transpose,
scatter_nd,
scatter_nd_add,
scatter_nd_sub,
gather,
gather_d,
gather_nd,
scalar_cast,
masked_fill,
tensor_scatter_add,
tensor_scatter_sub,
tensor_scatter_mul,
unique_consecutive,
tensor_scatter_div,
scatter_max,
scatter_min,
nonzero,
space_to_batch_nd,
range,
select,
one_hot,
matrix_diag,
diag,
masked_select,
meshgrid,
fills,
)
from .parameter_func import (
assign,
assign_add,
assign_sub,
index_add,
)
from .math_func import (
addn,
absolute,
abs,
tensor_add,
add,
neg_tensor,
neg,
tensor_lt,
less,
tensor_le,
le,
lerp,
lp_norm,
round,
tensor_gt,
gt,
tensor_ge,
ge,
tensor_sub,
sub,
tensor_mul,
mul,
tensor_div,
div,
tensor_floordiv,
floor_div,
floordiv,
tensor_pow,
pow,
pows,
tensor_mod,
floor_mod,
floormod,
tensor_exp,
exp,
tensor_expm1,
expm1,
equal,
not_equal,
ne,
isfinite,
isnan,
same_type_shape,
log,
log_matrix_determinant,
matrix_determinant,
maximum,
logaddexp,
logaddexp2,
inv,
invert,
minimum,
floor,
logical_not,
logical_or,
logical_and,
sin,
cos,
tan,
asin,
acos,
atan,
sinh,
cosh,
tanh,
asinh,
acosh,
atanh,
atan2,
bitwise_and,
bitwise_or,
bitwise_xor,
erf,
erfc,
cdist,
bessel_i0,
bessel_i0e,
bessel_j0,
bessel_j1,
bessel_k0,
bessel_k0e,
bessel_y0,
bessel_y1,
bessel_i1,
bessel_i1e,
bessel_k1,
bessel_k1e,
exp2,
deg2rad,
)
from .nn_func import (
fast_gelu,
hardshrink,
softsign,
)
from .linalg_func import (
svd,
)
from .clip_func import (
clip_by_norm,
)
__all__ = []
__all__.extend(array_func.__all__)

View File

@ -123,12 +123,12 @@ def matrix_band_part(x, lower, upper):
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
Args:
- **x** (Tensor) - Input tensor. :math:`(*, m, n)` where :math:`*` means, any number of additional dimensions.
The data type must be float16, float32, float64, int32 or int64.
- **lower** (int) - Number of subdiagonals to keep. It must be int32 or int64.
If negative, keep entire lower triangle.
- **upper** (int) - Number of superdiagonals to keep. It must be int32 or int64.
If negative, keep entire upper triangle.
x (Tensor): Input tensor. :math:`(*, m, n)` where :math:`*` means, any number of additional dimensions.
The data type must be float16, float32, float64, int32 or int64.
lower (int): Number of subdiagonals to keep. It must be int32 or int64.
If negative, keep entire lower triangle.
upper (int): Number of superdiagonals to keep. It must be int32 or int64.
If negative, keep entire upper triangle.
Returns:
Tensor, has the same type and shape as input shape value.
@ -144,8 +144,8 @@ def matrix_band_part(x, lower, upper):
Examples:
>>> from mindspore.ops import functional as F
>>> x = np.ones([2, 4, 4]).astype(np.float32)
>>> output = F.matrix_band_part(Tensor(x), 2, 1)
>>> x = Tensor(np.ones([2, 4, 4]).astype(np.float32))
>>> output = F.matrix_band_part(x, 2, 1)
>>> print(output)
[[[1. 1. 0. 0.]
[1. 1. 1. 0.]
@ -160,6 +160,39 @@ def matrix_band_part(x, lower, upper):
return matrix_band_part_(x, lower, upper)
def padding(x, pad_dim_size=8):
r"""
Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.
Args:
x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The rank of `x` must be at least 2.
The last dimension of `x` must be 1. The data type is Number.
pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive. Default: 8.
Returns:
Tensor, has the same type and shape as input shape value.
Raises:
TypeError: If `pad_dim_size` is not an int.
ValueError: If `pad_dim_size` is less than 1.
ValueError: If last dim of `x` is not equal to 1.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore.ops import functional as F
>>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
>>> pad_dim_size = 4
>>> output = F.padding(x, pad_dim_size)
>>> print(output)
[[ 8. 0. 0. 0.]
[10. 0. 0. 0.]]
"""
padding_ = P.array_ops.Padding(pad_dim_size)
return padding_(x)
def one_hot(indices, depth, on_value, off_value, axis=-1):
r"""
Computes a one-hot tensor.
@ -2114,6 +2147,7 @@ __all__ = [
'unique_consecutive',
'eye',
'matrix_band_part',
'padding',
'fill',
'fill_',
'fills',

View File

@ -82,6 +82,7 @@ atan2_ = P.Atan2()
bitwise_and_ = P.BitwiseAnd()
bitwise_or_ = P.BitwiseOr()
bitwise_xor_ = P.BitwiseXor()
inv_ = P.math_ops.Inv()
invert_ = P.Invert()
erf_ = P.Erf()
erfc_ = P.Erfc()
@ -1372,12 +1373,41 @@ def bitwise_xor(x, y):
return bitwise_xor_(x, y)
def inv(x):
r"""
Computes Reciprocal of input tensor element-wise.
.. math::
out_i = \frac{1}{x_{i} }
Args:
x (Tensor): Tensor of any dimension. Must be one of the following types: float16, float32 or int32.
Returns:
Tensor, has the same type and shape as input shape value.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is not one of float16, float32, int32.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore.ops import functional as F
>>> x = Tensor(np.array([0.25, 0.4, 0.31, 0.52]), mindspore.float32)
>>> output = F.inv(x)
>>> print(output)
[4. 2.5 3.2258065 1.923077 ]
"""
return inv_(x)
def invert(x):
r"""
Flips all bits of input tensor element-wise.
.. math::
out_i = ~x_{i}
Args:
@ -1391,11 +1421,12 @@ def invert(x):
TypeError: If dtype of `x` is neither int16 nor uint16.
Supported Platforms:
``Ascend``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore.ops import functional as F
>>> x = Tensor(np.array([25, 4, 13, 9]), mindspore.int16)
>>> output = ops.invert(x)
>>> output = F.invert(x)
>>> print(output)
[-26 -5 -14 -10]
"""
@ -2250,7 +2281,7 @@ def logaddexp(x1, x2):
log_op = P.Log()
exp_op = P.Exp()
y = log_op(exp_op(x1)+exp_op(x2))
y = log_op(exp_op(x1) + exp_op(x2))
return y
@ -2532,6 +2563,7 @@ def deg2rad(x):
out = x * math.pi / 180.0
return out
#####################################
# Reduction Operation Functions.
#####################################
@ -2649,6 +2681,7 @@ __all__ = [
'bitwise_and',
'bitwise_or',
'bitwise_xor',
'inv',
'invert',
'erf',
'erfc',

View File

@ -17,8 +17,10 @@
from mindspore.ops import operations as P
fast_gelu_ = P.FastGeLU()
softsign_ = P.Softsign()
def fast_gelu(x):
r"""
Fast Gaussian Error Linear Units activation function.
@ -91,8 +93,42 @@ def hardshrink(x, lambd=0.5):
return hshrink_op(x)
def softsign(x):
r"""
Softsign activation function.
The function is shown as follows:
.. math::
\text{SoftSign}(x) = \frac{x}{1 + |x|}
Args:
x (Tensor): Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
additional dimensions, with float16 or float32 data type.
Outputs:
Tensor, with the same type and shape as the `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> from mindspore.ops import functional as F
>>> x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32)
>>> output = F.softsign(x)
>>> print(output)
[ 0. -0.5 0.6666667 0.9677419 -0.9677419]
"""
return softsign_(x)
__all__ = [
'fast_gelu',
'hardshrink'
'hardshrink',
'softsign'
]
__all__.sort()

View File

@ -901,6 +901,10 @@ tensor_operator_registry.register('one_hot', P.OneHot)
tensor_operator_registry.register('masked_select', masked_select)
tensor_operator_registry.register('nonzero', nonzero)
tensor_operator_registry.register('matrix_diag', matrix_diag)
tensor_operator_registry.register('inv', inv)
tensor_operator_registry.register('invert', invert)
tensor_operator_registry.register('matrix_band_part', matrix_band_part)
tensor_operator_registry.register('padding', padding)
tensor_operator_registry.register('hardshrink', P.HShrink)
tensor_operator_registry.register('svd', linalg_ops.Svd)
tensor_operator_registry.register('diag', P.Diag)

View File

@ -968,28 +968,16 @@ class Padding(Primitive):
"""
Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.
Args:
pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive. Default: 8.
Inputs:
- **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The rank of `x` must be at least 2.
The last dimension of `x` must be 1. The data type is Number.
Outputs:
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
Raises:
TypeError: If `pad_dim_size` is not an int.
ValueError: If `pad_dim_size` is less than 1.
ValueError: If last dim of `x` is not equal to 1.
Refer to :func:`mindspore.ops.padding` for more details.
Supported Platforms:
``Ascend``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore.ops.operations.array_ops import Padding
>>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
>>> pad_dim_size = 4
>>> output = ops.Padding(pad_dim_size)(x)
>>> output = Padding(pad_dim_size)(x)
>>> print(output)
[[ 8. 0. 0. 0.]
[10. 0. 0. 0.]]
@ -1435,7 +1423,7 @@ class MatrixBandPart(Primitive):
r"""
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
Refer to :func:`mindspore.ops.matrix_band_part` for more detail.
Refer to :func:`mindspore.ops.matrix_band_part` for more details.
Supported Platforms:
``GPU`` ``CPU``

View File

@ -4780,25 +4780,14 @@ class Inv(Primitive):
r"""
Computes Reciprocal of input tensor element-wise.
.. math::
out_i = \frac{1}{x_{i} }
Inputs:
- **x** (Tensor) - Tensor of any dimension. Must be one of the following types: float16, float32 or int32.
Outputs:
Tensor, has the same shape and data type as `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is not one of float16, float32, int32.
Refer to :func:`mindspore.ops.inv` for more details.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> inv = ops.Inv()
>>> from mindspore.ops.operations.math_ops import Inv
>>> inv = Inv()
>>> x = Tensor(np.array([0.25, 0.4, 0.31, 0.52]), mindspore.float32)
>>> output = inv(x)
>>> print(output)
@ -4814,13 +4803,14 @@ class Invert(Primitive):
r"""
Flips all bits of input tensor element-wise.
Refer to :func:`mindspore.ops.invert` for more detail.
Refer to :func:`mindspore.ops.invert` for more details.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> invert = ops.Invert()
>>> from mindspore.ops.operations.math_ops import Invert
>>> invert = Invert()
>>> x = Tensor(np.array([25, 4, 13, 9]), mindspore.int16)
>>> output = invert(x)
>>> print(output)

View File

@ -429,29 +429,15 @@ class Softsign(Primitive):
r"""
Softsign activation function.
The function is shown as follows:
.. math::
\text{SoftSign}(x) = \frac{x}{1 + |x|}
Inputs:
- **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
additional dimensions, with float16 or float32 data type.
Outputs:
Tensor, with the same type and shape as the `input_x`.
Raises:
TypeError: If `input_x` is not a Tensor.
TypeError: If dtype of `input_x` is neither float16 nor float32.
Refer to :func:`mindspore.ops.softsign` for more details.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> from mindspore.ops.operations.nn_ops import Softsign
>>> input_x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32)
>>> softsign = ops.Softsign()
>>> softsign = Softsign()
>>> output = softsign(input_x)
>>> print(output)
[ 0. -0.5 0.6666667 0.9677419 -0.9677419]
@ -9565,7 +9551,6 @@ class PSROIPooling(Primitive):
@prim_attr_register
def __init__(self, spatial_scale, group_size, output_dim):
"""Initialize PSROIPooling"""
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
validator.check_value_type("group_size", group_size, [int], self.name)