!30095 Remove CholeskySolve and replace it by using solve_triangular

Merge pull request !30095 from zhuzhongrui/pub_master3
This commit is contained in:
i-robot 2022-02-16 08:10:56 +00:00 committed by Gitee
commit 9d1839119d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 67 additions and 63 deletions

View File

@ -41,14 +41,15 @@ void CholeskyCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, s
} }
*row = shape.at(shape.size() - kRowIndex); *row = shape.at(shape.size() - kRowIndex);
*col = shape.at(shape.size() - kColIndex); *col = shape.at(shape.size() - kColIndex);
outer_batch_ = min_dim;
for (int batch = 0; batch < static_cast<int>(shape.size() - kRowIndex); ++batch) {
outer_batch_ *= shape.at(batch);
}
if (*row != *col) { if (*row != *col) {
MS_LOG_EXCEPTION << kernel_name_ << " input shape is invalid. " MS_LOG_EXCEPTION << kernel_name_ << " input shape is invalid. "
<< "Cholesky expects a square matrix. but input or output shape is: " << *row << ", " << *col; << "Cholesky expects a square matrix. but input or output shape is: " << *row << ", " << *col;
} }
outer_batch_ = min_dim;
for (const auto &sh : shape) {
outer_batch_ *= sh;
}
outer_batch_ /= ((*row) * (*col));
} }
template <typename T> template <typename T>
@ -98,11 +99,6 @@ bool CholeskyCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, cons
output = llt.matrixLLT().transpose(); output = llt.matrixLLT().transpose();
} }
} }
if (output.RowsAtCompileTime != 0 && output.ColsAtCompileTime != 0) {
continue;
} else {
MS_LOG_EXCEPTION << kernel_name_ << " cholesky llt calculating failed.";
}
} }
return true; return true;
} }

View File

@ -41,9 +41,10 @@ void CholeskySolveCpuKernelMod<T>::InitRightMatrixInfo(const std::vector<size_t>
*row = shape.at(shape.size() - kRowIndex); *row = shape.at(shape.size() - kRowIndex);
*col = shape.at(shape.size() - kColIndex); *col = shape.at(shape.size() - kColIndex);
outer_batch_ = min_dim; outer_batch_ = min_dim;
for (int batch = 0; batch < static_cast<int>(shape.size() - kRowIndex); ++batch) { for (const auto &sh : shape) {
outer_batch_ *= shape.at(batch); outer_batch_ *= sh;
} }
outer_batch_ /= ((*row) * (*col));
} }
template <typename T> template <typename T>
@ -106,11 +107,6 @@ bool CholeskySolveCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs,
output.noalias() = input.adjoint().template triangularView<Lower>().solve(input_b); output.noalias() = input.adjoint().template triangularView<Lower>().solve(input_b);
input.template triangularView<Upper>().solveInPlace(output); input.template triangularView<Upper>().solveInPlace(output);
} }
if (output.RowsAtCompileTime != 0 && output.ColsAtCompileTime != 0) {
continue;
} else {
MS_LOG_EXCEPTION << kernel_name_ << " cholesky solve failed, please check input info.";
}
} }
return true; return true;
} }

View File

@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
"""Linear algebra submodule""" """Linear algebra submodule"""
from .ops import Cholesky from .ops import Cholesky
from .ops import CholeskySolve
from .ops import EighNet from .ops import EighNet
from .ops import LU from .ops import LU
from .ops import LUSolver from .ops import LUSolver
@ -313,15 +312,15 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
a_type = F.dtype(a) a_type = F.dtype(a)
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64): if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error( _raise_type_error(
"mindspore.scipy.linalg.cho_factor only support (Tensor[int32], Tensor[int64], Tensor[float32], " "mindspore.scipy.linalg.cho_factor input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).") "Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64): if a_type not in (mstype.float32, mstype.float64):
a = F.cast(a, mstype.float64) a = F.cast(a, mstype.float64)
a_shape = a.shape a_shape = a.shape
if a.ndim < 2: if a.ndim < 2:
_raise_value_error("input a to mindspore.scipy.linalg.cho_factor must be greater or equal to 2 dimensions.") _raise_value_error("mindspore.scipy.linalg.cho_factor input a must be equal to 2 dimensions.")
if a_shape[-1] != a_shape[-2]: if a_shape[-1] != a_shape[-2]:
_raise_value_error("input a to mindspore.scipy.linalg.cho_factor must be a square matrix.") _raise_value_error("mindspore.scipy.linalg.cho_factor input a must be a square matrix.")
cholesky_net = Cholesky(clean=False) cholesky_net = Cholesky(clean=False)
c = cholesky_net(a) c = cholesky_net(a)
if not lower: if not lower:
@ -377,16 +376,16 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
a_type = F.dtype(a) a_type = F.dtype(a)
if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64): if a_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error( _raise_type_error(
"mindspore.scipy.linalg.cholesky only support (Tensor[int32], Tensor[int64], Tensor[float32], " "mindspore.scipy.linalg.cholesky input a only support (Tensor[int32], Tensor[int64], Tensor[float32], "
"Tensor[float64]).") "Tensor[float64]).")
if a_type not in (mstype.float32, mstype.float64): if a_type not in (mstype.float32, mstype.float64):
a = F.cast(a, mstype.float64) a = F.cast(a, mstype.float64)
a_shape = a.shape a_shape = a.shape
if a.ndim < 2: if a.ndim != 2:
_raise_value_error("input a to mindspore.scipy.linalg.cholesky must be greater or equal to dimensions.") _raise_value_error("mindspore.scipy.linalg.cholesky input a must be equal to 2 dimensions.")
if a_shape[-1] != a_shape[-2]: if a_shape[-1] != a_shape[-2]:
_raise_value_error("input a to mindspore.scipy.linalg.cholesky must be a square matrix.") _raise_value_error("mindspore.scipy.linalg.cholesky input a must be a square matrix.")
cholesky_net = Cholesky(clean=True) cholesky_net = Cholesky(clean=True)
c = cholesky_net(a) c = cholesky_net(a)
if not lower: if not lower:
@ -436,13 +435,23 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
c_type = F.dtype(c) c_type = F.dtype(c)
if c_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64): if c_type not in (mstype.int32, mstype.int64, mstype.float32, mstype.float64):
_raise_type_error( _raise_type_error(
"mindspore.scipy.linalg.cho_solve only support (Tensor[int32], Tensor[int64], Tensor[float32], " "mindspore.scipy.linalg.cho_solve input c only support (Tensor[int32], Tensor[int64], Tensor[float32],"
" Tensor[float64]).") " Tensor[float64]).")
if c_type not in (mstype.float32, mstype.float64): if c_type not in (mstype.float32, mstype.float64):
c = F.cast(c, mstype.float64) c = F.cast(c, mstype.float64)
cholesky_solve_net = CholeskySolve(lower=lower) c_type = mstype.float64
x = cholesky_solve_net(c, b) if F.dtype(b) != c_type:
return x b = F.cast(b, c_type)
# Do not support complex, so trans is chosen from ('T', 'N')
if lower:
l_trans = 'N'
l_t_trans = 'T'
else:
l_trans = 'T'
l_t_trans = 'N'
b = SolveTriangular(lower=lower, unit_diagonal=False, trans=l_trans)(c, b)
b = SolveTriangular(lower=lower, unit_diagonal=False, trans=l_t_trans)(c, b)
return b
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False, def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,

View File

@ -306,40 +306,6 @@ def test_cholesky(n: int, lower: bool, data_type: Generic):
assert onp.allclose(osp_c, msp_c.asnumpy(), rtol=rtol, atol=atol) assert onp.allclose(osp_c, msp_c.asnumpy(), rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 5, 5), (2, 3, 5, 5)])
@pytest.mark.parametrize('lower', [True, False])
@pytest.mark.parametrize('data_type', [onp.float32, onp.float64])
def test_batch_cholesky(shape, lower: bool, data_type):
    """
    Feature: ALL To ALL
    Description: batched cholesky decomposition, one reference factor is computed per matrix of the batch
    Expectation: the result match to scipy
    """
    # The two trailing dimensions describe a single matrix; the leading ones are batch dimensions.
    matrix_dims = (shape[-2], shape[-1])
    batch_count = onp.zeros(shape[:-2]).size

    matrices = []
    expected = []
    for _ in range(batch_count):
        matrix = create_sym_pos_matrix(matrix_dims, data_type)
        expected.append(osp.linalg.cholesky(matrix, lower))
        matrices.append(matrix)

    result = msp.linalg.cholesky(Tensor(onp.array(matrices)), lower)
    # Give the stacked reference factors the same layout as the computed result before comparing.
    expected = onp.asarray(expected).reshape(result.shape)

    # float64 is compared with tighter tolerances than float32.
    if data_type == onp.float64:
        rtol, atol = 1.e-5, 1.e-8
    else:
        rtol, atol = 1.e-3, 1.e-3
    assert onp.allclose(result.asnumpy(), expected, rtol=rtol, atol=atol)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu

View File

@ -22,6 +22,7 @@ from scipy.linalg import solve_triangular, eig, eigvals
from mindspore import Tensor, context from mindspore import Tensor, context
from mindspore.scipy.ops import EighNet, Eig, Cholesky, SolveTriangular from mindspore.scipy.ops import EighNet, Eig, Cholesky, SolveTriangular
from mindspore.scipy.utils import _nd_transpose
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_random_rank_matrix, compare_eigen_decomposition, \ from tests.st.scipy_st.utils import create_sym_pos_matrix, create_random_rank_matrix, compare_eigen_decomposition, \
match_exception_info match_exception_info
@ -49,6 +50,42 @@ def test_cholesky(n: int, dtype: Generic):
assert np.allclose(expect, output.asnumpy()) assert np.allclose(expect, output.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 5, 5), (2, 3, 5, 5)])
@pytest.mark.parametrize('lower', [True, False])
@pytest.mark.parametrize('data_type', [np.float32, np.float64])
def test_batch_cholesky(shape, lower: bool, data_type):
    """
    Feature: ALL To ALL
    Description: batched Cholesky operator, one reference factor is computed per matrix of the batch
    Expectation: the result match to scipy
    """
    # The two trailing dimensions describe a single matrix; the leading ones are batch dimensions.
    matrix_dims = (shape[-2], shape[-1])
    batch_count = np.zeros(shape[:-2]).size

    matrices = []
    expected = []
    for _ in range(batch_count):
        matrix = create_sym_pos_matrix(matrix_dims, data_type)
        expected.append(scp.linalg.cholesky(matrix, lower))
        matrices.append(matrix)

    result = Cholesky(clean=True)(Tensor(np.array(matrices)))
    # The operator output is transposed over its matrix dimensions when the upper factor is requested.
    if not lower:
        result = _nd_transpose(result)
    # Give the stacked reference factors the same layout as the computed result before comparing.
    expected = np.asarray(expected).reshape(result.shape)

    # float64 is compared with tighter tolerances than float32.
    if data_type == np.float64:
        rtol, atol = 1.e-5, 1.e-8
    else:
        rtol, atol = 1.e-3, 1.e-3
    assert np.allclose(result.asnumpy(), expected, rtol=rtol, atol=atol)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard