From e0d36aa741448320d9b330186f09842a70679afa Mon Sep 17 00:00:00 2001 From: linjie <1789133296@qq.com> Date: Tue, 18 Oct 2022 18:23:32 +0800 Subject: [PATCH] [feat] [assistant] [I5EWKK] implement CholeskySolve operator in a new way --- .../kernel/math/cholesky_solve_gpu_kernel.cc | 19 +- .../kernel/math/cholesky_solve_gpu_kernel.h | 269 +++++++++--------- mindspore/core/ops/cholesky_solve.cc | 10 +- .../ops/_grad_experimental/grad_math_ops.py | 7 +- .../mindspore/ops/operations/math_ops.py | 3 +- tests/st/ops/gpu/test_cholesky_solve_op.py | 55 ++++ 6 files changed, 218 insertions(+), 145 deletions(-) create mode 100644 tests/st/ops/gpu/test_cholesky_solve_op.py diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/cholesky_solve_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/cholesky_solve_gpu_kernel.cc index 7630abdbf1f..7d784464974 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/cholesky_solve_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/cholesky_solve_gpu_kernel.cc @@ -15,16 +15,17 @@ */ #include "plugin/device/gpu/kernel/math/cholesky_solve_gpu_kernel.h" +#include "mindspore/core/ops/cholesky_solve.h" + namespace mindspore { namespace kernel { -MS_REG_GPU_KERNEL_ONE( - CholeskySolve, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - CholeskySolveGpuKernelMod, float) - -MS_REG_GPU_KERNEL_ONE( - CholeskySolve, - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - CholeskySolveGpuKernelMod, double) +using CSGKM = CholeskySolveGpuKernelMod; +std::vector> CSGKM::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &CholeskySolveGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &CholeskySolveGpuKernelMod::LaunchKernel}, +}; +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CholeskySolve, CholeskySolveGpuKernelMod); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/cholesky_solve_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/cholesky_solve_gpu_kernel.h index 191a81a3561..1fab774cfff 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/cholesky_solve_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/cholesky_solve_gpu_kernel.h @@ -20,13 +20,15 @@ #include #include #include -#include #include +#include +#include #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/triangle_matrix_copy_impl.cuh" #include "plugin/device/gpu/kernel/gpu_kernel.h" #include "plugin/device/gpu/kernel/gpu_kernel_factory.h" #include "plugin/device/gpu/kernel/kernel_constants.h" #include "include/common/utils/convert_utils.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/matrix_transpose_impl.cuh" namespace mindspore { namespace kernel { @@ -36,163 +38,174 @@ constexpr size_t kCholeskyOutputsNum = 1; constexpr size_t kOutputIndex = 0; constexpr size_t kRowIndex = 2; constexpr size_t kColIndex = 1; +inline cublasStatus_t cublasXtrsm(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const float *alpha, + const float *A, int lda, float *B, int ldb) { + return cublasStrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} +inline cublasStatus_t cublasXtrsm(cublasHandle_t handle, 
cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const double *alpha, + const double *A, int lda, double *B, int ldb) { + return cublasDtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} +inline cublasStatus_t cublasXtrsmBatched(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const float *alpha, const float *const A[], int lda, float *const B[], int ldb, + int batchCount) { + return cublasStrsmBatched(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); +} +inline cublasStatus_t cublasXtrsmBatched(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const double *alpha, const double *const A[], int lda, double *const B[], + int ldb, int batchCount) { + return cublasDtrsmBatched(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); +} -template class CholeskySolveGpuKernelMod : public NativeGpuKernelMod { public: - using pointer = T *; CholeskySolveGpuKernelMod() = default; ~CholeskySolveGpuKernelMod() override = default; bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, const std::vector &outputs) override { - constexpr size_t input_num = 1; - constexpr size_t output_num = 1; - CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_); - kernel_name_ = base_operator->GetPrim()->name(); - if (base_operator->HasAttr("upper")) { - upper_ = GetValue(base_operator->GetAttr("upper")); + kernel_name_ = base_operator->name(); + upper_ = GetValue(base_operator->GetAttr("upper")); + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For 'CholeskySolve', it does not support this kernel type: " << kernel_attr; + return false; } - // Gpu input is col major default, so need to change row major. - // In order to speedup it, just change lower to upper, because of cholesky input a is triangle matrix - // when input b_col is not equal to one, maybe need a normal transpose op inplace. 
+ kernel_func_ = func_list_[index].second; + return true; + } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = reinterpret_cast(stream_ptr); + return kernel_func_(this, inputs, workspace, outputs); + } + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + using pointer = T *; + CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast(cuda_stream_)), + "cholesky solve cublasSetStream failed"); + auto input_a_addr = GetDeviceAddress(inputs, kDim0); + auto input_b_addr = GetDeviceAddress(inputs, kDim1); + auto output_addr = GetDeviceAddress(outputs, kDim0); + auto d_a_array_addr = GetDeviceAddress(workspace, kDim0); + auto d_b_array_addr = GetDeviceAddress(workspace, kDim1); + auto d_c_array_addr = GetDeviceAddress(workspace, kDim2); + std::vector h_a_array(batch_num_); + std::vector h_b_array(batch_num_); + std::vector h_c_array(batch_num_); + for (size_t i = 0; i < batch_num_; i++) { + h_a_array[i] = input_a_addr + i * lda_ * nrhs_; + h_b_array[i] = input_b_addr + i * ldb_ * m_; + h_c_array[i] = output_addr + i * lda_ * nrhs_; + } + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(d_a_array_addr, h_a_array.data(), sizeof(pointer) * batch_num_, cudaMemcpyHostToDevice, + reinterpret_cast(cuda_stream_)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(d_b_array_addr, h_b_array.data(), sizeof(pointer) * batch_num_, cudaMemcpyHostToDevice, + reinterpret_cast(cuda_stream_)), + "cuda memcopy Fail"); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(d_c_array_addr, h_c_array.data(), sizeof(pointer) * batch_num_, cudaMemcpyHostToDevice, + reinterpret_cast(cuda_stream_)), + "cuda memcopy Fail"); + MatrixTranspose(input_a_addr, SizeToInt(batch_num_ * lda_ * nrhs_), SizeToInt(lda_), SizeToInt(nrhs_), output_addr, + device_id_, reinterpret_cast(cuda_stream_)); + cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER; if (upper_) { uplo_ = CUBLAS_FILL_MODE_LOWER; - } else { - uplo_ = CUBLAS_FILL_MODE_UPPER; + transa_ = CUBLAS_OP_N; + transa_t_ = CUBLAS_OP_T; } - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle(); - + T alpha = 1; + if (batch_num_ == 1) { + CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(cublasXtrsm(handle_, CUBLAS_SIDE_LEFT, uplo_, transa_, CUBLAS_DIAG_NON_UNIT, + lda_, nrhs_, &alpha, input_b_addr, ldb_, output_addr, lda_), + "cholesky solve cublasXtrsm failed!"); + CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE( + cublasXtrsm(handle_, CUBLAS_SIDE_LEFT, uplo_, transa_t_, CUBLAS_DIAG_NON_UNIT, lda_, nrhs_, &alpha, + input_b_addr, ldb_, output_addr, lda_), + "cholesky solve cublasXtrsm failed!"); + } else { + CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE( + cublasXtrsmBatched(handle_, CUBLAS_SIDE_LEFT, uplo_, transa_, CUBLAS_DIAG_NON_UNIT, lda_, nrhs_, &alpha, + d_b_array_addr, ldb_, d_c_array_addr, lda_, batch_num_), + "cholesky solve cublasXgetrsBatched failed!"); + CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE( + cublasXtrsmBatched(handle_, CUBLAS_SIDE_LEFT, uplo_, transa_t_, CUBLAS_DIAG_NON_UNIT, lda_, nrhs_, &alpha, + d_b_array_addr, ldb_, d_c_array_addr, lda_, batch_num_), + "cholesky solve cublasXgetrsBatched failed!"); + } + MatrixTranspose(output_addr, SizeToInt(batch_num_ * lda_ * nrhs_), SizeToInt(nrhs_), SizeToInt(lda_), input_a_addr, + device_id_, reinterpret_cast(cuda_stream_)); + auto output_elements = batch_num_ * lda_ * nrhs_; + 
MatrixCopy(input_a_addr, output_addr, output_elements, reinterpret_cast(cuda_stream_)); return true; } int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, const std::vector &outputs, const std::map &inputsOnHost) override { - if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) { + if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) { return ret; } - ResetResource(); - auto in_a_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector()); - auto in_b_shape = LongVecToSizeVec(inputs[kIndex1]->GetShapeVector()); - (void)InitDim(in_a_shape, in_b_shape); + + const auto b_shape = inputs.at(kIndex0)->GetShapeVector(); + const auto cho_shape = inputs.at(kIndex1)->GetShapeVector(); + + is_null_input_ = CHECK_SHAPE_NULL(LongVecToSizeVec(b_shape), kernel_name_, "input_a") || + CHECK_SHAPE_NULL(LongVecToSizeVec(cho_shape), kernel_name_, "input_b"); + batch_num_ = std::accumulate(b_shape.begin(), b_shape.end() - kIndex2, int64_t(1), std::multiplies{}); + m_ = cho_shape.back(); + ldb_ = m_; + lda_ = m_; + nrhs_ = b_shape.back(); + + workspace_size_list_.clear(); + workspace_size_list_ = {batch_num_ * sizeof(float *), batch_num_ * sizeof(float *), batch_num_ * sizeof(float *), + batch_num_ * sizeof(int)}; + return KRET_OK; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast(stream_ptr)), - "cholesky solve cusolverDnSetStream failed"); - auto input_a_addr = GetDeviceAddress(inputs, kDim0); - auto input_b_addr = GetDeviceAddress(inputs, kDim1); - auto output_addr = GetDeviceAddress(outputs, kDim0); - auto d_a_array_addr = GetDeviceAddress(workspace, kDim0); - auto d_b_array_addr = GetDeviceAddress(workspace, kDim1); - auto d_info_array_addr = GetDeviceAddress(workspace, kDim2); - for (size_t i = 0; i < outer_batch_; i++) { - h_a_array_[i] = input_a_addr + i * lda_ * m_; - h_b_array_[i] = input_b_addr + i * ldb_ * nrhs_; - } - CHECK_CUDA_RET_WITH_ERROR_NOTRACE( - cudaMemcpyAsync(d_a_array_addr, h_a_array_.data(), sizeof(pointer) * outer_batch_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cuda memcopy Fail"); - CHECK_CUDA_RET_WITH_ERROR_NOTRACE( - cudaMemcpyAsync(d_b_array_addr, h_b_array_.data(), sizeof(pointer) * outer_batch_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cuda memcopy Fail"); - // Only support rhs = 1 - if constexpr (std::is_same_v) { - CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( - cusolverDnSpotrsBatched(handle_, uplo_, m_, nrhs_, d_a_array_addr, lda_, d_b_array_addr, ldb_, - d_info_array_addr, outer_batch_), - "cusolver cholesky solve batched Fail"); - } else if constexpr (std::is_same_v) { - CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( - cusolverDnDpotrsBatched(handle_, uplo_, m_, nrhs_, d_a_array_addr, lda_, d_b_array_addr, ldb_, - d_info_array_addr, outer_batch_), - "cusolver cholesky solve batched Fail"); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now."; - } - size_t output_elements = outputs.at(kDim0)->size / unit_size_; - // Copy results from written input's matrix to output's matrix. 
- MatrixCopy(input_b_addr, output_addr, output_elements, reinterpret_cast(stream_ptr)); - return true; - } - - void ResetResource() { - input_size_list_.clear(); - workspace_size_list_.clear(); - output_size_list_.clear(); - h_b_array_.clear(); - h_a_array_.clear(); - } - - protected: - void InitSizeLists() { - size_t input_size = outer_batch_ * m_ * lda_ * unit_size_; - input_size_list_.emplace_back(input_size); - input_size = outer_batch_ * nrhs_ * ldb_ * unit_size_; - input_size_list_.emplace_back(input_size); - - size_t workspace_size = outer_batch_ * sizeof(pointer); - workspace_size_list_.emplace_back(workspace_size); - workspace_size_list_.emplace_back(workspace_size); - workspace_size = outer_batch_ * sizeof(int); - workspace_size_list_.emplace_back(workspace_size); - - size_t output_size = outer_batch_ * m_ * unit_size_; - output_size_list_.push_back(output_size); + std::vector GetOpSupport() override { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; } private: - void InitDim(const std::vector &in_a_shape, const std::vector &in_b_shape) { - constexpr size_t min_dim = 1; - if (in_a_shape.size() <= min_dim) { - MS_LOG_EXCEPTION << kernel_name_ << " input a shape dim is " << in_a_shape.size() << " which is invalid."; - } - cho_row_ = in_a_shape.at(in_a_shape.size() - kRowIndex); - cho_col_ = in_a_shape.at(in_a_shape.size() - kColIndex); - outer_batch_ = min_dim; - for (int batch = 0; batch < static_cast(in_a_shape.size() - kRowIndex); ++batch) { - outer_batch_ *= in_a_shape.at(batch); - } - if (cho_row_ != cho_col_) { - MS_LOG_EXCEPTION << kernel_name_ << " input shape is invalid. " - << "Cholesky expects a square matrix. 
but input a shape is: " << cho_row_ << ", " << cho_col_;
-    }
-    const bool is_right_equal_left = in_a_shape.size() == in_b_shape.size();
-    size_t b_row;
-    if (is_right_equal_left) {
-      b_row = in_b_shape.at(in_b_shape.size() - kRowIndex);
-    } else {
-      b_row = in_b_shape.back();
-    }
-    if (cho_row_ != b_row) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', right hand matrix should be equal to left matrix";
-    }
-    m_ = SizeToInt(cho_row_);
-    lda_ = m_;
-    ldb_ = m_;
-    h_a_array_.resize(outer_batch_);
-    h_b_array_.resize(outer_batch_);
-    InitSizeLists();
-  }
-  size_t cho_row_{0};
-  size_t cho_col_{0};
-  size_t unit_size_{sizeof(T)};
-  size_t nrhs_{1};
-  size_t outer_batch_{0};
+  using CholeskySolveFunc =
+    std::function<bool(CholeskySolveGpuKernelMod *, const std::vector<AddressPtr> &,
+                       const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
+  size_t nrhs_{0};
+  size_t batch_num_{0};
   size_t m_{0};
   size_t lda_{0};
   size_t ldb_{0};
-  cusolverDnHandle_t handle_{nullptr};
-  cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
-  std::vector<pointer> h_a_array_;
-  std::vector<pointer> h_b_array_;
+  cublasHandle_t handle_{nullptr};
+  cublasOperation_t transa_{CUBLAS_OP_T};
+  cublasOperation_t transa_t_{CUBLAS_OP_N};
   bool upper_{false};
+  bool is_null_input_{false};
+  CholeskySolveFunc kernel_func_;
+  static std::vector<std::pair<KernelAttr, CholeskySolveFunc>> func_list_;
+  void *cuda_stream_{nullptr};
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/core/ops/cholesky_solve.cc b/mindspore/core/ops/cholesky_solve.cc
index 7b844995a55..1437d37490e 100644
--- a/mindspore/core/ops/cholesky_solve.cc
+++ b/mindspore/core/ops/cholesky_solve.cc
@@ -30,7 +30,7 @@ abstract::ShapePtr CholeskySolveInferShape(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   const size_t kDefalutRank = 2;
-  const size_t kBatchRank = 3;
+  const size_t kBatchRank = 1;
   const size_t kBatchIndex = 3;
   const size_t kRowIndex = 2;
   const size_t kColIndex = 1;
@@ -44,12 +44,12 @@ abstract::ShapePtr CholeskySolveInferShape(const PrimitivePtr &primitive,
     out_shape.push_back(abstract::Shape::kShapeRankAny);
     return std::make_shared<abstract::Shape>(out_shape);
   }
-  if (x1_shape.size() != kDefalutRank && x1_shape.size() != kBatchRank) {
-    MS_EXCEPTION(ValueError) << "For CholeskySolve, the rank of x1 must be equal to 2 or 3"
+  if (x1_shape.size() <= kBatchRank) {
+    MS_EXCEPTION(ValueError) << "For CholeskySolve, the rank of x1 must be at least 2"
                              << ", while got x1 rank " << x1_shape.size() << ".";
   }
-  if (x2_shape.size() != kDefalutRank && x2_shape.size() != kBatchRank) {
-    MS_EXCEPTION(ValueError) << "For CholeskySolve, the rank of x2 must be equal to 2 or 3"
+  if (x2_shape.size() <= kBatchRank) {
+    MS_EXCEPTION(ValueError) << "For CholeskySolve, the rank of x2 must be at least 2"
                              << ", while got x2 rank " << x2_shape.size() << ".";
   }
   if (x1_shape.size() != x2_shape.size()) {
diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py
index 77862890ac6..7878039fb48 100644
--- a/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py
+++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py
@@ -1256,8 +1256,11 @@ def get_bprop_cholesky_solve(self):
             else:
                 dx2 = neg_op(matmul_op(common_term, x2))
         else:
-            common_term = batchmatmul_op(dx1, transpose(out, (0, 2, 1)))
-            common_term = common_term + transpose(common_term, (0, 2, 1))
+            x2_dim_size = len(shape_op(x2))
+            x2_dim_order = list(range(x2_dim_size))
+            target_order = x2_dim_order[:-2] + x2_dim_order[-2:][::-1]
+            common_term = batchmatmul_op(dx1, transpose(out, tuple(target_order)))
+            common_term = common_term + transpose(common_term, tuple(target_order))
             if upper is True:
                 dx2 = neg_op(batchmatmul_op(x2, common_term))
             else:
diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py
index 55771a8a6d7..4041a213038 100644
--- a/mindspore/python/mindspore/ops/operations/math_ops.py
+++ b/mindspore/python/mindspore/ops/operations/math_ops.py
@@ -6655,6 +6655,7 @@ class CholeskySolve(Primitive):
           with float32 or float64 data type.
         - **x2** (Tensor) - Tensor of shape :math:`(*, N, N)`, indicating 2D or 3D square matrices composed of
           upper or lower triangular Cholesky factor, with float32 or float64 data type.
+          `x1` and `x2` must have the same data type.

     Outputs:
         Tensor, has the same shape and data type as `x1`.
@@ -6670,7 +6671,7 @@ class CholeskySolve(Primitive):
         ValueError: If `x2` is not 2D or 3D square matrices.

     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> x1 = Tensor(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), mindspore.float32)
diff --git a/tests/st/ops/gpu/test_cholesky_solve_op.py b/tests/st/ops/gpu/test_cholesky_solve_op.py
new file mode 100644
index 00000000000..10d415306e7
--- /dev/null
+++ b/tests/st/ops/gpu/test_cholesky_solve_op.py
@@ -0,0 +1,55 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore
+from mindspore import nn
+from mindspore import context
+from mindspore import Tensor
+from mindspore.ops.operations.math_ops import CholeskySolve
+
+
+class Net(nn.Cell):
+    """A network that wraps the CholeskySolve GPU operator for testing."""
+
+    def __init__(self, upper=False):
+        super(Net, self).__init__()
+        self.cholesky_solve = CholeskySolve(upper=upper)
+
+    def construct(self, x1, x2):
+        """construct."""
+        return self.cholesky_solve(x1, x2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cholesky_solve():
+    """
+    Feature: CholeskySolve gpu TEST.
+    Description: test the CholeskySolve operator.
+    Expectation: the result matches the expected value.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+    x1 = Tensor(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), mindspore.float32)
+    x2 = Tensor(np.array([[2, 0, 0], [4, 1, 0], [-1, 1, 2]]), mindspore.float32)
+    expect = np.array([[5.8125, -2.625, 0.625], [-2.625, 1.25, -0.25], [0.625, -0.25, 0.25]])
+    net = Net()
+    mindspore_output = net(x1, x2)
+    diff = mindspore_output.asnumpy() - expect
+    error = np.ones(shape=expect.shape) * 1.0e-3
+    assert np.all(np.abs(diff) < error)
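
For reference, the `expect` matrix used in the new GPU test can be reproduced with plain NumPy, assuming the documented semantics that `CholeskySolve(x1, x2)` with `upper=False` treats `x2` as the lower Cholesky factor L and returns the solution of `(L @ L.T) @ out = x1`. This is a minimal standalone sketch, not part of the patch:

# NumPy cross-check for the expected values in tests/st/ops/gpu/test_cholesky_solve_op.py.
# Assumption: with upper=False, CholeskySolve solves (x2 @ x2.T) @ out = x1.
import numpy as np

x1 = np.eye(3, dtype=np.float32)                                      # right-hand side b (identity)
x2 = np.array([[2, 0, 0], [4, 1, 0], [-1, 1, 2]], dtype=np.float32)   # lower Cholesky factor L

expect = np.linalg.solve(x2 @ x2.T, x1)   # equals inv(L @ L.T) since b is the identity
print(expect)
# -> approximately [[5.8125, -2.625, 0.625], [-2.625, 1.25, -0.25], [0.625, -0.25, 0.25]],
#    which matches the expect array asserted in the test above.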