forked from mindspore-Ecosystem/mindspore
tensor_inverse_master
This commit is contained in:
parent
dc5d0d782c
commit
c4b0223330
|
@ -203,6 +203,7 @@ mindspore.ops
|
||||||
mindspore.ops.hypot
|
mindspore.ops.hypot
|
||||||
mindspore.ops.i0
|
mindspore.ops.i0
|
||||||
mindspore.ops.inv
|
mindspore.ops.inv
|
||||||
|
mindspore.ops.inverse
|
||||||
mindspore.ops.invert
|
mindspore.ops.invert
|
||||||
mindspore.ops.lcm
|
mindspore.ops.lcm
|
||||||
mindspore.ops.ldexp
|
mindspore.ops.ldexp
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
mindspore.Tensor.inverse
|
||||||
|
========================
|
||||||
|
|
||||||
|
.. py:method:: mindspore.Tensor.inverse()
|
||||||
|
|
||||||
|
详情请参考 :func:`mindspore.ops.inverse`。
|
|
@ -129,6 +129,7 @@ mindspore.Tensor
|
||||||
mindspore.Tensor.inplace_update
|
mindspore.Tensor.inplace_update
|
||||||
mindspore.Tensor.int
|
mindspore.Tensor.int
|
||||||
mindspore.Tensor.inv
|
mindspore.Tensor.inv
|
||||||
|
mindspore.Tensor.inverse
|
||||||
mindspore.Tensor.invert
|
mindspore.Tensor.invert
|
||||||
mindspore.Tensor.isclose
|
mindspore.Tensor.isclose
|
||||||
mindspore.Tensor.isfinite
|
mindspore.Tensor.isfinite
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
mindspore.ops.inverse
|
||||||
|
=====================
|
||||||
|
|
||||||
|
.. py:function:: mindspore.ops.inverse(x)
|
||||||
|
|
||||||
|
计算输入矩阵的逆。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **x** (Tensor) - 计算的矩阵。`x` 至少是两维的,最后两个维度大小相同。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Tensor,shape和类型和 `x` 相同。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - `x` 不是Tensor。
|
||||||
|
- **ValueError** - `x` 最后两个维度的大小不相同。
|
||||||
|
- **ValueError** - `x` 的维数小于2。
|
|
@ -135,6 +135,7 @@
|
||||||
mindspore.Tensor.inplace_update
|
mindspore.Tensor.inplace_update
|
||||||
mindspore.Tensor.int
|
mindspore.Tensor.int
|
||||||
mindspore.Tensor.inv
|
mindspore.Tensor.inv
|
||||||
|
mindspore.Tensor.inverse
|
||||||
mindspore.Tensor.invert
|
mindspore.Tensor.invert
|
||||||
mindspore.Tensor.isclose
|
mindspore.Tensor.isclose
|
||||||
mindspore.Tensor.isfinite
|
mindspore.Tensor.isfinite
|
||||||
|
|
|
@ -204,6 +204,7 @@ Element-by-Element Operations
|
||||||
mindspore.ops.hypot
|
mindspore.ops.hypot
|
||||||
mindspore.ops.i0
|
mindspore.ops.i0
|
||||||
mindspore.ops.inv
|
mindspore.ops.inv
|
||||||
|
mindspore.ops.inverse
|
||||||
mindspore.ops.invert
|
mindspore.ops.invert
|
||||||
mindspore.ops.lcm
|
mindspore.ops.lcm
|
||||||
mindspore.ops.ldexp
|
mindspore.ops.ldexp
|
||||||
|
|
|
@ -287,6 +287,7 @@ BuiltInTypeMap &GetMethodMap() {
|
||||||
{"is_floating_point", std::string("is_floating_point")}, // is_floating_point()
|
{"is_floating_point", std::string("is_floating_point")}, // is_floating_point()
|
||||||
{"is_signed", std::string("is_signed")}, // is_signed()
|
{"is_signed", std::string("is_signed")}, // is_signed()
|
||||||
{"inv", std::string("inv")}, // inv()
|
{"inv", std::string("inv")}, // inv()
|
||||||
|
{"inverse", std::string("inverse")}, // inverse()
|
||||||
{"invert", std::string("invert")}, // invert()
|
{"invert", std::string("invert")}, // invert()
|
||||||
{"searchsorted", std::string("searchsorted")}, // P.Select()
|
{"searchsorted", std::string("searchsorted")}, // P.Select()
|
||||||
{"take", std::string("take")}, // P.GatherNd()
|
{"take", std::string("take")}, // P.GatherNd()
|
||||||
|
|
|
@ -1418,6 +1418,13 @@ def inv(x):
|
||||||
return F.inv(x)
|
return F.inv(x)
|
||||||
|
|
||||||
|
|
||||||
|
def inverse(x):
|
||||||
|
"""
|
||||||
|
Computes the inverse of a square matrix.
|
||||||
|
"""
|
||||||
|
return F.inverse(x)
|
||||||
|
|
||||||
|
|
||||||
def invert(x):
|
def invert(x):
|
||||||
"""
|
"""
|
||||||
Flips all bits of input tensor element-wise.
|
Flips all bits of input tensor element-wise.
|
||||||
|
|
|
@ -1478,6 +1478,13 @@ class Tensor(Tensor_):
|
||||||
self._init_check()
|
self._init_check()
|
||||||
return tensor_operator_registry.get('inv')(self)
|
return tensor_operator_registry.get('inv')(self)
|
||||||
|
|
||||||
|
def inverse(self):
|
||||||
|
r"""
|
||||||
|
For details, please refer to :func:`mindspore.ops.inverse`.
|
||||||
|
"""
|
||||||
|
self._init_check()
|
||||||
|
return tensor_operator_registry.get('inverse')(self)
|
||||||
|
|
||||||
def invert(self):
|
def invert(self):
|
||||||
r"""
|
r"""
|
||||||
For details, please refer to :func:`mindspore.ops.invert`.
|
For details, please refer to :func:`mindspore.ops.invert`.
|
||||||
|
|
|
@ -216,6 +216,7 @@ from .math_func import (
|
||||||
inplace_sub,
|
inplace_sub,
|
||||||
inplace_update,
|
inplace_update,
|
||||||
inv,
|
inv,
|
||||||
|
inverse,
|
||||||
invert,
|
invert,
|
||||||
minimum,
|
minimum,
|
||||||
renorm,
|
renorm,
|
||||||
|
|
|
@ -2123,6 +2123,36 @@ def inv(x):
|
||||||
return inv_(x)
|
return inv_(x)
|
||||||
|
|
||||||
|
|
||||||
|
def inverse(x):
|
||||||
|
"""
|
||||||
|
Compute the inverse of the input matrix.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) - A matrix to be calculated. Input `x` must be at least two dimensions, and the size of
|
||||||
|
the last two dimensions must be the same size.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, has the same type and shape as input `x`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `x` is not a Tensor.
|
||||||
|
ValueError: If the size of the last two dimensions of `x` is not the same.
|
||||||
|
ValueError: If the dimension of `x` is less than 2.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``GPU`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> x = Tensor([[1., 2.], [3., 4.]], ms.float32)
|
||||||
|
>>> print(ops.inverse(x))
|
||||||
|
[[-2. 1. ]
|
||||||
|
[ 1.5 -0.5]]
|
||||||
|
"""
|
||||||
|
if x.dtype in mstype.int_type:
|
||||||
|
_get_cache_prim(P.Cast)()(x, mstype.float64)
|
||||||
|
return _get_cache_prim(P.MatrixInverse)()(x)
|
||||||
|
|
||||||
|
|
||||||
def invert(x):
|
def invert(x):
|
||||||
r"""
|
r"""
|
||||||
Flips all bits of input tensor element-wise.
|
Flips all bits of input tensor element-wise.
|
||||||
|
@ -7484,6 +7514,7 @@ __all__ = [
|
||||||
'bitwise_or',
|
'bitwise_or',
|
||||||
'bitwise_xor',
|
'bitwise_xor',
|
||||||
'inv',
|
'inv',
|
||||||
|
'inverse',
|
||||||
'invert',
|
'invert',
|
||||||
'erf',
|
'erf',
|
||||||
'erfc',
|
'erfc',
|
||||||
|
|
|
@ -1419,7 +1419,7 @@ def fractional_max_pool2d(input_x, kernel_size, output_size=None, output_ratio=N
|
||||||
raise ValueError(f"For fractional_max_pool2d, 'output_size' and 'output_ratio' can not be specified or None"
|
raise ValueError(f"For fractional_max_pool2d, 'output_size' and 'output_ratio' can not be specified or None"
|
||||||
f"at the same time, but got {output_ratio} and {output_size} .")
|
f"at the same time, but got {output_ratio} and {output_size} .")
|
||||||
if len(input_x.shape) == 3:
|
if len(input_x.shape) == 3:
|
||||||
input_x.expend_dims(axis=0)
|
input_x = input_x.expand_dims(axis=0)
|
||||||
if _random_samples is None:
|
if _random_samples is None:
|
||||||
_random_samples = Tensor([[[0, 0]]], mstype.float32)
|
_random_samples = Tensor([[[0, 0]]], mstype.float32)
|
||||||
if output_ratio is not None:
|
if output_ratio is not None:
|
||||||
|
@ -1517,8 +1517,6 @@ def fractional_max_pool3d(input_x, kernel_size, output_size=None, output_ratio=N
|
||||||
if output_ratio is not None and output_size is not None or output_ratio is None and output_size is None:
|
if output_ratio is not None and output_size is not None or output_ratio is None and output_size is None:
|
||||||
raise ValueError(f"For fractional_max_pool2d, 'output_size' and 'output_ratio' can not be specified or None"
|
raise ValueError(f"For fractional_max_pool2d, 'output_size' and 'output_ratio' can not be specified or None"
|
||||||
f"at the same time, but got {output_ratio} and {output_size} .")
|
f"at the same time, but got {output_ratio} and {output_size} .")
|
||||||
if len(input_x.shape) == 4:
|
|
||||||
input_x.expend_dims(axis=0)
|
|
||||||
if _random_samples is None:
|
if _random_samples is None:
|
||||||
_random_samples = Tensor([[[0, 0, 0]]], mstype.float32)
|
_random_samples = Tensor([[[0, 0, 0]]], mstype.float32)
|
||||||
if output_ratio is not None:
|
if output_ratio is not None:
|
||||||
|
|
|
@ -204,6 +204,7 @@ tensor_operator_registry.register('inv', inv)
|
||||||
tensor_operator_registry.register('logaddexp', logaddexp)
|
tensor_operator_registry.register('logaddexp', logaddexp)
|
||||||
tensor_operator_registry.register('logaddexp2', logaddexp2)
|
tensor_operator_registry.register('logaddexp2', logaddexp2)
|
||||||
tensor_operator_registry.register('logsumexp', logsumexp)
|
tensor_operator_registry.register('logsumexp', logsumexp)
|
||||||
|
tensor_operator_registry.register('inverse', inverse)
|
||||||
tensor_operator_registry.register('invert', invert)
|
tensor_operator_registry.register('invert', invert)
|
||||||
tensor_operator_registry.register('hardshrink', P.HShrink)
|
tensor_operator_registry.register('hardshrink', P.HShrink)
|
||||||
tensor_operator_registry.register('heaviside', heaviside)
|
tensor_operator_registry.register('heaviside', heaviside)
|
||||||
|
|
|
@ -44,7 +44,7 @@ def test_fractional_maxpool2d_normal(mode):
|
||||||
"""
|
"""
|
||||||
ms.set_context(mode=mode)
|
ms.set_context(mode=mode)
|
||||||
net = FractionalMaxPool2dNet()
|
net = FractionalMaxPool2dNet()
|
||||||
input_x = Tensor(np.random.rand(25).reshape([1, 1, 5, 5]), mstype.float32)
|
input_x = Tensor(np.random.rand(25).reshape([1, 5, 5]), mstype.float32)
|
||||||
output1, output2 = net(input_x)
|
output1, output2 = net(input_x)
|
||||||
assert output1[0].shape == output1[1].shape == (1, 1, 2, 2)
|
assert output1[0].shape == output1[1].shape == (1, 1, 2, 2)
|
||||||
assert output2[0].shape == output2[1].shape == (1, 1, 2, 2)
|
assert output2[0].shape == output2[1].shape == (1, 1, 2, 2)
|
||||||
|
@ -88,11 +88,11 @@ def test_fractional_maxpool3d_normal(mode):
|
||||||
Expectation: Success
|
Expectation: Success
|
||||||
"""
|
"""
|
||||||
ms.set_context(mode=mode)
|
ms.set_context(mode=mode)
|
||||||
input_x = Tensor(np.random.rand(16).reshape([1, 1, 2, 2, 4]), mstype.float32)
|
input_x = Tensor(np.random.rand(16).reshape([1, 2, 2, 4]), mstype.float32)
|
||||||
net = FractionalMaxPool3dNet()
|
net = FractionalMaxPool3dNet()
|
||||||
output1, output2 = net(input_x)
|
output1, output2 = net(input_x)
|
||||||
assert output1[0].shape == output1[1].shape == (1, 1, 1, 1, 2)
|
assert output1[0].shape == output1[1].shape == (1, 1, 1, 2)
|
||||||
assert output2[0].shape == output2[1].shape == (1, 1, 1, 1, 2)
|
assert output2[0].shape == output2[1].shape == (1, 1, 1, 2)
|
||||||
input_x = Tensor([[[[[5.76273143e-001, 7.97047436e-001, 5.05385816e-001, 7.98332036e-001],
|
input_x = Tensor([[[[[5.76273143e-001, 7.97047436e-001, 5.05385816e-001, 7.98332036e-001],
|
||||||
[5.79880655e-001, 9.75979388e-001, 3.17571498e-002, 8.08261558e-002]],
|
[5.79880655e-001, 9.75979388e-001, 3.17571498e-002, 8.08261558e-002]],
|
||||||
[[3.82758647e-001, 7.09801614e-001, 4.39641386e-001, 5.71077049e-001],
|
[[3.82758647e-001, 7.09801614e-001, 4.39641386e-001, 5.71077049e-001],
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor, ops
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def construct(self, x):
|
||||||
|
return ops.inverse(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||||
|
def test_ops_inverse(mode):
|
||||||
|
"""
|
||||||
|
Feature: ops.inverse
|
||||||
|
Description: Verify the result of inverse
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
ms.set_context(mode=mode)
|
||||||
|
x = Tensor([[1., 2, 3],
|
||||||
|
[4, 5., 6],
|
||||||
|
[8, 8, 9]], ms.float32)
|
||||||
|
net = Net()
|
||||||
|
output = net(x)
|
||||||
|
expect_output = [[1.0000008, -2.000001, 1.0000005],
|
||||||
|
[-4.0000014, 5.000002, -2.000001],
|
||||||
|
[2.6666675, -2.6666675, 1.0000002]]
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -0,0 +1,47 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def construct(self, x):
|
||||||
|
return x.inverse()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||||
|
def test_tensor_inverse(mode):
|
||||||
|
"""
|
||||||
|
Feature: tensor.inverse
|
||||||
|
Description: Verify the result of inverse
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
ms.set_context(mode=mode)
|
||||||
|
x = ms.Tensor([[1., 2, 3],
|
||||||
|
[4, 5., 6],
|
||||||
|
[8, 8, 9]], ms.float32)
|
||||||
|
net = Net()
|
||||||
|
output = net(x)
|
||||||
|
expect_output = [[1.0000008, -2.000001, 1.0000005],
|
||||||
|
[-4.0000014, 5.000002, -2.000001],
|
||||||
|
[2.6666675, -2.6666675, 1.0000002]]
|
||||||
|
assert np.allclose(output.asnumpy(), expect_output)
|
Loading…
Reference in New Issue