From 911e4660ce0f90df34bf5063285b7da59b7f997c Mon Sep 17 00:00:00 2001 From: zhujingxuan Date: Sun, 20 Feb 2022 12:00:05 +0800 Subject: [PATCH] Rename matrix_triangular_solve to be compatible with tensorflow's MatrixTriangularSolve OP --- ... => matrix_triangular_solve_cpu_kernel.cc} | 42 ++++++++----- ...h => matrix_triangular_solve_cpu_kernel.h} | 24 ++++--- ... => matrix_triangular_solve_gpu_kernel.cc} | 6 +- ...h => matrix_triangular_solve_gpu_kernel.h} | 63 ++++++++++--------- 4 files changed, 77 insertions(+), 58 deletions(-) rename mindspore/ccsrc/plugin/device/cpu/kernel/eigen/{solve_triangular_cpu_kernel.cc => matrix_triangular_solve_cpu_kernel.cc} (72%) rename mindspore/ccsrc/plugin/device/cpu/kernel/eigen/{solve_triangular_cpu_kernel.h => matrix_triangular_solve_cpu_kernel.h} (61%) rename mindspore/ccsrc/plugin/device/gpu/kernel/math/{solve_triangular_gpu_kernel.cc => matrix_triangular_solve_gpu_kernel.cc} (85%) rename mindspore/ccsrc/plugin/device/gpu/kernel/math/{solve_triangular_gpu_kernel.h => matrix_triangular_solve_gpu_kernel.h} (83%) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.cc similarity index 72% rename from mindspore/ccsrc/plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.cc rename to mindspore/ccsrc/plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.cc index e0c844e7627..79ceced3de4 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.h" +#include "plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.h" #include #include #include @@ -38,7 +38,7 @@ constexpr auto kAMatrixDimNum = 2; constexpr size_t kRowIndex = 2; constexpr size_t kColIndex = 1; template -void SolveTriangularCpuKernelMod::InitShape(const CNodePtr &kernel_node) { +void MatrixTriangularSolveCpuKernelMod::InitShape(const CNodePtr &kernel_node) { auto a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); // Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid. @@ -59,20 +59,30 @@ void SolveTriangularCpuKernelMod::InitShape(const CNodePtr &kernel_node) { } template -void SolveTriangularCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { +void MatrixTriangularSolveCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); InitShape(kernel_node); - lower_ = AnfAlgo::GetNodeAttr(kernel_node, LOWER); - unit_diagonal_ = AnfAlgo::GetNodeAttr(kernel_node, UNIT_DIAGONAL); - const std::string trans = AnfAlgo::GetNodeAttr(kernel_node, TRANS); - if (trans == "N") { - trans_ = false; - } else if (trans == "T") { - trans_ = true; - } else if (trans == "C") { - trans_ = true; + if (AnfAlgo::HasNodeAttr(ADJOINT, kernel_node)) { + // MatrixTriangularSolve attribute + trans_ = AnfAlgo::GetNodeAttr(kernel_node, ADJOINT); + if (AnfAlgo::HasNodeAttr(TRANS, kernel_node)) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the attribute 'adjoint' and 'trans' could not exist at the same time."; + } } else { - MS_LOG(EXCEPTION) << "Trans should be in [N, T, C], but got [" << trans << "]."; + lower_ = AnfAlgo::GetNodeAttr(kernel_node, LOWER); + unit_diagonal_ = AnfAlgo::GetNodeAttr(kernel_node, UNIT_DIAGONAL); + const std::string trans = AnfAlgo::GetNodeAttr(kernel_node, TRANS); + if (trans == "N") { + trans_ = false; + } else if (trans == "T") { + trans_ = true; + } else if (trans == "C") { + trans_ = true; + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'trans' should be in ['N', 'T', 'C'], but got [" << trans + << "]."; + } } } @@ -96,9 +106,9 @@ inline void solve(const MatrixBase &a, const MatrixBase &b } template -bool SolveTriangularCpuKernelMod::Launch(const std::vector &inputs, - const std::vector & /* workspace */, - const std::vector &outputs) { +bool MatrixTriangularSolveCpuKernelMod::Launch(const std::vector &inputs, + const std::vector & /* workspace */, + const std::vector &outputs) { CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.h similarity index 61% rename from mindspore/ccsrc/plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.h rename to mindspore/ccsrc/plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.h index 8829e701afa..6493e26508d 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SOLVE_TRIANGULAR_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SOLVE_TRIANGULAR_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ #include #include "plugin/device/cpu/kernel/cpu_kernel.h" @@ -24,10 +24,10 @@ namespace mindspore { namespace kernel { template -class SolveTriangularCpuKernelMod : public NativeCpuKernelMod { +class MatrixTriangularSolveCpuKernelMod : public NativeCpuKernelMod { public: - SolveTriangularCpuKernelMod() = default; - ~SolveTriangularCpuKernelMod() override = default; + MatrixTriangularSolveCpuKernelMod() = default; + ~MatrixTriangularSolveCpuKernelMod() override = default; void InitKernel(const CNodePtr &kernel_node) override; bool Launch(const std::vector &inputs, const std::vector &workspace, @@ -47,12 +47,20 @@ class SolveTriangularCpuKernelMod : public NativeCpuKernelMod { MS_REG_CPU_KERNEL_T( SolveTriangular, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SolveTriangularCpuKernelMod, float) + MatrixTriangularSolveCpuKernelMod, float) MS_REG_CPU_KERNEL_T( SolveTriangular, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - SolveTriangularCpuKernelMod, double) + MatrixTriangularSolveCpuKernelMod, double) +MS_REG_CPU_KERNEL_T( + MatrixTriangularSolve, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + MatrixTriangularSolveCpuKernelMod, float) +MS_REG_CPU_KERNEL_T( + MatrixTriangularSolve, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + MatrixTriangularSolveCpuKernelMod, double) } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SOLVE_TRIANGULAR_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.cc similarity index 85% rename from mindspore/ccsrc/plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.cc rename to mindspore/ccsrc/plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.cc index 905b97762ed..a39f67a95fe 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.cc @@ -14,17 +14,17 @@ * limitations under the License. */ -#include "plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.h" +#include "plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.h" namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_ONE( SolveTriangular, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SolveTriangularGpuKernelMod, float) + MatrixTriangularSolveGpuKernelMod, float) MS_REG_GPU_KERNEL_ONE( SolveTriangular, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - SolveTriangularGpuKernelMod, double) + MatrixTriangularSolveGpuKernelMod, double) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.h similarity index 83% rename from mindspore/ccsrc/plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.h rename to mindspore/ccsrc/plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.h index 66ce1e64541..9de672c25f2 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_ #include #include #include @@ -39,10 +39,10 @@ constexpr size_t kIndexBBuffer = 2; constexpr size_t kIndexBTransposeShape = 3; constexpr size_t kIndexBTransposeAxis = 4; template -class SolveTriangularGpuKernelMod : public NativeGpuKernelMod { +class MatrixTriangularSolveGpuKernelMod : public NativeGpuKernelMod { public: - SolveTriangularGpuKernelMod() = default; - ~SolveTriangularGpuKernelMod() = default; + MatrixTriangularSolveGpuKernelMod() = default; + ~MatrixTriangularSolveGpuKernelMod() = default; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { @@ -159,32 +159,33 @@ class SolveTriangularGpuKernelMod : public NativeGpuKernelMod { lda_ = SizeToInt(m_); ldb_ = SizeToInt(m_); - const std::string trans = AnfAlgo::GetNodeAttr(kernel_node, "trans"); - // converting row major to col major is the same as reverting the trans flag - if (trans == "N") { - trans_ = CUBLAS_OP_T; - } else if (trans == "T") { - trans_ = CUBLAS_OP_N; - } else if (trans == "C") { - trans_ = CUBLAS_OP_N; + if (AnfAlgo::HasNodeAttr("adjoint", kernel_node)) { + // MatrixTriangularSolve attribute + bool trans = AnfAlgo::GetNodeAttr(kernel_node, "adjoint"); + // converting row major to col major is the same as reverting the trans flag + trans_ = trans ? CUBLAS_OP_N : CUBLAS_OP_T; + if (AnfAlgo::HasNodeAttr("trans", kernel_node)) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the attribute 'adjoint' and 'trans' could not exist at the same time."; + } } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', trans should be in [N, T, C], but got [" << trans << "]."; - } - - bool lower = AnfAlgo::GetNodeAttr(kernel_node, "lower"); - // reverting the trans flag by default, so also flip the lower flag - lower = !lower; - if (lower) { - uplo_ = CUBLAS_FILL_MODE_LOWER; - } else { - uplo_ = CUBLAS_FILL_MODE_UPPER; - } - - bool unit_diagonal = AnfAlgo::GetNodeAttr(kernel_node, "unit_diagonal"); - if (unit_diagonal) { - unit_diagonal_ = CUBLAS_DIAG_UNIT; - } else { - unit_diagonal_ = CUBLAS_DIAG_NON_UNIT; + bool lower = AnfAlgo::GetNodeAttr(kernel_node, "lower"); + // reverting the trans flag by default, so also flip the lower flag + lower = !lower; + uplo_ = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + bool unit_diagonal = AnfAlgo::GetNodeAttr(kernel_node, "unit_diagonal"); + unit_diagonal_ = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; + const std::string trans = AnfAlgo::GetNodeAttr(kernel_node, "trans"); + // converting row major to col major is the same as reverting the trans flag + if (trans == "N") { + trans_ = CUBLAS_OP_T; + } else if (trans == "T") { + trans_ = CUBLAS_OP_N; + } else if (trans == "C") { + trans_ = CUBLAS_OP_N; + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', trans should be in [N, T, C], but got [" << trans << "]."; + } } InitSizeLists(); @@ -263,4 +264,4 @@ class SolveTriangularGpuKernelMod : public NativeGpuKernelMod { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_