forked from mindspore-Ecosystem/mindspore
commit a26958dc48

@ -0,0 +1,30 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/math/trsm_solve_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  SolveTriangular,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  TrsmGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
  SolveTriangular,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  TrsmGpuKernel, double)
}  // namespace kernel
}  // namespace mindspore
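
The two registrations above map onto the two cuBLAS entry points used by the kernel header below: `cublasStrsm` for float32 and `cublasDtrsm` for float64. As a point of reference only (not part of this commit), a minimal NumPy/SciPy sketch of the behaviour both instantiations are expected to reproduce:

```python
import numpy as np
from scipy.linalg import solve_triangular

# Mirrors the float32/float64 registrations: solve L @ x = b for a
# lower-triangular L and check the residual.
for dtype in (np.float32, np.float64):
    L = np.tril(np.random.random((5, 5)) + np.eye(5)).astype(dtype)
    b = np.random.random(5).astype(dtype)
    x = solve_triangular(L, b, lower=True)
    assert np.allclose(L @ x, b, atol=1e-5)
```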
@ -0,0 +1,150 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <type_traits>
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"

namespace mindspore {
namespace kernel {
constexpr auto kAVectorxDimNum = 1;
constexpr auto kAMatrixDimNum = 2;
template <typename T>
class TrsmGpuKernel : public GpuKernel {
 public:
  TrsmGpuKernel() = default;
  ~TrsmGpuKernel() = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    auto inputA_addr = GetDeviceAddress<T>(inputs, 0);
    auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
    auto output_addr = GetDeviceAddress<T>(outputs, 0);

    const size_t batch = m_ * n_;

    // trsm solves in place, so copy b into the output buffer first.
    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(output_addr, inputb_addr, batch * sizeof(T), cudaMemcpyDeviceToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync output_addr failed");

    // Note: cuBLAS treats the A buffer as column-major, so the tests below pass A transposed.
    T alpha = 1;
    if constexpr (std::is_same_v<T, float>) {
      CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
                                   cublasStrsm(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, CUBLAS_DIAG_NON_UNIT, m_,
                                               n_, &alpha, inputA_addr, lda_, output_addr, ldb_),
                                   "cublas trsm failed");
    } else {
      CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
                                   cublasDtrsm(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, CUBLAS_DIAG_NON_UNIT, m_,
                                               n_, &alpha, inputA_addr, lda_, output_addr, ldb_),
                                   "cublas trsm failed");
    }

    return true;
  }
bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
    auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    is_null_input_ = CHECK_NULL_INPUT(A_shape) || CHECK_NULL_INPUT(b_shape);
    if (is_null_input_) {
      MS_LOG(WARNING) << "For 'TrsmGpuKernel', input is null";
      InitSizeLists();
      return true;
    }

    if (A_shape[kDim0] != A_shape[kDim1]) {
      MS_LOG(EXCEPTION) << "wrong array shape, A should be a square matrix, but got [" << A_shape[kDim0] << " X "
                        << A_shape[kDim1] << "]";
    }
    m_ = A_shape[kDim0];

    if (b_shape.size() != kAVectorxDimNum && b_shape.size() != kAMatrixDimNum) {
      MS_LOG(EXCEPTION) << "wrong array shape, b should be 1D or 2D, but got [" << b_shape.size() << "] dimensions";
    }
    if (b_shape[kDim0] != m_) {
      MS_LOG(EXCEPTION) << "wrong array shape, b should match the shape of A, expected [" << m_ << "] but got ["
                        << b_shape[kDim0] << "]";
    }
    // Only a single right-hand side is supported for now: b must be a vector or an [N x 1] matrix.
    if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
      n_ = 1;
    } else {
      MS_LOG(EXCEPTION) << "b as a matrix is currently not supported.";
    }
    m_ = b_shape[kDim0];

    lda_ = SizeToInt(m_);
    ldb_ = SizeToInt(m_);

    bool lower = AnfAlgo::GetNodeAttr<bool>(kernel_node, "lower");
    if (lower) {
      uplo_ = CUBLAS_FILL_MODE_LOWER;
    } else {
      uplo_ = CUBLAS_FILL_MODE_UPPER;
    }

    const std::string trans = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "trans");
    if (trans == "N") {
      trans_ = CUBLAS_OP_N;
    } else if (trans == "T") {
      trans_ = CUBLAS_OP_T;
    } else {
      MS_LOG(EXCEPTION) << "trans should be in [N, T], but got [" << trans << "]";
    }

    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    size_t unit_size = sizeof(T);
    input_size_list_ = {m_ * m_ * unit_size, m_ * n_ * unit_size};
    output_size_list_ = {m_ * n_ * unit_size};
  }

 private:
  size_t m_{0};
  size_t n_{0};
  int lda_{0};
  int ldb_{0};
  bool is_null_input_{false};
  cublasHandle_t blas_handle_{nullptr};
  cublasFillMode_t uplo_{CUBLAS_FILL_MODE_UPPER};
  cublasOperation_t trans_{CUBLAS_OP_N};
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_
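
One subtlety worth spelling out: cuBLAS interprets the A buffer as column-major, while MindSpore tensors are row-major, so a buffer handed over unchanged is effectively read as A transposed. A NumPy sketch of that reinterpretation (reference only, not part of this commit):

```python
import numpy as np
from scipy.linalg import solve_triangular

n = 4
a = np.triu(np.random.random((n, n)) + np.eye(n)).astype(np.float32)  # upper triangular
b = np.random.random(n).astype(np.float32)

# Reading a row-major upper-triangular buffer as column-major yields its
# transpose, i.e. a lower-triangular matrix; the two solves below agree.
x_as_column_major = solve_triangular(a.T, b, lower=True)
x_via_trans_flag = solve_triangular(a, b, trans="T")
assert np.allclose(x_as_column_major, x_via_trans_flag)
```

This is also why match() in the tests below transposes a before handing it to the GPU kernel.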
@ -14,7 +14,7 @@
# ============================================================================
"""test for SolveTriangular"""

from typing import Generic

import pytest
import numpy as np
from scipy.linalg import solve_triangular

@ -65,6 +65,8 @@ def mind_solve(a, b, trans="N", lower=False, unit_diagonal=False,

def match(a, b, lower, trans):
    sci_x = solve_triangular(a, b, lower=lower, trans=trans)
+    if context.get_context("device_target") == "GPU":
+        a = a.T
    mind_x = mind_solve(Tensor(a), Tensor(
        b), lower=lower, trans=trans).asnumpy()

@ -81,7 +83,7 @@ def match(a, b, lower, trans):
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
-def test_2D(n: int, dtype: Generic, lower: bool, trans: str):
+def test_2D(n: int, dtype, lower: bool, trans: str):
    """
    Feature: ALL TO ALL
    Description: test cases for [N x N] X [N X 1]

@ -100,7 +102,7 @@ def test_2D(n: int, dtype: Generic, lower: bool, trans: str):
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
-def test_1D(n: int, dtype: Generic, lower: bool, trans: str):
+def test_1D(n: int, dtype, lower: bool, trans: str):
    """
    Feature: ALL TO ALL
    Description: test cases for [N x N] X [N]

@ -0,0 +1,114 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test for SolveTriangular"""


import pytest
import numpy as np
from scipy.linalg import solve_triangular
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
np.random.seed(100)


class SolveTriangular(PrimitiveWithInfer):
    """
    SolveTriangular op frontend implementation
    """

    @prim_attr_register
    def __init__(self, lower: bool, trans: str):
        """Initialize SolveTriangular"""
        self.lower = validator.check_value_type(
            "lower", lower, [bool], self.name)
        self.trans = validator.check_value_type(
            "trans", trans, [str], self.name)

        self.init_prim_io_names(inputs=['A', 'b'], outputs=['output'])

    def __infer__(self, A, b):
        # the output keeps b's shape and A's dtype
        out_shapes = b['shape']
        return {
            'shape': tuple(out_shapes),
            'dtype': A['dtype'],
            'value': None
        }

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64],
                                           self.name, True)
        return x_dtype


def mind_solve(a, b, trans="N", lower=False, unit_diagonal=False,
               overwrite_b=False, debug=None, check_finite=True):
    solve = SolveTriangular(lower, trans)
    return solve(a, b)


def match(a, b, lower, trans):
    sci_x = solve_triangular(a, b, lower=lower, trans=trans)
    if context.get_context("device_target") == "GPU":
        # cuBLAS reads the buffer as column-major while the Tensor is
        # row-major, so pass A transposed to compensate.
        a = a.T
    mind_x = mind_solve(Tensor(a), Tensor(
        b), lower=lower, trans=trans).asnumpy()

    print(sci_x)
    print(mind_x)
    print(f'lower: {lower}')
    assert np.allclose(sci_x, mind_x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
def test_2D(n: int, dtype, lower: bool, trans: str):
    """
    Feature: ALL TO ALL
    Description: test cases for [N x N] X [N X 1]
    Expectation: the result matches scipy
    """
    # add an identity matrix to make matrix A non-singular
    a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
    b = np.random.random((n, 1)).astype(dtype)
    match(a, b, lower, trans)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
def test_1D(n: int, dtype, lower: bool, trans: str):
    """
    Feature: ALL TO ALL
    Description: test cases for [N x N] X [N]
    Expectation: the result matches scipy
    """
    # add an identity matrix to make matrix A non-singular
    a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
    b = np.random.random(n).astype(dtype)
    match(a, b, lower, trans)
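
The kernel's Init currently fixes n_ = 1, so a right-hand side with more than one column is rejected on the C++ side. A hedged sketch of a test for that restriction (hypothetical, not part of this commit), reusing the definitions above:

```python
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_matrix_b_rejected():
    """
    Sketch only: b with several columns is expected to raise until matrix
    right-hand sides are supported by the GPU kernel.
    """
    n = 10
    a = (np.random.random((n, n)) + np.eye(n)).astype(np.float32)
    b = np.random.random((n, 3)).astype(np.float32)  # 3 columns: unsupported
    with pytest.raises(Exception):
        mind_solve(Tensor(a), Tensor(b), lower=True, trans="N")
```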