!45522 tensor_logaddexp_logaddexp2_logsumexp_master

Merge pull request !45522 from yide12/tensor_log_master
i-robot 2022-11-22 03:46:44 +00:00 committed by Gitee
commit 31c2e88b46
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 268 additions and 13 deletions

View File

@@ -207,6 +207,8 @@ mindspore.ops.function
mindspore.ops.log2
mindspore.ops.log10
mindspore.ops.log1p
mindspore.ops.logaddexp
mindspore.ops.logaddexp2
mindspore.ops.logical_and
mindspore.ops.logical_not
mindspore.ops.logical_or

View File

@@ -0,0 +1,6 @@
mindspore.Tensor.logaddexp
==========================
.. py:method:: mindspore.Tensor.logaddexp(other)
    For details, please refer to :func:`mindspore.ops.logaddexp`.

View File

@@ -0,0 +1,6 @@
mindspore.Tensor.logaddexp2
===========================
.. py:method:: mindspore.Tensor.logaddexp2(other)
    For details, please refer to :func:`mindspore.ops.logaddexp2`.

View File

@@ -0,0 +1,6 @@
mindspore.Tensor.logsumexp
==========================
.. py:method:: mindspore.Tensor.logsumexp(dim, keepdim=False)
    For details, please refer to :func:`mindspore.ops.logsumexp`.
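A minimal usage sketch of the method form (illustrative only; the values mirror the logsumexp test added later in this commit):

import mindspore as ms
from mindspore import Tensor

x = Tensor([[1., 2., 3., 4.],
            [5., 6., 7., 8.],
            [9., 10., 11., 12.]], ms.float32)
print(x.logsumexp(0, keepdim=True))  # ≈ [[9.0185, 10.0185, 11.0185, 12.0185]]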

View File

@@ -145,6 +145,9 @@ mindspore.Tensor
mindspore.Tensor.log
mindspore.Tensor.log10
mindspore.Tensor.log2
mindspore.Tensor.logaddexp
mindspore.Tensor.logaddexp2
mindspore.Tensor.logsumexp
mindspore.Tensor.log_matrix_determinant
mindspore.Tensor.log1p
mindspore.Tensor.logical_and

View File

@@ -0,0 +1,20 @@
mindspore.ops.logaddexp
=======================
.. py:function:: mindspore.ops.logaddexp(x1, x2)
    Computes the logarithm of the sum of exponentiations of the inputs.
    .. math::
        out_i = log(exp(x1_i) + exp(x2_i))
    Args:
        - **x1** (Tensor) - Input Tensor.
        - **x2** (Tensor) - Input Tensor. If the shape of `x1` is not equal to the shape of `x2`, they must be broadcastable to a common shape (which becomes the shape of the output).
    Returns:
        Tensor.
    Raises:
        - **TypeError** - If `x1` or `x2` is not a Tensor.
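A minimal usage sketch (illustrative only; the inputs are the same as in the test added later in this commit):

import mindspore as ms
from mindspore import ops

x1 = ms.Tensor([-100, 1, 30], ms.float32)
x2 = ms.Tensor([-1, -1, 3], ms.float32)
print(ops.logaddexp(x1, x2))  # ≈ [-1.0, 1.1269258, 30.0]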

View File

@@ -0,0 +1,20 @@
mindspore.ops.logaddexp2
========================
.. py:function:: mindspore.ops.logaddexp2(x1, x2)
    Computes the base-2 logarithm of the sum of base-2 exponentials of the inputs.
    .. math::
        out_i = log_2(2^{x1_i} + 2^{x2_i})
    Args:
        - **x1** (Tensor) - Input Tensor.
        - **x2** (Tensor) - Input Tensor. If the shape of `x1` is not equal to the shape of `x2`, they must be broadcastable to a common shape (which becomes the shape of the output).
    Returns:
        Tensor.
    Raises:
        - **TypeError** - If `x1` or `x2` is not a Tensor.
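A minimal usage sketch (illustrative only; the inputs are the same as in the test added later in this commit):

import mindspore as ms
from mindspore import ops

x1 = ms.Tensor([-100, 1, 10], ms.float32)
x2 = ms.Tensor([-1, -1, 3], ms.float32)
print(ops.logaddexp2(x1, x2))  # ≈ [-1.0, 1.3219349, 10.011246]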

View File

@@ -151,6 +151,9 @@
mindspore.Tensor.log
mindspore.Tensor.log10
mindspore.Tensor.log2
mindspore.Tensor.logaddexp
mindspore.Tensor.logaddexp2
mindspore.Tensor.logsumexp
mindspore.Tensor.log_matrix_determinant
mindspore.Tensor.log1p
mindspore.Tensor.logical_and

View File

@@ -208,6 +208,8 @@ Element-by-Element Operations
mindspore.ops.log2
mindspore.ops.log10
mindspore.ops.log1p
mindspore.ops.logaddexp
mindspore.ops.logaddexp2
mindspore.ops.logical_and
mindspore.ops.logical_not
mindspore.ops.logical_or

View File

@@ -256,6 +256,9 @@ BuiltInTypeMap &GetMethodMap() {
{"remainder", std::string("remainder")}, // remainder()
{"log10", std::string("log10")}, // F.log10()
{"log2", std::string("log2")}, // F.log2()
{"logaddexp", std::string("logaddexp")}, // logaddexp()
{"logaddexp2", std::string("logaddexp2")}, // logaddexp2()
{"logsumexp", std::string("logsumexp")}, // logsumexp()
{"minimum", std::string("minimum")}, // P.Minimum()
{"cosh", std::string("cosh")}, // P.Cosh()
{"tanh", std::string("tanh")}, // P.Tanh()

View File

@@ -1083,6 +1083,28 @@ def log2(x):
return F.log2(x)
def logaddexp(x, other):
"""
Computes the logarithm of the sum of exponentiations of the inputs.
"""
return F.logaddexp(x, other)
def logaddexp2(x, other):
"""
Computes the base-2 logarithm of the sum of base-2 exponentials of the inputs.
"""
return F.logaddexp2(x, other)
def logsumexp(x, dim, keepdim=False):
"""
Reduces a dimension of a tensor by taking the exponential of all elements along that dimension
and then the logarithm of their sum.
"""
return F.logsumexp(x, dim, keepdim)
def round_(x):
"""
Returns half to even of a tensor element-wise.

View File

@@ -1418,6 +1418,24 @@ class Tensor(Tensor_):
validator.check_value_type('eps', eps, (float,), 'Tensor.logit')
return tensor_operator_registry.get('logit')(self, eps)
def logaddexp(self, other):
r"""
For details, please refer to :func:`mindspore.ops.logaddexp`.
"""
return tensor_operator_registry.get('logaddexp')(self, other)
def logaddexp2(self, other):
r"""
For details, please refer to :func:`mindspore.ops.logaddexp2`.
"""
return tensor_operator_registry.get('logaddexp2')(self, other)
def logsumexp(self, dim, keepdim=False):
r"""
For details, please refer to :func:`mindspore.ops.logsumexp`.
"""
return tensor_operator_registry.get('logsumexp')(self, dim, keepdim)
def log_matrix_determinant(self):
r"""
For details, please refer to :func:`mindspore.ops.log_matrix_determinant`.

View File

@@ -3240,7 +3240,6 @@ def isreal(x):
Inputs:
- **x** (Tensor) - The input tensor.
:math:`(N,*)` where :math:`*` means any number of additional dimensions.
Outputs:
Tensor, has the same shape of input, and the dtype is bool.
@@ -3595,18 +3594,17 @@ def logaddexp(x1, x2):
"""
Computes the logarithm of the sum of exponentiations of the inputs.
Calculates ``log(exp(x1) + exp(x2))``. This function is useful in statistics, where the
computed probability of an event may be so small that it falls outside the range of normal
floating point numbers. In such cases, the logarithm of the calculated probability is stored.
This function allows adding probabilities stored in this way.
.. math::
out_i = log(exp(x1_i) + exp(x2_i))
Args:
x1 (Tensor): Input Tensor.
x2 (Tensor): Input Tensor. If ``x1.shape != x2.shape``, they must be broadcastable to
x2 (Tensor): Input Tensor. If the shape of `x1` is not equal to the shape of `x2`, they must be broadcastable to
a common shape (which becomes the shape of the output).
Returns:
Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars.
Tensor.
Raises:
TypeError: If `x1` or `x2` is not a Tensor.
@@ -3624,7 +3622,6 @@ def logaddexp(x1, x2):
log_op = _get_cache_prim(P.Log)()
exp_op = _get_cache_prim(P.Exp)()
y = log_op(exp_op(x1) + exp_op(x2))
return y
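As the docstring notes, working in log space keeps very small probabilities representable. A quick NumPy sanity check of the identity (illustrative only, not part of this change; numpy.logaddexp is NumPy's counterpart):

import numpy as np

log_p1, log_p2 = np.log(1e-3), np.log(2e-3)  # two probabilities stored as logarithms
log_sum = np.logaddexp(log_p1, log_p2)       # log(p1 + p2), computed without leaving log space
assert np.isclose(np.exp(log_sum), 3e-3)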
@@ -3633,10 +3630,9 @@ def logaddexp2(x1, x2):
"""
Computes the base-2 logarithm of the sum of base-2 exponentials of the inputs.
Calculates ``log2(2**x1 + 2**x2)``. This function is useful in machine learning when the computed
probability of an event may be so small that it falls outside the range of normal floating point numbers.
In this case, the base-2 logarithm of the calculated probability can be used instead.
This function allows adding probabilities stored in this way.
.. math::
out_i = log_2(2^{x1_i} + 2^{x2_i})
Args:
x1 (Tensor): Input tensor.
@@ -3644,7 +3640,7 @@ def logaddexp2(x1, x2):
a common shape (which becomes the shape of the output).
Returns:
Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars.
Tensor.
Raises:
TypeError: If `x1` or `x2` is not a Tensor.
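The same log-space idea applies in base 2; a quick NumPy cross-check (illustrative only; numpy.logaddexp2 is NumPy's counterpart):

import numpy as np

a, b = np.log2(0.25), np.log2(0.125)                    # base-2 logs of two probabilities
assert np.isclose(np.logaddexp2(a, b), np.log2(0.375))  # log2(0.25 + 0.125)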

View File

@@ -199,6 +199,9 @@ tensor_operator_registry.register('nonzero', nonzero)
tensor_operator_registry.register('i0', i0)
tensor_operator_registry.register('isclose', isclose)
tensor_operator_registry.register('inv', inv)
tensor_operator_registry.register('logaddexp', logaddexp)
tensor_operator_registry.register('logaddexp2', logaddexp2)
tensor_operator_registry.register('logsumexp', logsumexp)
tensor_operator_registry.register('invert', invert)
tensor_operator_registry.register('hardshrink', P.HShrink)
tensor_operator_registry.register('heaviside', heaviside)
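These registrations are what let the Tensor methods defined above dispatch to the functional implementations. A small equivalence sketch (illustrative only):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

a = Tensor([0.5, 1.0, 2.0], ms.float32)
b = Tensor([1.5, -1.0, 0.0], ms.float32)
# The method form and the functional form should agree
assert np.allclose(a.logaddexp(b).asnumpy(), ops.logaddexp(a, b).asnumpy())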

View File

@@ -0,0 +1,46 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x1, x2):
return x1.logaddexp(x2)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_logaddexp(mode):
"""
Feature: tensor.logaddexp
Description: Verify the result of logaddexp
Expectation: success
"""
ms.set_context(mode=mode)
x1 = ms.Tensor([-100, 1, 30], ms.float32)
x2 = ms.Tensor([-1, -1, 3], ms.float32)
net = Net()
output = net(x1, x2)
expect_output = [-0.99999994, 1.1269258, 30.]
assert np.allclose(output.asnumpy(), expect_output)

View File

@@ -0,0 +1,46 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x1, x2):
return x1.logaddexp2(x2)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_logaddexp2(mode):
"""
Feature: tensor.logaddexp2
Description: Verify the result of logaddexp2
Expectation: success
"""
ms.set_context(mode=mode)
x1 = ms.Tensor([-100, 1, 10], ms.float32)
x2 = ms.Tensor([-1, -1, 3], ms.float32)
net = Net()
output = net(x1, x2)
expect_output = [-1.0000042, 1.3219349, 10.011246]
assert np.allclose(output.asnumpy(), expect_output)

View File

@@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x, dim):
return x.logsumexp(dim, keepdim=True)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_logsumexp(mode):
"""
Feature: tensor.logsumexp
Description: Verify the result of logsumexp
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor([[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 10., 11., 12.]], ms.float32)
net = Net()
output = net(x, dim=0)
expect_output = [[9.0184765, 10.01848, 11.018479, 12.018477]]
assert np.allclose(output.asnumpy(), expect_output)
output = net(x, dim=1)
expect_output = [[4.440187],
[8.440188],
[12.440187]]
assert np.allclose(output.asnumpy(), expect_output)
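For reference, the expected values above can be reproduced with the naive formula (a checking sketch only, not a statement about how the op is implemented):

import numpy as np

x = np.array([[1., 2., 3., 4.],
              [5., 6., 7., 8.],
              [9., 10., 11., 12.]], dtype=np.float32)
ref_dim0 = np.log(np.exp(x).sum(axis=0, keepdims=True))  # ≈ [[9.0185, 10.0185, 11.0185, 12.0185]]
ref_dim1 = np.log(np.exp(x).sum(axis=1, keepdims=True))  # ≈ [[4.4402], [8.4402], [12.4402]]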