forked from mindspore-Ecosystem/mindspore
add det method and st test cases.
This commit is contained in:
parent
eaacc10cab
commit
43ffa30789
|
@ -649,3 +649,74 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
|
||||||
rhs_vector = m_lu.ndim == b.ndim + 1
|
rhs_vector = m_lu.ndim == b.ndim + 1
|
||||||
x = lu_solve_core(m_lu, permutation, b, trans)
|
x = lu_solve_core(m_lu, permutation, b, trans)
|
||||||
return x[..., 0] if rhs_vector else x
|
return x[..., 0] if rhs_vector else x
|
||||||
|
|
||||||
|
|
||||||
|
def _det_2x2(a):
|
||||||
|
return (a[..., 0, 0] * a[..., 1, 1] -
|
||||||
|
a[..., 0, 1] * a[..., 1, 0])
|
||||||
|
|
||||||
|
|
||||||
|
def _det_3x3(a):
|
||||||
|
return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] +
|
||||||
|
a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] +
|
||||||
|
a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] -
|
||||||
|
a[..., 0, 2] * a[..., 1, 1] * a[..., 2, 0] -
|
||||||
|
a[..., 0, 0] * a[..., 1, 2] * a[..., 2, 1] -
|
||||||
|
a[..., 0, 1] * a[..., 1, 0] * a[..., 2, 2])
|
||||||
|
|
||||||
|
|
||||||
|
def det(a, overwrite_a=False, check_finite=True):
|
||||||
|
"""
|
||||||
|
Compute the determinant of a matrix
|
||||||
|
|
||||||
|
The determinant of a square matrix is a value derived arithmetically
|
||||||
|
from the coefficients of the matrix.
|
||||||
|
|
||||||
|
The determinant for a 3x3 matrix, for example, is computed as follows::
|
||||||
|
|
||||||
|
a b c
|
||||||
|
d e f = A
|
||||||
|
g h i
|
||||||
|
|
||||||
|
det(A) = a*e*i + b*f*g + c*d*h - c*e*g - b*d*i - a*f*h
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (Tensor): A square matrix to compute. Note that if the input tensor is not a `float`,
|
||||||
|
then it will be casted to :class:`mstype.float32`.
|
||||||
|
overwrite_a (bool, optional): Allow overwriting data in a (may enhance performance).
|
||||||
|
check_finite (bool, optional): Whether to check that the input matrix contains
|
||||||
|
only finite numbers.
|
||||||
|
Disabling may give a performance gain, but may result in problems
|
||||||
|
(crashes, non-termination) if the inputs do contain infinities or NaNs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If :math:`a` is not square.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, Determinant of `a`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import numpy as onp
|
||||||
|
>>> from mindspore.common import Tensor
|
||||||
|
>>> from mindspore.scipy.linalg import det
|
||||||
|
>>> a = Tensor(onp.array([[0, 2, 3], [4, 5, 6], [7, 8, 9]])).astype(onp.float64)
|
||||||
|
>>> det(a)
|
||||||
|
3.0
|
||||||
|
"""
|
||||||
|
# special case
|
||||||
|
if a.ndim >= 2 and a.shape[-1] == 2 and a.shape[-2] == 2:
|
||||||
|
return _det_2x2(a)
|
||||||
|
if a.ndim >= 2 and a.shape[-1] == 3 and a.shape[-2] == 3:
|
||||||
|
return _det_3x3(a)
|
||||||
|
if a.ndim < 2 or a.shape[-1] != a.shape[-2]:
|
||||||
|
_raise_value_error("Arguments to det must be [..., n, n], but got shape {}.".format(a.shape))
|
||||||
|
|
||||||
|
if F.dtype(a) not in float_types:
|
||||||
|
a = F.cast(a, mstype.float32)
|
||||||
|
|
||||||
|
lu_matrix, pivot = lu_factor(a)
|
||||||
|
diag = lu_matrix.diagonal(axis1=-2, axis2=-1)
|
||||||
|
pivot_not_equal = (pivot != mnp.arange(a.shape[-1])).astype(mstype.int64)
|
||||||
|
pivot_sign = mnp.count_nonzero(pivot_not_equal, axis=-1)
|
||||||
|
sign = -2. * (pivot_sign % 2) + 1.
|
||||||
|
return sign * P.ReduceProd(keep_dims=False)(diag, -1)
|
||||||
|
|
|
@ -23,6 +23,7 @@ import mindspore.nn as nn
|
||||||
import mindspore.scipy as msp
|
import mindspore.scipy as msp
|
||||||
from mindspore import context, Tensor
|
from mindspore import context, Tensor
|
||||||
import mindspore.numpy as mnp
|
import mindspore.numpy as mnp
|
||||||
|
from mindspore.scipy.linalg import det
|
||||||
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \
|
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \
|
||||||
create_random_rank_matrix
|
create_random_rank_matrix
|
||||||
|
|
||||||
|
@ -328,6 +329,50 @@ def test_lu_solve(n: int, dtype):
|
||||||
assert onp.allclose(msp_x.asnumpy(), osp_x, rtol=rtol, atol=atol)
|
assert onp.allclose(msp_x.asnumpy(), osp_x, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('shape', [(3, 3), (5, 5), (10, 10), (20, 20)])
|
||||||
|
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
|
||||||
|
def test_det(shape, dtype):
|
||||||
|
"""
|
||||||
|
Feature: ALL To ALL
|
||||||
|
Description: test cases for det
|
||||||
|
Expectation: the result match to scipy
|
||||||
|
"""
|
||||||
|
a = onp.random.random(shape).astype(dtype)
|
||||||
|
sp_det = osp.linalg.det(a)
|
||||||
|
tensor_a = Tensor(a)
|
||||||
|
ms_det = msp.linalg.det(tensor_a)
|
||||||
|
rtol = 1.e-5
|
||||||
|
atol = 1.e-5
|
||||||
|
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('shape', [(2, 3, 3), (2, 3, 5, 5)])
|
||||||
|
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
|
||||||
|
def test_batch_det(shape, dtype):
|
||||||
|
"""
|
||||||
|
Feature: ALL To ALL
|
||||||
|
Description: test batch cases for det
|
||||||
|
Expectation: the result match to scipy
|
||||||
|
"""
|
||||||
|
a = onp.random.random(shape).astype(dtype)
|
||||||
|
tensor_a = Tensor(a)
|
||||||
|
ms_det = msp.linalg.det(tensor_a)
|
||||||
|
sp_det = onp.empty(shape=ms_det.shape, dtype=dtype)
|
||||||
|
for index, _ in onp.ndenumerate(sp_det):
|
||||||
|
sp_det[index] = osp.linalg.det(a[index])
|
||||||
|
rtol = 1.e-5
|
||||||
|
atol = 1.e-5
|
||||||
|
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_x86_gpu_training
|
@pytest.mark.platform_x86_gpu_training
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@ -351,3 +396,30 @@ def test_block_diag_graph(args):
|
||||||
|
|
||||||
scipy_res = osp.linalg.block_diag(*args)
|
scipy_res = osp.linalg.block_diag(*args)
|
||||||
match_array(ms_res.asnumpy(), scipy_res)
|
match_array(ms_res.asnumpy(), scipy_res)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('shape', [(3, 3), (5, 5), (10, 10), (20, 20)])
|
||||||
|
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
|
||||||
|
def test_det_graph(shape, dtype):
|
||||||
|
"""
|
||||||
|
Feature: ALL To ALL
|
||||||
|
Description: test cases for det in graph mode
|
||||||
|
Expectation: the result match to scipy
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
class TestNet(nn.Cell):
|
||||||
|
def construct(self, a):
|
||||||
|
return det(a)
|
||||||
|
|
||||||
|
a = onp.random.random(shape).astype(dtype)
|
||||||
|
sp_det = osp.linalg.det(a)
|
||||||
|
tensor_a = Tensor(a)
|
||||||
|
ms_det = TestNet()(tensor_a)
|
||||||
|
rtol = 1.e-5
|
||||||
|
atol = 1.e-5
|
||||||
|
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)
|
||||||
|
|
Loading…
Reference in New Issue