!30279 Rename matrix_triangular_solve

Merge pull request !30279 from zhujingxuan/master
This commit is contained in:
i-robot 2022-02-21 07:28:23 +00:00 committed by Gitee
commit 52574812b3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 77 additions and 58 deletions

View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.h"
#include "plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.h"
#include <Eigen/Dense>
#include <vector>
#include <string>
@@ -38,7 +38,7 @@ constexpr auto kAMatrixDimNum = 2;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
template <typename T>
void SolveTriangularCpuKernelMod<T>::InitShape(const CNodePtr &kernel_node) {
void MatrixTriangularSolveCpuKernelMod<T>::InitShape(const CNodePtr &kernel_node) {
auto a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
// Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid.
@@ -59,20 +59,30 @@ void SolveTriangularCpuKernelMod<T>::InitShape(const CNodePtr &kernel_node) {
}
template <typename T>
void SolveTriangularCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void MatrixTriangularSolveCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
InitShape(kernel_node);
lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
unit_diagonal_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, UNIT_DIAGONAL);
const std::string trans = AnfAlgo::GetNodeAttr<std::string>(kernel_node, TRANS);
if (trans == "N") {
trans_ = false;
} else if (trans == "T") {
trans_ = true;
} else if (trans == "C") {
trans_ = true;
if (AnfAlgo::HasNodeAttr(ADJOINT, kernel_node)) {
// MatrixTriangularSolve attribute
trans_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, ADJOINT);
if (AnfAlgo::HasNodeAttr(TRANS, kernel_node)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the attribute 'adjoint' and 'trans' could not exist at the same time.";
}
} else {
MS_LOG(EXCEPTION) << "Trans should be in [N, T, C], but got [" << trans << "].";
lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
unit_diagonal_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, UNIT_DIAGONAL);
const std::string trans = AnfAlgo::GetNodeAttr<std::string>(kernel_node, TRANS);
if (trans == "N") {
trans_ = false;
} else if (trans == "T") {
trans_ = true;
} else if (trans == "C") {
trans_ = true;
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'trans' should be in ['N', 'T', 'C'], but got [" << trans
<< "].";
}
}
}
@@ -96,9 +106,9 @@ inline void solve(const MatrixBase<Derived_a> &a, const MatrixBase<Derived_b> &b
}
template <typename T>
bool SolveTriangularCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> & /* workspace */,
const std::vector<AddressPtr> &outputs) {
bool MatrixTriangularSolveCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> & /* workspace */,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_);

View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SOLVE_TRIANGULAR_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SOLVE_TRIANGULAR_CPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
@@ -24,10 +24,10 @@
namespace mindspore {
namespace kernel {
template <typename T>
class SolveTriangularCpuKernelMod : public NativeCpuKernelMod {
class MatrixTriangularSolveCpuKernelMod : public NativeCpuKernelMod {
public:
SolveTriangularCpuKernelMod() = default;
~SolveTriangularCpuKernelMod() override = default;
MatrixTriangularSolveCpuKernelMod() = default;
~MatrixTriangularSolveCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
@@ -47,12 +47,20 @@ class SolveTriangularCpuKernelMod : public NativeCpuKernelMod {
MS_REG_CPU_KERNEL_T(
SolveTriangular,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SolveTriangularCpuKernelMod, float)
MatrixTriangularSolveCpuKernelMod, float)
MS_REG_CPU_KERNEL_T(
SolveTriangular,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SolveTriangularCpuKernelMod, double)
MatrixTriangularSolveCpuKernelMod, double)
MS_REG_CPU_KERNEL_T(
MatrixTriangularSolve,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MatrixTriangularSolveCpuKernelMod, float)
MS_REG_CPU_KERNEL_T(
MatrixTriangularSolve,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
MatrixTriangularSolveCpuKernelMod, double)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SOLVE_TRIANGULAR_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_

View File

@@ -14,17 +14,17 @@
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.h"
#include "plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
SolveTriangular,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SolveTriangularGpuKernelMod, float)
MatrixTriangularSolveGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
SolveTriangular,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SolveTriangularGpuKernelMod, double)
MatrixTriangularSolveGpuKernelMod, double)
} // namespace kernel
} // namespace mindspore

View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <type_traits>
@@ -39,10 +39,10 @@ constexpr size_t kIndexBBuffer = 2;
constexpr size_t kIndexBTransposeShape = 3;
constexpr size_t kIndexBTransposeAxis = 4;
template <typename T>
class SolveTriangularGpuKernelMod : public NativeGpuKernelMod {
class MatrixTriangularSolveGpuKernelMod : public NativeGpuKernelMod {
public:
SolveTriangularGpuKernelMod() = default;
~SolveTriangularGpuKernelMod() = default;
MatrixTriangularSolveGpuKernelMod() = default;
~MatrixTriangularSolveGpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
@@ -159,32 +159,33 @@ class SolveTriangularGpuKernelMod : public NativeGpuKernelMod {
lda_ = SizeToInt(m_);
ldb_ = SizeToInt(m_);
const std::string trans = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "trans");
// converting row major to col major is the same as reverting the trans flag
if (trans == "N") {
trans_ = CUBLAS_OP_T;
} else if (trans == "T") {
trans_ = CUBLAS_OP_N;
} else if (trans == "C") {
trans_ = CUBLAS_OP_N;
if (AnfAlgo::HasNodeAttr("adjoint", kernel_node)) {
// MatrixTriangularSolve attribute
bool trans = AnfAlgo::GetNodeAttr<bool>(kernel_node, "adjoint");
// converting row major to col major is the same as reverting the trans flag
trans_ = trans ? CUBLAS_OP_N : CUBLAS_OP_T;
if (AnfAlgo::HasNodeAttr("trans", kernel_node)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the attribute 'adjoint' and 'trans' could not exist at the same time.";
}
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', trans should be in [N, T, C], but got [" << trans << "].";
}
bool lower = AnfAlgo::GetNodeAttr<bool>(kernel_node, "lower");
// reverting the trans flag by default, so also flip the lower flag
lower = !lower;
if (lower) {
uplo_ = CUBLAS_FILL_MODE_LOWER;
} else {
uplo_ = CUBLAS_FILL_MODE_UPPER;
}
bool unit_diagonal = AnfAlgo::GetNodeAttr<bool>(kernel_node, "unit_diagonal");
if (unit_diagonal) {
unit_diagonal_ = CUBLAS_DIAG_UNIT;
} else {
unit_diagonal_ = CUBLAS_DIAG_NON_UNIT;
bool lower = AnfAlgo::GetNodeAttr<bool>(kernel_node, "lower");
// reverting the trans flag by default, so also flip the lower flag
lower = !lower;
uplo_ = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
bool unit_diagonal = AnfAlgo::GetNodeAttr<bool>(kernel_node, "unit_diagonal");
unit_diagonal_ = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
const std::string trans = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "trans");
// converting row major to col major is the same as reverting the trans flag
if (trans == "N") {
trans_ = CUBLAS_OP_T;
} else if (trans == "T") {
trans_ = CUBLAS_OP_N;
} else if (trans == "C") {
trans_ = CUBLAS_OP_N;
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', trans should be in [N, T, C], but got [" << trans << "].";
}
}
InitSizeLists();
@@ -263,4 +264,4 @@ class SolveTriangularGpuKernelMod : public NativeGpuKernelMod {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_