forked from mindspore-Ecosystem/mindspore
Add calculation of triangle matrix determinant op at GPU back-end
This commit is contained in:
parent
831d3a69d4
commit
99f2927c21
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "determinant_triangle_impl.cuh"
|
||||
template <typename T>
|
||||
__global__ void DetTriangleKernel(T *input, T *output, size_t matrix_n_, size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = 1;
|
||||
for (int pos = 0; pos < matrix_n_*matrix_n_; pos += matrix_n_+1) {
|
||||
output[i] *= input[i * matrix_n_ * matrix_n_ + pos];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DetTriangle(T *input, T *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream) {
|
||||
DetTriangleKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, matrix_n_, count);
|
||||
return;
|
||||
}
|
||||
|
||||
__device__ bool dev_error_res = false;
|
||||
|
||||
template <typename T>
|
||||
__global__ void CheckTriangleKernel(T *input, int fill_mode_, size_t matrix_n_, size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
size_t idx = 0;
|
||||
if (fill_mode_ == 0) { // UPPER half
|
||||
for (size_t row = 0; row < matrix_n_; row++) {
|
||||
for (size_t col = row + 1; col < matrix_n_; col++) {
|
||||
idx = i * matrix_n_ * matrix_n_ + row * matrix_n_ + col;
|
||||
if (static_cast<float>(input[idx]) != 0) {
|
||||
dev_error_res = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (fill_mode_ == 1) { // LOWER half
|
||||
for (size_t row = 0; row < matrix_n_; row++) {
|
||||
for (size_t col = 0; col < row; col++) {
|
||||
idx = i * matrix_n_ * matrix_n_ + row * matrix_n_ + col;
|
||||
if (static_cast<float>(input[idx]) != 0) {
|
||||
dev_error_res = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
dev_error_res = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
dev_error_res = true;
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CheckTriangle(T *input, int fill_mode_, size_t matrix_n_, size_t count, cudaStream_t cuda_stream) {
|
||||
CheckTriangleKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, fill_mode_, matrix_n_, count);
|
||||
bool host_error_res = false;
|
||||
cudaMemcpyFromSymbol(&host_error_res, dev_error_res, sizeof(bool));
|
||||
return host_error_res;
|
||||
}
|
||||
|
||||
template void DetTriangle<float>(float *input, float *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream);
|
||||
template void DetTriangle<half>(half *input, half *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream);
|
||||
template bool CheckTriangle<float>(float *input, int fill_mode_, size_t matrix_n_, size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template bool CheckTriangle<half>(half *input, int fill_mode_, size_t matrix_n_, size_t count,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_
|
||||
|
||||
#include <curand_kernel.h>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
void DetTriangle(T *input, T *output, size_t matrix_n_, size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
bool CheckTriangle(T *input, int fill_mode_, size_t matrix_n_, size_t count, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DETERMINANT_TRIANGLE_IMPL_H_
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/determinant_triangle_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(DetTriangle, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
DetTriangleGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(DetTriangle, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
DetTriangleGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,112 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETRMINANT_TRIANGLE_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETRMINANT_TRIANGLE_GPU_KERNEL_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/determinant_triangle_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class DetTriangleGpuKernel : public GpuKernel {
|
||||
public:
|
||||
DetTriangleGpuKernel() : input_size_(sizeof(T)), output_size_(sizeof(T)) {}
|
||||
~DetTriangleGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
if (!CheckTriangle(input_addr, fill_mode_, matrix_n_, outputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr))) {
|
||||
if (fill_mode_ == 0) {
|
||||
MS_LOG(ERROR) << "The elements in the upper half of the maxtices should be all 0.";
|
||||
} else if (fill_mode_ == 1) {
|
||||
MS_LOG(ERROR) << "The elements in the lower half of the maxtices should be all 0.";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "The input matrix should be either upper filled or lower filled.";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
DetTriangle(input_addr, output_addr, matrix_n_, outputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but DetTriangle needs 1 inputs.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but DetTriangle needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
}
|
||||
matrix_n_ = input_shape[input_shape.size() - 1];
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < output_shape.size(); i++) {
|
||||
output_size_ *= output_shape[i];
|
||||
}
|
||||
if (output_size_ != input_size_ / matrix_n_ / matrix_n_) {
|
||||
MS_LOG(ERROR) << "The output shape is wrong.";
|
||||
return false;
|
||||
}
|
||||
if (input_shape[input_shape.size() - 2] != input_shape[input_shape.size() - 1]) {
|
||||
MS_LOG(ERROR) << "The maxtices should be in shape of square.";
|
||||
return false;
|
||||
}
|
||||
fill_mode_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("fill_mode"));
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t matrix_n_;
|
||||
int fill_mode_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DETRMINANT_TRIANGLE_GPU_KERNEL_H_
|
|
@ -87,7 +87,7 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, Popul
|
|||
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
|
||||
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
|
||||
CusMatMulCubeDenseRight,
|
||||
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky)
|
||||
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, DetTriangle)
|
||||
from .sparse_ops import SparseToDense
|
||||
from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx
|
||||
|
||||
|
|
|
@ -636,3 +636,22 @@ class Cholesky(PrimitiveWithInfer):
|
|||
def infer_dtype(self, x1_dtype):
|
||||
validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
|
||||
return x1_dtype
|
||||
|
||||
class DetTriangle(PrimitiveWithInfer):
|
||||
"""
|
||||
Calculate the determinant of triangle matrices
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, fill_mode=0):
|
||||
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
|
||||
self.fill_mode = fill_mode
|
||||
self.add_prim_attr('fill_mode', self.fill_mode)
|
||||
|
||||
def infer_shape(self, x1_shape):
|
||||
out_shape = x1_shape
|
||||
del out_shape[-2:]
|
||||
return out_shape
|
||||
|
||||
def infer_dtype(self, x1_dtype):
|
||||
validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
|
||||
return x1_dtype
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, fill_mode=0):
|
||||
super(Net, self).__init__()
|
||||
self.det_triangle = P.DetTriangle(fill_mode=fill_mode)
|
||||
|
||||
def construct(self, x):
|
||||
return self.det_triangle(x)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_1D():
|
||||
fill_mode = 0
|
||||
input_x = np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]]).astype(np.float32)
|
||||
net = Net(fill_mode=fill_mode)
|
||||
tx = Tensor(input_x, mstype.float32)
|
||||
output = net(tx)
|
||||
assert output == 18
|
Loading…
Reference in New Issue