add cholesky && lu factorization for gpu backend

This commit is contained in:
z00512249 2021-10-19 16:01:34 +08:00
parent 5b87e64557
commit a125654fbc
8 changed files with 652 additions and 7 deletions

View File

@ -16,8 +16,8 @@
#include "triangle_matrix_copy_impl.cuh"
template <typename T>
__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m) {
__global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
size_t batchIdx = i / (ldb * m);
size_t row = (i - batchIdx * ldb * m) / m;
@ -42,8 +42,42 @@ __global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMo
}
template <typename T>
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream) {
__global__ void ScipyTriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
size_t batchIdx = i / (ldb * m);
size_t row = (i - batchIdx * ldb * m) / m;
size_t col = (i - batchIdx * ldb * m) % m;
// If fill mode is 'CUBLAS_FILL_MODE_LOWER', the upper half of the matrix should be all 0;
// If fill mode is 'CUBLAS_FILL_MODE_UPPER', the lower half of the matrix should be all 0;
// special case, only upper triangle data is correct, so copy up to lower, when lower case.
if (uplo == CUBLAS_FILL_MODE_UPPER) {
if (col < row) {
output[i] = 0;
} else {
output[i] = input[i];
}
} else if (uplo == CUBLAS_FILL_MODE_LOWER) {
if (col > row) {
output[i] = 0;
} else {
output[row * m + col] = input[col * m + row];
}
}
}
}
template <typename T>
__global__ void MatrixCopyKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = input[i];
}
}
template <typename T>
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
const size_t m, cudaStream_t cuda_stream) {
TriangleMatrixCopyKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, uplo, count, ldb, m);
return;
}
@ -52,3 +86,29 @@ template void TriangleMatrixCopy<float>(const float *input, float *output, cubla
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template void TriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template <typename T>
void ScipyTriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
const size_t m, cudaStream_t cuda_stream) {
ScipyTriangleMatrixCopyKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, uplo, count, ldb, m);
return;
}
template <typename T>
void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
MatrixCopyKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template void ScipyTriangleMatrixCopy<float>(const float *input, float *output, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m,
cudaStream_t cuda_stream);
template void ScipyTriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count,
const size_t ldb, const size_t m, cudaStream_t cuda_stream);
template void ScipyTriangleMatrixCopy<double>(const double *input, double *output, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m,
cudaStream_t cuda_stream);
template void MatrixCopy<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void MatrixCopy<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void MatrixCopy<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);

View File

@ -19,6 +19,13 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo,
const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream);
void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
const size_t m, cudaStream_t cuda_stream);
template <typename T>
void ScipyTriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, const size_t count, const size_t ldb,
const size_t m, cudaStream_t cuda_stream);
template <typename T>
void MatrixCopy(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/cholesky_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ScipyCholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ScipyCholeskyGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(ScipyCholesky, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ScipyCholeskyGpuKernel, double)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,181 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <algorithm>
#include <type_traits>
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "utils/convert_utils.h"
namespace mindspore {
namespace kernel {
constexpr size_t kCholeskyInputsNum = 1;
constexpr size_t kInputIndex = 0;
constexpr size_t kCholeskyOutputsNum = 1;
constexpr size_t kOutputIndex = 0;
constexpr size_t kCholeskyDefaultShape = 1;
constexpr size_t kCholeskyNormalShape = 2;
constexpr size_t kCholeskyBatchedShape = 3;
template <typename T>
class ScipyCholeskyGpuKernel : public GpuKernel {
public:
ScipyCholeskyGpuKernel() = default;
~ScipyCholeskyGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
// here all addresses are malloc by cuda, so deal with them as device's address.
auto input1_addr = GetDeviceAddress<T>(inputs, kDim0);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
auto d_array_addr = GetDeviceAddress<T *>(workspace, kDim0);
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim1);
for (size_t i = 0; i < batch_; i++) {
h_array_[i] = input1_addr + i * lda_ * m_;
}
// 5. copy input's addr to d_array_addr
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(d_array_addr, h_array_.data(), sizeof(T *) * batch_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
// 6. solve to cholesky factorization according to cuSolver api, outputs have been written to input's matrix.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnSpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnDpotrfBatched(handle_, uplo_, m_, d_array_addr, lda_, d_info_array_addr, batch_),
"cusolver cholesky batched Fail");
} else {
MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now.";
}
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
// 7. copy results from written input's matrix to output's matrix by up or lower flag.
ScipyTriangleMatrixCopy(input1_addr, output_addr, uplo_, output_elements, ldb_, m_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, "lower"));
if (lower_) {
uplo_ = CUBLAS_FILL_MODE_LOWER;
} else {
uplo_ = CUBLAS_FILL_MODE_UPPER;
}
// 1. get CuSolver Dense matrix handler
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
// 2. get Cublas handler
blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
// 3. check input shape not null
bool is_null_input = CHECK_NULL_INPUT(in_shape);
if (is_null_input) {
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input is null";
}
// 4. calculate input size
if (!InitInputSize(in_shape)) {
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input shape init failed.";
}
return true;
}
private:
bool InitInputSize(const std::vector<size_t> &in_shape) {
if (in_shape.size() == kCholeskyDefaultShape) {
batch_ = 1;
cho_row_ = in_shape.at(kDim0);
cho_col_ = cho_row_;
} else if (in_shape.size() == kCholeskyNormalShape) {
batch_ = 1;
cho_row_ = in_shape.at(kDim0);
cho_col_ = in_shape.at(kDim1);
} else if (in_shape.size() == kCholeskyBatchedShape) {
batch_ = SizeToInt(in_shape.at(kDim0));
cho_row_ = in_shape.at(kDim1);
cho_col_ = in_shape.at(kDim2);
} else {
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
return false;
}
if (cho_row_ != cho_col_) {
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
return false;
}
// set matrix row or col to be lead dimension
m_ = SizeToInt(cho_row_);
lda_ = m_;
ldb_ = m_;
h_array_.resize(batch_);
InitSizeLists();
return true;
}
void InitSizeLists() override {
input_size_ = batch_ * m_ * lda_ * unit_size_;
input_size_list_.push_back(input_size_);
output_size_ = batch_ * m_ * lda_ * unit_size_;
output_size_list_.push_back(output_size_);
size_t workspace_size = batch_ * sizeof(T *);
workspace_size_list_.resize(kDim2);
workspace_size_list_[kDim0] = workspace_size;
workspace_size = batch_ * sizeof(int);
workspace_size_list_[kDim1] = workspace_size;
}
size_t unit_size_{sizeof(T)};
size_t cho_row_{0};
size_t cho_col_{0};
size_t batch_{0};
size_t m_{0};
size_t lda_{0};
size_t ldb_{0};
size_t input_size_{0};
size_t output_size_{0};
cusolverDnHandle_t handle_{nullptr};
cublasHandle_t blas_handle_{nullptr};
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
bool lower_{false};
std::vector<T *> h_array_{};
std::vector<size_t> input_size_list_{};
std::vector<size_t> output_size_list_{};
std::vector<size_t> workspace_size_list_{};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_SOLVE_GPU_KERNEL_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/lu_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUGpuKernel, double)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,175 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <algorithm>
#include <type_traits>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "utils/convert_utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t kLuInputsNum = 1;
constexpr size_t kInputIndex = 0;
constexpr size_t kLuOutputsNum = 1;
constexpr size_t kOutputIndex = 0;
constexpr size_t kLuDefaultShape = 1;
constexpr size_t kLuNormalShape = 2;
template <typename T>
class LUGpuKernel : public GpuKernel {
public:
LUGpuKernel() = default;
~LUGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto input_addr = GetDeviceAddress<T>(inputs, kDim0);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
int *piv_output_addr = nullptr;
if (pivot_on_) {
piv_output_addr = GetDeviceAddress<int>(outputs, kDim1);
}
auto info_output_addr = GetDeviceAddress<int>(outputs, kDim2);
// 4. query working space of getrf
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
cusolverDnSgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_),
"cusolver query lu work size fail");
if (cudaMalloc(reinterpret_cast<void **>(&d_work_), unit_size_ * lwork_) != cudaSuccess) {
MS_LOG(EXCEPTION) << "cusolver malloc work size fail";
}
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnSgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
cusolverDnDgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_),
"cusolver query lu work size fail");
// 5. malloc device working space of getrf
if (cudaMalloc(reinterpret_cast<void **>(&d_work_), unit_size_ * lwork_) != cudaSuccess) {
MS_LOG(EXCEPTION) << "cusolver malloc work size fail";
}
// 6. solve to lu factorization according to cuSolver api, outputs have been written to input's matrix.
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnDgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else {
MS_LOG(EXCEPTION) << "cholesky factorization do not support other data type but only float or double, right now.";
}
// 7. copy results from written input's matrix to output's matrix.
// if (cudaMemcpy(output_addr, input_addr, lda_ * m_ * unit_size_, cudaMemcpyDeviceToDevice) != cudaSuccess) {
// MS_LOG(EXCEPTION) << "memcpy lu output fail.";
// }
MatrixCopy(input_addr, output_addr, lda_ * m_, reinterpret_cast<cudaStream_t>(stream_ptr));
if (d_work_) {
cudaFree(d_work_);
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
// 1. get CuSolver Dense matrix handler
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
// 2. check input shape not null
bool is_null_input = CHECK_NULL_INPUT(in_shape);
if (is_null_input) {
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input is null";
}
// 3. calculate input size
if (!InitInputSize(in_shape)) {
MS_LOG(EXCEPTION) << "For 'PureCholeskyGpuKernel', input shape init failed.";
}
return true;
}
private:
bool InitInputSize(const std::vector<size_t> &in_shape) {
if (in_shape.size() == kLuDefaultShape) {
lu_row_ = in_shape.at(kDim0);
lu_col_ = lu_row_;
} else if (in_shape.size() == kLuNormalShape) {
lu_row_ = in_shape.at(kDim0);
lu_col_ = in_shape.at(kDim1);
} else {
MS_LOG(ERROR) << "Input Only support Rank 1 OR 2";
return false;
}
if (lu_row_ != lu_col_) {
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
return false;
}
// set matrix row or col to be lead dimension
m_ = SizeToInt(lu_row_);
lda_ = m_;
ldb_ = m_;
InitSizeLists();
return true;
}
void InitSizeLists() override {
size_t input_size = lda_ * m_ * unit_size_;
input_size_list_.push_back(input_size);
size_t output_size = lda_ * m_ * unit_size_;
size_t output_piv_size = 0;
size_t output_info_size = sizeof(int);
if (pivot_on_) {
output_piv_size = m_ * sizeof(int);
}
output_size_list_.resize(kDim3);
output_size_list_[kDim0] = output_size;
output_size_list_[kDim1] = output_piv_size;
output_size_list_[kDim2] = output_info_size;
}
size_t unit_size_{sizeof(T)};
size_t lu_row_{0};
size_t lu_col_{0};
size_t m_{0};
size_t lda_{0};
size_t ldb_{0};
int lwork_{0};
bool pivot_on_{true};
T *d_work_{nullptr};
cusolverDnHandle_t handle_{nullptr};
std::vector<size_t> input_size_list_{};
std::vector<size_t> output_size_list_{};
std::vector<size_t> workspace_size_list_{};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CHOLESKY_SOLVE_GPU_KERNEL_H_

View File

@ -12,17 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import scipy as scp
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.ops import PrimitiveWithInfer
from mindspore.ops import prim_attr_register
from mindspore._checkparam import Validator as validator
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetCholesky(nn.Cell):
def __init__(self):
super(NetCholesky, self).__init__()
@ -32,13 +36,75 @@ class NetCholesky(nn.Cell):
return self.cholesky(x)
class ScipyCholesky(PrimitiveWithInfer):
"""
Inner API for Cholesky base class.
"""
@prim_attr_register
def __init__(self, lower=False, clean=False):
super().__init__(name="PureCholesky")
self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def __infer__(self, x):
x_shape = x['shape']
x_dtype = x['dtype']
return {
'shape': tuple(x_shape),
'dtype': x_dtype,
'value': None
}
class ScipyNetCholesky(nn.Cell):
def __init__(self, lower=False, clean=False):
super(ScipyNetCholesky, self).__init__()
self.cholesky = ScipyCholesky(lower, clean)
def construct(self, x):
return self.cholesky(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cholesky_fp32():
"""
Feature: ALL TO ALL
Description: test cases for origin cholesky [N,N]
Expectation: the result match np cholesky
"""
cholesky = NetCholesky()
x = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32)
output = cholesky(Tensor(x, dtype=mstype.float32))
expect = np.linalg.cholesky(x)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect) < tol).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scipy_cholesky_fp32():
"""
Feature: ALL TO ALL
Description: test cases for new scipy cholesky [N,N]
Expectation: the result match scipy cholesky
"""
a = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32)
tensor_a = Tensor(a)
cholesky = ScipyNetCholesky(lower=True, clean=False)
output = cholesky(tensor_a)
cholesky1 = ScipyNetCholesky(lower=False, clean=False)
output1 = cholesky1(tensor_a)
expect = scp.linalg.cholesky(a, lower=True)
expect1 = scp.linalg.cholesky(a, lower=False)
rtol = 1.e-4
atol = 1.e-5
assert np.allclose(expect, output.asnumpy(), rtol=rtol, atol=atol)
assert np.allclose(expect1, output1.asnumpy(), rtol=rtol, atol=atol)

View File

@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import Generic
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.numpy as mnp
import mindspore.common.dtype as mstype
from mindspore.ops import PrimitiveWithInfer
from mindspore.ops import prim_attr_register
import scipy as scp
import numpy as np
import pytest
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
class LU(PrimitiveWithInfer):
"""
LU decomposition with partial pivoting
P.A = L.U
"""
@prim_attr_register
def __init__(self):
super().__init__(name="LU")
self.init_prim_io_names(inputs=['x'], outputs=['lu', 'pivots', 'permutation'])
def __infer__(self, x):
x_shape = list(x['shape'])
x_dtype = x['dtype']
pivots_shape = []
permutation_shape = []
ndim = len(x_shape)
if ndim == 0:
pivots_shape = x_shape
permutation_shape = x_shape
elif ndim == 1:
pivots_shape = x_shape[:-1]
# permutation_shape = x_shape[:-1]
else:
pivots_shape = x_shape[-2:-1]
# permutation_shape = x_shape[-2:-1]
output = {
'shape': (x_shape, pivots_shape, permutation_shape),
'dtype': (x_dtype, mstype.int32, mstype.int32),
'value': None
}
return output
class LuNet(nn.Cell):
def __init__(self):
super(LuNet, self).__init__()
self.lu = LU()
def construct(self, a):
return self.lu(a)
@pytest.mark.platform_x86_gpu
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_lu_net(n: int, dtype: Generic):
"""
Feature: ALL To ALL
Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
Expectation: the result match to scipy
"""
a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
expect, _ = scp.linalg.lu_factor(a)
mscp_lu_net = LuNet()
# mindspore tensor is row major but gpu cusolver is col major, so we should transpose it.
tensor_a = Tensor(a)
tensor_a = mnp.transpose(tensor_a)
output, _, _ = mscp_lu_net(tensor_a)
# mindspore tensor is row major but gpu cusolver is col major, so we should transpose it.
output = mnp.transpose(output)
rtol = 1.e-4
atol = 1.e-5
assert np.allclose(expect, output.asnumpy(), rtol=rtol, atol=atol)