forked from mindspore-Ecosystem/mindspore
Refactor Cholesky to CholeskyTrsm
This commit is contained in:
parent
a07250e8b0
commit
9a81b50e43
|
@ -14,10 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CholeskyGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(CholeskyTrsm, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CholeskyTrsmGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H
|
||||
#define MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H
|
||||
#ifndef MINDSPORE_CHOLESKY_TRSM_SOLVE_GPU_KERNEL_H
|
||||
#define MINDSPORE_CHOLESKY_TRSM_SOLVE_GPU_KERNEL_H
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
|
@ -29,10 +29,10 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CholeskyGpuKernel : public GpuKernel {
|
||||
class CholeskyTrsmGpuKernel : public GpuKernel {
|
||||
public:
|
||||
CholeskyGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {}
|
||||
~CholeskyGpuKernel() = default;
|
||||
CholeskyTrsmGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {}
|
||||
~CholeskyTrsmGpuKernel() = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
@ -111,12 +111,12 @@ class CholeskyGpuKernel : public GpuKernel {
|
|||
if (in_shape.size() == 2) {
|
||||
batch_ = 1;
|
||||
if (in_shape[0] != in_shape[1]) {
|
||||
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
||||
MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input.";
|
||||
}
|
||||
} else if (in_shape.size() == 3) {
|
||||
batch_ = SizeToInt(in_shape[0]);
|
||||
if (in_shape[1] != in_shape[2]) {
|
||||
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
|
||||
MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input.";
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
|
||||
|
@ -140,12 +140,12 @@ class CholeskyGpuKernel : public GpuKernel {
|
|||
InitSizeLists();
|
||||
} else {
|
||||
if (in_shape.size() != 2) {
|
||||
MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2.";
|
||||
MS_LOG(ERROR) << "CholeskyTrsm Split Matrix Need Input Rank as 2.";
|
||||
}
|
||||
height = in_shape[0];
|
||||
width = in_shape[1];
|
||||
if (height != width) {
|
||||
MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input.";
|
||||
MS_LOG(ERROR) << "CholeskyTrsm Split Matrix Need Square Matrix as Input.";
|
||||
}
|
||||
if (SizeToInt(height) <= split_dim) {
|
||||
use_split_matrix = false;
|
|
@ -87,7 +87,7 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, Popul
|
|||
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
|
||||
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
|
||||
CusMatMulCubeDenseRight,
|
||||
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, DetTriangle)
|
||||
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, CholeskyTrsm, DetTriangle)
|
||||
from .sparse_ops import SparseToDense
|
||||
from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx
|
||||
|
||||
|
|
|
@ -608,7 +608,7 @@ class UpdateThorGradient(PrimitiveWithInfer):
|
|||
return x2_dtype
|
||||
|
||||
|
||||
class Cholesky(PrimitiveWithInfer):
|
||||
class CholeskyTrsm(PrimitiveWithInfer):
|
||||
"""
|
||||
Inner API for resnet50 THOR GPU backend
|
||||
"""
|
||||
|
|
|
@ -198,7 +198,7 @@ class Conv2d_Thor_GPU(_Conv):
|
|||
self.damping = Parameter(Tensor(damping), name="damping_value", requires_grad=False)
|
||||
self.dampingA = Tensor(np.identity(self.matrix_A_dim), mstype.float32)
|
||||
self.dampingG = Tensor(np.identity(self.matrix_G_dim), mstype.float32)
|
||||
self.cholesky = P.Cholesky(split_dim=split_dim)
|
||||
self.cholesky = P.CholeskyTrsm(split_dim=split_dim)
|
||||
self.vector_matmul = P.BatchMatMul(transpose_a=True)
|
||||
|
||||
def save_gradient(self, dout):
|
||||
|
@ -340,7 +340,7 @@ class Dense_Thor_GPU(Cell):
|
|||
self.axis = 0
|
||||
self.add = P.TensorAdd()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.cholesky = P.Cholesky(split_dim=split_dim)
|
||||
self.cholesky = P.CholeskyTrsm(split_dim=split_dim)
|
||||
self.vector_matmul = P.BatchMatMul(transpose_a=True)
|
||||
|
||||
def save_gradient(self, dout):
|
||||
|
|
Loading…
Reference in New Issue