forked from mindspore-Ecosystem/mindspore
add det method and st test cases.
This commit is contained in:
parent
eaacc10cab
commit
43ffa30789
|
@ -649,3 +649,74 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
|
|||
rhs_vector = m_lu.ndim == b.ndim + 1
|
||||
x = lu_solve_core(m_lu, permutation, b, trans)
|
||||
return x[..., 0] if rhs_vector else x
|
||||
|
||||
|
||||
def _det_2x2(a):
|
||||
return (a[..., 0, 0] * a[..., 1, 1] -
|
||||
a[..., 0, 1] * a[..., 1, 0])
|
||||
|
||||
|
||||
def _det_3x3(a):
|
||||
return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] +
|
||||
a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] +
|
||||
a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] -
|
||||
a[..., 0, 2] * a[..., 1, 1] * a[..., 2, 0] -
|
||||
a[..., 0, 0] * a[..., 1, 2] * a[..., 2, 1] -
|
||||
a[..., 0, 1] * a[..., 1, 0] * a[..., 2, 2])
|
||||
|
||||
|
||||
def det(a, overwrite_a=False, check_finite=True):
|
||||
"""
|
||||
Compute the determinant of a matrix
|
||||
|
||||
The determinant of a square matrix is a value derived arithmetically
|
||||
from the coefficients of the matrix.
|
||||
|
||||
The determinant for a 3x3 matrix, for example, is computed as follows::
|
||||
|
||||
a b c
|
||||
d e f = A
|
||||
g h i
|
||||
|
||||
det(A) = a*e*i + b*f*g + c*d*h - c*e*g - b*d*i - a*f*h
|
||||
|
||||
Args:
|
||||
a (Tensor): A square matrix to compute. Note that if the input tensor is not a `float`,
|
||||
then it will be casted to :class:`mstype.float32`.
|
||||
overwrite_a (bool, optional): Allow overwriting data in a (may enhance performance).
|
||||
check_finite (bool, optional): Whether to check that the input matrix contains
|
||||
only finite numbers.
|
||||
Disabling may give a performance gain, but may result in problems
|
||||
(crashes, non-termination) if the inputs do contain infinities or NaNs.
|
||||
|
||||
Raises:
|
||||
ValueError: If :math:`a` is not square.
|
||||
|
||||
Returns:
|
||||
Tensor, Determinant of `a`.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as onp
|
||||
>>> from mindspore.common import Tensor
|
||||
>>> from mindspore.scipy.linalg import det
|
||||
>>> a = Tensor(onp.array([[0, 2, 3], [4, 5, 6], [7, 8, 9]])).astype(onp.float64)
|
||||
>>> det(a)
|
||||
3.0
|
||||
"""
|
||||
# special case
|
||||
if a.ndim >= 2 and a.shape[-1] == 2 and a.shape[-2] == 2:
|
||||
return _det_2x2(a)
|
||||
if a.ndim >= 2 and a.shape[-1] == 3 and a.shape[-2] == 3:
|
||||
return _det_3x3(a)
|
||||
if a.ndim < 2 or a.shape[-1] != a.shape[-2]:
|
||||
_raise_value_error("Arguments to det must be [..., n, n], but got shape {}.".format(a.shape))
|
||||
|
||||
if F.dtype(a) not in float_types:
|
||||
a = F.cast(a, mstype.float32)
|
||||
|
||||
lu_matrix, pivot = lu_factor(a)
|
||||
diag = lu_matrix.diagonal(axis1=-2, axis2=-1)
|
||||
pivot_not_equal = (pivot != mnp.arange(a.shape[-1])).astype(mstype.int64)
|
||||
pivot_sign = mnp.count_nonzero(pivot_not_equal, axis=-1)
|
||||
sign = -2. * (pivot_sign % 2) + 1.
|
||||
return sign * P.ReduceProd(keep_dims=False)(diag, -1)
|
||||
|
|
|
@ -23,6 +23,7 @@ import mindspore.nn as nn
|
|||
import mindspore.scipy as msp
|
||||
from mindspore import context, Tensor
|
||||
import mindspore.numpy as mnp
|
||||
from mindspore.scipy.linalg import det
|
||||
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \
|
||||
create_random_rank_matrix
|
||||
|
||||
|
@ -328,6 +329,50 @@ def test_lu_solve(n: int, dtype):
|
|||
assert onp.allclose(msp_x.asnumpy(), osp_x, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('shape', [(3, 3), (5, 5), (10, 10), (20, 20)])
|
||||
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
|
||||
def test_det(shape, dtype):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for det
|
||||
Expectation: the result match to scipy
|
||||
"""
|
||||
a = onp.random.random(shape).astype(dtype)
|
||||
sp_det = osp.linalg.det(a)
|
||||
tensor_a = Tensor(a)
|
||||
ms_det = msp.linalg.det(tensor_a)
|
||||
rtol = 1.e-5
|
||||
atol = 1.e-5
|
||||
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('shape', [(2, 3, 3), (2, 3, 5, 5)])
|
||||
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
|
||||
def test_batch_det(shape, dtype):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test batch cases for det
|
||||
Expectation: the result match to scipy
|
||||
"""
|
||||
a = onp.random.random(shape).astype(dtype)
|
||||
tensor_a = Tensor(a)
|
||||
ms_det = msp.linalg.det(tensor_a)
|
||||
sp_det = onp.empty(shape=ms_det.shape, dtype=dtype)
|
||||
for index, _ in onp.ndenumerate(sp_det):
|
||||
sp_det[index] = osp.linalg.det(a[index])
|
||||
rtol = 1.e-5
|
||||
atol = 1.e-5
|
||||
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
|
@ -351,3 +396,30 @@ def test_block_diag_graph(args):
|
|||
|
||||
scipy_res = osp.linalg.block_diag(*args)
|
||||
match_array(ms_res.asnumpy(), scipy_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('shape', [(3, 3), (5, 5), (10, 10), (20, 20)])
|
||||
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
|
||||
def test_det_graph(shape, dtype):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for det in graph mode
|
||||
Expectation: the result match to scipy
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
class TestNet(nn.Cell):
|
||||
def construct(self, a):
|
||||
return det(a)
|
||||
|
||||
a = onp.random.random(shape).astype(dtype)
|
||||
sp_det = osp.linalg.det(a)
|
||||
tensor_a = Tensor(a)
|
||||
ms_det = TestNet()(tensor_a)
|
||||
rtol = 1.e-5
|
||||
atol = 1.e-5
|
||||
assert onp.allclose(ms_det.asnumpy(), sp_det, rtol=rtol, atol=atol)
|
||||
|
|
Loading…
Reference in New Issue