!37892 [MS][LITE] stack & unstack & atan2 & log1p support function & tensor interface

Merge pull request !37892 from jianghui58/leak-master
This commit is contained in:
i-robot 2022-07-15 01:03:41 +00:00 committed by Gitee
commit 8807bf038a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
18 changed files with 353 additions and 120 deletions

View File

@ -140,6 +140,7 @@ functional算子是经过初始化后的Primitive可以直接作为函数使
mindspore.ops.invert
mindspore.ops.lerp
mindspore.ops.log
mindspore.ops.log1p
mindspore.ops.logical_and
mindspore.ops.logical_not
mindspore.ops.logical_or
@ -364,6 +365,7 @@ Array操作
mindspore.ops.slice
mindspore.ops.space_to_batch_nd
mindspore.ops.split
mindspore.ops.stack
mindspore.ops.tensor_scatter_add
mindspore.ops.tensor_scatter_div
mindspore.ops.tensor_scatter_max
@ -380,6 +382,7 @@ Array操作
mindspore.ops.unique_consecutive
mindspore.ops.unique_with_pad
mindspore.ops.unsorted_segment_sum
mindspore.ops.unstack
.. list-table::
:widths: 50 50
@ -399,8 +402,6 @@ Array操作
- Refer to :class:`mindspore.ops.Sort`.
* - mindspore.ops.squeeze
- Refer to :class:`mindspore.ops.Squeeze`.
* - mindspore.ops.stack
- Refer to :class:`mindspore.ops.Stack`.
* - mindspore.ops.strided_slice
- Refer to :class:`mindspore.ops.StridedSlice`.
* - mindspore.ops.tensor_scatter_update

View File

@ -189,6 +189,27 @@ mindspore.Tensor
- **TypeError** - 指定了无法解析的类型。
.. py:method:: atan2(y)
逐元素计算x/y的反正切值。
`x` 指的当前 Tensor。
返回 :math:`\theta\ \in\ [-\pi, \pi]` ,使得 :math:`x = r*\sin(\theta), y = r*\cos(\theta)` 其中 :math:`r = \sqrt{x^2 + y^2}`
输入 `x``y` 会通过隐式数据类型转换使数据类型保持一致。如果数据类型不同,低精度的数据类型会被转换到高精度的数据类型。
**参数:**
- **y** (Tensor) - 输入Tensorshape应能在广播后与 `x` 相同,或 `x` 的shape在广播后与 `y` 相同。
**返回:**
Tensor与广播后的输入shape相同`x` 数据类型相同。
**异常:**
- **TypeError** - `x``y` 不是Tensor。
- **RuntimeError** - `x``y` 之间的数据类型转换不被支持
.. py:method:: bernoulli(p=0.5, seed=-1)
以p的概率随机将输出的元素设置为0或1服从伯努利分布。
@ -916,6 +937,21 @@ mindspore.Tensor
- **ValueError** - 如果 `end` 的维度信息无法相互广播到当前Tensor。
- **ValueError** - 如果 `weight` 为Tensor且 `weight` 的维度信息无法广播到当前Tensor。
.. py:method:: log1p()
对当前Tensor逐元素加一后计算自然对数。
.. math::
out_i = {log_e}(x_i + 1)
**返回:**
Tensor`x` 的shape相同。
**异常:**
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `x` 的数据类型非float16或float32。
.. py:method:: log_matrix_determinant()
计算一个或多个平方矩阵行列式绝对值的对数的符号和绝对值的对数。

View File

@ -5,20 +5,4 @@ mindspore.ops.Log1p
对输入Tensor逐元素加一后计算自然对数。
.. math::
out_i = {log_e}(x_i + 1)
**输入:**
- **x** (Tensor) - 输入Tensor。数据类型为float16或float32。
该值必须大于-1。
shape :math:`(N,*)` 其中 :math:`*` 表示任何数量的附加维度其秩应小于8。
**输出:**
Tensor`x` 的shape相同。
**异常:**
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `x` 的数据类型非float16或float32。
更多参考详见 :func:`mindspore.ops.log1p`

View File

@ -5,23 +5,4 @@
在指定轴上对输入Tensor序列进行堆叠。
输入秩为 `R` 的Tensor序列则输出秩为 `(R+1)` 的Tensor。
给定输入Tensor的shape为 :math:`(x_1, x_2, ..., x_R)` 。若输入Tensor的长度为 `N` 。如果存在 :math:`axis \ge 0` 则输出Tensor的shape为 :math:`(x_1, x_2, ..., x_{axis}, N, x_{axis+1}, ..., x_R)`
**参数:**
- **axis** (int) - 指定堆叠运算的轴。取值范围为[-(R+1), R+1)。默认值0。
**输入:**
- **input_x** (Union[tuple, list]) - 输入多个Tensor对象组成的tuple或list每个Tensor具有相同shape和数据类型。
**输出:**
堆叠运算后的Tensor数据类型和 `input_x` 的相同。
**异常:**
- **TypeError** - `input_x` 中元素的数据类型不相同。
- **ValueError** - `input_x` 的长度不大于1或axis不在[-(R+1),R+1)范围中,或 `input_x` 中元素的shape不相同。
更多参考详见 :func:`mindspore.ops.stack`

View File

@ -5,24 +5,4 @@
根据指定轴对输入矩阵进行分解。
若输入Tensor在指定的轴上的rank为 `R` 则输出Tensor的rank为 `(R-1)`
给定一个shape为 :math:`(x_1, x_2, ..., x_R)` 的Tensor。如果存在 :math:`0 \le axis` 则输出Tensor的shape为 :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`
与Stack函数操作相反。
**参数:**
- **axis** (int) - 指定矩阵分解的轴。取值范围为[-R,R)默认值0。
**输入:**
- **input_x** (Tensor) - 输入Tensor其shape为 :math:`(x_1, x_2, ..., x_R)` 。rank必须大于0。
**输出:**
Tensor对象组成的tuple。每个Tensor对象的shape相同。
**异常:**
- **ValueError** - axis超出[-len(input_x.shape), len(input_x.shape))范围。
更多参考详见 :func:`mindspore.ops.unstack`

View File

@ -0,0 +1,23 @@
mindspore.ops.atan2
===================
.. py:function:: mindspore.ops.atan2
逐元素计算x/y的反正切值。
返回 :math:`\theta\ \in\ [-\pi, \pi]` ,使得 :math:`x = r*\sin(\theta), y = r*\cos(\theta)` 其中 :math:`r = \sqrt{x^2 + y^2}`
输入 `x``y` 会通过隐式数据类型转换使数据类型保持一致。如果数据类型不同,低精度的数据类型会被转换到高精度的数据类型。
**参数:**
- **x** (Tensor) - 输入Tensorshape: :math:`(N,*)` ,其中 :math:`*` 表示任何数量的附加维度。
- **y** (Tensor) - 输入Tensorshape应能在广播后与 `x` 相同,或 `x` 的shape在广播后与 `y` 相同。
**返回:**
Tensor与广播后的输入shape相同`x` 数据类型相同。
**异常:**
- **TypeError** - `x``y` 不是Tensor。
- **RuntimeError** - `x``y` 之间的数据类型转换不被支持

View File

@ -0,0 +1,24 @@
mindspore.ops.log1p
===================
.. py:function:: mindspore.ops.Log1p
对输入Tensor逐元素加一后计算自然对数。
.. math::
out_i = {log_e}(x_i + 1)
**参数:**
- **x** (Tensor) - 输入Tensor。数据类型为float16或float32。
该值必须大于-1。
shape :math:`(N,*)` 其中 :math:`*` 表示任何数量的附加维度其秩应小于8。
**返回:**
Tensor`x` 的shape相同。
**异常:**
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `x` 的数据类型非float16或float32。

View File

@ -0,0 +1,24 @@
mindspore.ops.stack
====================
.. py:function:: mindspore.ops.stack(input_x, axis)
在指定轴上对输入Tensor序列进行堆叠。
输入秩为 `R` 的Tensor序列则输出秩为 `(R+1)` 的Tensor。
给定输入Tensor的shape为 :math:`(x_1, x_2, ..., x_R)` 。若输入Tensor的长度为 `N` 。如果存在 :math:`axis \ge 0` 则输出Tensor的shape为 :math:`(x_1, x_2, ..., x_{axis}, N, x_{axis+1}, ..., x_R)`
**参数:**
- **input_x** (Union[tuple, list]) - 输入多个Tensor对象组成的tuple或list每个Tensor具有相同shape和数据类型。
- **axis** (int) - 指定堆叠运算的轴。取值范围为[-(R+1), R+1)。默认值0。
**返回:**
堆叠运算后的Tensor数据类型和 `input_x` 的相同。
**异常:**
- **TypeError** - `input_x` 中元素的数据类型不相同。
- **ValueError** - `input_x` 的长度不大于1或axis不在[-(R+1),R+1)范围中,或 `input_x` 中元素的shape不相同。

View File

@ -0,0 +1,25 @@
mindspore.ops.unstack
=======================
.. py:function:: mindspore.ops.unstack(axis=0)
根据指定轴对输入矩阵进行分解。
若输入Tensor在指定的轴上的rank为 `R` 则输出Tensor的rank为 `(R-1)`
给定一个shape为 :math:`(x_1, x_2, ..., x_R)` 的Tensor。如果存在 :math:`0 \le axis` 则输出Tensor的shape为 :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`
与Stack函数操作相反。
**参数:**
- **input_x** (Tensor) - 输入Tensor其shape为 :math:`(x_1, x_2, ..., x_R)` 。rank必须大于0。
- **axis** (int) - 指定矩阵分解的轴。取值范围为[-R,R)默认值0。
**返回:**
Tensor对象组成的tuple。每个Tensor对象的shape相同。
**异常:**
- **ValueError** - axis超出[-len(input_x.shape), len(input_x.shape))范围。

View File

@ -152,6 +152,7 @@ BuiltInTypeMap &GetMethodMap() {
{"addcdiv", std::string("addcdiv")}, // C.addcdiv
{"addcmul", std::string("addcmul")}, // C.addcmul
{"all", std::string("all_")}, // C.reduce_all
{"atan2", std::string("atan2")}, // P.Atan2
{"any", std::string("any_")}, // C.reduce_any
{"__add__", std::string("add")}, // C.add
{"__sub__", std::string("sub")}, // C.sub
@ -209,6 +210,7 @@ BuiltInTypeMap &GetMethodMap() {
{"copy", std::string("copy")}, // copy()
{"inplace_update", std::string("inplace_update")}, // P.InplaceUpdate
{"lerp", std::string("lerp")}, // lerp()
{"log1p", std::string("log1p")}, // P.Log1p()
{"log_matrix_determinant", std::string("log_matrix_determinant")}, // log_matrix_determinant()
{"matrix_determinant", std::string("matrix_determinant")}, // log_matrix_determinant()
{"max", std::string("max")}, // P.reduce_max()

View File

@ -182,6 +182,14 @@ def any_(x, axis=(), keep_dims=False):
return reduce_any(x, axis)
def atan2(x, y):
r"""
Computes the first input tensor multiplied by the logarithm of second input tensor element-wise.
Refer to :func:`mindspore.ops.atan2` for more details.
"""
return F.atan2(x, y)
def size_(x):
"""
Return the number of elements in tensor `x`.
@ -2143,6 +2151,14 @@ def matrix_determinant(x):
return F.matrix_determinant(x)
def log1p(x):
r"""
Returns the natural logarithm of one plus the input tensor element-wise.
Refer to :func:`mindspore.ops.log1p` for more detail.
"""
return F.log1p(x)
def log_matrix_determinant(x):
"""Computes the sign and the log of the absolute value of the determinant of one or more square matrices."""
return F.log_matrix_determinant(x)

View File

@ -682,6 +682,43 @@ class Tensor(Tensor_):
axis = ()
return tensor_operator_registry.get('any')(keep_dims)(self, axis)
def atan2(self, y):
r"""
Returns arctangent of x/y element-wise.
It returns :math:`\theta\ \in\ [-\pi, \pi]`
such that :math:`x = r*\sin(\theta), y = r*\cos(\theta)`, where :math:`r = \sqrt{x^2 + y^2}`.
Args of `x` and `y` comply with the implicit type conversion rules to make the data types consistent.
If they have different data types, the lower precision data type will be converted to
the relatively highest precision data type.
Args:
x (Tensor): The input tensor.
:math:`(N,*)` where :math:`*` means, any number of additional dimensions.
y (Tensor): The input tensor. It has the same shape with `x`.
Returns:
Tensor, the shape is the same as the one after broadcasting,and the data type is same as `x`.
Raises:
TypeError: If `x` or `y` is not a Tensor.
RuntimeError: If the data type of `x` and `y` conversion of Parameter is required
when data type conversion of Parameter is not supported.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> x = Tensor(np.array([0, 1]), mindspore.float32)
>>> y = Tensor(np.array([1, 1]), mindspore.float32)
>>> output = x.atan2(y)
>>> print(output)
[0. 0.7853982]
"""
self._init_check()
return tensor_operator_registry.get('atan2')(self, y)
def view(self, *shape):
"""
Reshape the tensor according to the input shape. It's the same as :func:`mindspore.Tensor.reshape`,
@ -1280,6 +1317,37 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('matrix_determinant')(self)
def log1p(self):
r"""
Returns the natural logarithm of one plus the input tensor element-wise.
.. math::
out_i = {log_e}(x_i + 1)
Args:
- **x** (Tensor) - The input tensor. With float16 or float32 data type.
The value must be greater than -1.
:math:`(N,*)` where :math:`*` means, any number of additional dimensions, its rank should be less than 8.
Returns:
Tensor, has the same shape as the `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> output = x.log1p()
>>> print(output)
[0.6931472 1.0986123 1.609438 ]
"""
self._init_check()
return tensor_operator_registry.get('log1p')(self)
def log_matrix_determinant(self):
r"""
Computes the sign and the log of the absolute value of the determinant of one or more square matrices.

View File

@ -47,6 +47,8 @@ from .array_func import (
reshape_,
flatten,
concat,
stack,
unstack,
tensor_slice,
slice,
scalar_to_array,
@ -244,6 +246,7 @@ from .math_func import (
log2,
xlogy,
log10,
log1p,
approximate_equal,
frac,
kron

View File

@ -1100,6 +1100,81 @@ def concat(input_x, axis=0):
return _concat(input_x)
def stack(input_x, axis=0):
r"""
Stacks a list of tensors in specified axis.
Stacks the list of input tensors with the same rank `R`, output is a tensor of rank `(R+1)`.
Given input tensors of shape :math:`(x_1, x_2, ..., x_R)`. Set the number of input tensors as `N`.
If :math:`0 \le axis`, the shape of the output tensor is
:math:`(x_1, x_2, ..., x_{axis}, N, x_{axis+1}, ..., x_R)`.
Args:
input_x (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
axis (int): Dimension to stack. Default: 0.
Negative values wrap around. The range is [-(R+1), R+1).
Returns:
Tensor. A stacked Tensor with the same type as `input_x`.
Raises:
TypeError: If the data types of elements in `input_x` are not the same.
ValueError: If the length of `input_x` is not greater than 1;
or if axis is out of the range [-(R+1), R+1);
or if the shapes of elements in input_x are not the same.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x1 = Tensor(np.array([0, 1]).astype(np.float32))
>>> input_x2 = Tensor(np.array([2, 3]).astype(np.float32))
>>> output = ops.stack((input_x1, input_x2), 0)
>>> print(output)
[[0. 1.]
[2. 3.]]
"""
_stack = _get_cache_prim(P.Stack)(axis)
return _stack(input_x)
def unstack(input_x, axis=0):
r"""
Unstacks tensor in specified axis.
Unstacks a tensor of rank `R` along axis dimension, output tensors will have rank `(R-1)`.
Given a tensor of shape :math:`(x_1, x_2, ..., x_R)`. If :math:`0 \le axis`,
the shape of tensor in output is :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`.
This is the opposite of pack.
Args:
input_x (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
A tensor to be unstacked and the rank of the tensor must be greater than 0.
axis (int): Dimension along which to unpack. Default: 0.
Negative values wrap around. The range is [-R, R).
Returns:
A tuple of tensors, the shape of each objects is the same.
Raises:
ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)).
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
>>> output = ops.unstack(input_x, 0)
>>> print(output)
(Tensor(shape=[4], dtype=Int64, value= [1, 1, 1, 1]), Tensor(shape=[4], dtype=Int64, value= [2, 2, 2, 2]))
"""
_unstack = _get_cache_prim(P.Unstack)(axis)
return _unstack(input_x)
def expand_dims(input_x, axis):
"""
Adds an additional dimension to `input_x` at the given axis.
@ -3902,6 +3977,8 @@ __all__ = [
'tensor_slice',
'slice',
'concat',
'stack',
'unstack',
'scalar_cast',
'scalar_to_array',
'scalar_to_tensor',

View File

@ -114,7 +114,6 @@ tanh_ = P.Tanh()
asinh_ = P.Asinh()
acosh_ = P.Acosh()
atanh_ = P.Atanh()
atan2_ = P.Atan2()
bitwise_and_ = P.BitwiseAnd()
bitwise_or_ = P.BitwiseOr()
bitwise_xor_ = P.BitwiseXor()
@ -1615,7 +1614,8 @@ def atan2(x, y):
>>> print(output)
[0. 0.7853982]
"""
return atan2_(x, y)
_atan2 = _get_cache_prim(P.Atan2)()
return _atan2(x, y)
def bitwise_and(x, y):
@ -4685,6 +4685,38 @@ def log10(x):
return output
def log1p(x):
r"""
Returns the natural logarithm of one plus the input tensor element-wise.
.. math::
out_i = {log_e}(x_i + 1)
Args:
- **x** (Tensor) - The input tensor. With float16 or float32 data type.
The value must be greater than -1.
:math:`(N,*)` where :math:`*` means, any number of additional dimensions, its rank should be less than 8.
Returns:
Tensor, has the same shape as the `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> output = ops.log1p(x)
>>> print(output)
[0.6931472 1.0986123 1.609438 ]
"""
_log1p = _get_cache_prim(P.Log1p)()
return _log1p(x)
def kron(x, y):
"""
Computes the Kronecker product, denoted by , of `x` and `y`.
@ -5007,6 +5039,7 @@ __all__ = [
'log2',
'xlogy',
'log10',
'log1p',
'approximate_equal',
'frac',
'kron'

View File

@ -66,7 +66,6 @@ if not security.enable_security():
squeeze = P.Squeeze()
tensor_scatter_update = P.TensorScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
stack = P.Stack()
def pack(x):
@ -798,6 +797,7 @@ tensor_operator_registry.register('addcdiv', P.Addcdiv)
tensor_operator_registry.register('addcmul', P.Addcmul)
tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('atan2', atan2)
tensor_operator_registry.register('abs', P.Abs)
tensor_operator_registry.register('tan', P.Tan)
tensor_operator_registry.register('cosh', P.Cosh)
@ -827,6 +827,7 @@ tensor_operator_registry.register('random_categorical', random_categorical)
tensor_operator_registry.register('maximum', P.Maximum)
tensor_operator_registry.register('minimum', P.Minimum)
tensor_operator_registry.register('matrix_determinant', matrix_determinant)
tensor_operator_registry.register('log1p', log1p)
tensor_operator_registry.register('log_matrix_determinant', log_matrix_determinant)
tensor_operator_registry.register('ceil', P.Ceil)
tensor_operator_registry.register('fill', P.Fill)
@ -882,7 +883,8 @@ tensor_operator_registry.register('gather', gather)
tensor_operator_registry.register('gather_d', gather_d)
tensor_operator_registry.register('gather_elements', gather_elements)
tensor_operator_registry.register('gather_nd', gather_nd)
tensor_operator_registry.register('stack', P.Stack)
tensor_operator_registry.register('stack', stack)
tensor_operator_registry.register('unstack', unstack)
tensor_operator_registry.register('log', log)
tensor_operator_registry.register('lerp', lerp)
tensor_operator_registry.register('floor', floor)

View File

@ -101,6 +101,7 @@ class UnravelIndex(Primitive):
[[0 2]
[1 2]]
"""
@prim_attr_register
def __init__(self):
"""Initialize Shape"""
@ -2874,27 +2875,7 @@ class Stack(PrimitiveWithInfer):
r"""
Stacks a list of tensors in specified axis.
Stacks the list of input tensors with the same rank `R`, output is a tensor of rank `(R+1)`.
Given input tensors of shape :math:`(x_1, x_2, ..., x_R)`. Set the number of input tensors as `N`.
If :math:`0 \le axis`, the shape of the output tensor is
:math:`(x_1, x_2, ..., x_{axis}, N, x_{axis+1}, ..., x_R)`.
Args:
axis (int): Dimension to stack. Default: 0.
Negative values wrap around. The range is [-(R+1), R+1).
Inputs:
- **input_x** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
Outputs:
Tensor. A stacked Tensor with the same type as `input_x`.
Raises:
TypeError: If the data types of elements in `input_x` are not the same.
ValueError: If the length of `input_x` is not greater than 1;
or if axis is out of the range [-(R+1), R+1);
or if the shapes of elements in input_x are not the same.
Refer to :func:`mindspore.ops.stack` for more detail.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -3018,26 +2999,7 @@ class Unstack(Primitive):
r"""
Unstacks tensor in specified axis.
Unstacks a tensor of rank `R` along axis dimension, output tensors will have rank `(R-1)`.
Given a tensor of shape :math:`(x_1, x_2, ..., x_R)`. If :math:`0 \le axis`,
the shape of tensor in output is :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`.
This is the opposite of pack.
Args:
axis (int): Dimension along which to unpack. Default: 0.
Negative values wrap around. The range is [-R, R).
Inputs:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
A tensor to be unstacked and the rank of the tensor must be greater than 0.
Outputs:
A tuple of tensors, the shape of each objects is the same.
Raises:
ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)).
Refer to :func:`mindspore.ops.unstack` for more detail.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -6280,6 +6242,7 @@ class IdentityN(Primitive):
>>> print(output)
(Tensor(shape=[4], dtype=Int64, value= [1, 2, 3, 4]), Tensor(shape=[4], dtype=Int64, value= [4, 3, 1, 1]))
"""
@prim_attr_register
def __init__(self):
"""Initialize IdentityN"""

View File

@ -2586,23 +2586,10 @@ class Log(Primitive):
class Log1p(Primitive):
"""
r"""
Returns the natural logarithm of one plus the input tensor element-wise.
.. math::
out_i = {log_e}(x_i + 1)
Inputs:
- **x** (Tensor) - The input tensor. With float16 or float32 data type.
The value must be greater than -1.
:math:`(N,*)` where :math:`*` means, any number of additional dimensions, its rank should be less than 8.
Outputs:
Tensor, has the same shape as the `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
Refer to :func:`mindspore.ops.log1p` for more detail.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -4598,6 +4585,10 @@ class Atan2(_MathBinaryOp):
>>> print(output)
[0. 0.7853982]
"""
@prim_attr_register
def __init__(self):
"""Initialize Atan2"""
_MathBinaryOp.__init__(self)
class SquareSumAll(Primitive):