!30095 Remove CholeskySolve and replace it by using solve_triangular

Merge pull request !30095 from zhuzhongrui/pub_master3
This commit is contained in:
i-robot 2022-02-16 08:10:56 +00:00 committed by Gitee
commit 9d1839119d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 67 additions and 63 deletions

View File

@ -41,14 +41,15 @@ void CholeskyCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, s
}
*row = shape.at(shape.size() - kRowIndex);
*col = shape.at(shape.size() - kColIndex);
outer_batch_ = min_dim;
for (int batch = 0; batch < static_cast<int>(shape.size() - kRowIndex); ++batch) {
outer_batch_ *= shape.at(batch);
}
if (*row != *col) {
MS_LOG_EXCEPTION << kernel_name_ << " input shape is invalid. "
<< "Cholesky expects a square matrix. but input or output shape is: " << *row << ", " << *col;
}
outer_batch_ = min_dim;
for (const auto &sh : shape) {
outer_batch_ *= sh;
}
outer_batch_ /= ((*row) * (*col));
}
template <typename T>
@ -98,11 +99,6 @@ bool CholeskyCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, cons
output = llt.matrixLLT().transpose();
}
}
if (output.RowsAtCompileTime != 0 && output.ColsAtCompileTime != 0) {
continue;
} else {
MS_LOG_EXCEPTION << kernel_name_ << " cholesky llt calculating failed.";
}
}
return true;
}

View File

@ -41,9 +41,10 @@ void CholeskySolveCpuKernelMod<T>::InitRightMatrixInfo(const std::vector<size_t>
*row = shape.at(shape.size() - kRowIndex);
*col = shape.at(shape.size() - kColIndex);
outer_batch_ = min_dim;
for (int batch = 0; batch < static_cast<int>(shape.size() - kRowIndex); ++batch) {
outer_batch_ *= shape.at(batch);
for (const auto &sh : shape) {
outer_batch_ *= sh;
}
outer_batch_ /= ((*row) * (*col));
}
template <typename T>
@ -106,11 +107,6 @@ bool CholeskySolveCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs,
output.noalias() = input.adjoint().template triangularView<Lower>().solve(input_b);
input.template triangularView<Upper>().solveInPlace(output);
}
if (output.RowsAtCompileTime != 0 && output.ColsAtCompileTime != 0) {
continue;
} else {
MS_LOG_EXCEPTION << kernel_name_ << " cholesky solve failed, please check input info.";
}
}
return true;
}

View File

@ -14,7 +14,6 @@
# ============================================================================
"""Linear algebra submodule"""
from .ops import Cholesky
from .ops import CholeskySolve
from .ops import EighNet
from .ops import LU
from .ops import LUSolver
@ -313,15 +312,15 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
a_type = F.dtype(a)
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.cho_factor only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"mindspore.scipy.linalg.cho_factor input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
a = F.cast(a, mstype.float64)
a_shape = a.shape
if a.ndim < 2:
_raise_value_error("input a to mindspore.scipy.linalg.cho_factor must be greater or equal to 2 dimensions.")
_raise_value_error("mindspore.scipy.linalg.cho_factor input a must be equal to 2 dimensions.")
if a_shape[-1] != a_shape[-2]:
_raise_value_error("input a to mindspore.scipy.linalg.cho_factor must be a square matrix.")
_raise_value_error("mindspore.scipy.linalg.cho_factor input a must be a square matrix.")
cholesky_net = Cholesky(clean=False)
c = cholesky_net(a)
if not lower:
@ -377,16 +376,16 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
a_type = F.dtype(a)
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.cholesky only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"mindspore.scipy.linalg.cholesky input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64):
a = F.cast(a, mstype.float64)
a_shape = a.shape
if a.ndim < 2:
_raise_value_error("input a to mindspore.scipy.linalg.cholesky must be greater or equal to dimensions.")
if a.ndim != 2:
_raise_value_error("mindspore.scipy.linalg.cholesky input a must be equal to 2 dimensions.")
if a_shape[-1] != a_shape[-2]:
_raise_value_error("input a to mindspore.scipy.linalg.cholesky must be a square matrix.")
_raise_value_error("mindspore.scipy.linalg.cholesky input a must be a square matrix.")
cholesky_net = Cholesky(clean=True)
c = cholesky_net(a)
if not lower:
@ -436,13 +435,23 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
c_type = F.dtype(c)
if c_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error(
"mindspore.scipy.linalg.cho_solve only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).")
"mindspore.scipy.linalg.cho_solve input c only support (Tensor[int32], Tensor[int64], Tensor[float32],"
" Tensor[float64]).")
if c_type not in (mstype.float32, mstype.float64):
c = F.cast(c, mstype.float64)
cholesky_solve_net = CholeskySolve(lower=lower)
x = cholesky_solve_net(c, b)
return x
c_type = mstype.float64
if F.dtype(b) != c_type:
b = F.cast(b, c_type)
# Do not support complex, so trans is chosen from ('T', 'N')
if lower:
l_trans = 'N'
l_t_trans = 'T'
else:
l_trans = 'T'
l_t_trans = 'N'
b = SolveTriangular(lower=lower, unit_diagonal=False, trans=l_trans)(c, b)
b = SolveTriangular(lower=lower, unit_diagonal=False, trans=l_t_trans)(c, b)
return b
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,

View File

@ -306,40 +306,6 @@ def test_cholesky(n: int, lower: bool, data_type: Generic):
assert onp.allclose(osp_c, msp_c.asnumpy(), rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 5, 5), (2, 3, 5, 5)])
@pytest.mark.parametrize('lower', [True, False])
@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
def test_batch_cholesky(shape, lower: bool, data_type):
"""
Feature: ALL To ALL
Description: test cases for cholesky decomposition test cases for A[N,N]x = b[N,1]
Expectation: the result match to scipy
"""
b_s_l = list()
b_s_a = list()
tmp = onp.zeros(shape[:-2])
inner_row = shape[-2]
inner_col = shape[-1]
for _, _ in onp.ndenumerate(tmp):
a = create_sym_pos_matrix((inner_row, inner_col), data_type)
s_l = osp.linalg.cholesky(a, lower)
b_s_l.append(s_l)
b_s_a.append(a)
tensor_b_a = Tensor(onp.array(b_s_a))
b_m_l = msp.linalg.cholesky(tensor_b_a, lower)
b_s_l = onp.asarray(b_s_l).reshape(b_m_l.shape)
rtol = 1.e-3
atol = 1.e-3
if data_type == onp.float64:
rtol = 1.e-5
atol = 1.e-8
assert onp.allclose(b_m_l.asnumpy(), b_s_l, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu

View File

@ -22,6 +22,7 @@ from scipy.linalg import solve_triangular, eig, eigvals
from mindspore import Tensor, context
from mindspore.scipy.ops import EighNet, Eig, Cholesky, SolveTriangular
from mindspore.scipy.utils import _nd_transpose
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_random_rank_matrix, compare_eigen_decomposition, \
match_exception_info
@ -49,6 +50,42 @@ def test_cholesky(n: int, dtype: Generic):
assert np.allclose(expect, output.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 5, 5), (2, 3, 5, 5)])
@pytest.mark.parametrize('lower', [True, False])
@pytest.mark.parametrize('data_type', [np.float32, np.float64])
def test_batch_cholesky(shape, lower: bool, data_type):
"""
Feature: ALL To ALL
Description: test cases for cholesky decomposition test cases for A[N,N]x = b[N,1]
Expectation: the result match to scipy
"""
b_s_l = list()
b_s_a = list()
tmp = np.zeros(shape[:-2])
inner_row = shape[-2]
inner_col = shape[-1]
for _, _ in np.ndenumerate(tmp):
a = create_sym_pos_matrix((inner_row, inner_col), data_type)
s_l = scp.linalg.cholesky(a, lower)
b_s_l.append(s_l)
b_s_a.append(a)
tensor_b_a = Tensor(np.array(b_s_a))
b_m_l = Cholesky(clean=True)(tensor_b_a)
if not lower:
b_m_l = _nd_transpose(b_m_l)
b_s_l = np.asarray(b_s_l).reshape(b_m_l.shape)
rtol = 1.e-3
atol = 1.e-3
if data_type == np.float64:
rtol = 1.e-5
atol = 1.e-8
assert np.allclose(b_m_l.asnumpy(), b_s_l, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard