add test case for exceptions

This commit is contained in:
zhujingxuan 2022-01-28 15:46:14 +08:00
parent 6eef2b33fc
commit 3adaa8922d
6 changed files with 236 additions and 28 deletions

View File

@ -38,29 +38,37 @@ constexpr auto kAVectorxDimNum = 1;
constexpr auto kAMatrixDimNum = 2;
template <typename T>
void SolveTriangularCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
if (A_shape.size() != kAMatrixDimNum) {
MS_LOG(EXCEPTION) << "Wrong array shape, A should be 2D, but got [" << A_shape.size() << "] dimensions.";
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', A should be 2D, but got [" << A_shape.size() << "] dimensions.";
}
if (A_shape[kDim0] != A_shape[kDim1]) {
MS_LOG(EXCEPTION) << "Wrong array shape, A should be a squre matrix, but got [" << A_shape[kDim0] << " X "
<< A_shape[kDim1] << "].";
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the shape of input matrix A should be square matrix like [N X N], but got ["
<< A_shape[kDim0] << " X " << A_shape[kDim1] << "].";
}
m_ = A_shape[kDim0];
m_ = SizeToInt(A_shape[kDim0]);
if (b_shape.size() != kAVectorxDimNum && b_shape.size() != kAMatrixDimNum) {
MS_LOG(EXCEPTION) << "Wrong array shape, b should be 1D or 2D, but got [" << b_shape.size() << "] dimensions.";
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', b should be 1D or 2D, but got [" << b_shape.size()
<< "] dimensions.";
}
if (SizeToInt(b_shape[kDim0]) != m_) {
MS_LOG(EXCEPTION) << "Wrong array shape, b should match the shape of A, excepted [" << m_ << "] but got ["
<< b_shape[kDim0] << "].";
if (b_shape.size() == kAVectorxDimNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input should be [" << m_ << "], but got ["
<< b_shape[kDim0] << "].";
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input should be [" << m_ << " X "
<< b_shape[kDim1] << "], but got [" << b_shape[kDim0] << " X " << b_shape[kDim1] << "].";
}
}
if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
n_ = 1;
} else {
n_ = b_shape[kDim1];
n_ = SizeToInt(b_shape[kDim1]);
}
lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
unit_diagonal_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, UNIT_DIAGONAL);

View File

@ -115,32 +115,41 @@ class TrsmGpuKernelMod : public NativeGpuKernelMod {
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
is_null_input_ =
CHECK_SHAPE_NULL(A_shape, kernel_name, "input_A") || CHECK_SHAPE_NULL(b_shape, kernel_name, "input_b");
CHECK_SHAPE_NULL(A_shape, kernel_name_, "input_A") || CHECK_SHAPE_NULL(b_shape, kernel_name_, "input_b");
if (is_null_input_) {
InitSizeLists();
return true;
}
if (A_shape.size() != kAMatrixDimNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', A should be 2D, but got [" << A_shape.size()
<< "] dimensions.";
}
if (A_shape[kDim0] != A_shape[kDim1]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the shape of input matrix A should be square matrix like [N X N], but got ["
<< A_shape[kDim0] << " X " << A_shape[kDim1] << "].";
}
m_ = A_shape[kDim0];
if (b_shape.size() != kAVectorxDimNum && b_shape.size() != kAMatrixDimNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', b should be 1D or 2D, but got [" << b_shape.size()
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', b should be 1D or 2D, but got [" << b_shape.size()
<< "] dimensions.";
}
if (b_shape[kDim0] != m_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the shape of input should be [" << m_ << "], but got ["
<< b_shape[kDim0] << "].";
if (b_shape.size() == kAVectorxDimNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input should be [" << m_ << "], but got ["
<< b_shape[kDim0] << "].";
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input should be [" << m_ << " X "
<< b_shape[kDim1] << "], but got [" << b_shape[kDim0] << " X " << b_shape[kDim1] << "].";
}
}
if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
n_ = 1;
@ -152,16 +161,7 @@ class TrsmGpuKernelMod : public NativeGpuKernelMod {
ldb_ = SizeToInt(m_);
const std::string trans = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "trans");
// converting row major to col major is the same as reverting the trans flag
if (trans == "N") {
trans_ = CUBLAS_OP_T;
} else if (trans == "T") {
trans_ = CUBLAS_OP_N;
} else if (trans == "C") {
trans_ = CUBLAS_OP_N;
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', trans should be in [N, T, C], but got [" << trans << "].";
}
SetOperation(trans);
bool lower = AnfAlgo::GetNodeAttr<bool>(kernel_node, "lower");
// reverting the trans flag by default, so also flip the lower flag
@ -195,6 +195,19 @@ class TrsmGpuKernelMod : public NativeGpuKernelMod {
}
}
void SetOperation(const std::string &trans) {
// converting row major to col major is the same as reverting the trans flag
if (trans == "N") {
trans_ = CUBLAS_OP_T;
} else if (trans == "T") {
trans_ = CUBLAS_OP_N;
} else if (trans == "C") {
trans_ = CUBLAS_OP_N;
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', trans should be in [N, T, C], but got [" << trans << "].";
}
}
private:
size_t m_{0};
size_t n_{0};

View File

@ -113,7 +113,7 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
Args:
A (Tensor): A non-singular triangular matrix of shape :math:`(M, M)`.
b (Tensor): A Tensor of shape :math:`(M,)` or :math:`(M, N)`. Right-hand side matrix in :math:`A x = b`.
lower (bool, optional): Use only data contained in the lower triangle of `a`. Default: False.
lower (bool, optional): Use only data contained in the lower triangle of `A`. Default: False.
trans (0, 1, 2, 'N', 'T', 'C', optional): Type of system to solve. Default: 0.
======== =========
@ -174,6 +174,7 @@ def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
[4. 2. 4. 2.]
"""
_type_check('trans', trans, (int, str), 'solve_triangular')
_type_check('lower', lower, bool, 'solve_triangular')
_type_check('overwrite_b', overwrite_b, bool, 'solve_triangular')
_type_check('check_finite', check_finite, bool, 'solve_triangular')
if debug is not None:

View File

@ -91,6 +91,8 @@ class SolveTriangular(PrimitiveWithInfer):
def __infer__(self, A, b):
out_shapes = b['shape']
validator.check_scalar_or_tensor_types_same({"A_dtype": A['dtype'], "b_dtype": b['dtype']},
[mstype.float32, mstype.float64], self.name)
return {
'shape': tuple(out_shapes),
'dtype': A['dtype'],
@ -98,8 +100,6 @@ class SolveTriangular(PrimitiveWithInfer):
}
def infer_dtype(self, A_dtype, b_dtype):
validator.check_scalar_or_tensor_types_same({"A_dtype": A_dtype, "b_dtype": b_dtype},
[mstype.float32, mstype.float64], self.name)
return A_dtype

View File

@ -25,7 +25,7 @@ from mindspore import context, Tensor
import mindspore.numpy as mnp
from mindspore.scipy.linalg import det, solve_triangular
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \
create_random_rank_matrix
create_random_rank_matrix, match_runtime_exception
onp.random.seed(0)
context.set_context(mode=context.PYNATIVE_MODE)
@ -85,6 +85,161 @@ def test_solve_triangular(n: int, dtype, lower: bool, unit_diagonal: bool, trans
assert onp.allclose(expect, output, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [3, 4, 6])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64, onp.int32, onp.int64])
def test_solve_triangular_error_dims(n: int, dtype):
"""
Feature: ALL TO ALL
Description: test cases for solve_triangular for triangular matrix solver [N,N]
Expectation: solve_triangular raises expectated Exception
"""
a = onp.random.randint(low=-1024, high=1024, size=(10,) * n).astype(dtype)
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(dtype)
with pytest.raises(RuntimeError) as err:
solve_triangular(Tensor(a), Tensor(b))
msg = f"For 'SolveTriangular', A should be 2D, but got [{n}] dimensions."
assert match_runtime_exception(err, msg)
a = onp.random.randint(low=-1024, high=1024, size=(n, n + 1)).astype(dtype)
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(dtype)
with pytest.raises(RuntimeError) as err:
solve_triangular(Tensor(a), Tensor(b))
msg = f"For 'SolveTriangular', the shape of input matrix A should be square matrix like [N X N], " \
f"but got [{n} X {n + 1}]."
assert match_runtime_exception(err, msg)
a = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(dtype)
b = onp.random.randint(low=-1024, high=1024, size=(11,) * n).astype(dtype)
with pytest.raises(RuntimeError) as err:
solve_triangular(Tensor(a), Tensor(b))
msg = f"For 'SolveTriangular', b should be 1D or 2D, but got [{n}] dimensions."
assert match_runtime_exception(err, msg)
a = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(dtype)
b = onp.random.randint(low=-1024, high=1024, size=(n,)).astype(dtype)
with pytest.raises(RuntimeError) as err:
solve_triangular(Tensor(a), Tensor(b))
msg = f"For 'SolveTriangular', the shape of input should be [10], but got [{n}]."
assert match_runtime_exception(err, msg)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_solve_triangular_error_tensor_dtype():
"""
Feature: ALL TO ALL
Description: test cases for solve_triangular for triangular matrix solver [N,N]
Expectation: solve_triangular raises expectated Exception
"""
a = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(onp.float16)
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(onp.float16)
with pytest.raises(TypeError) as err:
solve_triangular(Tensor(a), Tensor(b))
msg = f"For 'SolveTriangular', the type of `A_dtype` should be in " \
f"[mindspore.float32, mindspore.float64], but got Float16."
assert str(err.value) == msg
a = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(onp.float32)
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(onp.float16)
with pytest.raises(TypeError) as err:
solve_triangular(Tensor(a), Tensor(b))
msg = f"For 'SolveTriangular', the type of `b_dtype` should be in " \
f"[mindspore.float32, mindspore.float64], but got Float16."
assert str(err.value) == msg
a = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(onp.float32)
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(onp.float64)
with pytest.raises(TypeError) as err:
solve_triangular(Tensor(a), Tensor(b))
msg = "For 'SolveTriangular' type of `b_dtype` should be same as `A_dtype`, " \
"but `A_dtype` is Float32 and `b_dtype` is Float64."
assert str(err.value) == msg
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64, onp.int32, onp.int64])
@pytest.mark.parametrize('argname, argtype', [('lower', 'bool'), ('overwrite_b', 'bool'), ('check_finite', 'bool')])
@pytest.mark.parametrize('wrong_argvalue, wrong_argtype', [(5.0, 'float'), (None, 'NoneType'), ('test', 'str')])
def test_solve_triangular_error_type(dtype, argname, argtype, wrong_argvalue, wrong_argtype):
"""
Feature: ALL TO ALL
Description: test cases for solve_triangular for triangular matrix solver [N,N]
Expectation: solve_triangular raises expectated Exception
"""
a = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(dtype)
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(dtype)
kwargs = {argname: wrong_argvalue}
with pytest.raises(TypeError) as err:
solve_triangular(Tensor(a), Tensor(b), **kwargs)
msg = f"For 'solve_triangular', the type of `{argname}` should be '{argtype}', " \
f"but got '{wrong_argvalue}' with type '{wrong_argtype}'."
assert str(err.value) == msg
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64, onp.int32, onp.int64])
@pytest.mark.parametrize('wrong_argvalue, wrong_argtype', [(5.0, 'float'), (None, 'NoneType')])
def test_solve_triangular_error_type_trans(dtype, wrong_argvalue, wrong_argtype):
"""
Feature: ALL TO ALL
Description: test cases for solve_triangular for triangular matrix solver [N,N]
Expectation: solve_triangular raises expectated Exception
"""
a = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(dtype)
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(dtype)
with pytest.raises(TypeError) as err:
solve_triangular(Tensor(a), Tensor(b), trans=wrong_argvalue)
msg = f"For 'solve_triangular', the type of `trans` should be one of '['int', 'str']', " \
f"but got '{wrong_argvalue}' with type '{wrong_argtype}'."
assert str(err.value) == msg
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_solve_triangular_error_tensor_type():
"""
Feature: ALL TO ALL
Description: test cases for solve_triangular for triangular matrix solver [N,N]
Expectation: solve_triangular raises expectated Exception
"""
a = 'test'
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(onp.float32)
with pytest.raises(TypeError) as err:
solve_triangular(a, Tensor(b))
msg = "For Primitive[DType], the input argument[infer type]must be a Tensor or CSRTensor but got String."
assert match_runtime_exception(err, msg)
a = [1, 2, 3]
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(onp.float32)
with pytest.raises(TypeError) as err:
solve_triangular(a, Tensor(b))
msg = "For Primitive[DType], the input argument[infer type]must be a Tensor or CSRTensor but got List[Int64*3]."
assert match_runtime_exception(err, msg)
a = (1, 2, 3)
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(onp.float32)
with pytest.raises(TypeError) as err:
solve_triangular(a, Tensor(b))
msg = "For Primitive[DType], the input argument[infer type]must be a Tensor or CSRTensor but got Tuple[Int64*3]."
assert match_runtime_exception(err, msg)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@ -256,6 +411,31 @@ def test_eigh_complex(n: int, data_type):
assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64, onp.int32, onp.int64])
@pytest.mark.parametrize('argname, argtype',
[('lower', 'bool'), ('eigvals_only', 'bool'), ('overwrite_a', 'bool'), ('overwrite_b', 'bool'),
('turbo', 'bool'), ('check_finite', 'bool')])
@pytest.mark.parametrize('wrong_argvalue, wrong_argtype', [(5.0, 'float'), (None, 'NoneType')])
def test_eigh_error_type(dtype, argname, argtype, wrong_argvalue, wrong_argtype):
"""
Feature: ALL TO ALL
Description: test cases for solve_triangular for triangular matrix solver [N,N]
Expectation: eigh raises expectated Exception
"""
a = onp.random.randint(low=-1024, high=1024, size=(10, 10)).astype(dtype)
b = onp.random.randint(low=-1024, high=1024, size=(10,)).astype(dtype)
kwargs = {argname: wrong_argvalue}
with pytest.raises(TypeError) as err:
msp.linalg.eigh(Tensor(a), Tensor(b), **kwargs)
assert str(err.value) == f"For 'eigh', the type of `{argname}` should be '{argtype}', " \
f"but got '{wrong_argvalue}' with type '{wrong_argtype}'."
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu

View File

@ -138,3 +138,9 @@ def gradient_check(x, net, epsilon=1e-3):
denominator = onp.linalg.norm(x_grad) + onp.linalg.norm(x_grad_approx)
difference = numerator / denominator
return difference
def match_runtime_exception(err, expected_str):
err_str = str(err.value)
err_str = err_str[err_str.find("]") + 2:]
return err_str == expected_str