forked from mindspore-Ecosystem/mindspore
fix solve_triangular and eigh api issue
This commit is contained in:
parent
33b74255d2
commit
29febb2e2a
|
@ -34,11 +34,11 @@ void EighCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
|
||||
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (A_shape.size() != kShape2dDims) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, A should be 2D, but got [" << A_shape.size() << "] dimensions";
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, A should be 2D, but got [" << A_shape.size() << "] dimensions.";
|
||||
}
|
||||
if (A_shape[kDim0] != A_shape[kDim1]) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, A should be a squre matrix like [N X N], but got shape [" << A_shape[kDim0]
|
||||
<< " X " << A_shape[kDim1] << "]";
|
||||
<< " X " << A_shape[kDim1] << "].";
|
||||
}
|
||||
m_ = A_shape[kDim0];
|
||||
}
|
||||
|
|
|
@ -42,20 +42,20 @@ void SolveTriangularCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
auto b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
|
||||
if (A_shape.size() != kAMatrixDimNum) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, A should be 2D, but got [" << A_shape.size() << "] dimensions";
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, A should be 2D, but got [" << A_shape.size() << "] dimensions.";
|
||||
}
|
||||
if (A_shape[kDim0] != A_shape[kDim1]) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, A should be a squre matrix, but got [" << A_shape[kDim0] << " X "
|
||||
<< A_shape[kDim1] << "]";
|
||||
<< A_shape[kDim1] << "].";
|
||||
}
|
||||
m_ = A_shape[kDim0];
|
||||
|
||||
if (b_shape.size() != kAVectorxDimNum && b_shape.size() != kAMatrixDimNum) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, b should be 1D or 2D, but got [" << b_shape.size() << "] dimensions";
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, b should be 1D or 2D, but got [" << b_shape.size() << "] dimensions.";
|
||||
}
|
||||
if (SizeToInt(b_shape[kDim0]) != m_) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, b should match the shape of A, excepted [" << m_ << "] but got ["
|
||||
<< b_shape[kDim0] << "]";
|
||||
<< b_shape[kDim0] << "].";
|
||||
}
|
||||
if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
|
||||
n_ = 1;
|
||||
|
@ -72,7 +72,7 @@ void SolveTriangularCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
} else if (trans == "C") {
|
||||
trans_ = true;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Trans should be in [N, T], but got [" << trans << "]";
|
||||
MS_LOG(EXCEPTION) << "Trans should be in [N, T, C], but got [" << trans << "].";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -74,9 +74,14 @@ class EighcGpuKernelMod : public NativeGpuKernelMod {
|
|||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
if (A_shape.size() != kShape2dDims || A_shape[1] != A_shape[1]) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input should be square matrix, but got ["
|
||||
<< A_shape[0] << " X " << A_shape[1] << "]";
|
||||
if (A_shape.size() != kShape2dDims) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape. For '" << kernel_name_ << "', a should be 2D, but got ["
|
||||
<< A_shape.size() << "] dimensions.";
|
||||
}
|
||||
if (A_shape[kDim0] != A_shape[kDim1]) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, For '" << kernel_name_
|
||||
<< "', a should be a squre matrix like [N X N], but got shape [" << A_shape[kDim0] << " X "
|
||||
<< A_shape[kDim1] << "].";
|
||||
}
|
||||
m_ = A_shape[0];
|
||||
InitSizeLists();
|
||||
|
|
|
@ -59,9 +59,14 @@ class EighGpuKernelMod : public NativeGpuKernelMod {
|
|||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
if (A_shape.size() != kShape2dDims || A_shape[0] != A_shape[1]) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the shape of input should be square matrix, but got ["
|
||||
<< A_shape[0] << " X " << A_shape[1] << "]";
|
||||
if (A_shape.size() != kShape2dDims) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape. For '" << kernel_name << "', a should be 2D, but got [" << A_shape.size()
|
||||
<< "] dimensions.";
|
||||
}
|
||||
if (A_shape[kDim0] != A_shape[kDim1]) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, For '" << kernel_name
|
||||
<< "', a should be a squre matrix like [N X N], but got shape [" << A_shape[kDim0] << " X "
|
||||
<< A_shape[kDim1] << "].";
|
||||
}
|
||||
m_ = A_shape[0];
|
||||
InitSizeLists();
|
||||
|
|
|
@ -128,18 +128,19 @@ class TrsmGpuKernelMod : public NativeGpuKernelMod {
|
|||
}
|
||||
|
||||
if (A_shape[kDim0] != A_shape[kDim1]) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the shape of input should be square matrix, but got ["
|
||||
<< A_shape[kDim0] << " X " << A_shape[kDim1] << "]";
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name
|
||||
<< "', the shape of input matrix A should be square matrix like [N X N], but got ["
|
||||
<< A_shape[kDim0] << " X " << A_shape[kDim1] << "].";
|
||||
}
|
||||
m_ = A_shape[kDim0];
|
||||
|
||||
if (b_shape.size() != kAVectorxDimNum && b_shape.size() != kAMatrixDimNum) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input should be 1 or 2, but got "
|
||||
<< b_shape.size();
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', b should be 1D or 2D, but got [" << b_shape.size()
|
||||
<< "] dimensions.";
|
||||
}
|
||||
if (b_shape[kDim0] != m_) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the shape of input should be [" << m_ << "], but got ["
|
||||
<< b_shape[kDim0] << "]";
|
||||
<< b_shape[kDim0] << "].";
|
||||
}
|
||||
if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
|
||||
n_ = 1;
|
||||
|
@ -156,8 +157,10 @@ class TrsmGpuKernelMod : public NativeGpuKernelMod {
|
|||
trans_ = CUBLAS_OP_T;
|
||||
} else if (trans == "T") {
|
||||
trans_ = CUBLAS_OP_N;
|
||||
} else if (trans == "C") {
|
||||
trans_ = CUBLAS_OP_N;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', trans should be in [N, T], but got [" << trans << "]";
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', trans should be in [N, T, C], but got [" << trans << "].";
|
||||
}
|
||||
|
||||
bool lower = AnfAlgo::GetNodeAttr<bool>(kernel_node, "lower");
|
||||
|
|
|
@ -98,7 +98,7 @@ def block_diag(*arrs):
|
|||
|
||||
|
||||
def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
|
||||
overwrite_b=False, debug=None, check_finite=True):
|
||||
overwrite_b=False, debug=None, check_finite=False):
|
||||
"""
|
||||
Assuming a is a triangular matrix, solve the equation
|
||||
|
||||
|
@ -106,16 +106,15 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
|
|||
A x = b
|
||||
|
||||
Note:
|
||||
`solve_triangular` is not supported on Windows platform yet.
|
||||
- `solve_triangular` is not supported on Windows platform yet.
|
||||
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
|
||||
`int64` is passed, it will be cast to :class:`mstype.float64`.
|
||||
|
||||
Args:
|
||||
A (Tensor): A non-singular triangular matrix of shape :math:`(M, M)`. Note that if the input tensor is neither
|
||||
`float32` nor `float64`, then it will be cast to :class:`mstype.float32`.
|
||||
b (Tensor): A Tensor of shape :math:`(M,)` or :math:`(M, N)`.
|
||||
Right-hand side matrix in :math:`A x = b`. Note that if the input tensor is neither `float32` nor `float64`,
|
||||
then it will be cast to :class:`mstype.float32`.
|
||||
A (Tensor): A non-singular triangular matrix of shape :math:`(M, M)`.
|
||||
b (Tensor): A Tensor of shape :math:`(M,)` or :math:`(M, N)`. Right-hand side matrix in :math:`A x = b`.
|
||||
lower (bool, optional): Use only data contained in the lower triangle of `a`. Default: False.
|
||||
trans (0, 1, 2, 'N', 'T', 'C', optional): Type of system to solve. Default: 'N'.
|
||||
trans (0, 1, 2, 'N', 'T', 'C', optional): Type of system to solve. Default: 0.
|
||||
|
||||
======== =========
|
||||
trans system
|
||||
|
@ -127,6 +126,7 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
|
|||
unit_diagonal (bool, optional): If True, diagonal elements of :math:`A` are assumed to be 1 and
|
||||
will not be referenced. Default: False.
|
||||
overwrite_b (bool, optional): Allow overwriting data in :math:`b` (may enhance performance). Default: False.
|
||||
debug (None): Not implemented now. Default: False.
|
||||
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
|
||||
Disabling may give a performance gain, but may result in problems
|
||||
(crashes, non-termination) if the inputs do contain infinities or NaNs. Default: False.
|
||||
|
@ -142,13 +142,13 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
|
|||
TypeError: If dtype of `A` and `b` are not the same.
|
||||
RuntimeError: If shape of `A` and `b` are not matched or more than 2D.
|
||||
TypeError: If `trans` is not int or str.
|
||||
RuntimeError: If `trans` is not in set {0, 1, 2, 'N', 'T', 'C'}.
|
||||
ValueError: If `trans` is not in set {0, 1, 2, 'N', 'T', 'C'}.
|
||||
TypeError: If `lower` is not bool.
|
||||
TypeError: If `unit_diagonal` is not bool.
|
||||
TypeError: If `overwrite_b` is not bool.
|
||||
TypeError: If `check_finite` is not bool.
|
||||
ValueError: If `debug` is not None.
|
||||
ValueError: If `A` is singular matrix.
|
||||
ValueError: If the shape of `A` and `b` not match.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU`` ``GPU``
|
||||
|
@ -173,12 +173,14 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
|
|||
>>> print(mnp.dot(A, x)) # Check the result
|
||||
[4. 2. 4. 2.]
|
||||
"""
|
||||
_type_check('trans', trans, (int, str), 'solve_triangular')
|
||||
_type_check('overwrite_b', overwrite_b, bool, 'solve_triangular')
|
||||
_type_check('check_finite', check_finite, bool, 'solve_triangular')
|
||||
if F.dtype(A) in (mstype.int32, mstype.int64):
|
||||
A = F.cast(A, mstype.float32)
|
||||
if F.dtype(b) in (mstype.int32, mstype.int64):
|
||||
b = F.cast(b, mstype.float32)
|
||||
if debug is not None:
|
||||
_raise_value_error("Currently only case debug=None of solve_triangular Implemented.")
|
||||
if F.dtype(A) == F.dtype(b) and F.dtype(A) in (mstype.int32, mstype.int64):
|
||||
A = F.cast(A, mstype.float64)
|
||||
b = F.cast(b, mstype.float64)
|
||||
if trans not in (0, 1, 2, 'N', 'T', 'C'):
|
||||
_raise_value_error("The value of trans should be one of (0, 1, 2, 'N', 'T', 'C'), but got " + str(trans))
|
||||
if isinstance(trans, int):
|
||||
|
@ -427,7 +429,9 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
|||
In the standard problem, `b` is assumed to be the identity matrix.
|
||||
|
||||
Note:
|
||||
`eigh` is not supported on Windows platform yet.
|
||||
- `eigh` is not supported on Windows platform yet.
|
||||
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
|
||||
`int64` is passed, it will be cast to :class:`mstype.float64`.
|
||||
|
||||
Args:
|
||||
a (Tensor): A :math:`(M, M)` complex Hermitian or real symmetric matrix whose eigenvalues and
|
||||
|
@ -469,7 +473,7 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
|||
definite positive. Note that if input matrices are not symmetric or Hermitian, no error will
|
||||
be reported but results will be wrong.
|
||||
TypeError: If `A` is not Tensor.
|
||||
Runtime: If `A` is not square matrix.
|
||||
RuntimeError: If `A` is not square matrix.
|
||||
ValueError: If `b` is not None.
|
||||
TypeError: If `lower` is not bool.
|
||||
TypeError: If `eigvals_only` is not bool.
|
||||
|
|
|
@ -99,7 +99,7 @@ class SolveTriangular(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, A_dtype, b_dtype):
|
||||
validator.check_scalar_or_tensor_types_same({"A_dtype": A_dtype, "b_dtype": b_dtype},
|
||||
[mstype.float32, mstype.float64], self.name, True)
|
||||
[mstype.float32, mstype.float64], self.name)
|
||||
return A_dtype
|
||||
|
||||
|
||||
|
@ -234,8 +234,8 @@ class EighNet(nn.Cell):
|
|||
self.eigh = Eigh(bv, lower)
|
||||
|
||||
def construct(self, A):
|
||||
if F.dtype(A) in (mstype.int8, mstype.int16, mstype.int32, mstype.int64):
|
||||
A = F.cast(A, mstype.float32)
|
||||
if F.dtype(A) in (mstype.int32, mstype.int64):
|
||||
A = F.cast(A, mstype.float64)
|
||||
r = self.eigh(A)
|
||||
if self.bv:
|
||||
return (r[0], r[1])
|
||||
|
|
|
@ -23,7 +23,7 @@ import mindspore.nn as nn
|
|||
import mindspore.scipy as msp
|
||||
from mindspore import context, Tensor
|
||||
import mindspore.numpy as mnp
|
||||
from mindspore.scipy.linalg import det
|
||||
from mindspore.scipy.linalg import det, solve_triangular
|
||||
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \
|
||||
create_random_rank_matrix
|
||||
|
||||
|
@ -50,6 +50,41 @@ def test_block_diag(args):
|
|||
match_array(ms_res.asnumpy(), scipy_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [10, 20, 52])
|
||||
@pytest.mark.parametrize('trans', ["N", "T", "C"])
|
||||
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64, onp.int32, onp.int64])
|
||||
@pytest.mark.parametrize('lower', [False, True])
|
||||
@pytest.mark.parametrize('unit_diagonal', [False, True])
|
||||
def test_solve_triangular(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for solve_triangular for triangular matrix solver [N,N]
|
||||
Expectation: the result match scipy solve_triangular result
|
||||
"""
|
||||
onp.random.seed(0)
|
||||
if dtype in (onp.int32, onp.int64):
|
||||
a = (onp.random.randint(low=-1024, high=1024, size=(n, n)) + onp.eye(n)).astype(dtype)
|
||||
b = onp.random.randint(low=-1024, high=1024, size=(n,)).astype(dtype)
|
||||
else:
|
||||
a = (onp.random.random((n, n)) + onp.eye(n)).astype(dtype)
|
||||
b = onp.random.random(n).astype(dtype)
|
||||
|
||||
output = solve_triangular(Tensor(a), Tensor(b), trans, lower, unit_diagonal).asnumpy()
|
||||
expect = osp.linalg.solve_triangular(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
|
||||
|
||||
rtol = 1.e-5
|
||||
atol = 1.e-8
|
||||
if dtype == onp.float32:
|
||||
rtol = 1.e-3
|
||||
atol = 1.e-3
|
||||
|
||||
assert onp.allclose(expect, output, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
|
@ -155,8 +190,7 @@ def test_cholesky_solver(n: int, lower: bool, data_type):
|
|||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [4, 6, 9, 20])
|
||||
@pytest.mark.parametrize('data_type',
|
||||
[(onp.int8, "f"), (onp.int16, "f"), (onp.int32, "f"), (onp.int64, "d"), (onp.float32, "f"),
|
||||
(onp.float64, "d")])
|
||||
[(onp.int32, "f"), (onp.int64, "d"), (onp.float32, "f"), (onp.float64, "d")])
|
||||
def test_eigh(n: int, data_type):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
|
|
|
@ -318,7 +318,7 @@ def test_eigh_net_gpu(n: int):
|
|||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [10, 20])
|
||||
@pytest.mark.parametrize('trans', ["N", "T"])
|
||||
@pytest.mark.parametrize('trans', ["N", "T", "C"])
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
@pytest.mark.parametrize('lower', [False, True])
|
||||
@pytest.mark.parametrize('unit_diagonal', [False])
|
||||
|
@ -342,7 +342,7 @@ def test_solve_triangular_2d(n: int, dtype, lower: bool, unit_diagonal: bool, tr
|
|||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [10, 20])
|
||||
@pytest.mark.parametrize('trans', ["N", "T"])
|
||||
@pytest.mark.parametrize('trans', ["N", "T", "C"])
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
@pytest.mark.parametrize('lower', [False, True])
|
||||
@pytest.mark.parametrize('unit_diagonal', [False, True])
|
||||
|
@ -366,7 +366,7 @@ def test_solve_triangular_1d(n: int, dtype, lower: bool, unit_diagonal: bool, tr
|
|||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('shape', [(4, 5), (10, 20)])
|
||||
@pytest.mark.parametrize('trans', ["N", "T"])
|
||||
@pytest.mark.parametrize('trans', ["N", "T", "C"])
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
@pytest.mark.parametrize('lower', [False, True])
|
||||
@pytest.mark.parametrize('unit_diagonal', [False, True])
|
||||
|
|
Loading…
Reference in New Issue