[feat] [assistant] [I48OF6] Add new AICPU operator MatrixSolveLs

zy 2022-11-17 15:11:07 +08:00
parent 04145279ef
commit bffb600fab
9 changed files with 980 additions and 3 deletions
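For reviewers, a minimal usage sketch of the new operator (it mirrors the docstring example added in math_ops.py below and assumes the standard `from mindspore import Tensor, ops` / `from mindspore import dtype as mstype` imports):

import numpy as np
from mindspore import Tensor, ops
from mindspore import dtype as mstype

# Solve min ||matrix @ x - rhs|| with the fast (Cholesky / normal-equations) path.
matrix_solve_ls = ops.MatrixSolveLs(fast=True)
matrix = Tensor([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], mstype.float32)
rhs = Tensor(np.array([[4], [2], [4], [2]]), mstype.float32)
l2 = Tensor(0.0, mstype.float64)  # l2_regularizer must be a float64 scalar tensor
print(matrix_solve_ls(matrix, rhs, l2))
# Expected output (from the docstring example):
# [[ 1.3333334]
#  [-0.6666667]
#  [ 2.6666665]
#  [-1.3333333]]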


@@ -0,0 +1,498 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/eigen/matrix_solve_ls_cpu_kernel.h"
#include <Eigen/Cholesky>
#include <Eigen/Dense>
#include <algorithm>
namespace mindspore {
namespace kernel {
namespace {
constexpr auto kInputNum = 3;
constexpr auto kOutputNum = 1;
constexpr int64_t kNum2 = 2;
constexpr char kFast[] = "fast";
constexpr bool kMatrixSolveLsComputeOk = true;
constexpr bool kMatrixSolveLsComputeFailed = false;
template <typename InputIt, typename T>
T GetNumElements(InputIt first, InputIt last, T init) {
for (; first != last; ++first) {
init = std::move(init) * (*first);
}
return init;
}
} // namespace
template <typename T, int Major>
using EigenMatrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Major>;
template <typename T>
void MatrixSolveLsCpuKernelMod::RealCholeskySingleCompute(T *aptr, T *bptr, T *xptr, double *l2, int64_t m, int64_t k,
int64_t n) {
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> a(m, k);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> x(k, n);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> b(m, n);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> a_copy;
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> a_b;
for (int i = 0; i < m * k; i++) {
*(a.data() + i) = *(aptr + i);
}
for (int i = 0; i < m * n; i++) {
*(b.data() + i) = *(bptr + i);
}
if (m >= k) {
a_copy = a.transpose() * a + static_cast<T>(*l2) * EigenMatrix<T, Eigen::RowMajor>::Identity(k, k);
a_b = a.transpose() * b;
} else {
a_copy = a * a.transpose() + static_cast<T>(*l2) * EigenMatrix<T, Eigen::RowMajor>::Identity(m, m);
a_b = b;
}
for (int64_t i = 0; i < n; i++) {
EigenMatrix<T, Eigen::RowMajor> xi = a_copy.ldlt().solve(a_b.col(i));
if (m < k) {
xi = a.transpose() * xi;
}
x.col(i) = xi;
}
for (int64_t i = 0; i < k * n; i++) {
*(xptr + i) = *(x.data() + i);
}
}
template <typename T>
void MatrixSolveLsCpuKernelMod::ComplexCholeskySingleCompute(std::complex<T> *aptr, std::complex<T> *bptr,
std::complex<T> *xptr, double *l2, int64_t m, int64_t k,
int64_t n) {
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> A(kNum2 * m, kNum2 * k);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> x(kNum2 * k, n);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> b(kNum2 * m, n);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> a_copy;
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> a_b;
auto l2value = std::abs(*l2);
for (int64_t i = 0; i < k; i++) {
for (int64_t j = 0; j < m; j++) {
*(A.data() + i + j * kNum2 * k) = std::real(*(aptr + i + j * k));
}
for (int64_t j = 0; j < m; j++) {
*(A.data() + (i + k) + (j + m) * kNum2 * k) = std::real(*(aptr + i + j * k));
}
for (int64_t j = 0; j < m; j++) {
*(A.data() + (i + k) + j * kNum2 * k) = -std::imag(*(aptr + i + j * k));
}
for (int64_t j = 0; j < m; j++) {
*(A.data() + i + (j + m) * kNum2 * k) = std::imag(*(aptr + i + j * k));
}
}
for (int64_t i = 0; i < n; i++) {
for (int64_t j = 0; j < m; j++) {
*(b.data() + i + j * n) = std::real(*(bptr + i + j * n));
*(b.data() + i + (j + m) * n) = std::imag(*(bptr + i + j * n));
}
}
if (m >= k) {
a_copy = A.transpose() * A +
static_cast<T>(l2value) * EigenMatrix<T, Eigen::RowMajor>::Identity(kNum2 * k, kNum2 * k);
a_b = A.transpose() * b;
} else {
a_copy = A * A.transpose() +
static_cast<T>(l2value) * EigenMatrix<T, Eigen::RowMajor>::Identity(kNum2 * m, kNum2 * m);
a_b = b;
}
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> xi;
for (int64_t i = 0; i < n; i++) {
xi = a_copy.ldlt().solve(a_b.col(i));
if (m < k) {
xi = A.transpose() * xi;
}
x.col(i) = xi;
for (int64_t j = 0; j < k; j++) {
(xptr + i + j * n)->real(*(x.data() + i + j * n));
(xptr + i + j * n)->imag(*(x.data() + i + (j + k) * n));
}
}
}
template <typename T>
void MatrixSolveLsCpuKernelMod::RealQrSingleCompute(T *aptr, T *bptr, T *xptr, int64_t m, int64_t k, int64_t n) {
EigenMatrix<T, Eigen::RowMajor> a(m, k);
EigenMatrix<T, Eigen::RowMajor> x(k, n);
EigenMatrix<T, Eigen::RowMajor> b(m, n);
for (int i = 0; i < m * k; i++) {
*(a.data() + i) = *(aptr + i);
}
for (int i = 0; i < m * n; i++) {
*(b.data() + i) = *(bptr + i);
}
Eigen::CompleteOrthogonalDecomposition<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> qr_solve(a);
for (int64_t i = 0; i < n; i++) {
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> xi = qr_solve.solve(b.col(i));
x.col(i) = xi;
}
for (int64_t i = 0; i < k * n; i++) {
*(xptr + i) = *(x.data() + i);
}
}
template <typename T>
void MatrixSolveLsCpuKernelMod::ComplexQrSingleCompute(std::complex<T> *aptr, std::complex<T> *bptr,
std::complex<T> *xptr, int64_t m, int64_t k, int64_t n) {
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> A(kNum2 * m, kNum2 * k);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> x(kNum2 * k, n);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> b(kNum2 * m, n);
for (int64_t i = 0; i < k; i++) {
for (int64_t j = 0; j < m; j++) {
*(A.data() + i + j * kNum2 * k) = std::real(*(aptr + i + j * k));
}
for (int64_t j = 0; j < m; j++) {
*(A.data() + (i + k) + (j + m) * kNum2 * k) = std::real(*(aptr + i + j * k));
}
for (int64_t j = 0; j < m; j++) {
*(A.data() + (i + k) + j * kNum2 * k) = -std::imag(*(aptr + i + j * k));
}
for (int64_t j = 0; j < m; j++) {
*(A.data() + i + (j + m) * kNum2 * k) = std::imag(*(aptr + i + j * k));
}
}
for (int64_t i = 0; i < n; i++) {
for (int64_t j = 0; j < m; j++) {
*(b.data() + i + j * n) = std::real(*(bptr + i + j * n));
*(b.data() + i + (j + m) * n) = std::imag(*(bptr + i + j * n));
}
}
Eigen::CompleteOrthogonalDecomposition<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> qr_solve(A);
for (int64_t i = 0; i < n; i++) {
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> xi = qr_solve.solve(b.col(i));
x.col(i) = xi;
for (int64_t j = 0; j < k; j++) {
(xptr + i + j * n)->real(*(x.data() + i + j * n));
(xptr + i + j * n)->imag(*(x.data() + i + (j + k) * n));
}
}
}
template <typename T>
bool MatrixSolveLsCpuKernelMod::ComplexCholesky(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
auto dims = x_shape_vector.size();
auto l2 = reinterpret_cast<double *>(inputs[2]->addr);
auto aptr = reinterpret_cast<std::complex<T> *>(inputs[0]->addr);
auto bptr = reinterpret_cast<std::complex<T> *>(inputs[1]->addr);
auto xptr = reinterpret_cast<std::complex<T> *>(outputs[0]->addr);
int64_t m = x_shape_vector[dims - kNum2];
int64_t k = x_shape_vector[dims - 1];
int64_t n = 1;
auto b_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
if (b_shape_vector.size() > 1) {
n = b_shape_vector.back();
}
int64_t data_num = 1;
data_num = GetNumElements(x_shape_vector.begin(), x_shape_vector.end(), data_num);
const int64_t mat_size = m * k;
const int64_t rhs_size = m * n;
const int64_t res_size = n * k;
const int64_t batch = data_num / mat_size;
const int64_t kParallelDataNum = 16 * mat_size;
if (data_num >= kParallelDataNum) {
auto sharder_matrix_solve_ls = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
ComplexCholeskySingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, l2, m, k, n);
}
};
ParallelLaunchAutoSearch(sharder_matrix_solve_ls, batch, this, &parallel_search_info_);
} else {
for (int64_t i = 0; i < batch; i++) {
ComplexCholeskySingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, l2, m, k, n);
}
}
return kMatrixSolveLsComputeOk;
}
template <typename T>
bool MatrixSolveLsCpuKernelMod::RealQr(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
auto dims = x_shape_vector.size();
auto aptr = reinterpret_cast<T *>(inputs[0]->addr);
auto bptr = reinterpret_cast<T *>(inputs[1]->addr);
auto xptr = reinterpret_cast<T *>(outputs[0]->addr);
int64_t m = x_shape_vector[dims - kNum2];
int64_t k = x_shape_vector[dims - 1];
int64_t n = 1;
auto b_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
if (b_shape_vector.size() > 1) {
n = b_shape_vector.back();
}
int64_t data_num = 1;
data_num = GetNumElements(x_shape_vector.begin(), x_shape_vector.end(), data_num);
const int64_t mat_size = m * k;
const int64_t rhs_size = m * n;
const int64_t res_size = n * k;
const int64_t batch = data_num / mat_size;
const int64_t kParallelDataNum = 16 * mat_size;
if (data_num >= kParallelDataNum) {
auto sharder_matrix_solve_ls = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
RealQrSingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, m, k, n);
}
};
ParallelLaunchAutoSearch(sharder_matrix_solve_ls, batch, this, &parallel_search_info_);
} else {
for (int64_t i = 0; i < batch; i++) {
RealQrSingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, m, k, n);
}
}
return kMatrixSolveLsComputeOk;
}
template <typename T>
bool MatrixSolveLsCpuKernelMod::ComplexQr(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
auto dims = x_shape_vector.size();
int64_t m = x_shape_vector[dims - kNum2];
int64_t k = x_shape_vector[dims - 1];
int64_t n = 1;
auto b_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
if (b_shape_vector.size() > 1) {
n = b_shape_vector.back();
}
int64_t data_num = 1;
data_num = GetNumElements(x_shape_vector.begin(), x_shape_vector.end(), data_num);
const int64_t mat_size = m * k;
const int64_t rhs_size = m * n;
const int64_t res_size = n * k;
const int64_t batch = data_num / mat_size;
const int64_t kParallelDataNum = 16 * mat_size;
auto aptr = reinterpret_cast<std::complex<T> *>(inputs[0]->addr);
auto bptr = reinterpret_cast<std::complex<T> *>(inputs[1]->addr);
auto xptr = reinterpret_cast<std::complex<T> *>(outputs[0]->addr);
if (data_num >= kParallelDataNum) {
auto sharder_matrix_solve_ls = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
ComplexQrSingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, m, k, n);
}
};
ParallelLaunchAutoSearch(sharder_matrix_solve_ls, batch, this, &parallel_search_info_);
} else {
for (int64_t i = 0; i < batch; i++) {
ComplexQrSingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, m, k, n);
}
}
return kMatrixSolveLsComputeOk;
}
template <typename T>
bool MatrixSolveLsCpuKernelMod::RealCholesky(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
auto dims = x_shape_vector.size();
auto aptr = reinterpret_cast<T *>(inputs[0]->addr);
auto bptr = reinterpret_cast<T *>(inputs[1]->addr);
auto xptr = reinterpret_cast<T *>(outputs[0]->addr);
auto l2 = reinterpret_cast<double *>(inputs[2]->addr);
int64_t m = x_shape_vector[dims - kNum2];
int64_t k = x_shape_vector[dims - 1];
int64_t n = 1;
auto b_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
if (b_shape_vector.size() > 1) {
n = b_shape_vector.back();
}
int64_t data_num = 1;
data_num = GetNumElements(x_shape_vector.begin(), x_shape_vector.end(), data_num);
const int64_t mat_size = m * k;
const int64_t rhs_size = m * n;
const int64_t res_size = n * k;
const int64_t batch = data_num / mat_size;
const int64_t kParallelDataNum = 16 * mat_size;
if (data_num >= kParallelDataNum) {
auto sharder_matrix_solve_ls = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
RealCholeskySingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, l2, m, k, n);
}
};
ParallelLaunchAutoSearch(sharder_matrix_solve_ls, batch, this, &parallel_search_info_);
} else {
for (int64_t i = 0; i < batch; i++) {
RealCholeskySingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, l2, m, k, n);
}
}
return kMatrixSolveLsComputeOk;
}
void MatrixSolveLsCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
node_wpt_ = kernel_node;
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
if (common::AnfAlgo::HasNodeAttr(kFast, kernel_node)) {
qr_chole = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, kFast);
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the attribute 'fast' does not exist.";
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
bool MatrixSolveLsCpuKernelMod::Resize(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
auto shapea = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
auto shapeb = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
auto shapel2 = AnfAlgo::GetInputDeviceShape(node_wpt_, 2);
auto shapex = AnfAlgo::GetOutputDeviceShape(node_wpt_, 0);
auto dims = shapea.size();
if (shapeb.size() == 1) {
if (shapea[dims - kNum2] != shapeb[0]) {
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", #Rows mismatch between A and rhs. "
<< "#Rows of A = [" << shapea[dims - kNum2] << "], "
<< "#Rows of rhs = [" << shapeb[0] << "].";
return kMatrixSolveLsComputeFailed;
}
} else {
if (shapea[dims - kNum2] != shapeb[dims - kNum2]) {
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", #Rows mismatch between A and rhs. "
<< "#Rows of A = [" << shapea[dims - kNum2] << "], "
<< "#Rows of rhs = [" << shapeb[dims - kNum2] << "].";
return kMatrixSolveLsComputeFailed;
}
}
if (shapel2.size() != 0) {
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", the tensor l2 should be a scalar.";
return kMatrixSolveLsComputeFailed;
}
if (shapeb.size() == 1) {
if ((shapex.size() != shapeb.size()) || (shapea[dims - 1] != shapex[0]) || (shapex.back() != shapeb[0])) {
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", the shape of output y does not match.";
return kMatrixSolveLsComputeFailed;
}
} else {
if ((shapex.size() != shapeb.size()) || (shapea[dims - 1] != shapex[shapex.size() - kNum2]) ||
(shapex.back() != shapeb.back())) {
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", the shape of output y does not match.";
return kMatrixSolveLsComputeFailed;
}
}
return kMatrixSolveLsComputeOk;
}
template <typename T>
bool MatrixSolveLsCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
auto a_data_type = AnfAlgo::GetInputDeviceDataType(node_wpt_, 0);
auto b_data_type = AnfAlgo::GetInputDeviceDataType(node_wpt_, 1);
if (!Resize(inputs, outputs)) {
return kMatrixSolveLsComputeFailed;
}
if (a_data_type != b_data_type) {
MS_EXCEPTION(TypeError) << "For " << kernel_name_ << ", the data types of matrix and rhs must be the same.";
return kMatrixSolveLsComputeFailed;
}
if (a_data_type != kNumberTypeFloat32 && a_data_type != kNumberTypeFloat64 && a_data_type != kNumberTypeComplex64 &&
a_data_type != kNumberTypeComplex128) {
MS_EXCEPTION(TypeError) << "For " << kernel_name_ << ", unsupported data type: " << TypeIdLabel(a_data_type) << ".";
return kMatrixSolveLsComputeFailed;
}
if (qr_chole) {
if (a_data_type == kNumberTypeComplex64) {
return ComplexCholesky<float>(inputs, outputs);
}
if (a_data_type == kNumberTypeComplex128) {
return ComplexCholesky<double>(inputs, outputs);
}
if (a_data_type == kNumberTypeFloat64) {
return RealCholesky<double>(inputs, outputs);
}
if (a_data_type == kNumberTypeFloat32) {
return RealCholesky<float>(inputs, outputs);
}
} else {
if (a_data_type == kNumberTypeComplex64) {
return ComplexQr<float>(inputs, outputs);
}
if (a_data_type == kNumberTypeComplex128) {
return ComplexQr<double>(inputs, outputs);
}
if (a_data_type == kNumberTypeFloat64) {
return RealQr<double>(inputs, outputs);
}
if (a_data_type == kNumberTypeFloat32) {
return RealQr<float>(inputs, outputs);
}
}
return kMatrixSolveLsComputeOk;
}
std::vector<std::pair<KernelAttr, MatrixSolveLsCpuKernelMod::MatrixSolveLsFunc>> MatrixSolveLsCpuKernelMod::func_list_ =
{{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
&MatrixSolveLsCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&MatrixSolveLsCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeComplex64),
&MatrixSolveLsCpuKernelMod::LaunchKernel<std::complex<float>>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeComplex128),
&MatrixSolveLsCpuKernelMod::LaunchKernel<std::complex<double>>}};
std::vector<KernelAttr> MatrixSolveLsCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MatrixSolveLsFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixSolveLs, MatrixSolveLsCpuKernelMod);
} // namespace kernel
} // namespace mindspore
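A brief reading of the numerics in this kernel (a sketch based on the code above, not a spec): for the fast path, each right-hand-side column is obtained from the regularized normal equations via an LDL^T solve,

$$x = (A^{T}A + \lambda I)^{-1} A^{T} b \quad (m \ge k), \qquad x = A^{T}(AA^{T} + \lambda I)^{-1} b \quad (m < k),$$

where $\lambda$ is the l2 regularizer, and complex inputs are handled by embedding the complex system into a real one of twice the size,

$$\tilde{A} = \begin{bmatrix} \operatorname{Re}(A) & -\operatorname{Im}(A) \\ \operatorname{Im}(A) & \operatorname{Re}(A) \end{bmatrix}, \qquad \tilde{b} = \begin{bmatrix} \operatorname{Re}(b) \\ \operatorname{Im}(b) \end{bmatrix},$$

after which the top and bottom halves of the real solution give the real and imaginary parts of x. The non-fast path solves the same (possibly embedded) system with Eigen's CompleteOrthogonalDecomposition instead.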


@@ -0,0 +1,85 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_SOLVE_LS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_SOLVE_LS_CPU_KERNEL_H_
#include <vector>
#include <utility>
#include <complex>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MatrixSolveLsCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MatrixSolveLsCpuKernelMod() = default;
~MatrixSolveLsCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
bool Resize(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using MatrixSolveLsFunc =
std::function<bool(MatrixSolveLsCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MatrixSolveLsFunc>> func_list_;
MatrixSolveLsFunc kernel_func_;
template <typename T>
void RealCholeskySingleCompute(T *aptr, T *bptr, T *xptr, double *l2, int64_t m, int64_t k, int64_t n);
template <typename T>
bool RealCholesky(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
template <typename T>
void RealQrSingleCompute(T *aptr, T *bptr, T *xptr, int64_t m, int64_t k, int64_t n);
template <typename T>
bool RealQr(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
template <typename T>
void ComplexCholeskySingleCompute(std::complex<T> *aptr, std::complex<T> *bptr, std::complex<T> *xptr, double *l2,
int64_t m, int64_t k, int64_t n);
template <typename T>
bool ComplexCholesky(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
template <typename T>
void ComplexQrSingleCompute(std::complex<T> *aptr, std::complex<T> *bptr, std::complex<T> *xptr, int64_t m, int64_t k,
int64_t n);
template <typename T>
bool ComplexQr(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
bool qr_chole{true};
CNodePtr node_wpt_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_SOLVE_LS_CPU_KERNEL_H_


@@ -140,6 +140,7 @@ constexpr auto kDiagonal = "Diagonal";
constexpr auto kEditDistance = "EditDistance";
constexpr auto kNextAfter = "NextAfter";
constexpr auto kMaximumGradGrad = "MaximumGradGrad";
constexpr auto kMatrixSolveLs = "MatrixSolveLs";
constexpr auto kSparseSegmentMean = "SparseSegmentMean";
constexpr auto kTridiagonalMatMul = "TridiagonalMatMul";
constexpr auto kTridiagonalSolve = "TridiagonalSolve";
@@ -1269,6 +1270,7 @@ GVAR_DEF(PrimitivePtr, kPrimLinSpace, std::make_shared<Primitive>("LinSpace"));
GVAR_DEF(PrimitivePtr, kPrimNonMaxSuppression, std::make_shared<Primitive>("NonMaxSuppression"));
GVAR_DEF(PrimitivePtr, kPrimSign, std::make_shared<Primitive>("Sign"));
GVAR_DEF(PrimitivePtr, kPrimACos, std::make_shared<Primitive>(kACos));
GVAR_DEF(PrimitivePtr, kPrimMatrixSolveLs, std::make_shared<Primitive>(kMatrixSolveLs));
GVAR_DEF(PrimitivePtr, kPrimAsinGrad, std::make_shared<Primitive>(kAsinGrad));
GVAR_DEF(PrimitivePtr, kPrimACosGrad, std::make_shared<Primitive>(kACosGrad));
GVAR_DEF(PrimitivePtr, kPrimAtanGrad, std::make_shared<Primitive>("AtanGrad"));


@@ -0,0 +1,104 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include "ops/matrix_solve_ls.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MatrixSolveLsInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto matrix_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto rhs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto l2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input matrix rank", SizeToLong(matrix_shape.size()), kGreaterEqual, 2L,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("input rhs rank", SizeToLong(rhs_shape.size()), kGreaterEqual, 2L,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("input l2 rank", SizeToLong(l2_shape.size()), kEqual, 0L, prim_name);
constexpr size_t offset = 2;
std::vector<int64_t> matrix_last(matrix_shape.end() - offset, matrix_shape.end());
std::vector<int64_t> rhs_last(rhs_shape.end() - offset, rhs_shape.end());
std::vector<int64_t> y_shape(rhs_shape.begin(), rhs_shape.end() - offset);
int64_t matrix_row = matrix_last[0];
int64_t matrix_col = matrix_last[1];
int64_t rhs_row = rhs_last[0];
int64_t rhs_col = rhs_last[1];
for (size_t i = 0; i < matrix_shape.size() - offset; ++i) {
if (matrix_shape[i] != rhs_shape[i]) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", shapes in batch dimension must be same, but dim[" << i
<< "] are not the same, "
<< "got matrix_dim[" << i << "]: " << matrix_shape[i] << ", rhs_dim[" << i
<< "]: " << rhs_shape[i] << ".";
}
}
if (matrix_row != rhs_row) {
MS_EXCEPTION(ValueError) << "MatrixSolveLs shape error, got matrix_row: " << matrix_row << ", rhs_row: " << rhs_row
<< ". In MatrixSolveLs matrix_row and rhs_row should be equal.";
}
y_shape.push_back(matrix_col);
y_shape.push_back(rhs_col);
return std::make_shared<abstract::Shape>(y_shape);
}
TypePtr MatrixSolveLsInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const std::set<TypePtr> valid_types = {kFloat32, kFloat64, kComplex64, kComplex128};
const std::set<TypePtr> l2_valid_types = {kFloat64};
auto matrix_type = input_args[0]->BuildType();
auto rhs_type = input_args[1]->BuildType();
auto l2_type = input_args[2]->BuildType();
std::map<std::string, TypePtr> types;
(void)types.emplace("matrix", matrix_type);
(void)types.emplace("rhs", rhs_type);
(void)CheckAndConvertUtils::CheckTypeValid("matrix", matrix_type, valid_types, primitive->name());
(void)CheckAndConvertUtils::CheckTypeValid("rhs", rhs_type, valid_types, primitive->name());
(void)CheckAndConvertUtils::CheckTypeValid("l2_regularizer", l2_type, l2_valid_types, primitive->name());
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
return matrix_type;
}
} // namespace
AbstractBasePtr MatrixSolveLsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = MatrixSolveLsInferType(primitive, input_args);
auto infer_shape = MatrixSolveLsInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixSolveLs, prim::kPrimMatrixSolveLs, MatrixSolveLsInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
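As a concrete reading of the shape inference above: the output keeps the batch dimensions of `rhs` and replaces the trailing two dimensions with (N, K); for example, a matrix of shape (2, 5, 3) and an rhs of shape (2, 5, 4) yield y of shape (2, 3, 4).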


@@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MATRIX_SOLVE_LS_H
#define MINDSPORE_CORE_OPS_MATRIX_SOLVE_LS_H
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMatrixSolveLs = "MatrixSolveLs";
class MIND_API MatrixSolveLs : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MatrixSolveLs);
/// \brief Constructor.
MatrixSolveLs() : BaseOperator(kNameMatrixSolveLs) { InitIOName({"matrix", "rhs", "l2"}, {"y"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.MatrixSolveLs for the inputs.
void Init() const {}
};
abstract::AbstractBasePtr MatrixSolveLsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMatrixSolveLsPtr = std::shared_ptr<MatrixSolveLs>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MATRIX_SOLVE_LS_H


@@ -51,6 +51,7 @@ from mindspore.ops.operations.math_ops import LuUnpack
from mindspore.ops.operations.math_ops import MatrixExp
from mindspore.ops.operations.math_ops import CumulativeLogsumexp
from mindspore.ops.operations.math_ops import MatrixSolve
from mindspore.ops.operations.math_ops import MatrixSolveLs
from mindspore.ops.operations.math_ops import MatrixPower
from mindspore.ops.operations.math_ops import Median
from mindspore.ops.operations.math_ops import MatrixTriangularSolve
@@ -73,6 +74,7 @@ from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element
from mindspore.ops._grad.grad_base import dyn_ones, dyn_fill, sum_grad_reduce_axis
from mindspore.ops._grad.grad_math_ops import binop_grad_common
from mindspore.ops.operations.array_ops import MatrixBandPart
from mindspore.ops.operations.array_ops import ConjugateTranspose
transpose = P.Transpose()
dyn_shape_op = P.TensorShape()
@@ -672,6 +674,155 @@ def get_bprop_matrix_solve(self):
return bprop
@constexpr
def _generate_perm_matrix_solve_ls(x_dim):
perm = tuple(range(x_dim - 2))
perm = perm + (x_dim-1, x_dim-2)
return perm
@bprop_getters.register(MatrixSolveLs)
def get_bprop_matrix_solve_ls(self):
"""Grad definition for 'MatrixSolveLs' operation"""
fast = self.fast
cast = P.Cast()
neg = P.Neg()
rank = P.Rank()
cholesky = Cholesky()
eye = P.Eye()
add = P.Add()
mul = P.Mul()
matmul = P.MatMul()
batch_matmul = P.BatchMatMul()
cholesky_solve = CholeskySolve()
_transpose = Transpose()
conjugate_transpose = ConjugateTranspose()
shape = P.Shape()
_complex = P.Complex()
def regularized_gramian_cholesky(matrix, l2, first_kind):
matrix_dim = rank(matrix)
perm = _generate_perm_matrix_solve_ls(matrix_dim)
if matrix.dtype in (mstype.complex64, mstype.complex128):
matrix_temp = conjugate_transpose(matrix, perm)
else:
matrix_temp = _transpose(matrix, perm)
if first_kind:
if matrix_dim > 2:
gramian = batch_matmul(matrix_temp, matrix)
else:
gramian = matmul(matrix_temp, matrix)
else:
if matrix_dim > 2:
gramian = batch_matmul(matrix, matrix_temp)
else:
gramian = matmul(matrix, matrix_temp)
if isinstance(l2, Tensor) or l2 != 0:
matrix_shape = shape(matrix)
if first_kind:
small_dim = matrix_shape[-1]
else:
small_dim = matrix_shape[-2]
identity = eye(small_dim, small_dim, matrix.dtype)
gramian = add(gramian, mul(l2, identity))
# Cholesky does not support complex dtype for now
return cholesky(gramian)
def bprop(matrix, rhs, l2, out, dout):
# supported dtype: float32
# supported dimensions: 2D, 3D
def over_determined(matrix, rhs, out, l2, dout):
if matrix.dtype == mstype.complex64:
l2_regularizer = _complex(cast(l2, mstype.float32), Tensor(0, mstype.float32))
elif matrix.dtype == mstype.complex128:
l2_regularizer = _complex(cast(l2, mstype.float64), Tensor(0, mstype.float64))
else:
l2_regularizer = cast(l2, matrix.dtype)
chol = cast(regularized_gramian_cholesky(matrix, l2_regularizer, first_kind=True), matrix.dtype)
# CholeskySolve does not support complex dtype and only supports 2D or 3D matrices for now
z = cholesky_solve(dout, chol)
matrix_dim = rank(matrix)
perm = _generate_perm_matrix_solve_ls(matrix_dim)
if matrix.dtype in (mstype.complex64, mstype.complex128):
z_temp = conjugate_transpose(z, perm)
else:
z_temp = _transpose(z, perm)
if matrix_dim > 2:
xzt = batch_matmul(out, z_temp)
else:
xzt = matmul(out, z_temp)
zx_sym = add(xzt, _transpose(xzt, perm))
if matrix_dim > 2:
grad_a = add(neg(batch_matmul(matrix, zx_sym)), batch_matmul(rhs, z_temp))
grad_b = batch_matmul(matrix, z)
else:
grad_a = add(neg(matmul(matrix, zx_sym)), matmul(rhs, z_temp))
grad_b = matmul(matrix, z)
return (grad_a, grad_b, None)
def under_determined(matrix, rhs, l2, dout):
if matrix.dtype == mstype.complex64:
l2_regularizer = _complex(cast(l2, mstype.float32), Tensor(0, mstype.float32))
elif matrix.dtype == mstype.complex128:
l2_regularizer = _complex(cast(l2, mstype.float64), Tensor(0, mstype.float64))
else:
l2_regularizer = cast(l2, matrix.dtype)
chol = cast(regularized_gramian_cholesky(matrix, l2_regularizer, first_kind=False), matrix.dtype)
matrix_dim = rank(matrix)
perm = _generate_perm_matrix_solve_ls(matrix_dim)
if matrix_dim > 2:
gramian = batch_matmul(matrix, dout)
else:
gramian = matmul(matrix, dout)
# CholeskySolve does not support complex dtype and only supports 2D or 3D matrices for now
grad_b = cholesky_solve(gramian, chol)
tmp = cholesky_solve(rhs, chol)
if matrix.dtype in (mstype.complex64, mstype.complex128):
tmp_temp = conjugate_transpose(tmp, perm)
matrix_temp = conjugate_transpose(matrix, perm)
else:
tmp_temp = _transpose(tmp, perm)
matrix_temp = _transpose(matrix, perm)
if matrix_dim > 2:
a1 = batch_matmul(tmp_temp, matrix)
a1 = neg(batch_matmul(grad_b, a1))
a2 = dout - batch_matmul(matrix_temp, grad_b)
if matrix.dtype in (mstype.complex64, mstype.complex128):
a2_temp = conjugate_transpose(a2, perm)
else:
a2_temp = _transpose(a2, perm)
a2 = batch_matmul(tmp, a2_temp)
else:
a1 = matmul(tmp_temp, matrix)
a1 = neg(matmul(grad_b, a1))
a2 = dout - matmul(matrix_temp, grad_b)
if matrix.dtype in (mstype.complex64, mstype.complex128):
a2_temp = conjugate_transpose(a2, perm)
else:
a2_temp = _transpose(a2, perm)
a2 = matmul(tmp, a2_temp)
grad_a = add(a1, a2)
return (grad_a, grad_b, None)
if fast is False:
raise ValueError("For MatrixSolveLs, gradient not defined for fast=False")
matrix_shape = shape(matrix)[-2:]
if matrix_shape[-2] >= matrix_shape[-1]:
return over_determined(matrix, rhs, out, l2, dout)
return under_determined(matrix, rhs, l2, dout)
return bprop
@bprop_getters.register(P.MatrixDeterminant)
def get_bprop_matrix_determinant(self):
"""Generate bprop for MatrixDeterminant"""


@@ -0,0 +1,36 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MatrixSolveLs op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
matrix_solve_ls_op_info = AiCPURegOp("MatrixSolveLs") \
.fusion_type("OPAQUE") \
.attr("fast", "bool") \
.input(0, "matrix", "required") \
.input(1, "rhs", "required") \
.input(2, "l2", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F64_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.F64_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.F64_Default, DataType.C128_Default) \
.get_op_info()
@op_info_register(matrix_solve_ls_op_info)
def _matrix_solve_ls_aicpu():
"""MatrixSolveLs aicpu register"""
return


@@ -1518,9 +1518,8 @@ class MatMul(PrimitiveWithCheck):
def check_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2}
validator.check_tensors_dtypes_same_and_valid(args,
mstype.float_type + mstype.int_type + (mstype.complex64,),
self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type
+ (mstype.complex64, mstype.complex128), self.name)
class BatchMatMul(Primitive):
@@ -6218,6 +6217,57 @@ class MatrixSolve(Primitive):
self.adjoint = validator.check_value_type("adjoint", adjoint, [bool], self.name)
class MatrixSolveLs(Primitive):
r"""
Solves one or more linear least-squares problems.
If `fast` is `True`, then the solution is computed by solving the normal equations using Cholesky decomposition.
If `fast` is `False`, an algorithm based on the numerically robust complete orthogonal decomposition is used. This
path is typically 6-7 times slower than the fast path. If `fast` is `False`, then `l2_regularizer` is ignored.
Args:
fast (bool): Whether to solve via the Cholesky-based fast path. Default: True.
Inputs:
- **matrix** (Tensor) - A Tensor. Must be one of the following data types: float64, float32, complex64,
complex128. Shape is :math:`(*, M, N)`.
- **rhs** (Tensor) - A Tensor. Must have the same data type as matrix. Shape is :math:`(*, M, K)`.
`matrix` and `rhs` should have the same dimensions except the last one.
- **l2_regularizer** (Tensor) - A scalar Tensor of type float64.
Outputs:
Tensor of shape :math:`(*, N, K)` with the same data type as `matrix`.
Raises:
TypeError: If `matrix`, `rhs` or `l2_regularizer` is not tensor.
TypeError: If the data type of `matrix` or `rhs` is not float32, float64, complex64 or complex128.
TypeError: If `l2_regularizer` is not float64.
TypeError: If `fast` is not bool.
ValueError: If the rank of `matrix` or `rhs` is less than 2.
ValueError: If the shape of `matrix` does not match the shape of `rhs`.
Supported Platforms:
``CPU``
Examples:
>>> matrix_solve_ls = ops.MatrixSolveLs(fast=True)
>>> matrix = Tensor([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], mstype.float32)
>>> rhs = Tensor(np.array([[4], [2], [4], [2]]), mstype.float32)
>>> l2 = Tensor(0.0, mstype.float64)
>>> output = matrix_solve_ls(matrix, rhs, l2)
>>> print(output)
[[ 1.3333334]
[-0.6666667]
[ 2.6666665]
[-1.3333333]]
"""
@prim_attr_register
def __init__(self, fast=True):
"""Initialize MatrixSolveLs"""
validator.check_value_type('fast', fast, [bool], self.name)
class LuSolve(Primitive):
"""
Return the solution of the linear equation Ax = b.


@@ -51,6 +51,7 @@ from mindspore.ops.operations.math_ops import MatrixExp
from mindspore.ops.operations.math_ops import FFTWithSize
from mindspore.ops.operations.math_ops import MatrixPower
from mindspore.ops.operations.math_ops import MatrixSolve
from mindspore.ops.operations.math_ops import MatrixSolveLs
from mindspore.ops.operations.math_ops import MatrixLogarithm
from mindspore.ops.operations.math_ops import CholeskySolve
from mindspore.ops.operations.math_ops import NextAfter
@@ -2465,6 +2466,12 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.array([[[1., 4.], [2., 7.]], [[1., 4.], [2., 7.]]]).astype(np.float32)),
Tensor(np.array([[[1.], [3.]], [[1.], [3.]]]).astype(np.float32))],
'desc_bprop': [Tensor(np.array([[[1.], [1.]], [[1.], [1.]]]).astype(np.float32))]}),
('MatrixSolveLs', {
'block': MatrixSolveLs(fast=True),
'desc_inputs': [Tensor(np.array([[[1., 4.], [2., 7.]], [[1., 4.], [2., 7.]]]).astype(np.float32)),
Tensor(np.array([[[1.], [3.]], [[1.], [3.]]]).astype(np.float32)),
Tensor(np.random.uniform(0.0, 5.0), mstype.float64)],
'desc_bprop': [Tensor(np.array([[[1.], [1.]], [[1.], [1.]]]).astype(np.float32))]}),
('MatrixDeterminant', {
'block': P.MatrixDeterminant(),
'desc_inputs': [Tensor(np.array([[[-1, -2], [-3, -4]], [[5, 6], [7, 8]]]).astype(np.float32))],