!3045 Gpu support TopK kernel

Merge pull request !3045 from chenweifeng/sort
This commit is contained in:
mindspore-ci-bot 2020-07-16 21:54:28 +08:00 committed by Gitee
commit 251683096a
7 changed files with 420 additions and 1 deletion

View File

@@ -44,7 +44,7 @@ if(ENABLE_GPU)
"backend/kernel_compiler/akg/akg_kernel_attrs_process.cc"
)
list(APPEND CUDA_NVCC_FLAGS -arch=sm_53)
list(APPEND CUDA_NVCC_FLAGS -arch=sm_53 --expt-relaxed-constexpr)
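# note: --expt-relaxed-constexpr (added above) lets device code call constexpr host
# functions such as std::numeric_limits<T>::max(), which the TopK kernels use for padding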
list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/blocking_queue.cc" "runtime/device/gpu/gpu_buffer_mgr.cc")
list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/mpi/mpi_initializer.cc"
"runtime/device/gpu/distribution/collective_wrapper.cc"

View File

@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h"
namespace mindspore {
namespace kernel {
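// Register the TopK GPU kernel for float32 values and int32 indices; the trailing
// type arguments bind the T and S template parameters of TopKGpuKernel.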
MS_REG_GPU_KERNEL_TWO(TopK,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32),
TopKGpuKernel, float, int)
}  // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,110 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class TopKGpuKernel : public GpuKernel {
public:
TopKGpuKernel() : sorted_(false), outer_size_(1), inner_size_(1), k_(1), use_share_mem_(true), ceil_power2_(0) {}
~TopKGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
S *k = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
S *indices = GetDeviceAddress<S>(outputs, 1);
T *data_buff = nullptr;
S *index_buff = nullptr;
if (use_share_mem_ == false) {
data_buff = GetDeviceAddress<T>(workspaces, 0);
index_buff = GetDeviceAddress<S>(workspaces, 1);
}
TopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, data_buff, index_buff,
reinterpret_cast<cudaStream_t>(stream_ptr));
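// The "sorted" attribute selects descending output. When it is false, the selected
// top-k values are re-sorted by their original indices so they come back in input order.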
if (sorted_ == false) {
BitonicSortByKey(outer_size_, k_, output_addr, indices, data_buff, index_buff,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shapes.size() - 1; i++) {
outer_size_ *= input_shapes[i];
}
inner_size_ = input_shapes[input_shapes.size() - 1];
k_ = output_shapes[output_shapes.size() - 1];
sorted_ = GetAttr<bool>(kernel_node, "sorted");
ceil_power2_ = RoundUpPower2(inner_size_);
size_t buffer_size = ceil_power2_ * (sizeof(T) + sizeof(S));
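// buffer_size is the per-block bitonic buffer (values plus indices, padded to a power
// of two); if it does not fit in shared memory, sort in a global-memory workspace instead.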
if (buffer_size > SHARED_MEM_PER_BLOCK) {
use_share_mem_ = false;
MS_LOG(WARNING) << "CUDA shared memory is insufficient, falling back to sorting in global memory";
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(outer_size_ * inner_size_ * sizeof(T));
input_size_list_.push_back(sizeof(S));
output_size_list_.push_back(outer_size_ * k_ * sizeof(T));
output_size_list_.push_back(outer_size_ * k_ * sizeof(S));
if (use_share_mem_ == false) {
workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(T));
workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(S));
}
}
private:
bool sorted_;
int outer_size_;
int inner_size_;
int k_;
bool use_share_mem_;
int ceil_power2_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_

View File

@@ -0,0 +1,162 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
#include <limits>
#include <algorithm>
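// Round v up to the nearest power of two: decrement, smear the highest set bit into every
// lower bit, then increment (e.g. 1000 -> 1024; exact powers of two are returned unchanged).
// The bitonic sort below requires a power-of-two buffer length.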
int RoundUpPower2(int v) {
v--;
v |= v >> 1;
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
v++;
return v;
}
template <typename T>
__inline__ __device__ void Swap(T *lhs, T *rhs) {
T tmp = lhs[0];
lhs[0] = rhs[0];
rhs[0] = tmp;
}
template <typename T, typename S>
__global__ void TopkKernel(const int outer, const int inner, const int ceil_power2, const T *input, const S *k,
T *output, S *indices, T *data_buff, S *index_buff) {
// default: sort with shared memory
extern __shared__ T share_mem[];
T *data_arr = share_mem;
S *index_arr = reinterpret_cast<S *>(data_arr + ceil_power2);
// sort with global memory workspace
if (data_buff != nullptr && index_buff != nullptr) {
data_arr = data_buff + blockIdx.x * ceil_power2;
index_arr = index_buff + blockIdx.x * ceil_power2;
}
for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max();
index_arr[i] = i;
}
__syncthreads();
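// Bitonic sort over the padded buffer: i is the length of the bitonic sequences being
// merged, j the comparison distance within a merge step. Each thread compares element
// tid with its partner tid ^ j, and the bit (tid & i) picks the direction, so the whole
// buffer ends up sorted in ascending order.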
for (size_t i = 2; i <= ceil_power2; i <<= 1) {
for (size_t j = (i >> 1); j > 0; j >>= 1) {
for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
size_t tid_comp = tid ^ j;
if (tid_comp > tid) {
if ((tid & i) == 0) {
if (data_arr[tid] > data_arr[tid_comp]) {
Swap(&data_arr[tid], &data_arr[tid_comp]);
Swap(&index_arr[tid], &index_arr[tid_comp]);
}
} else {
if (data_arr[tid] < data_arr[tid_comp]) {
Swap(&data_arr[tid], &data_arr[tid_comp]);
Swap(&index_arr[tid], &index_arr[tid_comp]);
}
}
}
}
__syncthreads();
}
}
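// The +max padding values sort to the top of the ascending buffer, so the k largest real
// elements sit at positions inner-1, inner-2, ...; copy them out in descending order.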
for (size_t tid = threadIdx.x; tid < k[0]; tid += blockDim.x) {
output[blockIdx.x * k[0] + tid] = data_arr[inner - tid - 1];
indices[blockIdx.x * k[0] + tid] = index_arr[inner - tid - 1];
}
}
template <typename T, typename S>
void TopK(const int &outer, const int &inner, const T *input, const S *k, T *output, S *indices, T *data_buff,
S *index_buff, cudaStream_t stream) {
int ceil_power2 = RoundUpPower2(inner);
int share_mem = (data_buff == nullptr) ? ceil_power2 * (sizeof(T) + sizeof(S)) : 0;
int thread = std::min(ceil_power2, GET_THREADS);
TopkKernel<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, input, k, output, indices, data_buff,
index_buff);
}
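// BitonicSortByKey re-sorts (value, index) pairs using the index as the key. TopKGpuKernel
// calls it when the "sorted" attribute is false, so the selected top-k values are returned
// in their original input order.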
template <typename T, typename S>
__global__ void BitonicSortByKeyKernel(const int outer, const int inner, const int ceil_power2, T *input,
S *indices, T *data_buff, S *index_buff) {
// default: sort with shared memory
extern __shared__ T share_mem[];
T *data_arr = share_mem;
S *index_arr = reinterpret_cast<S *>(data_arr + ceil_power2);
// sort with global memory workspace
if (data_buff != nullptr && index_buff != nullptr) {
data_arr = data_buff + blockIdx.x * ceil_power2;
index_arr = index_buff + blockIdx.x * ceil_power2;
}
for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max();
index_arr[i] = (i < inner) ? indices[blockIdx.x * inner + i] : std::numeric_limits<S>::max();
}
__syncthreads();
for (size_t i = 2; i <= ceil_power2; i <<= 1) {
for (size_t j = (i >> 1); j > 0; j >>= 1) {
for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
size_t tid_comp = tid ^ j;
if (tid_comp > tid) {
if ((tid & i) == 0) {
if (index_arr[tid] > index_arr[tid_comp]) {
Swap(&data_arr[tid], &data_arr[tid_comp]);
Swap(&index_arr[tid], &index_arr[tid_comp]);
}
} else {
if (index_arr[tid] < index_arr[tid_comp]) {
Swap(&data_arr[tid], &data_arr[tid_comp]);
Swap(&index_arr[tid], &index_arr[tid_comp]);
}
}
}
}
__syncthreads();
}
}
for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) {
input[blockIdx.x * inner + tid] = data_arr[tid];
indices[blockIdx.x * inner + tid] = index_arr[tid];
}
}
template <typename T, typename S>
void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices, T *data_buff, S *index_buff,
cudaStream_t stream) {
int ceil_power2 = RoundUpPower2(inner);
size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
if (share_mem > SHARED_MEM_PER_BLOCK) {
share_mem = 0;
} else {
data_buff = nullptr;
index_buff = nullptr;
}
int thread = std::min(ceil_power2, GET_THREADS);
BitonicSortByKeyKernel<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, input, indices, data_buff,
index_buff);
}
template void TopK(const int &outer, const int &inner, const float *input_addr, const int *k, float *output,
int *indices, float *data_buff, int *index_buff, cudaStream_t stream);
template void BitonicSortByKey(const int &outer, const int &inner, float *input, int *indices, float *data_buff,
int *index_buff, cudaStream_t stream);
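
For illustration, a minimal, hypothetical host-side driver for the float/int instantiation above could look like the sketch below. It is not part of this commit: it assumes the MindSpore include paths and the definition linked from topk_impl.cu, and it keeps the row width small enough that the shared-memory path (null workspace buffers) is valid. Error checking is omitted.

// Hypothetical standalone driver for TopK<float, int>; a sketch, not part of this commit.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"

int main() {
  const int outer = 2;   // independent rows
  const int inner = 8;   // elements per row
  const int k_host = 3;  // top-k per row
  std::vector<float> h_in(outer * inner);
  for (int i = 0; i < outer * inner; i++) {
    h_in[i] = static_cast<float>((i * 37) % 19);  // arbitrary distinct test data per row
  }
  float *d_in = nullptr;
  float *d_out = nullptr;
  int *d_k = nullptr;
  int *d_idx = nullptr;
  cudaMalloc(&d_in, outer * inner * sizeof(float));
  cudaMalloc(&d_out, outer * k_host * sizeof(float));
  cudaMalloc(&d_k, sizeof(int));
  cudaMalloc(&d_idx, outer * k_host * sizeof(int));
  cudaMemcpy(d_in, h_in.data(), outer * inner * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_k, &k_host, sizeof(int), cudaMemcpyHostToDevice);
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // Null workspace buffers select the shared-memory sort inside the kernel.
  float *data_buff = nullptr;
  int *index_buff = nullptr;
  TopK(outer, inner, d_in, d_k, d_out, d_idx, data_buff, index_buff, stream);
  cudaStreamSynchronize(stream);
  std::vector<float> h_out(outer * k_host);
  std::vector<int> h_idx(outer * k_host);
  cudaMemcpy(h_out.data(), d_out, outer * k_host * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_idx.data(), d_idx, outer * k_host * sizeof(int), cudaMemcpyDeviceToHost);
  for (int r = 0; r < outer; r++) {
    for (int c = 0; c < k_host; c++) {
      printf("row %d: value %.1f index %d\n", r, h_out[r * k_host + c], h_idx[r * k_host + c]);
    }
  }
  cudaStreamDestroy(stream);
  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_k);
  cudaFree(d_idx);
  return 0;
}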

View File

@@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
void TopK(const int &outer, const int &inner, const T *input_addr, const S *k, T *output, S *indices, T *data_buff,
S *index_buff, cudaStream_t stream);
template <typename T, typename S>
void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices, T *data_buff, S *index_buff,
cudaStream_t stream);
int RoundUpPower2(int v);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_

View File

@@ -30,6 +30,7 @@ class CudaCommon {
inline int blocks_num(const int total_threads) const {
return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_);
}
size_t share_memory_size() const { return max_share_memory_; }
static CudaCommon &GetInstance() {
static CudaCommon instance;
@@ -44,6 +45,7 @@ class CudaCommon {
threads_per_block_ = prop.maxThreadsPerBlock;
max_blocks_ = prop.multiProcessorCount;
major_sm_ = prop.major;
max_share_memory_ = prop.sharedMemPerBlock;
}
~CudaCommon() = default;
CudaCommon(const CudaCommon &) = delete;
@@ -52,10 +54,12 @@ class CudaCommon {
int max_blocks_;
int threads_per_block_;
int major_sm_;
size_t max_share_memory_;
};
#define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads)
#define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num()
#define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm()
#define SHARED_MEM_PER_BLOCK mindspore::device::gpu::CudaCommon::GetInstance().share_memory_size()
#define MINIUM_SM 6
#define RECOMMEND_SM 7
} // namespace gpu

View File

@@ -0,0 +1,82 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_topk():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x_np = np.random.rand(3, 4).astype(np.float32)
k = 4
ms_output = P.TopK(True)(Tensor(x_np), k)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
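# with sorted=False and k equal to the last dimension, the kernel restores the
# original order of the selected elements, so the output values equal the input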
x_np = np.random.rand(3, 4).astype(np.float32)
k = 4
ms_output = P.TopK(False)(Tensor(x_np), k)
assert np.allclose(ms_output[0].asnumpy(), x_np)
x_np = np.random.rand(2, 3, 4).astype(np.float32)
k = 2
ms_output = P.TopK(True)(Tensor(x_np), k)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(512, 1024).astype(np.float32)
k = 512
ms_output = P.TopK(True)(Tensor(x_np), k)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
# number of elements to sort exceeds the max threads per block
x_np = np.random.rand(512, 2048).astype(np.float32)
k = 1
ms_output = P.TopK(True)(Tensor(x_np), k)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(512, 2048).astype(np.float32)
k = 2048
ms_output = P.TopK(True)(Tensor(x_np), k)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
# number of elements to sort exceeds the max shared memory per block
x_np = np.random.rand(512, 40960).astype(np.float32)
k = 1
ms_output = P.TopK(True)(Tensor(x_np), k)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(512, 40960).astype(np.float32)
k = 40960
ms_output = P.TopK(True)(Tensor(x_np), k)
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
assert np.allclose(ms_output[0].asnumpy(), np_output)
x_np = np.random.rand(512, 40960).astype(np.float32)
k = 40960
ms_output = P.TopK(False)(Tensor(x_np), k)
assert np.allclose(ms_output[0].asnumpy(), x_np)