!22754 [assistant][ops] Add math operators MatrixInverse, MatrixDeterminant and LogMatrixDeterminant

Merge pull request !22754 from 孟权令/MatrixInverse
This commit is contained in:
i-robot 2022-01-14 06:56:30 +00:00 committed by Gitee
commit fd5d5f1d21
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
21 changed files with 1128 additions and 19 deletions

View File

@@ -0,0 +1,166 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/log_matrix_determinant_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "Eigen/LU"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputSize = 1;
constexpr size_t kOutputSize = 2;
constexpr int64_t kParallelDataNums = 8 * 1024;
static constexpr int kNumber0 = 0;
static constexpr int kNumber1 = 1;
static constexpr int kNumber2 = 2;
} // namespace
void LogMatrixDeterminantCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
node_wpt_ = kernel_node;
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kInputSize, kernel_name_);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputSize, kernel_name_);
auto shape_x = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto shape_sign = AnfAlgo::GetOutputInferShape(kernel_node, 0);
auto shape_y = AnfAlgo::GetOutputInferShape(kernel_node, 1);
size_t shape_size_x = shape_x.size();
size_t shape_size_sign = shape_sign.size();
size_t shape_size_y = shape_y.size();
if (shape_size_x < kNumber2) {
MS_LOG(EXCEPTION) << "Input x must be at least rank 2.";
}
if (shape_x[shape_size_x - kNumber1] < kNumber1) {
MS_LOG(EXCEPTION) << "Input x last dimension must be at least 1.";
}
if (shape_x[shape_size_x - kNumber2] != shape_x[shape_size_x - kNumber1]) {
MS_LOG(EXCEPTION) << "The last two dimensions of Input x should be equal.";
}
if (shape_size_sign != shape_size_x - kNumber2) {
MS_LOG(EXCEPTION) << "Output sign must be rank [" << shape_size_x - kNumber2 << "], got [" << shape_size_sign
<< "].";
}
if (shape_size_y != shape_size_x - kNumber2) {
MS_LOG(EXCEPTION) << "Output y must be rank [" << shape_size_x - kNumber2 << "], got [" << shape_size_y << "].";
}
for (size_t i = kNumber0; i < shape_size_x - kNumber2; i++) {
if (shape_sign[i] != shape_x[i]) {
MS_LOG(EXCEPTION) << "Output sign and Input x dimension " << i << " must be equal.";
}
if (shape_y[i] != shape_x[i]) {
MS_LOG(EXCEPTION) << "Output y and Input x dimension " << i << " must be equal.";
}
}
}
bool LogMatrixDeterminantCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /* workspace */,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat32) {
LaunchLogMatrixDeterminant<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchLogMatrixDeterminant<double>(inputs, outputs);
} else if (dtype_ == kNumberTypeComplex64) {
LaunchLogMatrixDeterminant<std::complex<float>>(inputs, outputs);
} else if (dtype_ == kNumberTypeComplex128) {
LaunchLogMatrixDeterminant<std::complex<double>>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "LogMatrixDeterminant kernel data type " << TypeIdLabel(dtype_) << " not support.";
}
return true;
}
template <typename T>
void LogMatrixDeterminantCPUKernel::LaunchLogMatrixDeterminant(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto node_ = node_wpt_.lock();
if (!node_) {
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
}
auto input_x = reinterpret_cast<T *>(inputs[0]->addr);
auto output_sign = reinterpret_cast<T *>(outputs[0]->addr);
auto output_y = reinterpret_cast<T *>(outputs[1]->addr);
auto shape_x = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0);
size_t shape_size = shape_x.size();
size_t m = shape_x[shape_size - 1];
size_t size_mm = m * m;
typedef Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXd;
using RealT = typename Eigen::NumTraits<T>::Real;
if (size_mm > 0) {
size_t input_num = 1;
for (size_t i = 0; i < shape_x.size(); i++) {
input_num *= shape_x[i];
}
size_t matrix_num = input_num / size_mm;
size_t data_size = input_num * sizeof(T);
if (data_size <= kParallelDataNums) {
for (size_t i = 0; i < matrix_num; i++) {
RealT log_abs_det = 0;
T sign = 1;
Eigen::Map<MatrixXd> matrix_x(input_x + i * m * m, m, m);
if (matrix_x.size() > 0) {
Eigen::PartialPivLU<MatrixXd> lu(matrix_x);
MatrixXd LU = lu.matrixLU();
sign = lu.permutationP().determinant();
auto diag = LU.diagonal().array().eval();
auto abs_diag = diag.cwiseAbs().eval();
auto abs_diag_inverse = abs_diag.cwiseInverse();
log_abs_det += abs_diag.log().sum();
sign *= (diag * abs_diag_inverse).prod();
}
if (!Eigen::numext::isfinite(log_abs_det)) {
sign = 0;
log_abs_det = log_abs_det > 0 ? -std::log(RealT(0)) : std::log(RealT(0));
}
*(output_sign + i) = sign;
*(output_y + i) = log_abs_det;
}
} else {
auto task = [this, &m, input_x, output_sign, output_y](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
RealT log_abs_det = 0;
T sign = 1;
Eigen::Map<MatrixXd> matrix_x(input_x + i * m * m, m, m);
if (matrix_x.size() > 0) {
Eigen::PartialPivLU<MatrixXd> lu(matrix_x);
MatrixXd LU = lu.matrixLU();
sign = lu.permutationP().determinant();
auto diag = LU.diagonal().array().eval();
auto abs_diag = diag.cwiseAbs().eval();
auto abs_diag_inverse = abs_diag.cwiseInverse();
log_abs_det += abs_diag.log().sum();
sign *= (diag * abs_diag_inverse).prod();
}
if (!Eigen::numext::isfinite(log_abs_det)) {
sign = 0;
log_abs_det = log_abs_det > 0 ? -std::log(RealT(0)) : std::log(RealT(0));
}
*(output_sign + i) = sign;
*(output_y + i) = log_abs_det;
}
};
CPUKernelUtils::ParallelFor(task, matrix_num);
}
}
}
} // namespace kernel
} // namespace mindspore
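For reference, both outputs above fall out of a single partial-pivoting LU factorization. Eigen's PartialPivLU computes $A = P^{-1} L U$ with unit-diagonal $L$, so for each batched matrix

$$\det(A) = \det(P)\prod_i u_{ii},\qquad \log\lvert\det(A)\rvert = \sum_i \log\lvert u_{ii}\rvert,\qquad \operatorname{sign}(\det(A)) = \det(P)\prod_i \frac{u_{ii}}{\lvert u_{ii}\rvert},$$

where $\det(P) = \pm 1$ (equal to $\det(P^{-1})$) comes from `lu.permutationP().determinant()` and the $u_{ii}$ are the diagonal of `lu.matrixLU()`. This is exactly what the loop accumulates, with the non-finite check mapping singular or overflowing inputs to sign 0 and an infinite log-magnitude.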

View File

@@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LOG_MATRIX_DETERMINANT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LOG_MATRIX_DETERMINANT_CPU_KERNEL_H_
#include <complex>
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class LogMatrixDeterminantCPUKernel : public CPUKernel {
public:
LogMatrixDeterminantCPUKernel() = default;
~LogMatrixDeterminantCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
CNodeWeakPtr node_wpt_;
TypeId dtype_{kTypeUnknown};
template <typename T>
void LaunchLogMatrixDeterminant(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
};
MS_REG_CPU_KERNEL(
LogMatrixDeterminant,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
LogMatrixDeterminantCPUKernel);
MS_REG_CPU_KERNEL(
LogMatrixDeterminant,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
LogMatrixDeterminantCPUKernel);
MS_REG_CPU_KERNEL(LogMatrixDeterminant,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
LogMatrixDeterminantCPUKernel);
MS_REG_CPU_KERNEL(LogMatrixDeterminant,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
LogMatrixDeterminantCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LOG_MATRIX_DETERMINANT_CPU_KERNEL_H_

View File

@@ -0,0 +1,100 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/matrix_determinant_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "Eigen/Core"
#include "Eigen/LU"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputSize = 1;
constexpr size_t kOutputSize = 1;
static constexpr int kNumber0 = 0;
static constexpr int kNumber1 = 1;
static constexpr int kNumber2 = 2;
} // namespace
void MatrixDeterminantCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
node_wpt_ = kernel_node;
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kInputSize, kernel_name_);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputSize, kernel_name_);
}
bool MatrixDeterminantCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /* workspace */,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat32) {
LaunchMatrixDeterminant<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchMatrixDeterminant<double>(inputs, outputs);
} else if (dtype_ == kNumberTypeComplex64) {
LaunchMatrixDeterminant<std::complex<float>>(inputs, outputs);
} else if (dtype_ == kNumberTypeComplex128) {
LaunchMatrixDeterminant<std::complex<double>>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "MatrixDeterminant kernel data type " << TypeIdLabel(dtype_) << " not support.";
}
return true;
}
template <typename T>
void MatrixDeterminantCPUKernel::LaunchMatrixDeterminant(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto node_ = node_wpt_.lock();
if (!node_) {
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
}
T *input = reinterpret_cast<T *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(input);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output);
// Check that the input is a batch of square matrices
auto dims = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0);
if (dims.size() < kNumber2) {
MS_LOG(EXCEPTION) << "Input x must be at least rank 2.";
}
if (dims[dims.size() - kNumber1] != dims[dims.size() - kNumber2]) {
MS_LOG(EXCEPTION) << "The last two dimensions of Input x should be equal.";
}
size_t m = dims[dims.size() - 1];
size_t n = 1;
for (size_t i = kNumber0; i < dims.size() - kNumber2; i++) {
n *= dims[i];
}
auto task = [this, &m, input, output](size_t start, size_t end) {
for (size_t k = start; k < end; k++) {
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> eMatrix(m, m);
for (size_t i = 0; i < m; i++) {
for (size_t j = 0; j < m; j++) {
eMatrix(i, j) = *(input + k * m * m + i * m + j);
}
}
// Use Eigen to compute the determinant
T result = eMatrix.determinant();
*(output + k) = result;
}
};
CPUKernelUtils::ParallelFor(task, n);
}
} // namespace kernel
} // namespace mindspore
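As a sanity check on the loop above, here is a minimal standalone sketch (an illustration, not part of the PR: it assumes only that Eigen 3 headers are on the include path, and reuses the values from the MatrixDeterminant docstring example later in this commit) of taking per-matrix determinants over a flattened row-major batch:

#include <iostream>
#include <vector>
#include "Eigen/Dense"

int main() {
  constexpr int m = 2;
  // Two 2x2 matrices flattened in row-major order, as the kernel receives them.
  std::vector<float> buf = {-4.5f, -1.5f, 7.0f, 6.0f, 2.5f, 0.5f, 3.0f, 9.0f};
  using RowMajorMat = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  for (int k = 0; k < 2; ++k) {
    // Mapping replaces the element-by-element copy in the kernel; the result is
    // identical because the kernel also copies in row-major order.
    Eigen::Map<RowMajorMat> mat(buf.data() + k * m * m, m, m);
    std::cout << mat.determinant() << std::endl;  // prints -16.5, then 21
  }
  return 0;
}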

View File

@@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DETERMINANT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DETERMINANT_CPU_KERNEL_H_
#include <complex>
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class MatrixDeterminantCPUKernel : public CPUKernel {
public:
MatrixDeterminantCPUKernel() = default;
~MatrixDeterminantCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
CNodeWeakPtr node_wpt_;
TypeId dtype_{kTypeUnknown};
template <typename T>
void LaunchMatrixDeterminant(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
};
MS_REG_CPU_KERNEL(MatrixDeterminant, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MatrixDeterminantCPUKernel);
MS_REG_CPU_KERNEL(MatrixDeterminant, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
MatrixDeterminantCPUKernel);
MS_REG_CPU_KERNEL(MatrixDeterminant,
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
MatrixDeterminantCPUKernel);
MS_REG_CPU_KERNEL(MatrixDeterminant,
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
MatrixDeterminantCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DETERMINANT_CPU_KERNEL_H_

View File

@@ -0,0 +1,128 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/matrix_inverse_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "Eigen/Core"
#include "Eigen/LU"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputSize = 1;
constexpr size_t kOutputSize = 1;
static constexpr int kNumber1 = 1;
static constexpr int kNumber2 = 2;
constexpr size_t kParallelDataNums = 1 * 1024;
} // namespace
void MatrixInverseCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
node_wpt_ = kernel_node;
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
}
bool MatrixInverseCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /* workspace */,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputSize, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputSize, kernel_name_);
if (dtype_ == kNumberTypeFloat32) {
LaunchMatrixInverse<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchMatrixInverse<double>(inputs, outputs);
} else if (dtype_ == kNumberTypeComplex64) {
LaunchMatrixInverse<std::complex<float>>(inputs, outputs);
} else if (dtype_ == kNumberTypeComplex128) {
LaunchMatrixInverse<std::complex<double>>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "MatrixInverse kernel data type " << TypeIdLabel(dtype_) << " not support.";
}
return true;
}
template <typename T>
void MatrixInverseCPUKernel::LaunchMatrixInverse(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto node_ = node_wpt_.lock();
if (!node_) {
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
}
T *input_ptr = reinterpret_cast<T *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(input_ptr);
T *output_ptr = reinterpret_cast<T *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_ptr);
// Validate the input shape
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0);
if (shape.size() < kNumber2) {
MS_LOG(EXCEPTION) << "Input x must be at least rank 2.";
}
if (shape[shape.size() - kNumber1] != shape[shape.size() - kNumber2]) {
MS_LOG(EXCEPTION) << "The last two dimensions of Input x should be equal.";
}
auto last_dimsize = shape[shape.size() - 1];
// Total number of input elements
size_t input_num = 1;
for (size_t i = 0; i < shape.size(); i++) {
input_num *= shape[i];
}
size_t matrix_size = last_dimsize * last_dimsize;
// Number of matrices
size_t matrix_num = input_num / matrix_size;
// Copy the input into per-matrix buffers for slicing
std::vector<std::vector<T>> temp(matrix_num, std::vector<T>(matrix_size));
for (size_t i = 0; i < matrix_num; i++) {
for (size_t j = 0; j < matrix_size; j++) {
temp[i][j] = *(input_ptr + i * matrix_size + j);
}
}
// Get the value of the 'adjoint' attribute
adjoint_ = AnfAlgo::GetNodeAttr<bool>(node_, "adjoint");
auto one_size = sizeof(*input_ptr);
if ((one_size * input_num) <= kParallelDataNums) {
for (size_t i = 0; i < matrix_num; i++) {
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> eigen_input(temp[i].data(), last_dimsize,
last_dimsize);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> eigen_output(output_ptr + i * matrix_size,
last_dimsize, last_dimsize);
if (adjoint_) {
eigen_input = eigen_input.adjoint().eval();
}
Eigen::FullPivLU<Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>> lu(eigen_input);
eigen_output = lu.inverse();
}
} else {
auto task = [this, &last_dimsize, &matrix_size, &temp, output_ptr](size_t start, size_t end) {
for (auto i = start; i < end; i++) {
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> eigen_input(temp[i].data(), last_dimsize,
last_dimsize);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> eigen_output(output_ptr + i * matrix_size,
last_dimsize, last_dimsize);
if (adjoint_) {
eigen_input = eigen_input.adjoint().eval();
}
Eigen::FullPivLU<Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>> lu(eigen_input);
eigen_output = lu.inverse();
}
};
CPUKernelUtils::ParallelFor(task, matrix_num);
}
}
} // namespace kernel
} // namespace mindspore
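The inversion path reduces to two Eigen calls; below is a minimal sketch under the same assumption (Eigen 3 headers available; the 2x2 matrix is hand-picked to be invertible) of the adjoint-then-FullPivLU sequence used above:

#include <iostream>
#include "Eigen/Dense"

int main() {
  Eigen::MatrixXd a(2, 2);
  a << 4.0, 7.0,
       2.0, 6.0;
  const bool adjoint = false;  // mirrors the kernel's 'adjoint' attribute
  if (adjoint) {
    a = a.adjoint().eval();  // conjugate transpose; eval() breaks aliasing
  }
  // Full pivoting is slower than partial pivoting but more robust near singularity.
  Eigen::FullPivLU<Eigen::MatrixXd> lu(a);
  Eigen::MatrixXd inv = lu.inverse();
  std::cout << inv << std::endl;  // [[0.6, -0.7], [-0.2, 0.4]]
  return 0;
}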

View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_INVERSE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_INVERSE_CPU_KERNEL_H_
#include <complex>
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class MatrixInverseCPUKernel : public CPUKernel {
public:
MatrixInverseCPUKernel() = default;
~MatrixInverseCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
CNodeWeakPtr node_wpt_;
bool adjoint_{false};
TypeId dtype_{kTypeUnknown};
template <typename T>
void LaunchMatrixInverse(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
};
MS_REG_CPU_KERNEL(MatrixInverse, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MatrixInverseCPUKernel);
MS_REG_CPU_KERNEL(MatrixInverse, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
MatrixInverseCPUKernel);
MS_REG_CPU_KERNEL(MatrixInverse, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
MatrixInverseCPUKernel);
MS_REG_CPU_KERNEL(MatrixInverse, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
MatrixInverseCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_INVERSE_CPU_KERNEL_H_

View File

@@ -65,6 +65,9 @@ constexpr auto kAdd = "Add";
constexpr auto kBiasAdd = "BiasAdd";
constexpr auto kTile = "Tile";
constexpr auto kBiasAddGrad = "BiasAddGrad";
constexpr auto kMatrixInverse = "MatrixInverse";
constexpr auto kMatrixDeterminant = "MatrixDeterminant";
constexpr auto kLogMatrixDeterminant = "LogMatrixDeterminant";
constexpr auto kCos = "Cos";
constexpr auto kAbs = "Abs";
constexpr auto kTrunc = "Trunc";
@@ -622,6 +625,9 @@ inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod"
inline const PrimitivePtr kPrimCdist = std::make_shared<Primitive>(kCdist);
inline const PrimitivePtr kPrimCdistGrad = std::make_shared<Primitive>(kCdistGrad);
inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
inline const PrimitivePtr kPrimMatrixInverse = std::make_shared<Primitive>(kMatrixInverse);
inline const PrimitivePtr kPrimMatrixDeterminant = std::make_shared<Primitive>(kMatrixDeterminant);
inline const PrimitivePtr kPrimLogMatrixDeterminant = std::make_shared<Primitive>(kLogMatrixDeterminant);
inline const PrimitivePtr kPrimIndexAdd = std::make_shared<Primitive>("IndexAdd");
inline const PrimitivePtr kPrimIdentityMath = std::make_shared<Primitive>("Identity", kSideEffectPropagate);
inline const PrimitivePtr kPrimInvGrad = std::make_shared<Primitive>("InvGrad");

View File

@@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/log_matrix_determinant.h"
#include <set>
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size());
const constexpr int64_t kNumber1 = 1;
const constexpr int64_t kNumber2 = 2;
CheckAndConvertUtils::CheckInteger("x rank", x_rank, kGreaterEqual, kNumber2, prim_name);
CheckAndConvertUtils::Check("row size", x_shape[x_rank - kNumber1], kEqual, "column size", x_shape[x_rank - kNumber2],
prim_name);
CheckAndConvertUtils::CheckInteger("row size", x_shape[x_rank - kNumber1], kGreaterEqual, kNumber2, prim_name);
CheckAndConvertUtils::CheckInteger("column size", x_shape[x_rank - kNumber2], kGreaterEqual, kNumber2, prim_name);
std::vector<int64_t> shape(x_shape.begin(), (x_shape.end() - kNumber2));
abstract::ShapePtr out_shape = std::make_shared<abstract::Shape>(shape);
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape, out_shape});
}
TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kFloat32};
auto x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, x_type});
}
} // namespace
AbstractBasePtr LogMatrixDeterminantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infertype = InferType(primitive, input_args);
auto infershape = InferShape(primitive, input_args);
return abstract::MakeAbstract(infershape, infertype);
}
REGISTER_PRIMITIVE_EVAL_IMPL(LogMatrixDeterminant, prim::kPrimLogMatrixDeterminant, LogMatrixDeterminantInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore
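In shape terms, the infer function above maps an input of shape $(d_1, \dots, d_k, n, n)$ with $n \ge 2$ to two outputs of the batch shape $(d_1, \dots, d_k)$: both `sign` and `y` drop the trailing two matrix dimensions, matching the `x_shape[:-2]` wording in the Python docstrings later in this commit.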

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_LOG_MATRIX_DETERMINANT_H_
#define MINDSPORE_CORE_OPS_LOG_MATRIX_DETERMINANT_H_
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLogMatrixDeterminant = "LogMatrixDeterminant";
class LogMatrixDeterminant : public PrimitiveC {
public:
LogMatrixDeterminant() : PrimitiveC(kNameLogMatrixDeterminant) { InitIOName({"x"}, {"sign", "output"}); }
~LogMatrixDeterminant() = default;
MS_DECLARE_PARENT(LogMatrixDeterminant, PrimitiveC);
};
AbstractBasePtr LogMatrixDeterminantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimLogMatrixDeterminantPtr = std::shared_ptr<LogMatrixDeterminant>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LOG_MATRIX_DETERMINANT_H_

View File

@@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/matrix_determinant.h"
#include <set>
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size());
const constexpr int64_t kNumber1 = 1;
const constexpr int64_t kNumber2 = 2;
CheckAndConvertUtils::CheckInteger("x rank", x_rank, kGreaterEqual, kNumber2, prim_name);
CheckAndConvertUtils::Check("row size", x_shape[x_rank - kNumber1], kEqual, "column size", x_shape[x_rank - kNumber2],
prim_name);
CheckAndConvertUtils::CheckInteger("row size", x_shape[x_rank - kNumber1], kGreaterEqual, kNumber2, prim_name);
CheckAndConvertUtils::CheckInteger("column size", x_shape[x_rank - kNumber2], kGreaterEqual, kNumber2, prim_name);
std::vector<int64_t> out_shape(x_shape.begin(), (x_shape.end() - kNumber2));
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kFloat32};
auto infer_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, valid_types, prim->name());
return infer_type;
}
} // namespace
AbstractBasePtr MatrixDeterminantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infertype = InferType(primitive, input_args);
auto infershape = InferShape(primitive, input_args);
return abstract::MakeAbstract(infershape, infertype);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDeterminant, prim::kPrimMatrixDeterminant, MatrixDeterminantInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MATRIX_DETERMINANT_H_
#define MINDSPORE_CORE_OPS_MATRIX_DETERMINANT_H_
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMatrixDeterminant = "MatrixDeterminant";
class MatrixDeterminant : public PrimitiveC {
public:
MatrixDeterminant() : PrimitiveC(kNameMatrixDeterminant) { InitIOName({"x"}, {"y"}); }
~MatrixDeterminant() = default;
MS_DECLARE_PARENT(MatrixDeterminant, PrimitiveC);
};
AbstractBasePtr MatrixDeterminantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimMatrixDeterminantPtr = std::shared_ptr<MatrixDeterminant>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MATRIX_DETERMINANT_H_

View File

@@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/matrix_inverse.h"
#include <set>
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size());
const constexpr int64_t kNumber1 = 1;
const constexpr int64_t kNumber2 = 2;
CheckAndConvertUtils::CheckInteger("x rank", x_rank, kGreaterEqual, kNumber2, prim_name);
CheckAndConvertUtils::Check("row size", x_shape[x_rank - kNumber1], kEqual, "column size", x_shape[x_rank - kNumber2],
prim_name);
CheckAndConvertUtils::CheckInteger("row size", x_shape[x_rank - kNumber1], kGreaterEqual, kNumber2, prim_name);
CheckAndConvertUtils::CheckInteger("column size", x_shape[x_rank - kNumber2], kGreaterEqual, kNumber2, prim_name);
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
auto infer_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, valid_types, prim->name());
return infer_type;
}
} // namespace
AbstractBasePtr MatrixInverseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infertype = InferType(primitive, input_args);
auto infershape = InferShape(primitive, input_args);
return abstract::MakeAbstract(infershape, infertype);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixInverse, prim::kPrimMatrixInverse, MatrixInverseInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MATRIX_INVERSE_H_
#define MINDSPORE_CORE_OPS_MATRIX_INVERSE_H_
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMatrixInverse = "MatrixInverse";
class MatrixInverse : public PrimitiveC {
public:
MatrixInverse() : PrimitiveC(kNameMatrixInverse) { InitIOName({"x"}, {"y"}); }
~MatrixInverse() = default;
MS_DECLARE_PARENT(MatrixInverse, PrimitiveC);
};
AbstractBasePtr MatrixInverseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimMatrixInversePtr = std::shared_ptr<MatrixInverse>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MATRIX_INVERSE_H_

View File

@@ -16,6 +16,7 @@
"""Define the grad rules of math related operations."""
from mindspore.common import dtype as mstype
from mindspore import nn
import mindspore.numpy as mnp
import numpy as np
from .. import functional as F
@@ -148,6 +149,54 @@ def get_bprop_lp_norm(self):
return bprop
@bprop_getters.register(P.MatrixInverse)
def get_bprop_matrix_inverse(self):
"""Generate bprop for MatrixInverse"""
matmul_x1 = nn.MatMul(transpose_x1=True)
matmul_x2 = nn.MatMul(transpose_x2=True)
neg = P.Neg()
def bprop(x, out, dout):
dx = matmul_x2(dout, out)
dx = matmul_x1(out, dx)
dx = neg(dx)
return (dx,)
return bprop
@bprop_getters.register(P.MatrixDeterminant)
def get_bprop_matrix_determinant(self):
"""Generate bprop for MatrixDeterminant"""
inverse_op = P.MatrixInverse(adjoint=True)
shape_op = P.Shape()
reshape = P.Reshape()
def bprop(x, out, dout):
x_adj_inv = inverse_op(x)
multipliers = reshape(dout * out, shape_op(out) + (1, 1))
dx = multipliers * x_adj_inv
return (dx,)
return bprop
@bprop_getters.register(P.LogMatrixDeterminant)
def get_bprop_log_matrix_determinant(self):
"""Generate bprop for LogMatrixDeterminant"""
inverse_op = P.MatrixInverse(adjoint=True)
shape_op = P.Shape()
reshape = P.Reshape()
def bprop(x, out, dout):
x_adj_inv = inverse_op(x)
multipliers = reshape(dout[1], shape_op(out[1]) + (1, 1))
dx = multipliers * x_adj_inv
return (dx,)
return bprop
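The three backward rules above are the standard matrix-calculus identities, stated here for real inputs (so the MatrixInverse(adjoint=True) call plays the role of $X^{-\top}$):

$$Y = X^{-1}:\ \bar X = -\,Y^{\top}\,\bar Y\,Y^{\top},\qquad y = \det X:\ \bar X = \bar y\,y\,X^{-\top}\ \text{(Jacobi's formula)},\qquad y = \log\lvert\det X\rvert:\ \bar X = \bar y\,X^{-\top},$$

with the scalar factors broadcast over the trailing $1 \times 1$ axes added by the reshape. For LogMatrixDeterminant only the log-magnitude output contributes a gradient, which is why only `dout[1]` appears in its bprop.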
@bprop_getters.register(P.Erfinv)
def get_bprop_erfinv(self):
"""Grad definition for `Erfinv` operation."""

View File

@@ -50,6 +50,9 @@ from .log_uniform_candidate_sampler import _log_uniform_candidate_sampler_aicpu
from .compute_accidental_hits import _compute_accidental_hits_aicpu
from .ctcloss import _ctcloss_aicpu
from .reverse_sequence import _reverse_sequence_aicpu
from .matrix_inverse import _matrix_inverse_aicpu
from .matrix_determinant import _matrix_determinant_aicpu
from .log_matrix_determinant import _log_matrix_determinant_aicpu
from .crop_and_resize import _crop_and_resize_aicpu
from .rnnt_loss import _rnnt_loss_aicpu
from .random_categorical import _random_categorical_aicpu

View File

@@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LogMatrixDeterminant op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
log_matrix_determinant_op_info = AiCPURegOp("LogMatrixDeterminant") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "sign", "required") \
.output(1, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(log_matrix_determinant_op_info)
def _log_matrix_determinant_aicpu():
"""LogMatrixDeterminant aicpu register"""
return

View File

@@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MatrixDeterminant op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
matrix_determinant_op_info = AiCPURegOp("MatrixDeterminant") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(matrix_determinant_op_info)
def _matrix_determinant_aicpu():
"""MatrixDeterminant aicpu register"""
return

View File

@@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MatrixInverse op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
matrix_inverse_op_info = AiCPURegOp("MatrixInverse") \
.fusion_type("OPAQUE") \
.attr("adjoint", "bool")\
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(matrix_inverse_op_info)
def _matrix_inverse_aicpu():
"""MatrixInverse aicpu register"""
return

View File

@@ -55,7 +55,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
LogicalNot, LogicalOr, LpNorm, MatMul, Maximum, MulNoNan,
Minimum, Mul, Neg, NMSWithMask, NotEqual,
MatrixDeterminant, LogMatrixDeterminant, Minimum, Mul, Neg, NMSWithMask, NotEqual,
NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
@@ -476,6 +476,8 @@ __all__ = [
"SparseTensorDenseMatmul",
"SparseApplyAdadelta",
"MatrixInverse",
"MatrixDeterminant",
"LogMatrixDeterminant",
"Range",
"SearchSorted",
"IndexAdd",

View File

@@ -5314,7 +5314,7 @@ class LinSpace(PrimitiveWithInfer):
return out
class MatrixInverse(PrimitiveWithInfer):
class MatrixInverse(Primitive):
"""
Returns the inverse of the input matrix. If the matrix is not invertible, an error may be reported or an unknown
result may be returned.
@@ -5323,20 +5323,19 @@ class MatrixInverse(PrimitiveWithInfer):
The parameter 'adjoint' currently only supports False, because complex numbers are not supported at present.
Args:
adjoint (bool) : Whether to support complex matrix. False means that complex matrix is not supported.
Default: False.
adjoint (bool) : An optional bool. Default: False.
Inputs:
- **x** (Tensor) - A matrix to be calculated. The matrix must be at least two dimensions, and the last two
dimensions must be the same size. dtypes: float32, float64.
dimensions must be the same size.
Outputs:
Tensor, has the same type and shape as input `x`.
Raises:
TypeError: If `adjoint` is not a bool.
TypeError: If dtype of `x` is neither float32 nor float64.
ValueError: If the last two dimensions of `x` is not the same size.
TypeError: If `x` is not a Tensor.
ValueError: If the last two dimensions of `x` are not the same size.
ValueError: If the dimension of `x` is less than 2.
Supported Platforms:
@@ -5350,8 +5349,8 @@ class MatrixInverse(PrimitiveWithInfer):
>>> matrix_inverse = ops.MatrixInverse(adjoint=False)
>>> output = matrix_inverse(x)
>>> print(output)
[[[ 2.4095483 -1.536419 ]
[-2.4197974 0.97401696]]
[[[ 2.4095478 -1.5364188 ]
[-2.419797 0.9740167 ]]
[[-0.79111797 1.0569006 ]
[ 0.74180895 -0.2904787 ]]]
"""
@@ -5359,18 +5358,77 @@ class MatrixInverse(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, adjoint=False):
"""Initialize MatrixInverse"""
validator.check_type_name("adjoint", adjoint, False, self.name)
self.adjoint = adjoint
self.init_prim_io_names(inputs=['x'], outputs=['y'])
validator.check_value_type('adjoint', adjoint, [bool], self.name)
def infer_dtype(self, x_dtype):
valid_type = [mstype.float32, mstype.double]
validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape):
validator.check_int(len(x_shape), 2, Rel.GE, self.name, None)
validator.check_equal_int(x_shape[-1], x_shape[-2], self.name, None)
return x_shape
class MatrixDeterminant(Primitive):
"""
Computes the determinant of one or more square matrices.
Inputs:
- **x** (Tensor) - A matrix to be calculated. The matrix must be at least two dimensions, and the last two
dimensions must be the same size.
Outputs:
Tensor, the shape is `x_shape[:-2]`, and the dtype is the same as `x`.
Raises:
TypeError: If `x` is not a Tensor.
ValueError: If the last two dimensions of `x` are not the same size.
ValueError: If the dimension of `x` is less than 2.
Supported Platforms:
``CPU``
Examples:
>>> input_x = Tensor(np.array([[[-4.5, -1.5], [7.0, 6.0]], [[2.5, 0.5], [3.0, 9.0]]]), mindspore.float32)
>>> op = ops.MatrixDeterminant()
>>> output = op(input_x)
>>> print(output)
[-16.5 21. ]
"""
@prim_attr_register
def __init__(self):
"""Initialize MatrixDeterminant."""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
class LogMatrixDeterminant(Primitive):
"""
Computes the sign and the log of the absolute value of the determinant of one or more square matrices.
Inputs:
- **x** (Tensor) - A matrix to be calculated. The matrix must be at least two dimensions, and the last two
dimensions must be the same size.
Outputs:
- **sign** (Tensor) - The signs of the determinants. The shape is `x_shape[:-2]`, and the dtype is the same as `x`.
- **y** (Tensor) - The logs of the absolute values of the determinants. The shape is `x_shape[:-2]`, and the dtype is
the same as `x`.
Raises:
TypeError: If `x` is not a Tensor.
ValueError: If the last two dimensions of `x` are not the same size.
ValueError: If the dimension of `x` is less than 2.
Supported Platforms:
``CPU``
Examples:
>>> input_x = Tensor(np.array([[[-4.5, -1.5], [7.0, 6.0]], [[2.5, 0.5], [3.0, 9.0]]]), mindspore.float32)
>>> op = ops.LogMatrixDeterminant()
>>> output = op(input_x)
>>> print(output)
(Tensor(shape=[2], dtype=Float32, value= [-1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32,
value= [ 2.80336046e+00, 3.04452229e+00]))
"""
@prim_attr_register
def __init__(self):
"""Initialize LogMatrixDeterminant."""
self.init_prim_io_names(inputs=['x'], outputs=['sign', 'y'])
class IndexAdd(Primitive):

View File

@@ -1783,6 +1783,18 @@ test_case_math_ops = [
Tensor(np.random.rand(4).astype(np.int32))],
'desc_bprop': [],
'skip': ['backward']}),
('MatrixInverse', {
'block': P.MatrixInverse(),
'desc_inputs': [Tensor(np.array([[[-1, -2], [-3, -4]], [[5, 6], [7, 8]]]).astype(np.float32))],
'desc_bprop': [Tensor(np.array([[[-1, -2], [-3, -4]], [[5, 6], [7, 8]]]).astype(np.float32))]}),
('MatrixDeterminant', {
'block': P.MatrixDeterminant(),
'desc_inputs': [Tensor(np.array([[[-1, -2], [-3, -4]], [[5, 6], [7, 8]]]).astype(np.float32))],
'desc_bprop': [Tensor(np.array([1.0, 2.0]).astype(np.float32))]}),
('LogMatrixDeterminant', {
'block': P.LogMatrixDeterminant(),
'desc_inputs': [Tensor(np.array([[[-1, -2], [-3, -4]], [[5, 6], [7, 8]]]).astype(np.float32))],
'desc_bprop': [(Tensor(np.array([1, 2]).astype(np.float32)), Tensor(np.array([1, 2]).astype(np.float32)))]}),
('Erfinv', {
'block': P.Erfinv(),
'desc_inputs': [Tensor(np.array([0.1, 0.1, 0.1]).astype(np.float16))],