forked from mindspore-Ecosystem/mindspore
!25590 fix lu api bugs
Merge pull request !25590 from zhuzhongrui/pub_master
This commit is contained in:
commit
073ff9d2b9
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_
|
||||
#include "Eigen/Dense"
|
||||
#include "Eigen/Core"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using Eigen::ColMajor;
|
||||
using Eigen::Dynamic;
|
||||
using Eigen::Lower;
|
||||
using Eigen::Map;
|
||||
using Eigen::MatrixBase;
|
||||
using Eigen::RowMajor;
|
||||
using Eigen::UnitLower;
|
||||
using Eigen::UnitUpper;
|
||||
using Eigen::Upper;
|
||||
template <typename T, int Major>
|
||||
using Matrix = Eigen::Matrix<T, Dynamic, Dynamic, Major>;
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_
|
|
@ -16,13 +16,13 @@
|
|||
|
||||
#include "backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h"
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "Eigen/Dense"
|
||||
#include "Eigen/LU"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
namespace {
|
||||
constexpr size_t kLUInputsNum = 1;
|
||||
constexpr size_t kLUaIndex = 0;
|
||||
|
@ -73,27 +73,44 @@ template <typename T>
|
|||
bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_a(a_value, a_row_, a_col_);
|
||||
Map<Matrix<T, RowMajor>> input_a(a_value, a_row_, a_col_);
|
||||
|
||||
T *lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_lu(lu_value, lu_row_, lu_col_);
|
||||
Map<Matrix<T, RowMajor>> output_lu(lu_value, lu_row_, lu_col_);
|
||||
int *pivots_value = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
|
||||
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_pivots(
|
||||
pivots_value, pivots_row_, pivots_col_);
|
||||
int *permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
|
||||
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_permutation(
|
||||
permutation_value, permutation_row_, permutation_col_);
|
||||
Map<Matrix<int, RowMajor>> output_permutation(permutation_value, permutation_row_, permutation_col_);
|
||||
|
||||
if (a_row_ == a_col_) {
|
||||
// partial_piv_lu
|
||||
output_lu = input_a.lu().matrixLU();
|
||||
output_pivots = input_a.lu().permutationP().indices();
|
||||
auto partial_lu = input_a.lu();
|
||||
auto partial_p = partial_lu.permutationP();
|
||||
output_lu.noalias() = partial_lu.matrixLU();
|
||||
output_permutation.noalias() = partial_p.toDenseMatrix();
|
||||
} else {
|
||||
// full_piv_lu
|
||||
output_lu = input_a.fullPivLu().matrixLU();
|
||||
output_pivots = input_a.fullPivLu().permutationP().indices();
|
||||
auto full_piv_lu = input_a.fullPivLu();
|
||||
auto full_piv_p = full_piv_lu.permutationP();
|
||||
output_lu.noalias() = full_piv_lu.matrixLU();
|
||||
output_permutation.noalias() = full_piv_p.toDenseMatrix();
|
||||
}
|
||||
output_permutation = output_pivots;
|
||||
|
||||
// calculate permutation array from permutation matrix to indicate scipy's pivots.
|
||||
for (int i = 0; i < static_cast<int>(output_permutation.rows()); ++i) {
|
||||
if (output_permutation(i, i) != 0) {
|
||||
pivots_value[i] = i;
|
||||
continue;
|
||||
}
|
||||
for (int j = 0; j < static_cast<int>(output_permutation.cols()); ++j) {
|
||||
if (output_permutation(i, j) != 0) {
|
||||
pivots_value[i] = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// here, we note that eigen calculate permutation matrix is col major, so transpose it to row major,
|
||||
// but permutation array is based on permutation matrix before transposed, which is consistent to scipy and jax.
|
||||
output_permutation.transposeInPlace();
|
||||
if (output_lu.RowsAtCompileTime != 0 && output_lu.ColsAtCompileTime != 0 && output_permutation.size() != 0) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_CPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
|
@ -62,4 +62,4 @@ MS_REG_CPU_KERNEL_T(LU,
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_CPU_KERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
|
||||
|
|
|
@ -16,13 +16,13 @@
|
|||
|
||||
#include "backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "utils/ms_utils.h"
|
||||
#include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h"
|
||||
#include "Eigen/Dense"
|
||||
#include "Eigen/LU"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
namespace {
|
||||
constexpr size_t kLUInputsNum = 2;
|
||||
constexpr size_t kLUaIndex = 0;
|
||||
|
@ -70,6 +70,7 @@ void LUSolverCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
out_row_ = output_lu_shape.at(output_lu_shape.size() - kRowIndex);
|
||||
out_col_ = output_lu_shape.at(output_lu_shape.size() - kColIndex);
|
||||
}
|
||||
trans_ = AnfAlgo ::GetNodeAttr<std::string>(kernel_node, TRANS);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -77,23 +78,28 @@ bool LUSolverCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_a(a_value, a_row_, a_col_);
|
||||
Map<Matrix<T, RowMajor>> input_a(a_value, a_row_, a_col_);
|
||||
|
||||
T *b_value = reinterpret_cast<T *>(inputs[kLUbIndex]->addr);
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_b(b_value, b_row_, b_col_);
|
||||
Map<Matrix<T, RowMajor>> input_b(b_value, b_row_, b_col_);
|
||||
T *output_lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_lu(output_lu_value, out_row_,
|
||||
out_col_);
|
||||
if (a_row_ == a_col_) {
|
||||
// partial_piv_lu
|
||||
output_lu = input_a.lu().solve(input_b);
|
||||
Map<Matrix<T, RowMajor>> output_lu(output_lu_value, out_row_, out_col_);
|
||||
if (trans_ == "N") {
|
||||
output_lu.noalias() = input_a.template triangularView<UnitLower>().solve(input_b);
|
||||
output_lu.noalias() = input_a.template triangularView<Upper>().solve(output_lu);
|
||||
} else if (trans_ == "T") {
|
||||
output_lu.noalias() = input_a.template triangularView<Upper>().solve(input_b);
|
||||
output_lu.noalias() = input_a.template triangularView<UnitLower>().solve(output_lu);
|
||||
} else if (trans_ == "C") {
|
||||
MS_LOG_EXCEPTION << kernel_name_ << " trans_ flag is not supported C: " << trans_;
|
||||
} else {
|
||||
// full_piv_lu
|
||||
output_lu = input_a.fullPivLu().solve(input_b);
|
||||
MS_LOG_EXCEPTION << kernel_name_ << " trans_ flag is invalid: " << trans_;
|
||||
}
|
||||
|
||||
if (output_lu.RowsAtCompileTime == 0 || output_lu.ColsAtCompileTime == 0) {
|
||||
MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -14,10 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_SOLVER_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LUSOLVER_CPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
|
@ -39,6 +40,7 @@ class LUSolverCPUKernel : public CPUKernel {
|
|||
size_t b_col_{1};
|
||||
size_t out_row_{1};
|
||||
size_t out_col_{1};
|
||||
std::string trans_{};
|
||||
TypeId dtype_{kNumberTypeFloat32};
|
||||
};
|
||||
|
||||
|
@ -53,4 +55,4 @@ MS_REG_CPU_KERNEL_T(
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LUSOLVER_CPU_KERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_
|
||||
|
|
|
@ -106,6 +106,8 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|||
LaunchKernel<float16>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt32) {
|
||||
LaunchKernel<int>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_;
|
||||
}
|
||||
|
|
|
@ -72,6 +72,13 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate,
|
|||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
ScatterNdUpdateCPUKernel);
|
||||
MS_REG_CPU_KERNEL(ScatterNdUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
ScatterNdUpdateCPUKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -19,11 +19,14 @@ from mindspore import Tensor
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import PrimitiveWithInfer
|
||||
from mindspore.ops import prim_attr_register
|
||||
from scipy.linalg import lu_factor
|
||||
from scipy.linalg import lu_solve
|
||||
from mindspore._checkparam import Validator as validator
|
||||
import mindspore.numpy as mnp
|
||||
import scipy as scp
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
|
@ -41,19 +44,14 @@ class LU(PrimitiveWithInfer):
|
|||
def __infer__(self, x):
|
||||
x_shape = list(x['shape'])
|
||||
x_dtype = x['dtype']
|
||||
pivots_shape = []
|
||||
permutation_shape = []
|
||||
ndim = len(x_shape)
|
||||
permutation_shape = x_shape
|
||||
if ndim == 0:
|
||||
pivots_shape = x_shape
|
||||
permutation_shape = x_shape
|
||||
elif ndim == 1:
|
||||
pivots_shape = x_shape[:-1]
|
||||
permutation_shape = x_shape[:-1]
|
||||
else:
|
||||
pivots_shape = x_shape[-2:-1]
|
||||
permutation_shape = x_shape[-2:-1]
|
||||
|
||||
output = {
|
||||
'shape': (x_shape, pivots_shape, permutation_shape),
|
||||
'dtype': (x_dtype, mstype.int32, mstype.int32),
|
||||
|
@ -68,9 +66,10 @@ class LUSolver(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
def __init__(self, trans: str):
|
||||
super().__init__(name="LUSolver")
|
||||
self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
|
||||
self.trans = validator.check_value_type("trans", trans, [str], self.name)
|
||||
|
||||
def __infer__(self, x, b):
|
||||
b_shape = list(b['shape'])
|
||||
|
@ -92,42 +91,128 @@ class LuNet(nn.Cell):
|
|||
return self.lu(a)
|
||||
|
||||
|
||||
class LUSolverNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(LUSolverNet, self).__init__()
|
||||
self.lu_solver = LUSolver()
|
||||
def lu_pivots_to_permutation(pivots, permutation_size: int):
|
||||
batch_dims = pivots.shape[:-1]
|
||||
k = pivots.shape[-1]
|
||||
per = mnp.arange(0, permutation_size)
|
||||
permutation = mnp.broadcast_to(per, batch_dims + (permutation_size,))
|
||||
permutation = mnp.array(permutation)
|
||||
if permutation_size == 0:
|
||||
return permutation
|
||||
|
||||
def construct(self, a, b):
|
||||
return self.lu_solver(a, b)
|
||||
for i in range(k):
|
||||
j = pivots[..., i]
|
||||
loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
|
||||
x = permutation[..., i]
|
||||
y = permutation[loc + (j,)]
|
||||
permutation[..., i] = y
|
||||
permutation[loc + (j,)] = x
|
||||
return permutation
|
||||
|
||||
|
||||
def _match_array(actual, expected, error=0):
|
||||
if isinstance(actual, int):
|
||||
actual = np.asarray(actual)
|
||||
if isinstance(actual, tuple):
|
||||
actual = np.asarray(actual)
|
||||
|
||||
if error > 0:
|
||||
np.testing.assert_almost_equal(actual, expected, decimal=error)
|
||||
def _lu_solve_core(in_lu, permutation, b, trans):
|
||||
m = in_lu.shape[0]
|
||||
res_shape = b.shape[1:]
|
||||
prod_result = 1
|
||||
for sh in res_shape:
|
||||
prod_result *= sh
|
||||
x = mnp.reshape(b, (m, prod_result))
|
||||
if trans == 0:
|
||||
trans_str = "N"
|
||||
x = x[permutation, :]
|
||||
elif trans == 1:
|
||||
trans_str = "T"
|
||||
elif trans == 2:
|
||||
trans_str = "C"
|
||||
else:
|
||||
np.testing.assert_equal(actual, expected)
|
||||
raise ValueError("trans error, it's value must be 0, 1, 2")
|
||||
ms_lu_solve = LUSolver(trans_str)
|
||||
output = ms_lu_solve(in_lu, x)
|
||||
return mnp.reshape(output, b.shape)
|
||||
|
||||
|
||||
def _check_lu_shape(in_lu, b):
|
||||
if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]:
|
||||
raise ValueError("last two dimensions of LU decomposition must be equal.")
|
||||
|
||||
if b.shape is None:
|
||||
raise ValueError(" LU decomposition input b's rank must >=1.")
|
||||
rhs_vector = in_lu.ndim == b.ndim + 1
|
||||
if rhs_vector:
|
||||
if b.shape[-1] != in_lu.shape[-1]:
|
||||
raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions")
|
||||
mnp.expand_dims(b, axis=1)
|
||||
else:
|
||||
if b.shape[-2] != in_lu.shape[-1]:
|
||||
raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions")
|
||||
|
||||
|
||||
def lu_factor(a, overwrite_a=False, check_finite=True):
|
||||
del overwrite_a, check_finite
|
||||
mscp_lu = LuNet()
|
||||
m_lu, pivots, _ = mscp_lu(a)
|
||||
return m_lu, pivots
|
||||
|
||||
|
||||
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
|
||||
del overwrite_a, check_finite
|
||||
mscp_lu = LuNet()
|
||||
m_lu, _, p = mscp_lu(a)
|
||||
m = a.shape[-2]
|
||||
n = a.shape[-1]
|
||||
k = min(m, n)
|
||||
a_dtype = a.dtype
|
||||
l = mnp.tril(m_lu, -1)[:, :k] + mnp.eye(m, k, dtype=a_dtype)
|
||||
u = mnp.triu(m_lu)[:k, :]
|
||||
if permute_l:
|
||||
return mnp.matmul(p, l), u
|
||||
return p, l, u
|
||||
|
||||
|
||||
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
|
||||
del overwrite_b, check_finite
|
||||
m_lu, pivots = lu_and_piv
|
||||
# 1. check shape
|
||||
_check_lu_shape(m_lu, b)
|
||||
# here permutation array has been calculated, just use it.
|
||||
# 2. calculate permutation
|
||||
permutation = pivots
|
||||
# 3. rhs_vector
|
||||
rhs_vector = m_lu.ndim == b.ndim + 1
|
||||
x = _lu_solve_core(m_lu, permutation, b, trans)
|
||||
|
||||
return x[..., 0] if rhs_vector else x
|
||||
|
||||
|
||||
def create_full_rank_matrix(m, n, dtype):
|
||||
a_rank = 0
|
||||
a = np.random.random((m, n)).astype(dtype)
|
||||
while a_rank != m:
|
||||
a = (a + np.eye(m, n)).astype(dtype)
|
||||
a_rank = np.linalg.matrix_rank(a)
|
||||
return a
|
||||
|
||||
|
||||
def create_sym_pos_matrix(m, n, dtype):
|
||||
a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype)
|
||||
return np.dot(a, a.T)
|
||||
|
||||
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.parametrize('n', [10, 20])
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_lu_net(n: int, dtype: Generic):
|
||||
def test_square_lu_net(n: int, dtype: Generic):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
|
||||
Expectation: the result match to scipy
|
||||
"""
|
||||
a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
|
||||
s_lu, _ = lu_factor(a)
|
||||
a = create_full_rank_matrix(n, n, dtype)
|
||||
s_lu, _ = scp.linalg.lu_factor(a)
|
||||
mscp_lu_net = LuNet()
|
||||
tensor_a = Tensor(a)
|
||||
mscp_lu, _, _ = mscp_lu_net(tensor_a)
|
||||
_match_array(mscp_lu.asnumpy(), s_lu, error=4)
|
||||
assert np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3)
|
||||
|
||||
|
||||
@pytest.mark.platform_x86_cpu
|
||||
|
@ -139,13 +224,24 @@ def test_lu_solver_net(n: int, dtype: Generic):
|
|||
Description: test cases for lu_solve test cases for A[N,N]x = b[N,1]
|
||||
Expectation: the result match to scipy
|
||||
"""
|
||||
a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
|
||||
a = create_full_rank_matrix(n, n, dtype)
|
||||
b = np.random.random((n, 1)).astype(dtype)
|
||||
s_lu, s_piv = lu_factor(a)
|
||||
lu_factor_x = (s_lu, s_piv)
|
||||
scp_x = lu_solve(lu_factor_x, b)
|
||||
mscp_lu_net = LUSolverNet()
|
||||
s_lu, s_piv = scp.linalg.lu_factor(a)
|
||||
|
||||
tensor_a = Tensor(a)
|
||||
tensor_b = Tensor(b)
|
||||
mscp_x = mscp_lu_net(tensor_a, tensor_b)
|
||||
_match_array(mscp_x.asnumpy(), scp_x, error=4)
|
||||
mscp_lu_net = LuNet()
|
||||
mscp_lu, pivots, _ = mscp_lu_net(tensor_a)
|
||||
np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3)
|
||||
|
||||
lu_factor_x = (s_lu, s_piv)
|
||||
msc_lu_factor = (mscp_lu, pivots)
|
||||
|
||||
scp_x = scp.linalg.lu_solve(lu_factor_x, b)
|
||||
mscp_x = lu_solve(msc_lu_factor, tensor_b)
|
||||
|
||||
real_b = mnp.dot(tensor_a, mscp_x)
|
||||
expected_b = np.dot(a, scp_x)
|
||||
|
||||
assert np.allclose(real_b.asnumpy(), expected_b, rtol=1.e-3, atol=1.e-3)
|
||||
assert np.allclose(mscp_x.asnumpy(), scp_x, rtol=1.e-3, atol=1.e-3)
|
||||
|
|
Loading…
Reference in New Issue