!45522 tensor_logaddexp_logaddexp2_logsumexp_master
Merge pull request !45522 from yide12/tensor_log_master
This commit is contained in:
commit
31c2e88b46
|
@ -207,6 +207,8 @@ mindspore.ops.function
|
|||
mindspore.ops.log2
|
||||
mindspore.ops.log10
|
||||
mindspore.ops.log1p
|
||||
mindspore.ops.logaddexp
|
||||
mindspore.ops.logaddexp2
|
||||
mindspore.ops.logical_and
|
||||
mindspore.ops.logical_not
|
||||
mindspore.ops.logical_or
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.logaddexp
|
||||
==========================
|
||||
|
||||
.. py:method:: mindspore.Tensor.logaddexp(other)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.logaddexp`。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.logaddexp2
|
||||
===========================
|
||||
|
||||
.. py:method:: mindspore.Tensor.logaddexp2(other)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.logaddexp2`。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.logsumexp
|
||||
==========================
|
||||
|
||||
.. py:method:: mindspore.Tensor.logsumexp(dim, keepdim=False)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.logsumexp`。
|
|
@ -145,6 +145,9 @@ mindspore.Tensor
|
|||
mindspore.Tensor.log
|
||||
mindspore.Tensor.log10
|
||||
mindspore.Tensor.log2
|
||||
mindspore.Tensor.logaddexp
|
||||
mindspore.Tensor.logaddexp2
|
||||
mindspore.Tensor.logsumexp
|
||||
mindspore.Tensor.log_matrix_determinant
|
||||
mindspore.Tensor.log1p
|
||||
mindspore.Tensor.logical_and
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
mindspore.ops.logaddexp
|
||||
=======================
|
||||
|
||||
.. py:function:: mindspore.ops.logaddexp(x1, x2)
|
||||
|
||||
计算输入的指数和的对数。
|
||||
|
||||
.. math::
|
||||
|
||||
out_i = log(exp(x1_i) + exp(x2_i))
|
||||
|
||||
参数:
|
||||
- **x1** (Tensor) - 输入Tensor。
|
||||
- **x2** (Tensor) - 输入Tensor。如果 `x1` 的shape不等于 `x2` 的shape,它们必须被广播成相同shape(输出的形状)。
|
||||
|
||||
返回:
|
||||
Tensor。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x1` 或 `x2` 不是Tensor。
|
|
@ -0,0 +1,20 @@
|
|||
mindspore.ops.logaddexp2
|
||||
========================
|
||||
|
||||
.. py:function:: mindspore.ops.logaddexp2(x1, x2)
|
||||
|
||||
计算以2为底的输入的指数和的对数。
|
||||
|
||||
.. math::
|
||||
|
||||
out_i = log_2(2^{x1_i} + 2^{x2_i})
|
||||
|
||||
参数:
|
||||
- **x1** (Tensor) - 输入Tensor。
|
||||
- **x2** (Tensor) - 输入Tensor。如果 `x1` 的shape不等于 `x2` 的shape,它们必须被广播成相同shape(输出的形状)。
|
||||
|
||||
返回:
|
||||
Tensor。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x1` 或 `x2` 不是Tensor。
|
|
@ -151,6 +151,9 @@
|
|||
mindspore.Tensor.log
|
||||
mindspore.Tensor.log10
|
||||
mindspore.Tensor.log2
|
||||
mindspore.Tensor.logaddexp
|
||||
mindspore.Tensor.logaddexp2
|
||||
mindspore.Tensor.logsumexp
|
||||
mindspore.Tensor.log_matrix_determinant
|
||||
mindspore.Tensor.log1p
|
||||
mindspore.Tensor.logical_and
|
||||
|
|
|
@ -208,6 +208,8 @@ Element-by-Element Operations
|
|||
mindspore.ops.log2
|
||||
mindspore.ops.log10
|
||||
mindspore.ops.log1p
|
||||
mindspore.ops.logaddexp
|
||||
mindspore.ops.logaddexp2
|
||||
mindspore.ops.logical_and
|
||||
mindspore.ops.logical_not
|
||||
mindspore.ops.logical_or
|
||||
|
|
|
@ -256,6 +256,9 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"remainder", std::string("remainder")}, // remainder()
|
||||
{"log10", std::string("log10")}, // F.log10()
|
||||
{"log2", std::string("log2")}, // F.log2()
|
||||
{"logaddexp", std::string("logaddexp")}, // logaddexp()
|
||||
{"logaddexp2", std::string("logaddexp2")}, // logaddexp2()
|
||||
{"logsumexp", std::string("logsumexp")}, // logsumexp()
|
||||
{"minimum", std::string("minimum")}, // P.Minimum()
|
||||
{"cosh", std::string("cosh")}, // P.Cosh()
|
||||
{"tanh", std::string("tanh")}, // P.Tanh()
|
||||
|
|
|
@ -1083,6 +1083,28 @@ def log2(x):
|
|||
return F.log2(x)
|
||||
|
||||
|
||||
def logaddexp(x, other):
|
||||
"""
|
||||
Computes the logarithm of the sum of exponentiations of the inputs.
|
||||
"""
|
||||
return F.logaddexp(x, other)
|
||||
|
||||
|
||||
def logaddexp2(x, other):
|
||||
"""
|
||||
Computes the logarithm of the sum of exponentiations in base of 2 of the inputs.
|
||||
"""
|
||||
return F.logaddexp2(x, other)
|
||||
|
||||
|
||||
def logsumexp(x, dim, keepdim=False):
|
||||
"""
|
||||
Reduces a dimension of a tensor by calculating exponential for all elements in the dimension,
|
||||
then calculate logarithm of the sum.
|
||||
"""
|
||||
return F.logsumexp(x, dim, keepdim)
|
||||
|
||||
|
||||
def round_(x):
|
||||
"""
|
||||
Returns half to even of a tensor element-wise.
|
||||
|
|
|
@ -1418,6 +1418,24 @@ class Tensor(Tensor_):
|
|||
validator.check_value_type('eps', eps, (float,), 'Tensor.logit')
|
||||
return tensor_operator_registry.get('logit')(self, eps)
|
||||
|
||||
def logaddexp(self, other):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.logaddexp`.
|
||||
"""
|
||||
return tensor_operator_registry.get('logaddexp')(self, other)
|
||||
|
||||
def logaddexp2(self, other):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.logaddexp2`.
|
||||
"""
|
||||
return tensor_operator_registry.get('logaddexp2')(self, other)
|
||||
|
||||
def logsumexp(self, dim, keepdim=False):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.logsumexp`.
|
||||
"""
|
||||
return tensor_operator_registry.get('logsumexp')(self, dim, keepdim)
|
||||
|
||||
def log_matrix_determinant(self):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.log_matrix_determinant`.
|
||||
|
|
|
@ -3240,7 +3240,6 @@ def isreal(x):
|
|||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input tensor.
|
||||
:math:`(N,*)` where :math:`*` means, any number of additional dimensions.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape of input, and the dtype is bool.
|
||||
|
@ -3595,18 +3594,17 @@ def logaddexp(x1, x2):
|
|||
"""
|
||||
Computes the logarithm of the sum of exponentiations of the inputs.
|
||||
|
||||
Calculates ``log(exp(x1) + exp(x2))``. This function is useful in statistics, where the
|
||||
computed probability of an event may be so small that it exceeds the range of a normal
|
||||
floating point number. In this case, the logarithm of the calculated probability is stored.
|
||||
This function allows to add probabilities stored in this way.
|
||||
.. math::
|
||||
|
||||
out_i = log(exp(x1_i) + exp(x2_i))
|
||||
|
||||
Args:
|
||||
x1 (Tensor): Input Tensor.
|
||||
x2 (Tensor): Input Tensor. If ``x1.shape != x2.shape``, they must be broadcastable to
|
||||
x2 (Tensor): Input Tensor. If the shape of `x1` is not equal to the shape of `x2`, they must be broadcastable to
|
||||
a common shape (which becomes the shape of the output).
|
||||
|
||||
Returns:
|
||||
Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars.
|
||||
Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x1`, `x2` is not a Tensor.
|
||||
|
@ -3624,7 +3622,6 @@ def logaddexp(x1, x2):
|
|||
|
||||
log_op = _get_cache_prim(P.Log)()
|
||||
exp_op = _get_cache_prim(P.Exp)()
|
||||
|
||||
y = log_op(exp_op(x1) + exp_op(x2))
|
||||
return y
|
||||
|
||||
|
@ -3633,10 +3630,9 @@ def logaddexp2(x1, x2):
|
|||
"""
|
||||
Computes the logarithm of the sum of exponentiations in base of 2 of the inputs.
|
||||
|
||||
Calculates ``log2(2**x1 + 2**x2)``. This function is useful in machine learning when the computed
|
||||
probability of an event may be small beyond the range of normal floating point numbers.
|
||||
In this case, the base-2 logarithm of the calculated probability can be used instead.
|
||||
This function allows to add probabilities stored in this way.
|
||||
.. math::
|
||||
|
||||
out_i = log_2(2^{x1_i} + 2^{x2_i})
|
||||
|
||||
Args:
|
||||
x1 (Tensor): Input tensor.
|
||||
|
@ -3644,7 +3640,7 @@ def logaddexp2(x1, x2):
|
|||
a common shape (which becomes the shape of the output).
|
||||
|
||||
Returns:
|
||||
Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars.
|
||||
Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x1`, `x2` is not a Tensor.
|
||||
|
|
|
@ -199,6 +199,9 @@ tensor_operator_registry.register('nonzero', nonzero)
|
|||
tensor_operator_registry.register('i0', i0)
|
||||
tensor_operator_registry.register('isclose', isclose)
|
||||
tensor_operator_registry.register('inv', inv)
|
||||
tensor_operator_registry.register('logaddexp', logaddexp)
|
||||
tensor_operator_registry.register('logaddexp2', logaddexp2)
|
||||
tensor_operator_registry.register('logsumexp', logsumexp)
|
||||
tensor_operator_registry.register('invert', invert)
|
||||
tensor_operator_registry.register('hardshrink', P.HShrink)
|
||||
tensor_operator_registry.register('heaviside', heaviside)
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x1, x2):
|
||||
return x1.logaddexp(x2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tensor_logaddexp(mode):
|
||||
"""
|
||||
Feature: tensor.logaddexp
|
||||
Description: Verify the result of logaddexp
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x1 = ms.Tensor([-100, 1, 30], ms.float32)
|
||||
x2 = ms.Tensor([-1, -1, 3], ms.float32)
|
||||
net = Net()
|
||||
output = net(x1, x2)
|
||||
expect_output = [-0.99999994, 1.1269258, 30.]
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x1, x2):
|
||||
return x1.logaddexp2(x2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tensor_logaddexp2(mode):
|
||||
"""
|
||||
Feature: tensor.logaddexp2
|
||||
Description: Verify the result of logaddexp2
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x1 = ms.Tensor([-100, 1, 10], ms.float32)
|
||||
x2 = ms.Tensor([-1, -1, 3], ms.float32)
|
||||
net = Net()
|
||||
output = net(x1, x2)
|
||||
expect_output = [-1.0000042, 1.3219349, 10.011246]
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, dim):
|
||||
return x.logsumexp(dim, keepdim=True)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tensor_logsumexp(mode):
|
||||
"""
|
||||
Feature: tensor.logsumexp
|
||||
Description: Verify the result of logsumexp
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor([[1., 2., 3., 4.],
|
||||
[5., 6., 7., 8.],
|
||||
[9., 10., 11., 12.]], ms.float32)
|
||||
net = Net()
|
||||
output = net(x, dim=0)
|
||||
expect_output = [[9.0184765, 10.01848, 11.018479, 12.018477]]
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
||||
output = net(x, dim=1)
|
||||
expect_output = [[4.440187],
|
||||
[8.440188],
|
||||
[12.440187]]
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
Loading…
Reference in New Issue