fix solve_triangular and eigh api issue

This commit is contained in:
zhujingxuan 2022-01-25 20:21:32 +08:00
parent 33b74255d2
commit 29febb2e2a
9 changed files with 95 additions and 44 deletions

View File

@ -34,11 +34,11 @@ void EighCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (A_shape.size() != kShape2dDims) {
MS_LOG(EXCEPTION) << "Wrong array shape, A should be 2D, but got [" << A_shape.size() << "] dimensions";
MS_LOG(EXCEPTION) << "Wrong array shape, A should be 2D, but got [" << A_shape.size() << "] dimensions.";
}
if (A_shape[kDim0] != A_shape[kDim1]) {
MS_LOG(EXCEPTION) << "Wrong array shape, A should be a squre matrix like [N X N], but got shape [" << A_shape[kDim0]
<< " X " << A_shape[kDim1] << "]";
<< " X " << A_shape[kDim1] << "].";
}
m_ = A_shape[kDim0];
}

View File

@ -42,20 +42,20 @@ void SolveTriangularCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
auto b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
if (A_shape.size() != kAMatrixDimNum) {
MS_LOG(EXCEPTION) << "Wrong array shape, A should be 2D, but got [" << A_shape.size() << "] dimensions";
MS_LOG(EXCEPTION) << "Wrong array shape, A should be 2D, but got [" << A_shape.size() << "] dimensions.";
}
if (A_shape[kDim0] != A_shape[kDim1]) {
MS_LOG(EXCEPTION) << "Wrong array shape, A should be a squre matrix, but got [" << A_shape[kDim0] << " X "
<< A_shape[kDim1] << "]";
<< A_shape[kDim1] << "].";
}
m_ = A_shape[kDim0];
if (b_shape.size() != kAVectorxDimNum && b_shape.size() != kAMatrixDimNum) {
MS_LOG(EXCEPTION) << "Wrong array shape, b should be 1D or 2D, but got [" << b_shape.size() << "] dimensions";
MS_LOG(EXCEPTION) << "Wrong array shape, b should be 1D or 2D, but got [" << b_shape.size() << "] dimensions.";
}
if (SizeToInt(b_shape[kDim0]) != m_) {
MS_LOG(EXCEPTION) << "Wrong array shape, b should match the shape of A, excepted [" << m_ << "] but got ["
<< b_shape[kDim0] << "]";
<< b_shape[kDim0] << "].";
}
if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
n_ = 1;
@ -72,7 +72,7 @@ void SolveTriangularCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
} else if (trans == "C") {
trans_ = true;
} else {
MS_LOG(EXCEPTION) << "Trans should be in [N, T], but got [" << trans << "]";
MS_LOG(EXCEPTION) << "Trans should be in [N, T, C], but got [" << trans << "].";
}
}

View File

@ -74,9 +74,14 @@ class EighcGpuKernelMod : public NativeGpuKernelMod {
InitSizeLists();
return true;
}
if (A_shape.size() != kShape2dDims || A_shape[1] != A_shape[1]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input should be square matrix, but got ["
<< A_shape[0] << " X " << A_shape[1] << "]";
if (A_shape.size() != kShape2dDims) {
MS_LOG(EXCEPTION) << "Wrong array shape. For '" << kernel_name_ << "', a should be 2D, but got ["
<< A_shape.size() << "] dimensions.";
}
if (A_shape[kDim0] != A_shape[kDim1]) {
MS_LOG(EXCEPTION) << "Wrong array shape, For '" << kernel_name_
<< "', a should be a squre matrix like [N X N], but got shape [" << A_shape[kDim0] << " X "
<< A_shape[kDim1] << "].";
}
m_ = A_shape[0];
InitSizeLists();

View File

@ -59,9 +59,14 @@ class EighGpuKernelMod : public NativeGpuKernelMod {
InitSizeLists();
return true;
}
if (A_shape.size() != kShape2dDims || A_shape[0] != A_shape[1]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the shape of input should be square matrix, but got ["
<< A_shape[0] << " X " << A_shape[1] << "]";
if (A_shape.size() != kShape2dDims) {
MS_LOG(EXCEPTION) << "Wrong array shape. For '" << kernel_name << "', a should be 2D, but got [" << A_shape.size()
<< "] dimensions.";
}
if (A_shape[kDim0] != A_shape[kDim1]) {
MS_LOG(EXCEPTION) << "Wrong array shape, For '" << kernel_name
<< "', a should be a squre matrix like [N X N], but got shape [" << A_shape[kDim0] << " X "
<< A_shape[kDim1] << "].";
}
m_ = A_shape[0];
InitSizeLists();

View File

@ -128,18 +128,19 @@ class TrsmGpuKernelMod : public NativeGpuKernelMod {
}
if (A_shape[kDim0] != A_shape[kDim1]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the shape of input should be square matrix, but got ["
<< A_shape[kDim0] << " X " << A_shape[kDim1] << "]";
MS_LOG(EXCEPTION) << "For '" << kernel_name
<< "', the shape of input matrix A should be square matrix like [N X N], but got ["
<< A_shape[kDim0] << " X " << A_shape[kDim1] << "].";
}
m_ = A_shape[kDim0];
if (b_shape.size() != kAVectorxDimNum && b_shape.size() != kAMatrixDimNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input should be 1 or 2, but got "
<< b_shape.size();
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', b should be 1D or 2D, but got [" << b_shape.size()
<< "] dimensions.";
}
if (b_shape[kDim0] != m_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the shape of input should be [" << m_ << "], but got ["
<< b_shape[kDim0] << "]";
<< b_shape[kDim0] << "].";
}
if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
n_ = 1;
@ -156,8 +157,10 @@ class TrsmGpuKernelMod : public NativeGpuKernelMod {
trans_ = CUBLAS_OP_T;
} else if (trans == "T") {
trans_ = CUBLAS_OP_N;
} else if (trans == "C") {
trans_ = CUBLAS_OP_N;
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', trans should be in [N, T], but got [" << trans << "]";
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', trans should be in [N, T, C], but got [" << trans << "].";
}
bool lower = AnfAlgo::GetNodeAttr<bool>(kernel_node, "lower");

View File

@ -98,7 +98,7 @@ def block_diag(*arrs):
def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
overwrite_b=False, debug=None, check_finite=True):
overwrite_b=False, debug=None, check_finite=False):
"""
Assuming `A` is a triangular matrix, solve the equation
@ -106,16 +106,15 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
A x = b
Note:
`solve_triangular` is not supported on Windows platform yet.
- `solve_triangular` is not supported on Windows platform yet.
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
`int64` is passed, it will be cast to :class:`mstype.float64`.
Args:
A (Tensor): A non-singular triangular matrix of shape :math:`(M, M)`. Note that if the input tensor is neither
`float32` nor `float64`, then it will be cast to :class:`mstype.float32`.
b (Tensor): A Tensor of shape :math:`(M,)` or :math:`(M, N)`.
Right-hand side matrix in :math:`A x = b`. Note that if the input tensor is neither `float32` nor `float64`,
then it will be cast to :class:`mstype.float32`.
A (Tensor): A non-singular triangular matrix of shape :math:`(M, M)`.
b (Tensor): A Tensor of shape :math:`(M,)` or :math:`(M, N)`. Right-hand side matrix in :math:`A x = b`.
lower (bool, optional): Use only data contained in the lower triangle of `A`. Default: False.
trans (0, 1, 2, 'N', 'T', 'C', optional): Type of system to solve. Default: 'N'.
trans (0, 1, 2, 'N', 'T', 'C', optional): Type of system to solve. Default: 0.
======== =========
trans system
@ -127,6 +126,7 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
unit_diagonal (bool, optional): If True, diagonal elements of :math:`A` are assumed to be 1 and
will not be referenced. Default: False.
overwrite_b (bool, optional): Allow overwriting data in :math:`b` (may enhance performance). Default: False.
debug (None): Not implemented now. Default: None.
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs. Default: False.
@ -142,13 +142,13 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
TypeError: If dtype of `A` and `b` are not the same.
RuntimeError: If the shapes of `A` and `b` do not match, or if either has more than 2 dimensions.
TypeError: If `trans` is not int or str.
RuntimeError: If `trans` is not in set {0, 1, 2, 'N', 'T', 'C'}.
ValueError: If `trans` is not in set {0, 1, 2, 'N', 'T', 'C'}.
TypeError: If `lower` is not bool.
TypeError: If `unit_diagonal` is not bool.
TypeError: If `overwrite_b` is not bool.
TypeError: If `check_finite` is not bool.
ValueError: If `debug` is not None.
ValueError: If `A` is a singular matrix.
ValueError: If the shapes of `A` and `b` do not match.
Supported Platforms:
``CPU`` ``GPU``
@ -173,12 +173,14 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
>>> print(mnp.dot(A, x)) # Check the result
[4. 2. 4. 2.]
"""
_type_check('trans', trans, (int, str), 'solve_triangular')
_type_check('overwrite_b', overwrite_b, bool, 'solve_triangular')
_type_check('check_finite', check_finite, bool, 'solve_triangular')
if F.dtype(A) in (mstype.int32, mstype.int64):
A = F.cast(A, mstype.float32)
if F.dtype(b) in (mstype.int32, mstype.int64):
b = F.cast(b, mstype.float32)
if debug is not None:
_raise_value_error("Currently only case debug=None of solve_triangular Implemented.")
if F.dtype(A) == F.dtype(b) and F.dtype(A) in (mstype.int32, mstype.int64):
A = F.cast(A, mstype.float64)
b = F.cast(b, mstype.float64)
if trans not in (0, 1, 2, 'N', 'T', 'C'):
_raise_value_error("The value of trans should be one of (0, 1, 2, 'N', 'T', 'C'), but got " + str(trans))
if isinstance(trans, int):
@ -427,7 +429,9 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
In the standard problem, `b` is assumed to be the identity matrix.
Note:
`eigh` is not supported on Windows platform yet.
- `eigh` is not supported on Windows platform yet.
- Only `float32`, `float64`, `int32`, `int64` are supported Tensor dtypes. If Tensor with dtype `int32` or
`int64` is passed, it will be cast to :class:`mstype.float64`.
Args:
a (Tensor): A :math:`(M, M)` complex Hermitian or real symmetric matrix whose eigenvalues and
@ -469,7 +473,7 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
definite positive. Note that if input matrices are not symmetric or Hermitian, no error will
be reported but results will be wrong.
TypeError: If `a` is not a Tensor.
Runtime: If `A` is not square matrix.
RuntimeError: If `a` is not a square matrix.
ValueError: If `b` is not None.
TypeError: If `lower` is not bool.
TypeError: If `eigvals_only` is not bool.

View File

@ -99,7 +99,7 @@ class SolveTriangular(PrimitiveWithInfer):
def infer_dtype(self, A_dtype, b_dtype):
validator.check_scalar_or_tensor_types_same({"A_dtype": A_dtype, "b_dtype": b_dtype},
[mstype.float32, mstype.float64], self.name, True)
[mstype.float32, mstype.float64], self.name)
return A_dtype
@ -234,8 +234,8 @@ class EighNet(nn.Cell):
self.eigh = Eigh(bv, lower)
def construct(self, A):
if F.dtype(A) in (mstype.int8, mstype.int16, mstype.int32, mstype.int64):
A = F.cast(A, mstype.float32)
if F.dtype(A) in (mstype.int32, mstype.int64):
A = F.cast(A, mstype.float64)
r = self.eigh(A)
if self.bv:
return (r[0], r[1])

View File

@ -23,7 +23,7 @@ import mindspore.nn as nn
import mindspore.scipy as msp
from mindspore import context, Tensor
import mindspore.numpy as mnp
from mindspore.scipy.linalg import det
from mindspore.scipy.linalg import det, solve_triangular
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \
create_random_rank_matrix
@ -50,6 +50,41 @@ def test_block_diag(args):
match_array(ms_res.asnumpy(), scipy_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [10, 20, 52])
@pytest.mark.parametrize('trans', ["N", "T", "C"])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64, onp.int32, onp.int64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False, True])
def test_solve_triangular(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
    """
    Feature: ALL TO ALL
    Description: run solve_triangular on a random [n, n] matrix and a length-n right-hand side
                 for every combination of dtype, trans, lower and unit_diagonal
    Expectation: the output is close to the result of scipy's solve_triangular
    """
    # Fixed seed so every parametrized case sees reproducible inputs.
    onp.random.seed(0)

    # The matrix must be drawn before the vector to keep the random stream order stable.
    if dtype in (onp.int32, onp.int64):
        raw_matrix = onp.random.randint(low=-1024, high=1024, size=(n, n))
        raw_rhs = onp.random.randint(low=-1024, high=1024, size=(n,))
    else:
        raw_matrix = onp.random.random((n, n))
        raw_rhs = onp.random.random(n)
    a = (raw_matrix + onp.eye(n)).astype(dtype)
    b = raw_rhs.astype(dtype)

    output = solve_triangular(Tensor(a), Tensor(b), trans=trans, lower=lower,
                              unit_diagonal=unit_diagonal).asnumpy()
    expect = osp.linalg.solve_triangular(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)

    # float32 inputs are compared with looser tolerances than the other dtypes.
    rtol, atol = (1.e-3, 1.e-3) if dtype == onp.float32 else (1.e-5, 1.e-8)
    assert onp.allclose(expect, output, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@ -155,8 +190,7 @@ def test_cholesky_solver(n: int, lower: bool, data_type):
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [4, 6, 9, 20])
@pytest.mark.parametrize('data_type',
[(onp.int8, "f"), (onp.int16, "f"), (onp.int32, "f"), (onp.int64, "d"), (onp.float32, "f"),
(onp.float64, "d")])
[(onp.int32, "f"), (onp.int64, "d"), (onp.float32, "f"), (onp.float64, "d")])
def test_eigh(n: int, data_type):
"""
Feature: ALL TO ALL

View File

@ -318,7 +318,7 @@ def test_eigh_net_gpu(n: int):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('trans', ["N", "T", "C"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False])
@ -342,7 +342,7 @@ def test_solve_triangular_2d(n: int, dtype, lower: bool, unit_diagonal: bool, tr
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('trans', ["N", "T", "C"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False, True])
@ -366,7 +366,7 @@ def test_solve_triangular_1d(n: int, dtype, lower: bool, unit_diagonal: bool, tr
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(4, 5), (10, 20)])
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('trans', ["N", "T", "C"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False, True])