forked from mindspore-Ecosystem/mindspore
!43926 [assistant][ops] Add MatrixSolveLs
Merge pull request !43926 from 张渝/MatrixSolveLs
This commit is contained in:
commit
cbc9496860
|
@ -0,0 +1,498 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/eigen/matrix_solve_ls_cpu_kernel.h"
|
||||
#include <Eigen/Cholesky>
|
||||
#include <Eigen/Dense>
|
||||
#include <algorithm>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr auto kInputNum = 3;
|
||||
constexpr auto kOutputNum = 1;
|
||||
constexpr int64_t kNum2 = 2;
|
||||
constexpr char kFast[] = "fast";
|
||||
constexpr bool kMatrixSolveLsComputeOk = true;
|
||||
constexpr bool kMatrixSolveLsComputeFailed = false;
|
||||
|
||||
template <typename InputIt, typename T>
|
||||
T GetNumElements(InputIt first, InputIt last, T init) {
|
||||
for (; first != last; ++first) {
|
||||
init = std::move(init) * (*first);
|
||||
}
|
||||
return init;
|
||||
}
|
||||
} // namespace
|
||||
template <typename T, int Major>
|
||||
using EigenMatrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Major>;
|
||||
|
||||
template <typename T>
|
||||
void MatrixSolveLsCpuKernelMod::RealCholeskySingleCompute(T *aptr, T *bptr, T *xptr, double *l2, int64_t m, int64_t k,
|
||||
int64_t n) {
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> a(m, k);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> x(k, n);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> b(m, n);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> a_copy;
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> a_b;
|
||||
|
||||
for (int i = 0; i < m * k; i++) {
|
||||
*(a.data() + i) = *(aptr + i);
|
||||
}
|
||||
for (int i = 0; i < m * n; i++) {
|
||||
*(b.data() + i) = *(bptr + i);
|
||||
}
|
||||
|
||||
if (m >= k) {
|
||||
a_copy = a.transpose() * a + ((T)*l2) * EigenMatrix<T, Eigen::RowMajor>::Identity(k, k);
|
||||
a_b = a.transpose() * b;
|
||||
} else {
|
||||
a_copy = a * a.transpose() + ((T)*l2) * EigenMatrix<T, Eigen::RowMajor>::Identity(m, m);
|
||||
a_b = b;
|
||||
}
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
EigenMatrix<T, Eigen::RowMajor> xi = a_copy.ldlt().solve(a_b.col(i));
|
||||
if (m < k) {
|
||||
xi = a.transpose() * xi;
|
||||
}
|
||||
x.col(i) = xi;
|
||||
}
|
||||
for (int64_t i = 0; i < k * n; i++) {
|
||||
*(xptr + i) = *(x.data() + i);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MatrixSolveLsCpuKernelMod::ComplexCholeskySingleCompute(std::complex<T> *aptr, std::complex<T> *bptr,
|
||||
std::complex<T> *xptr, double *l2, int64_t m, int64_t k,
|
||||
int64_t n) {
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> A(kNum2 * m, kNum2 * k);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> x(kNum2 * k, n);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> b(kNum2 * m, n);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> a_copy;
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> a_b;
|
||||
auto l2value = abs(*l2);
|
||||
|
||||
for (int64_t i = 0; i < k; i++) {
|
||||
for (int64_t j = 0; j < m; j++) {
|
||||
*(A.data() + i + j * kNum2 * k) = std::real(*(aptr + i + j * k));
|
||||
}
|
||||
for (int64_t j = 0; j < m; j++) {
|
||||
*(A.data() + (i + k) + (j + m) * kNum2 * k) = std::real(*(aptr + i + j * k));
|
||||
}
|
||||
for (int64_t j = 0; j < m; j++) {
|
||||
*(A.data() + (i + k) + j * kNum2 * k) = -std::imag(*(aptr + i + j * k));
|
||||
}
|
||||
for (int64_t j = 0; j < m; j++) {
|
||||
*(A.data() + i + (j + m) * kNum2 * k) = std::imag(*(aptr + i + j * k));
|
||||
}
|
||||
}
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
for (int64_t j = 0; j < m; j++) {
|
||||
*(b.data() + i + j * n) = std::real(*(bptr + i + j * n));
|
||||
*(b.data() + i + (j + m) * n) = std::imag(*(bptr + i + j * n));
|
||||
}
|
||||
}
|
||||
|
||||
if (m >= k) {
|
||||
a_copy =
|
||||
A.transpose() * A +
|
||||
((T)l2value) * Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Identity(kNum2 * k, kNum2 * k);
|
||||
a_b = A.transpose() * b;
|
||||
} else {
|
||||
a_copy =
|
||||
A * A.transpose() +
|
||||
((T)l2value) * Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Identity(kNum2 * m, kNum2 * m);
|
||||
a_b = b;
|
||||
}
|
||||
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> xi;
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
xi = a_copy.ldlt().solve(a_b.col(i));
|
||||
if (m < k) {
|
||||
xi = A.transpose() * xi;
|
||||
}
|
||||
x.col(i) = xi;
|
||||
for (int64_t j = 0; j < k; j++) {
|
||||
(xptr + i + j * n)->real(*(x.data() + i + j * n));
|
||||
(xptr + i + j * n)->imag(*(x.data() + i + (j + k) * n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MatrixSolveLsCpuKernelMod::RealQrSingleCompute(T *aptr, T *bptr, T *xptr, int64_t m, int64_t k, int64_t n) {
|
||||
EigenMatrix<T, Eigen::RowMajor> a(m, k);
|
||||
EigenMatrix<T, Eigen::RowMajor> x(k, n);
|
||||
EigenMatrix<T, Eigen::RowMajor> b(m, n);
|
||||
|
||||
for (int i = 0; i < m * k; i++) {
|
||||
*(a.data() + i) = *(aptr + i);
|
||||
}
|
||||
for (int i = 0; i < m * n; i++) {
|
||||
*(b.data() + i) = *(bptr + i);
|
||||
}
|
||||
|
||||
Eigen::CompleteOrthogonalDecomposition<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> qr_solve(a);
|
||||
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> xi = qr_solve.solve(b.col(i));
|
||||
x.col(i) = xi;
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < k * n; i++) {
|
||||
*(xptr + i) = *(x.data() + i);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MatrixSolveLsCpuKernelMod::ComplexQrSingleCompute(std::complex<T> *aptr, std::complex<T> *bptr,
|
||||
std::complex<T> *xptr, int64_t m, int64_t k, int64_t n) {
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> A(kNum2 * m, kNum2 * k);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> x(kNum2 * k, n);
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> b(kNum2 * m, n);
|
||||
|
||||
for (int64_t i = 0; i < k; i++) {
|
||||
for (int64_t j = 0; j < m; j++) {
|
||||
*(A.data() + i + j * kNum2 * k) = std::real(*(aptr + i + j * k));
|
||||
}
|
||||
for (int64_t j = 0; j < m; j++) {
|
||||
*(A.data() + (i + k) + (j + m) * kNum2 * k) = std::real(*(aptr + i + j * k));
|
||||
}
|
||||
for (int64_t j = 0; j < m; j++) {
|
||||
*(A.data() + (i + k) + j * kNum2 * k) = -std::imag(*(aptr + i + j * k));
|
||||
}
|
||||
for (int64_t j = 0; j < m; j++) {
|
||||
*(A.data() + i + (j + m) * kNum2 * k) = std::imag(*(aptr + i + j * k));
|
||||
}
|
||||
}
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
for (int64_t j = 0; j < m; j++) {
|
||||
*(b.data() + i + j * n) = std::real(*(bptr + i + j * n));
|
||||
*(b.data() + i + (j + m) * n) = std::imag(*(bptr + i + j * n));
|
||||
}
|
||||
}
|
||||
|
||||
Eigen::CompleteOrthogonalDecomposition<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> qr_solve(A);
|
||||
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> xi = qr_solve.solve(b.col(i));
|
||||
x.col(i) = xi;
|
||||
|
||||
for (int64_t j = 0; j < k; j++) {
|
||||
(xptr + i + j * n)->real(*(x.data() + i + j * n));
|
||||
(xptr + i + j * n)->imag(*(x.data() + i + (j + k) * n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool MatrixSolveLsCpuKernelMod::ComplexCholesky(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto x_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
|
||||
auto dims = x_shape_vector.size();
|
||||
auto l2 = reinterpret_cast<double *>(inputs[2]->addr);
|
||||
auto aptr = reinterpret_cast<std::complex<T> *>(inputs[0]->addr);
|
||||
auto bptr = reinterpret_cast<std::complex<T> *>(inputs[1]->addr);
|
||||
auto xptr = reinterpret_cast<std::complex<T> *>(outputs[0]->addr);
|
||||
int64_t m = x_shape_vector[dims - kNum2];
|
||||
int64_t k = x_shape_vector[dims - 1];
|
||||
int64_t n = 1;
|
||||
auto b_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
|
||||
if (b_shape_vector.size() > 1) {
|
||||
n = b_shape_vector[dims - 1];
|
||||
}
|
||||
int64_t data_num = 1;
|
||||
data_num = GetNumElements(x_shape_vector.begin(), x_shape_vector.end(), data_num);
|
||||
const int64_t mat_size = m * k;
|
||||
const int64_t rhs_size = m * n;
|
||||
const int64_t res_size = n * k;
|
||||
const int64_t batch = data_num / mat_size;
|
||||
const int64_t kParallelDataNum = 16 * mat_size;
|
||||
if (data_num >= kParallelDataNum) {
|
||||
auto sharder_matrix_solve_ls = [&](int64_t start, int64_t end) {
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
ComplexCholeskySingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, l2, m, k, n);
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(sharder_matrix_solve_ls, batch, this, ¶llel_search_info_);
|
||||
} else {
|
||||
for (int64_t i = 0; i < batch; i++) {
|
||||
ComplexCholeskySingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, l2, m, k, n);
|
||||
}
|
||||
}
|
||||
return kMatrixSolveLsComputeOk;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool MatrixSolveLsCpuKernelMod::RealQr(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto x_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
|
||||
auto dims = x_shape_vector.size();
|
||||
auto aptr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto bptr = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto xptr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
int64_t m = x_shape_vector[dims - kNum2];
|
||||
int64_t k = x_shape_vector[dims - 1];
|
||||
int64_t n = 1;
|
||||
auto b_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
|
||||
if (b_shape_vector.size() > 1) {
|
||||
n = b_shape_vector[dims - 1];
|
||||
}
|
||||
int64_t data_num = 1;
|
||||
data_num = GetNumElements(x_shape_vector.begin(), x_shape_vector.end(), data_num);
|
||||
const int64_t mat_size = m * k;
|
||||
const int64_t rhs_size = m * n;
|
||||
const int64_t res_size = n * k;
|
||||
const int64_t batch = data_num / mat_size;
|
||||
const int64_t kParallelDataNum = 16 * mat_size;
|
||||
if (data_num >= kParallelDataNum) {
|
||||
auto sharder_matrix_solve_ls = [&](int64_t start, int64_t end) {
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
RealQrSingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, m, k, n);
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(sharder_matrix_solve_ls, batch, this, ¶llel_search_info_);
|
||||
} else {
|
||||
for (int64_t i = 0; i < batch; i++) {
|
||||
RealQrSingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, m, k, n);
|
||||
}
|
||||
}
|
||||
return kMatrixSolveLsComputeOk;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool MatrixSolveLsCpuKernelMod::ComplexQr(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto x_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
|
||||
auto dims = x_shape_vector.size();
|
||||
int64_t m = x_shape_vector[dims - kNum2];
|
||||
int64_t k = x_shape_vector[dims - 1];
|
||||
int64_t n = 1;
|
||||
auto b_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
|
||||
if (b_shape_vector.size() > 1) {
|
||||
n = b_shape_vector[dims - 1];
|
||||
}
|
||||
int64_t data_num = 1;
|
||||
data_num = GetNumElements(x_shape_vector.begin(), x_shape_vector.end(), data_num);
|
||||
const int64_t mat_size = m * k;
|
||||
const int64_t rhs_size = m * n;
|
||||
const int64_t res_size = n * k;
|
||||
const int64_t batch = data_num / mat_size;
|
||||
const int64_t kParallelDataNum = 16 * mat_size;
|
||||
auto aptr = reinterpret_cast<std::complex<T> *>(inputs[0]->addr);
|
||||
auto bptr = reinterpret_cast<std::complex<T> *>(inputs[1]->addr);
|
||||
auto xptr = reinterpret_cast<std::complex<T> *>(outputs[0]->addr);
|
||||
if (data_num >= kParallelDataNum) {
|
||||
auto sharder_matrix_solve_ls = [&](int64_t start, int64_t end) {
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
ComplexQrSingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, m, k, n);
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(sharder_matrix_solve_ls, batch, this, ¶llel_search_info_);
|
||||
} else {
|
||||
for (int64_t i = 0; i < batch; i++) {
|
||||
ComplexQrSingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, m, k, n);
|
||||
}
|
||||
}
|
||||
return kMatrixSolveLsComputeOk;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool MatrixSolveLsCpuKernelMod::RealCholesky(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto x_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
|
||||
auto dims = x_shape_vector.size();
|
||||
auto aptr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto bptr = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto xptr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto l2 = reinterpret_cast<double *>(inputs[2]->addr);
|
||||
int64_t m = x_shape_vector[dims - kNum2];
|
||||
int64_t k = x_shape_vector[dims - 1];
|
||||
int64_t n = 1;
|
||||
auto b_shape_vector = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
|
||||
if (b_shape_vector.size() > 1) {
|
||||
n = b_shape_vector[dims - 1];
|
||||
}
|
||||
int64_t data_num = 1;
|
||||
data_num = GetNumElements(x_shape_vector.begin(), x_shape_vector.end(), data_num);
|
||||
const int64_t mat_size = m * k;
|
||||
const int64_t rhs_size = m * n;
|
||||
const int64_t res_size = n * k;
|
||||
const int64_t batch = data_num / mat_size;
|
||||
const int64_t kParallelDataNum = 16 * mat_size;
|
||||
if (data_num >= kParallelDataNum) {
|
||||
auto sharder_matrix_solve_ls = [&](int64_t start, int64_t end) {
|
||||
for (int64_t i = start; i < end; i++) {
|
||||
RealCholeskySingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, l2, m, k, n);
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(sharder_matrix_solve_ls, batch, this, ¶llel_search_info_);
|
||||
} else {
|
||||
for (int64_t i = 0; i < batch; i++) {
|
||||
RealCholeskySingleCompute(aptr + i * mat_size, bptr + i * rhs_size, xptr + i * res_size, l2, m, k, n);
|
||||
}
|
||||
}
|
||||
return kMatrixSolveLsComputeOk;
|
||||
}
|
||||
|
||||
void MatrixSolveLsCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
node_wpt_ = kernel_node;
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (common::AnfAlgo::HasNodeAttr(kFast, kernel_node)) {
|
||||
qr_chole = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, kFast);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the attribute 'fast' does not exist.";
|
||||
}
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
bool MatrixSolveLsCpuKernelMod::Resize(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
auto shapea = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
|
||||
auto shapeb = AnfAlgo::GetInputDeviceShape(node_wpt_, 1);
|
||||
auto shapel2 = AnfAlgo::GetInputDeviceShape(node_wpt_, 2);
|
||||
auto shapex = AnfAlgo::GetOutputDeviceShape(node_wpt_, 0);
|
||||
auto dims = shapea.size();
|
||||
|
||||
if (shapeb.size() == 1) {
|
||||
if (shapea[dims - kNum2] != shapeb[0]) {
|
||||
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", #Rows mismatch between A and rhs."
|
||||
<< "#Rows of A = [" << shapea[dims - kNum2] << "]"
|
||||
<< "#Rows of rhs = [" << shapeb[0] << "]";
|
||||
return kMatrixSolveLsComputeFailed;
|
||||
}
|
||||
} else {
|
||||
if (shapea[dims - kNum2] != shapeb[dims - kNum2]) {
|
||||
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << "#Rows mismatch between A and rhs."
|
||||
<< "#Rows of A = [" << shapea[dims - kNum2] << "]"
|
||||
<< "#Rows of rhs = [" << shapeb[dims - kNum2] << "]";
|
||||
return kMatrixSolveLsComputeFailed;
|
||||
}
|
||||
}
|
||||
|
||||
if (shapel2.size() != 0) {
|
||||
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << "Tensor l2 should be a scalar.";
|
||||
return kMatrixSolveLsComputeFailed;
|
||||
}
|
||||
if (shapeb.size() == 1) {
|
||||
if ((shapex.size() != shapeb.size()) || (shapea[dims - 1] != shapex[0]) || (shapex.back() != shapeb[0])) {
|
||||
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << "Tensor y shape mismatch.";
|
||||
return kMatrixSolveLsComputeFailed;
|
||||
}
|
||||
} else {
|
||||
if ((shapex.size() != shapeb.size()) || (shapea[dims - 1] != shapex[shapex.size() - kNum2]) ||
|
||||
(shapex.back() != shapeb.back())) {
|
||||
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << "Tensor y shape mismatch.";
|
||||
return kMatrixSolveLsComputeFailed;
|
||||
}
|
||||
}
|
||||
return kMatrixSolveLsComputeOk;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool MatrixSolveLsCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
|
||||
|
||||
auto a_data_type = AnfAlgo::GetInputDeviceDataType(node_wpt_, 0);
|
||||
auto b_data_type = AnfAlgo::GetInputDeviceDataType(node_wpt_, 1);
|
||||
|
||||
if (Resize(inputs, outputs) != true) {
|
||||
return kMatrixSolveLsComputeFailed;
|
||||
}
|
||||
|
||||
if (a_data_type != b_data_type) {
|
||||
MS_EXCEPTION(TypeError) << "For " << kernel_name_ << "Tensor data type mismatch.";
|
||||
return kMatrixSolveLsComputeFailed;
|
||||
}
|
||||
if (a_data_type != kNumberTypeFloat32 && a_data_type != kNumberTypeFloat64 && a_data_type != kNumberTypeComplex64 &&
|
||||
a_data_type != kNumberTypeComplex128) {
|
||||
MS_EXCEPTION(TypeError) << "For " << kernel_name_ << ", unsupported data type: " << TypeIdLabel(a_data_type) << ".";
|
||||
return kMatrixSolveLsComputeFailed;
|
||||
}
|
||||
|
||||
if (qr_chole) {
|
||||
if (a_data_type == kNumberTypeComplex64) {
|
||||
return ComplexCholesky<float>(inputs, outputs);
|
||||
}
|
||||
if (a_data_type == kNumberTypeComplex128) {
|
||||
return ComplexCholesky<double>(inputs, outputs);
|
||||
}
|
||||
if (a_data_type == kNumberTypeFloat64) {
|
||||
return RealCholesky<double>(inputs, outputs);
|
||||
}
|
||||
if (a_data_type == kNumberTypeFloat32) {
|
||||
return RealCholesky<float>(inputs, outputs);
|
||||
}
|
||||
} else {
|
||||
if (a_data_type == kNumberTypeComplex64) {
|
||||
return ComplexQr<float>(inputs, outputs);
|
||||
}
|
||||
if (a_data_type == kNumberTypeComplex128) {
|
||||
return ComplexQr<double>(inputs, outputs);
|
||||
}
|
||||
if (a_data_type == kNumberTypeFloat64) {
|
||||
return RealQr<double>(inputs, outputs);
|
||||
}
|
||||
if (a_data_type == kNumberTypeFloat32) {
|
||||
return RealQr<float>(inputs, outputs);
|
||||
}
|
||||
}
|
||||
return kMatrixSolveLsComputeOk;
|
||||
} // un pass
|
||||
|
||||
std::vector<std::pair<KernelAttr, MatrixSolveLsCpuKernelMod::MatrixSolveLsFunc>> MatrixSolveLsCpuKernelMod::func_list_ =
|
||||
{{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&MatrixSolveLsCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&MatrixSolveLsCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
&MatrixSolveLsCpuKernelMod::LaunchKernel<std::complex<float>>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
&MatrixSolveLsCpuKernelMod::LaunchKernel<std::complex<double>>}};
|
||||
|
||||
std::vector<KernelAttr> MatrixSolveLsCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, MatrixSolveLsFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixSolveLs, MatrixSolveLsCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_SOLVE_LS_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_SOLVE_LS_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <complex>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class MatrixSolveLsCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
public:
|
||||
MatrixSolveLsCpuKernelMod() = default;
|
||||
~MatrixSolveLsCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
bool Resize(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using MatrixSolveLsFunc =
|
||||
std::function<bool(MatrixSolveLsCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, MatrixSolveLsFunc>> func_list_;
|
||||
MatrixSolveLsFunc kernel_func_;
|
||||
|
||||
template <typename T>
|
||||
void RealCholeskySingleCompute(T *aptr, T *bptr, T *xptr, double *l2, int64_t m, int64_t k, int64_t n);
|
||||
|
||||
template <typename T>
|
||||
bool RealCholesky(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
template <typename T>
|
||||
void RealQrSingleCompute(T *aptr, T *bptr, T *xptr, int64_t m, int64_t k, int64_t n);
|
||||
|
||||
template <typename T>
|
||||
bool RealQr(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
template <typename T>
|
||||
void ComplexCholeskySingleCompute(std::complex<T> *aptr, std::complex<T> *bptr, std::complex<T> *xptr, double *l2,
|
||||
int64_t m, int64_t k, int64_t n);
|
||||
|
||||
template <typename T>
|
||||
bool ComplexCholesky(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
template <typename T>
|
||||
void ComplexQrSingleCompute(std::complex<T> *aptr, std::complex<T> *bptr, std::complex<T> *xptr, int64_t m, int64_t k,
|
||||
int64_t n);
|
||||
|
||||
template <typename T>
|
||||
bool ComplexQr(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
bool qr_chole{true};
|
||||
CNodePtr node_wpt_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_SOLVE_LS_CPU_KERNEL_H_
|
|
@ -141,6 +141,7 @@ constexpr auto kDiagonal = "Diagonal";
|
|||
constexpr auto kEditDistance = "EditDistance";
|
||||
constexpr auto kNextAfter = "NextAfter";
|
||||
constexpr auto kMaximumGradGrad = "MaximumGradGrad";
|
||||
constexpr auto kMatrixSolveLs = "MatrixSolveLs";
|
||||
constexpr auto kSparseSegmentMean = "SparseSegmentMean";
|
||||
constexpr auto kTridiagonalMatMul = "TridiagonalMatMul";
|
||||
constexpr auto kTridiagonalSolve = "TridiagonalSolve";
|
||||
|
@ -1272,6 +1273,7 @@ GVAR_DEF(PrimitivePtr, kPrimLinSpace, std::make_shared<Primitive>("LinSpace"));
|
|||
GVAR_DEF(PrimitivePtr, kPrimNonMaxSuppression, std::make_shared<Primitive>("NonMaxSuppression"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSign, std::make_shared<Primitive>("Sign"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimACos, std::make_shared<Primitive>(kACos));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMatrixSolveLs, std::make_shared<Primitive>(kMatrixSolveLs));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAsinGrad, std::make_shared<Primitive>(kAsinGrad));
|
||||
GVAR_DEF(PrimitivePtr, kPrimACosGrad, std::make_shared<Primitive>(kACosGrad));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAtanGrad, std::make_shared<Primitive>("AtanGrad"));
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/matrix_solve_ls.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr MatrixSolveLsInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto matrix_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto rhs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto l2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
|
||||
(void)CheckAndConvertUtils::CheckInteger("input matrix rank", SizeToLong(matrix_shape.size()), kGreaterEqual, 2L,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input rhs rank", SizeToLong(rhs_shape.size()), kGreaterEqual, 2L,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input l2 rank", SizeToLong(l2_shape.size()), kEqual, 0L, prim_name);
|
||||
|
||||
constexpr size_t offset = 2;
|
||||
std::vector<int64_t> matrix_last(matrix_shape.end() - offset, matrix_shape.end());
|
||||
std::vector<int64_t> rhs_last(rhs_shape.end() - offset, rhs_shape.end());
|
||||
std::vector<int64_t> y_shape(rhs_shape.begin(), rhs_shape.end() - offset);
|
||||
int64_t matrix_row = matrix_last[0];
|
||||
int64_t matrix_col = matrix_last[1];
|
||||
int64_t rhs_row = rhs_last[0];
|
||||
int64_t rhs_col = rhs_last[1];
|
||||
|
||||
for (size_t i = 0; i < matrix_shape.size() - offset; ++i) {
|
||||
if (matrix_shape[i] != rhs_shape[i]) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", shapes in batch dimension must be same, but dim[" << i
|
||||
<< "] are not the same, "
|
||||
<< "got matrix_dim[" << i << "]: " << matrix_shape[i] << ", rhs_dim[" << i
|
||||
<< "]: " << rhs_shape[i] << ".";
|
||||
}
|
||||
}
|
||||
|
||||
if (matrix_row != rhs_row) {
|
||||
MS_EXCEPTION(ValueError) << "MatrixSolveLs shape error, got matrix_row: " << matrix_row << ", rhs_row: " << rhs_row
|
||||
<< ". In MatrixSolveLs matrix_row and rhs_row should be equal.";
|
||||
}
|
||||
|
||||
y_shape.push_back(matrix_col);
|
||||
y_shape.push_back(rhs_col);
|
||||
|
||||
return std::make_shared<abstract::Shape>(y_shape);
|
||||
}
|
||||
|
||||
TypePtr MatrixSolveLsInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const std::set<TypePtr> valid_types = {kFloat32, kFloat64, kComplex64, kComplex128};
|
||||
const std::set<TypePtr> l2_valid_types = {kFloat64};
|
||||
auto matrix_type = input_args[0]->BuildType();
|
||||
auto rhs_type = input_args[1]->BuildType();
|
||||
auto l2_type = input_args[2]->BuildType();
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("matrix", matrix_type);
|
||||
(void)types.emplace("rhs", rhs_type);
|
||||
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("matrix", matrix_type, valid_types, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("rhs", rhs_type, valid_types, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("l2_regularizer", l2_type, l2_valid_types, primitive->name());
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
|
||||
|
||||
return matrix_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr MatrixSolveLsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 3;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infer_type = MatrixSolveLsInferType(primitive, input_args);
|
||||
auto infer_shape = MatrixSolveLsInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixSolveLs, prim::kPrimMatrixSolveLs, MatrixSolveLsInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_MATRIX_SOLVE_LS_H
|
||||
#define MINDSPORE_CORE_OPS_MATRIX_SOLVE_LS_H
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameMatrixSolveLs = "MatrixSolveLs";
|
||||
|
||||
class MIND_API MatrixSolveLs : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(MatrixSolveLs);
|
||||
|
||||
/// \brief Constructor.
|
||||
MatrixSolveLs() : BaseOperator(kNameMatrixSolveLs) { InitIOName({"matrix", "rhs", "l2"}, {"y"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.MatrixSolveLs for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr MatrixSolveLsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimMatrixSolveLsPtr = std::shared_ptr<MatrixSolveLs>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_MATRIX_SOLVE_LS_H
|
|
@ -51,6 +51,7 @@ from mindspore.ops.operations.math_ops import LuUnpack
|
|||
from mindspore.ops.operations.math_ops import MatrixExp
|
||||
from mindspore.ops.operations.math_ops import CumulativeLogsumexp
|
||||
from mindspore.ops.operations.math_ops import MatrixSolve
|
||||
from mindspore.ops.operations.math_ops import MatrixSolveLs
|
||||
from mindspore.ops.operations.math_ops import MatrixPower
|
||||
from mindspore.ops.operations.math_ops import Median
|
||||
from mindspore.ops.operations.math_ops import MatrixTriangularSolve
|
||||
|
@ -74,6 +75,7 @@ from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_elemen
|
|||
from mindspore.ops._grad.grad_base import dyn_ones, dyn_fill, sum_grad_reduce_axis
|
||||
from mindspore.ops._grad.grad_math_ops import binop_grad_common
|
||||
from mindspore.ops.operations.array_ops import MatrixBandPart
|
||||
from mindspore.ops.operations.array_ops import ConjugateTranspose
|
||||
|
||||
transpose = P.Transpose()
|
||||
dyn_shape_op = P.TensorShape()
|
||||
|
@ -673,6 +675,155 @@ def get_bprop_matrix_solve(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@constexpr
|
||||
def _generate_perm_matrix_solve_ls(x_dim):
|
||||
perm = tuple(range(x_dim - 2))
|
||||
perm = perm + (x_dim-1, x_dim-2)
|
||||
return perm
|
||||
|
||||
|
||||
@bprop_getters.register(MatrixSolveLs)
|
||||
def get_bprop_matrix_solve_ls(self):
|
||||
"""Grad definition for 'MatrixSolveLs' operation"""
|
||||
fast = self.fast
|
||||
cast = P.Cast()
|
||||
neg = P.Neg()
|
||||
rank = P.Rank()
|
||||
cholesky = Cholesky()
|
||||
eye = P.Eye()
|
||||
add = P.Add()
|
||||
mul = P.Mul()
|
||||
matmul = P.MatMul()
|
||||
batch_matmul = P.BatchMatMul()
|
||||
cholesky_solve = CholeskySolve()
|
||||
_transpose = Transpose()
|
||||
conjugate_transpose = ConjugateTranspose()
|
||||
shape = P.Shape()
|
||||
_complex = P.Complex()
|
||||
|
||||
def regularized_gramian_cholesky(matrix, l2, first_kind):
|
||||
matrix_dim = rank(matrix)
|
||||
perm = _generate_perm_matrix_solve_ls(matrix_dim)
|
||||
if matrix.dtype in (mstype.complex64, mstype.complex128):
|
||||
matrix_temp = conjugate_transpose(matrix, perm)
|
||||
else:
|
||||
matrix_temp = _transpose(matrix, perm)
|
||||
if first_kind:
|
||||
if matrix_dim > 2:
|
||||
gramian = batch_matmul(matrix_temp, matrix)
|
||||
else:
|
||||
gramian = matmul(matrix_temp, matrix)
|
||||
else:
|
||||
if matrix_dim > 2:
|
||||
gramian = batch_matmul(matrix, matrix_temp)
|
||||
else:
|
||||
gramian = matmul(matrix, matrix_temp)
|
||||
if isinstance(l2, Tensor) or l2 != 0:
|
||||
matrix_shape = shape(matrix)
|
||||
if first_kind:
|
||||
small_dim = matrix_shape[-1]
|
||||
else:
|
||||
small_dim = matrix_shape[-2]
|
||||
identity = eye(small_dim, small_dim, matrix.dtype)
|
||||
gramian = add(gramian, mul(l2, identity))
|
||||
|
||||
#Cholesky not support complex dtype for now
|
||||
return cholesky(gramian)
|
||||
|
||||
def bprop(matrix, rhs, l2, out, dout):
|
||||
#support dtype:float32
|
||||
#support dimension: 2D,3D
|
||||
def over_determined(matrix, rhs, out, l2, dout):
|
||||
if matrix.dtype == mstype.complex64:
|
||||
l2_regularizer = _complex(cast(l2, mstype.float32), Tensor(0, mstype.float32))
|
||||
elif matrix.dtype == mstype.complex128:
|
||||
l2_regularizer = _complex(cast(l2, mstype.float64), Tensor(0, mstype.float64))
|
||||
else:
|
||||
l2_regularizer = cast(l2, matrix.dtype)
|
||||
chol = cast(regularized_gramian_cholesky(matrix, l2_regularizer, first_kind=True), matrix.dtype)
|
||||
#CholeskySolve not support complex dtype and just support 2D or 3D matrices for now
|
||||
z = cholesky_solve(dout, chol)
|
||||
|
||||
matrix_dim = rank(matrix)
|
||||
perm = _generate_perm_matrix_solve_ls(matrix_dim)
|
||||
if matrix.dtype in (mstype.complex64, mstype.complex128):
|
||||
z_temp = conjugate_transpose(z, perm)
|
||||
else:
|
||||
z_temp = _transpose(z, perm)
|
||||
if matrix_dim > 2:
|
||||
xzt = batch_matmul(out, z_temp)
|
||||
else:
|
||||
xzt = matmul(out, z_temp)
|
||||
zx_sym = add(xzt, _transpose(xzt, perm))
|
||||
|
||||
if matrix_dim > 2:
|
||||
grad_a = add(neg(batch_matmul(matrix, zx_sym)), batch_matmul(rhs, z_temp))
|
||||
grad_b = batch_matmul(matrix, z)
|
||||
else:
|
||||
grad_a = add(neg(matmul(matrix, zx_sym)), matmul(rhs, z_temp))
|
||||
grad_b = matmul(matrix, z)
|
||||
|
||||
return (grad_a, grad_b, None)
|
||||
|
||||
def under_determined(matrix, rhs, l2, dout):
|
||||
if matrix.dtype == mstype.complex64:
|
||||
l2_regularizer = _complex(cast(l2, mstype.float32), Tensor(0, mstype.float32))
|
||||
elif matrix.dtype == mstype.complex128:
|
||||
l2_regularizer = _complex(cast(l2, mstype.float64), Tensor(0, mstype.float64))
|
||||
else:
|
||||
l2_regularizer = cast(l2, matrix.dtype)
|
||||
chol = cast(regularized_gramian_cholesky(matrix, l2_regularizer, first_kind=False), matrix.dtype)
|
||||
|
||||
matrix_dim = rank(matrix)
|
||||
perm = _generate_perm_matrix_solve_ls(matrix_dim)
|
||||
if matrix_dim > 2:
|
||||
gramian = batch_matmul(matrix, dout)
|
||||
else:
|
||||
gramian = matmul(matrix, dout)
|
||||
#CholeskySolve not support complex dtype and just support 2D or 3D matrices for now
|
||||
grad_b = cholesky_solve(gramian, chol)
|
||||
tmp = cholesky_solve(rhs, chol)
|
||||
|
||||
if matrix.dtype in (mstype.complex64, mstype.complex128):
|
||||
tmp_temp = conjugate_transpose(tmp, perm)
|
||||
matrix_temp = conjugate_transpose(matrix, perm)
|
||||
else:
|
||||
tmp_temp = _transpose(tmp, perm)
|
||||
matrix_temp = _transpose(matrix, perm)
|
||||
if matrix_dim > 2:
|
||||
a1 = batch_matmul(tmp_temp, matrix)
|
||||
a1 = neg(batch_matmul(grad_b, a1))
|
||||
a2 = dout - batch_matmul(matrix_temp, grad_b)
|
||||
if matrix.dtype in (mstype.complex64, mstype.complex128):
|
||||
a2_temp = conjugate_transpose(a2, perm)
|
||||
else:
|
||||
a2_temp = _transpose(a2, perm)
|
||||
a2 = batch_matmul(tmp, a2_temp)
|
||||
else:
|
||||
a1 = matmul(tmp_temp, matrix)
|
||||
a1 = neg(matmul(grad_b, a1))
|
||||
a2 = dout - matmul(matrix_temp, grad_b)
|
||||
if matrix.dtype in (mstype.complex64, mstype.complex128):
|
||||
a2_temp = conjugate_transpose(a2, perm)
|
||||
else:
|
||||
a2_temp = _transpose(a2, perm)
|
||||
a2 = matmul(tmp, a2_temp)
|
||||
|
||||
grad_a = add(a1, a2)
|
||||
return (grad_a, grad_b, None)
|
||||
|
||||
if fast is False:
|
||||
raise ValueError("For MatrixSolveLs, gradient not defined for fast=False")
|
||||
matrix_shape = shape(matrix)[-2:]
|
||||
|
||||
if matrix_shape[-2] >= matrix_shape[-1]:
|
||||
return over_determined(matrix, rhs, out, l2, dout)
|
||||
|
||||
return under_determined(matrix, rhs, l2, dout)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.MatrixDeterminant)
|
||||
def get_bprop_matrix_determinant(self):
|
||||
"""Generate bprop for MatrixDeterminant"""
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""MatrixSolveLs op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
matrix_solve_ls_op_info = AiCPURegOp("MatrixSolveLs") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("fast", "bool") \
|
||||
.input(0, "matrix", "required") \
|
||||
.input(1, "rhs", "required") \
|
||||
.input(2, "l2", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.F64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.F64_Default, DataType.C128_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(matrix_solve_ls_op_info)
|
||||
def _matrix_solve_ls_aicpu():
|
||||
"""MatrixSolveLs aicpu register"""
|
||||
return
|
|
@ -1518,9 +1518,8 @@ class MatMul(PrimitiveWithCheck):
|
|||
|
||||
def check_dtype(self, x1, x2):
|
||||
args = {"x1": x1, "x2": x2}
|
||||
validator.check_tensors_dtypes_same_and_valid(args,
|
||||
mstype.float_type + mstype.int_type + (mstype.complex64,),
|
||||
self.name)
|
||||
validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type
|
||||
+ (mstype.complex64, mstype.complex128), self.name)
|
||||
|
||||
|
||||
class BatchMatMul(Primitive):
|
||||
|
@ -6264,6 +6263,57 @@ class MatrixSolve(Primitive):
|
|||
self.adjoint = validator.check_value_type("adjoint", adjoint, [bool], self.name)
|
||||
|
||||
|
||||
class MatrixSolveLs(Primitive):
|
||||
r"""
|
||||
Solves one or more linear least-squares problems.
|
||||
|
||||
If `fast` is `True`,then the solution is computed by solving the normal equations using Cholesky decomposition.
|
||||
If `fast` is `False` an algorithm based on the numerically robust complete orthogonal decomposition is used. This
|
||||
path is typically 6-7 times slower than the fast path. If `fast` is `False` then `l2_regularizer` is ignored.
|
||||
|
||||
Args:
|
||||
fast (bool): An optional bool. Defaults to True.
|
||||
|
||||
Inputs:
|
||||
- **matrix** (Tensor) - A Tensor. Must be one of the following data types: float64, float32, complex64,
|
||||
complex128. Shape is :math:`(*, M, N)`.
|
||||
- **rhs** (Tensor) - A Tensor. Must have the same data type as matrix. Shape is :math:`(*, M, K)`.
|
||||
`matrix` and `rhs` should have the same dimensions except the last one.
|
||||
- **l2_regularizer** (Tensor) - A Tensor of type float64. Scalar tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(*, N, K)` with the same data type as `matrix`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `matrix`, `rhs` or `l2_regularizer` is not tensor.
|
||||
TypeError: If either of `matrix` and `rhs` is not float32, float64, complex64 or complex128.
|
||||
TypeError: If `l2_regularizer` is not float64.
|
||||
TypeError: If `fast` is not bool.
|
||||
ValueError: If dimensions of `matrix` or `rhs` is less than 2.
|
||||
ValueError: If shape of `matrix` dose not match the shape of `rhs`.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> matrix_solve_ls = ops.MatrixSolveLs(fast=True)
|
||||
>>> matrix = Tensor([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], mstype.float32)
|
||||
>>> rhs = Tensor(np.array([[4], [2], [4], [2]]), mstype.float32)
|
||||
>>> l2 = Tensor(0.0, mstype.float64)
|
||||
>>> output = matrix_solve_ls(matrix, rhs, l2)
|
||||
>>> print(output)
|
||||
[[ 1.3333334]
|
||||
[-0.6666667]
|
||||
[ 2.6666665]
|
||||
[-1.3333333]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, fast=True):
|
||||
"""Initialize MatrixSolveLs"""
|
||||
validator.check_value_type('fast', fast, [bool], self.name)
|
||||
|
||||
|
||||
class LuSolve(Primitive):
|
||||
"""
|
||||
Return the solution of the linear equation Ax = b.
|
||||
|
|
|
@ -51,6 +51,7 @@ from mindspore.ops.operations.math_ops import MatrixExp
|
|||
from mindspore.ops.operations.math_ops import FFTWithSize
|
||||
from mindspore.ops.operations.math_ops import MatrixPower
|
||||
from mindspore.ops.operations.math_ops import MatrixSolve
|
||||
from mindspore.ops.operations.math_ops import MatrixSolveLs
|
||||
from mindspore.ops.operations.math_ops import MatrixLogarithm
|
||||
from mindspore.ops.operations.math_ops import CholeskySolve
|
||||
from mindspore.ops.operations.math_ops import InplaceIndexAdd
|
||||
|
@ -2484,6 +2485,12 @@ test_case_math_ops = [
|
|||
'desc_inputs': [Tensor(np.array([[[1., 4.], [2., 7.]], [[1., 4.], [2., 7.]]]).astype(np.float32)),
|
||||
Tensor(np.array([[[1.], [3.]], [[1.], [3.]]]).astype(np.float32))],
|
||||
'desc_bprop': [Tensor(np.array([[[1.], [1.]], [[1.], [1.]]]).astype(np.float32))]}),
|
||||
('MatrixSolveLs', {
|
||||
'block': MatrixSolveLs(fast=True),
|
||||
'desc_inputs': [Tensor(np.array([[[1., 4.], [2., 7.]], [[1., 4.], [2., 7.]]]).astype(np.float32)),
|
||||
Tensor(np.array([[[1.], [3.]], [[1.], [3.]]]).astype(np.float32)),
|
||||
Tensor(np.random.uniform(0.0, 5.0), mstype.float64)],
|
||||
'desc_bprop': [Tensor(np.array([[[1.], [1.]], [[1.], [1.]]]).astype(np.float32))]}),
|
||||
('MatrixDeterminant', {
|
||||
'block': P.MatrixDeterminant(),
|
||||
'desc_inputs': [Tensor(np.array([[[-1, -2], [-3, -4]], [[5, 6], [7, 8]]]).astype(np.float32))],
|
||||
|
|
Loading…
Reference in New Issue