forked from mindspore-Ecosystem/mindspore
Add function and tensor interfaces.
This commit is contained in:
parent
e65fe8fabe
commit
73e4ee0f95
|
@ -14,13 +14,13 @@ mindspore.nn.Softsign
|
|||
|
||||
**输入:**
|
||||
|
||||
**input_x** (Tensor) - shape为 :math:`(N, *)` 的Tensor, 其中 :math:`*` 表示任意个数的维度。它的数据类型必须为float16或float32。
|
||||
**x** (Tensor) - shape为 :math:`(N, *)` 的Tensor, 其中 :math:`*` 表示任意个数的维度。它的数据类型必须为float16或float32。
|
||||
|
||||
**输出:**
|
||||
|
||||
Tensor,数据类型和shape与 `input_x` 相同。
|
||||
Tensor,数据类型和shape与 `x` 相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `input_x` 不是Tensor。
|
||||
- **TypeError** - `input_x` 的数据类型既不是float16也不是float32。
|
||||
- **TypeError** - `x` 不是Tensor。
|
||||
- **TypeError** - `x` 的数据类型既不是float16也不是float32。
|
||||
|
|
|
@ -218,6 +218,10 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"choose", std::string("choose")}, // P.Select()
|
||||
{"diagonal", std::string("diagonal")}, // P.Eye()
|
||||
{"matrix_diag", std::string("matrix_diag")}, // matrix_diag()
|
||||
{"inv", std::string("inv")}, // inv()
|
||||
{"invert", std::string("invert")}, // invert()
|
||||
{"matrix_band_part", std::string("matrix_band_part")}, // matrix_band_part()
|
||||
{"padding", std::string("padding")}, // padding()
|
||||
{"searchsorted", std::string("searchsorted")}, // P.Select()
|
||||
{"take", std::string("take")}, // P.GatherNd()
|
||||
{"tensor_scatter_add", std::string("tensor_scatter_add")}, // P.TensorScatterAdd()
|
||||
|
|
|
@ -832,6 +832,34 @@ def matrix_diag(x, k=0, num_rows=-1, num_cols=-1, padding_value=0, align="RIGHT_
|
|||
return F.matrix_diag(x, k, num_rows, num_cols, padding_value, align)
|
||||
|
||||
|
||||
def inv(x):
|
||||
"""
|
||||
Computes Reciprocal of input tensor element-wise.
|
||||
"""
|
||||
return F.inv(x)
|
||||
|
||||
|
||||
def invert(x):
|
||||
"""
|
||||
Flips all bits of input tensor element-wise.
|
||||
"""
|
||||
return F.invert(x)
|
||||
|
||||
|
||||
def matrix_band_part(x, lower, upper):
|
||||
"""
|
||||
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
|
||||
"""
|
||||
return F.matrix_band_part(x, lower, upper)
|
||||
|
||||
|
||||
def padding(x, pad_dim_size=8):
|
||||
"""
|
||||
Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.
|
||||
"""
|
||||
return F.padding(x, pad_dim_size)
|
||||
|
||||
|
||||
def trace(x, offset=0, axis1=0, axis2=1, dtype=None):
|
||||
"""
|
||||
Returns the sum along diagonals of the array.
|
||||
|
|
|
@ -1058,6 +1058,130 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('matrix_diag')(self, k, num_rows, num_cols, padding_value, align)
|
||||
|
||||
def inv(self):
|
||||
r"""
|
||||
Computes Reciprocal of input tensor element-wise.
|
||||
|
||||
.. math::
|
||||
out_i = \frac{1}{x_{i} }
|
||||
|
||||
Returns:
|
||||
Tensor, has the same type and shape as input shape value.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` is not one of float16, float32, int32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = Tensor(np.array([0.25, 0.4, 0.31, 0.52]), mindspore.float32)
|
||||
>>> output = x.inv()
|
||||
>>> print(output)
|
||||
[4. 2.5 3.2258065 1.923077 ]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('inv')(self)
|
||||
|
||||
def invert(self):
|
||||
r"""
|
||||
Flips all bits of input tensor element-wise.
|
||||
|
||||
.. math::
|
||||
out_i = ~x_{i}
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `x` is neither int16 nor uint16.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = Tensor(np.array([25, 4, 13, 9]), mindspore.int16)
|
||||
>>> output = x.invert()
|
||||
>>> print(output)
|
||||
[-26 -5 -14 -10]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('invert')(self)
|
||||
|
||||
def matrix_band_part(self, lower, upper):
|
||||
r"""
|
||||
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
|
||||
|
||||
Args:
|
||||
lower (int): Number of subdiagonals to keep. It must be int32 or int64.
|
||||
If negative, keep entire lower triangle.
|
||||
upper (int): Number of superdiagonals to keep. It must be int32 or int64.
|
||||
If negative, keep entire upper triangle.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same type and shape as input shape value.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `x` is not one of float16, float32, float64, int32 or int64.
|
||||
TypeError: If dtype of `lower` is not int32 or int64.
|
||||
TypeError: If dtype of `upper` is not int32 or int64.
|
||||
ValueError: If the shape of `x` is not greater than or equal to 2D.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = Tensor(np.ones([2, 4, 4]).astype(np.float32))
|
||||
>>> output = x.matrix_band_part(2, 1)
|
||||
>>> print(output)
|
||||
[[[1. 1. 0. 0.]
|
||||
[1. 1. 1. 0.]
|
||||
[1. 1. 1. 1.]
|
||||
[0. 1. 1. 1.]]
|
||||
|
||||
[[1. 1. 0. 0.]
|
||||
[1. 1. 1. 0.]
|
||||
[1. 1. 1. 1.]
|
||||
[0. 1. 1. 1.]]]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('matrix_band_part')(self, lower, upper)
|
||||
|
||||
def padding(self, pad_dim_size):
|
||||
r"""
|
||||
Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.
|
||||
|
||||
Args:
|
||||
pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive.
|
||||
Default: 8.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same type and shape as input shape value.
|
||||
|
||||
Raises:
|
||||
TypeError: If `pad_dim_size` is not an int.
|
||||
ValueError: If `pad_dim_size` is less than 1.
|
||||
ValueError: If last dim of `x` is not equal to 1.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
|
||||
>>> pad_dim_size = 4
|
||||
>>> output = x.padding(pad_dim_size)
|
||||
>>> print(output)
|
||||
[[ 8. 0. 0. 0.]
|
||||
[10. 0. 0. 0.]]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('padding')(self, pad_dim_size)
|
||||
|
||||
def pow(self, power):
|
||||
r"""
|
||||
Calculate the power of Tensor.
|
||||
|
@ -2236,7 +2360,6 @@ class Tensor(Tensor_):
|
|||
self.init.seed = slice_index + Tensor.delta_seed
|
||||
Tensor.delta_seed += self._device_num
|
||||
|
||||
|
||||
def __exit__(self, ptype, value, trace):
|
||||
if self.need_set_seed:
|
||||
np.random.seed(self._np_seed)
|
||||
|
|
|
@ -706,27 +706,26 @@ class Softsign(Cell):
|
|||
Softsign is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{SoftSign}(x) = \frac{x}{1 + |x|}
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
|
||||
additional dimensions, with float16 or float32 data type.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `input_x`.
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
TypeError: If dtype of `input_x` is neither float16 nor float32.
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32)
|
||||
>>> x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32)
|
||||
>>> softsign = nn.Softsign()
|
||||
>>> output = softsign(input_x)
|
||||
>>> output = softsign(x)
|
||||
>>> print(output)
|
||||
[ 0. -0.5 0.6666667 0.9677419 -0.9677419]
|
||||
"""
|
||||
|
|
|
@ -19,28 +19,173 @@ Function operator.
|
|||
A collection of function to build neural networks or to compute functions.
|
||||
"""
|
||||
|
||||
from . import array_func, parameter_func, math_func, nn_func, clip_func
|
||||
from .array_func import (unique, eye, matrix_band_part, fill, fill_, tile, size, ones, ones_like, shape, shape_, ger,
|
||||
dyn_shape, rank, reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor,
|
||||
tuple_to_array, expand_dims, transpose, scatter_nd, scatter_nd_add, scatter_nd_sub, gather,
|
||||
gather_d, gather_nd, scalar_cast, masked_fill, tensor_scatter_add, tensor_scatter_sub,
|
||||
tensor_scatter_mul, unique_consecutive,
|
||||
tensor_scatter_div, scatter_max, scatter_min, nonzero, space_to_batch_nd, range, select,
|
||||
one_hot, matrix_diag, diag, masked_select, meshgrid)
|
||||
from .array_func import fills
|
||||
from .parameter_func import assign, assign_add, assign_sub, index_add
|
||||
from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le, lerp,
|
||||
lp_norm, round, tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div,
|
||||
tensor_floordiv, floor_div, floordiv, tensor_pow, pow, pows, tensor_mod, floor_mod, floormod,
|
||||
tensor_exp, exp, tensor_expm1, expm1, equal, not_equal, ne, isfinite, isnan, same_type_shape,
|
||||
log, log_matrix_determinant, matrix_determinant, maximum, logaddexp, logaddexp2,
|
||||
invert, minimum, floor, logical_not, logical_or, logical_and, sin, cos, tan,
|
||||
asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, bitwise_and, bitwise_or,
|
||||
bitwise_xor, erf, erfc, cdist, bessel_i0, bessel_i0e, bessel_j0, bessel_j1, bessel_k0,
|
||||
bessel_k0e, bessel_y0, bessel_y1, bessel_i1, bessel_i1e, bessel_k1, bessel_k1e, exp2, deg2rad)
|
||||
from .nn_func import (fast_gelu, hardshrink)
|
||||
from .linalg_func import svd
|
||||
from .clip_func import (clip_by_norm)
|
||||
from . import (
|
||||
array_func,
|
||||
parameter_func,
|
||||
math_func,
|
||||
nn_func,
|
||||
linalg_func,
|
||||
clip_func,
|
||||
)
|
||||
from .array_func import (
|
||||
unique,
|
||||
eye,
|
||||
matrix_band_part,
|
||||
padding,
|
||||
fill,
|
||||
fill_,
|
||||
tile,
|
||||
size,
|
||||
ones,
|
||||
ones_like,
|
||||
shape,
|
||||
shape_,
|
||||
ger,
|
||||
dyn_shape,
|
||||
rank,
|
||||
reshape,
|
||||
reshape_,
|
||||
tensor_slice,
|
||||
slice,
|
||||
scalar_to_array,
|
||||
scalar_to_tensor,
|
||||
tuple_to_array,
|
||||
expand_dims,
|
||||
transpose,
|
||||
scatter_nd,
|
||||
scatter_nd_add,
|
||||
scatter_nd_sub,
|
||||
gather,
|
||||
gather_d,
|
||||
gather_nd,
|
||||
scalar_cast,
|
||||
masked_fill,
|
||||
tensor_scatter_add,
|
||||
tensor_scatter_sub,
|
||||
tensor_scatter_mul,
|
||||
unique_consecutive,
|
||||
tensor_scatter_div,
|
||||
scatter_max,
|
||||
scatter_min,
|
||||
nonzero,
|
||||
space_to_batch_nd,
|
||||
range,
|
||||
select,
|
||||
one_hot,
|
||||
matrix_diag,
|
||||
diag,
|
||||
masked_select,
|
||||
meshgrid,
|
||||
fills,
|
||||
)
|
||||
from .parameter_func import (
|
||||
assign,
|
||||
assign_add,
|
||||
assign_sub,
|
||||
index_add,
|
||||
)
|
||||
from .math_func import (
|
||||
addn,
|
||||
absolute,
|
||||
abs,
|
||||
tensor_add,
|
||||
add,
|
||||
neg_tensor,
|
||||
neg,
|
||||
tensor_lt,
|
||||
less,
|
||||
tensor_le,
|
||||
le,
|
||||
lerp,
|
||||
lp_norm,
|
||||
round,
|
||||
tensor_gt,
|
||||
gt,
|
||||
tensor_ge,
|
||||
ge,
|
||||
tensor_sub,
|
||||
sub,
|
||||
tensor_mul,
|
||||
mul,
|
||||
tensor_div,
|
||||
div,
|
||||
tensor_floordiv,
|
||||
floor_div,
|
||||
floordiv,
|
||||
tensor_pow,
|
||||
pow,
|
||||
pows,
|
||||
tensor_mod,
|
||||
floor_mod,
|
||||
floormod,
|
||||
tensor_exp,
|
||||
exp,
|
||||
tensor_expm1,
|
||||
expm1,
|
||||
equal,
|
||||
not_equal,
|
||||
ne,
|
||||
isfinite,
|
||||
isnan,
|
||||
same_type_shape,
|
||||
log,
|
||||
log_matrix_determinant,
|
||||
matrix_determinant,
|
||||
maximum,
|
||||
logaddexp,
|
||||
logaddexp2,
|
||||
inv,
|
||||
invert,
|
||||
minimum,
|
||||
floor,
|
||||
logical_not,
|
||||
logical_or,
|
||||
logical_and,
|
||||
sin,
|
||||
cos,
|
||||
tan,
|
||||
asin,
|
||||
acos,
|
||||
atan,
|
||||
sinh,
|
||||
cosh,
|
||||
tanh,
|
||||
asinh,
|
||||
acosh,
|
||||
atanh,
|
||||
atan2,
|
||||
bitwise_and,
|
||||
bitwise_or,
|
||||
bitwise_xor,
|
||||
erf,
|
||||
erfc,
|
||||
cdist,
|
||||
bessel_i0,
|
||||
bessel_i0e,
|
||||
bessel_j0,
|
||||
bessel_j1,
|
||||
bessel_k0,
|
||||
bessel_k0e,
|
||||
bessel_y0,
|
||||
bessel_y1,
|
||||
bessel_i1,
|
||||
bessel_i1e,
|
||||
bessel_k1,
|
||||
bessel_k1e,
|
||||
exp2,
|
||||
deg2rad,
|
||||
)
|
||||
from .nn_func import (
|
||||
fast_gelu,
|
||||
hardshrink,
|
||||
softsign,
|
||||
)
|
||||
from .linalg_func import (
|
||||
svd,
|
||||
)
|
||||
from .clip_func import (
|
||||
clip_by_norm,
|
||||
)
|
||||
|
||||
__all__ = []
|
||||
__all__.extend(array_func.__all__)
|
||||
|
|
|
@ -123,12 +123,12 @@ def matrix_band_part(x, lower, upper):
|
|||
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
|
||||
|
||||
Args:
|
||||
- **x** (Tensor) - Input tensor. :math:`(*, m, n)` where :math:`*` means, any number of additional dimensions.
|
||||
The data type must be float16, float32, float64, int32 or int64.
|
||||
- **lower** (int) - Number of subdiagonals to keep. It must be int32 or int64.
|
||||
If negative, keep entire lower triangle.
|
||||
- **upper** (int) - Number of superdiagonals to keep. It must be int32 or int64.
|
||||
If negative, keep entire upper triangle.
|
||||
x (Tensor): Input tensor. :math:`(*, m, n)` where :math:`*` means, any number of additional dimensions.
|
||||
The data type must be float16, float32, float64, int32 or int64.
|
||||
lower (int): Number of subdiagonals to keep. It must be int32 or int64.
|
||||
If negative, keep entire lower triangle.
|
||||
upper (int): Number of superdiagonals to keep. It must be int32 or int64.
|
||||
If negative, keep entire upper triangle.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same type and shape as input shape value.
|
||||
|
@ -144,8 +144,8 @@ def matrix_band_part(x, lower, upper):
|
|||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = np.ones([2, 4, 4]).astype(np.float32)
|
||||
>>> output = F.matrix_band_part(Tensor(x), 2, 1)
|
||||
>>> x = Tensor(np.ones([2, 4, 4]).astype(np.float32))
|
||||
>>> output = F.matrix_band_part(x, 2, 1)
|
||||
>>> print(output)
|
||||
[[[1. 1. 0. 0.]
|
||||
[1. 1. 1. 0.]
|
||||
|
@ -160,6 +160,39 @@ def matrix_band_part(x, lower, upper):
|
|||
return matrix_band_part_(x, lower, upper)
|
||||
|
||||
|
||||
def padding(x, pad_dim_size=8):
|
||||
r"""
|
||||
Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.
|
||||
|
||||
Args:
|
||||
x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The rank of `x` must be at least 2.
|
||||
The last dimension of `x` must be 1. The data type is Number.
|
||||
pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive. Default: 8.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same type and shape as input shape value.
|
||||
|
||||
Raises:
|
||||
TypeError: If `pad_dim_size` is not an int.
|
||||
ValueError: If `pad_dim_size` is less than 1.
|
||||
ValueError: If last dim of `x` is not equal to 1.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
|
||||
>>> pad_dim_size = 4
|
||||
>>> output = F.padding(x, pad_dim_size)
|
||||
>>> print(output)
|
||||
[[ 8. 0. 0. 0.]
|
||||
[10. 0. 0. 0.]]
|
||||
"""
|
||||
padding_ = P.array_ops.Padding(pad_dim_size)
|
||||
return padding_(x)
|
||||
|
||||
|
||||
def one_hot(indices, depth, on_value, off_value, axis=-1):
|
||||
r"""
|
||||
Computes a one-hot tensor.
|
||||
|
@ -2114,6 +2147,7 @@ __all__ = [
|
|||
'unique_consecutive',
|
||||
'eye',
|
||||
'matrix_band_part',
|
||||
'padding',
|
||||
'fill',
|
||||
'fill_',
|
||||
'fills',
|
||||
|
|
|
@ -82,6 +82,7 @@ atan2_ = P.Atan2()
|
|||
bitwise_and_ = P.BitwiseAnd()
|
||||
bitwise_or_ = P.BitwiseOr()
|
||||
bitwise_xor_ = P.BitwiseXor()
|
||||
inv_ = P.math_ops.Inv()
|
||||
invert_ = P.Invert()
|
||||
erf_ = P.Erf()
|
||||
erfc_ = P.Erfc()
|
||||
|
@ -1372,12 +1373,41 @@ def bitwise_xor(x, y):
|
|||
return bitwise_xor_(x, y)
|
||||
|
||||
|
||||
def inv(x):
|
||||
r"""
|
||||
Computes Reciprocal of input tensor element-wise.
|
||||
|
||||
.. math::
|
||||
out_i = \frac{1}{x_{i} }
|
||||
|
||||
Args:
|
||||
x (Tensor): Tensor of any dimension. Must be one of the following types: float16, float32 or int32.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same type and shape as input shape value.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` is not one of float16, float32, int32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = Tensor(np.array([0.25, 0.4, 0.31, 0.52]), mindspore.float32)
|
||||
>>> output = F.inv(x)
|
||||
>>> print(output)
|
||||
[4. 2.5 3.2258065 1.923077 ]
|
||||
"""
|
||||
return inv_(x)
|
||||
|
||||
|
||||
def invert(x):
|
||||
r"""
|
||||
Flips all bits of input tensor element-wise.
|
||||
|
||||
.. math::
|
||||
|
||||
out_i = ~x_{i}
|
||||
|
||||
Args:
|
||||
|
@ -1391,11 +1421,12 @@ def invert(x):
|
|||
TypeError: If dtype of `x` is neither int16 nor uint16.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = Tensor(np.array([25, 4, 13, 9]), mindspore.int16)
|
||||
>>> output = ops.invert(x)
|
||||
>>> output = F.invert(x)
|
||||
>>> print(output)
|
||||
[-26 -5 -14 -10]
|
||||
"""
|
||||
|
@ -2250,7 +2281,7 @@ def logaddexp(x1, x2):
|
|||
log_op = P.Log()
|
||||
exp_op = P.Exp()
|
||||
|
||||
y = log_op(exp_op(x1)+exp_op(x2))
|
||||
y = log_op(exp_op(x1) + exp_op(x2))
|
||||
return y
|
||||
|
||||
|
||||
|
@ -2532,6 +2563,7 @@ def deg2rad(x):
|
|||
out = x * math.pi / 180.0
|
||||
return out
|
||||
|
||||
|
||||
#####################################
|
||||
# Reduction Operation Functions.
|
||||
#####################################
|
||||
|
@ -2649,6 +2681,7 @@ __all__ = [
|
|||
'bitwise_and',
|
||||
'bitwise_or',
|
||||
'bitwise_xor',
|
||||
'inv',
|
||||
'invert',
|
||||
'erf',
|
||||
'erfc',
|
||||
|
|
|
@ -17,8 +17,10 @@
|
|||
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
fast_gelu_ = P.FastGeLU()
|
||||
softsign_ = P.Softsign()
|
||||
|
||||
|
||||
def fast_gelu(x):
|
||||
r"""
|
||||
Fast Gaussian Error Linear Units activation function.
|
||||
|
@ -91,8 +93,42 @@ def hardshrink(x, lambd=0.5):
|
|||
return hshrink_op(x)
|
||||
|
||||
|
||||
def softsign(x):
|
||||
r"""
|
||||
Softsign activation function.
|
||||
|
||||
The function is shown as follows:
|
||||
|
||||
.. math::
|
||||
\text{SoftSign}(x) = \frac{x}{1 + |x|}
|
||||
|
||||
Args:
|
||||
x (Tensor): Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
|
||||
additional dimensions, with float16 or float32 data type.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32)
|
||||
>>> output = F.softsign(x)
|
||||
>>> print(output)
|
||||
[ 0. -0.5 0.6666667 0.9677419 -0.9677419]
|
||||
"""
|
||||
return softsign_(x)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'fast_gelu',
|
||||
'hardshrink'
|
||||
'hardshrink',
|
||||
'softsign'
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -901,6 +901,10 @@ tensor_operator_registry.register('one_hot', P.OneHot)
|
|||
tensor_operator_registry.register('masked_select', masked_select)
|
||||
tensor_operator_registry.register('nonzero', nonzero)
|
||||
tensor_operator_registry.register('matrix_diag', matrix_diag)
|
||||
tensor_operator_registry.register('inv', inv)
|
||||
tensor_operator_registry.register('invert', invert)
|
||||
tensor_operator_registry.register('matrix_band_part', matrix_band_part)
|
||||
tensor_operator_registry.register('padding', padding)
|
||||
tensor_operator_registry.register('hardshrink', P.HShrink)
|
||||
tensor_operator_registry.register('svd', linalg_ops.Svd)
|
||||
tensor_operator_registry.register('diag', P.Diag)
|
||||
|
|
|
@ -968,28 +968,16 @@ class Padding(Primitive):
|
|||
"""
|
||||
Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.
|
||||
|
||||
Args:
|
||||
pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive. Default: 8.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The rank of `x` must be at least 2.
|
||||
The last dimension of `x` must be 1. The data type is Number.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `pad_dim_size` is not an int.
|
||||
ValueError: If `pad_dim_size` is less than 1.
|
||||
ValueError: If last dim of `x` is not equal to 1.
|
||||
Refer to :func:`mindspore.ops.padding` for more details.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops.operations.array_ops import Padding
|
||||
>>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
|
||||
>>> pad_dim_size = 4
|
||||
>>> output = ops.Padding(pad_dim_size)(x)
|
||||
>>> output = Padding(pad_dim_size)(x)
|
||||
>>> print(output)
|
||||
[[ 8. 0. 0. 0.]
|
||||
[10. 0. 0. 0.]]
|
||||
|
@ -1435,7 +1423,7 @@ class MatrixBandPart(Primitive):
|
|||
r"""
|
||||
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
|
||||
|
||||
Refer to :func:`mindspore.ops.matrix_band_part` for more detail.
|
||||
Refer to :func:`mindspore.ops.matrix_band_part` for more details.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
|
|
@ -4780,25 +4780,14 @@ class Inv(Primitive):
|
|||
r"""
|
||||
Computes Reciprocal of input tensor element-wise.
|
||||
|
||||
.. math::
|
||||
|
||||
out_i = \frac{1}{x_{i} }
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of any dimension. Must be one of the following types: float16, float32 or int32.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and data type as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If dtype of `x` is not one of float16, float32, int32.
|
||||
Refer to :func:`mindspore.ops.inv` for more details.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> inv = ops.Inv()
|
||||
>>> from mindspore.ops.operations.math_ops import Inv
|
||||
>>> inv = Inv()
|
||||
>>> x = Tensor(np.array([0.25, 0.4, 0.31, 0.52]), mindspore.float32)
|
||||
>>> output = inv(x)
|
||||
>>> print(output)
|
||||
|
@ -4814,13 +4803,14 @@ class Invert(Primitive):
|
|||
r"""
|
||||
Flips all bits of input tensor element-wise.
|
||||
|
||||
Refer to :func:`mindspore.ops.invert` for more detail.
|
||||
Refer to :func:`mindspore.ops.invert` for more details.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> invert = ops.Invert()
|
||||
>>> from mindspore.ops.operations.math_ops import Invert
|
||||
>>> invert = Invert()
|
||||
>>> x = Tensor(np.array([25, 4, 13, 9]), mindspore.int16)
|
||||
>>> output = invert(x)
|
||||
>>> print(output)
|
||||
|
|
|
@ -429,29 +429,15 @@ class Softsign(Primitive):
|
|||
r"""
|
||||
Softsign activation function.
|
||||
|
||||
The function is shown as follows:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{SoftSign}(x) = \frac{x}{1 + |x|}
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
|
||||
additional dimensions, with float16 or float32 data type.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
TypeError: If dtype of `input_x` is neither float16 nor float32.
|
||||
Refer to :func:`mindspore.ops.softsign` for more details.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops.operations.nn_ops import Softsign
|
||||
>>> input_x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32)
|
||||
>>> softsign = ops.Softsign()
|
||||
>>> softsign = Softsign()
|
||||
>>> output = softsign(input_x)
|
||||
>>> print(output)
|
||||
[ 0. -0.5 0.6666667 0.9677419 -0.9677419]
|
||||
|
@ -9565,7 +9551,6 @@ class PSROIPooling(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, spatial_scale, group_size, output_dim):
|
||||
|
||||
"""Initialize PSROIPooling"""
|
||||
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
|
||||
validator.check_value_type("group_size", group_size, [int], self.name)
|
||||
|
|
Loading…
Reference in New Issue