!10856 Change topk shape size from int to size_t
From: @TFbunny
Reviewed-by: @robingrosman
Signed-off-by:
commit d27ba8ff2a
@@ -93,11 +93,11 @@ class TopKGpuKernel : public GpuKernel {
  private:
   bool sorted_;
-  int outer_size_;
-  int inner_size_;
-  int k_;
+  size_t outer_size_;
+  size_t inner_size_;
+  size_t k_;
   bool use_share_mem_;
-  int ceil_power2_;
+  size_t ceil_power2_;
 
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
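For context, these members hold sizes derived from the input tensor shape (outer_size_ being the product of the leading dimensions, inner_size_ the dimension TopK sorts over), which the framework reports as size_t, so keeping them as int risked narrowing for large tensors. A minimal sketch of how such members might be filled — the helper name and shape layout are assumptions, not the kernel's actual Init — looks like:

#include <cstddef>
#include <vector>

// Hypothetical helper: accumulate the leading dims into outer_size and take the
// last dim as inner_size, all in size_t so large shapes cannot overflow an int.
void InitSizesFromShape(const std::vector<size_t> &shape, size_t *outer_size, size_t *inner_size) {
  *outer_size = 1;
  for (size_t i = 0; i + 1 < shape.size(); ++i) {
    *outer_size *= shape[i];
  }
  *inner_size = shape.empty() ? 1 : shape.back();
}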
@@ -18,7 +18,7 @@
 #include <limits>
 #include <algorithm>
 
-int RoundUpPower2(int v) {
+size_t RoundUpPower2(size_t v) {
   v--;
   v |= v >> 1;
   v |= v >> 2;
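The hunk truncates the body after the first shifts; the usual round-up-to-next-power-of-two bit trick for a 64-bit size_t simply continues the shifted ORs, so a plausible full version (a sketch under that assumption, not necessarily the file's exact body) is:

#include <cstddef>

// Round v up to the next power of two. Assumes a 64-bit size_t; the final
// `v >> 32` step only matters for values wider than 32 bits.
size_t RoundUpPower2(size_t v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v |= v >> 32;
  v++;
  return v;
}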
@@ -37,7 +37,7 @@ __inline__ __device__ void Swap(T *lhs, T *rhs) {
 }
 
 template <typename T, typename S>
-__global__ void TopkKernel(const int outer, const int inner, const int ceil_power2, const T *input, const S *k,
+__global__ void TopkKernel(const size_t outer, const size_t inner, const size_t ceil_power2, const T *input, const S *k,
                            T *output, S *indices, T *data_buff, S *index_buff) {
   // default: sort with share memory
   extern __shared__ T share_mem[];
@@ -49,7 +49,7 @@ __global__ void TopkKernel(const int outer, const int inner, const int ceil_powe
     index_arr = index_buff + blockIdx.x * ceil_power2;
   }
 
-  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
+  for (size_t i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
     data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max();
     index_arr[i] = i;
   }
@@ -84,17 +84,17 @@ __global__ void TopkKernel(const int outer, const int inner, const int ceil_powe
 }
 
 template <typename T, typename S>
-void TopK(const int &outer, const int &inner, const T *input, const S *k, T *output, S *indices, T *data_buff,
+void TopK(const size_t &outer, const size_t &inner, const T *input, const S *k, T *output, S *indices, T *data_buff,
           S *index_buff, cudaStream_t stream) {
-  int ceil_power2 = RoundUpPower2(inner);
-  int share_mem = (data_buff == nullptr) ? ceil_power2 * (sizeof(T) + sizeof(S)) : 0;
-  int thread = std::min(ceil_power2, GET_THREADS);
-  TopkKernel<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, input, k, output, indices, data_buff,
-                                                   index_buff);
+  size_t ceil_power2 = RoundUpPower2(inner);
+  size_t share_mem = (data_buff == nullptr) ? ceil_power2 * (sizeof(T) + sizeof(S)) : 0;
+  size_t thread_num = std::min(ceil_power2, static_cast<size_t>(GET_THREADS));
+  TopkKernel<<<outer, thread_num, share_mem, stream>>>(outer, inner, ceil_power2, input, k, output, indices, data_buff,
+                                                       index_buff);
 }
 
 template <typename T, typename S>
-__global__ void BitonicSortByKeyKernel(const int outer, const int inner, const int ceil_power2, T *input,
+__global__ void BitonicSortByKeyKernel(const size_t outer, const size_t inner, const size_t ceil_power2, T *input,
                                        S *indices, T *data_buff, S *index_buff) {
   // default: sort with share memory
   extern __shared__ T share_mem[];
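Note the extra static_cast in the new thread-count line: std::min deduces a single template type for both arguments, so once ceil_power2 becomes size_t it can no longer be mixed with the integer GET_THREADS value without an explicit conversion. A minimal illustration, with GET_THREADS stubbed as a plain constant rather than the real macro from cuda_common.h:

#include <algorithm>
#include <cstddef>

constexpr int GET_THREADS = 256;  // stand-in for the real macro

size_t PickThreadNum(size_t ceil_power2) {
  // std::min(ceil_power2, GET_THREADS) would not compile: size_t vs int gives
  // no single deduced template argument, hence the explicit cast.
  return std::min(ceil_power2, static_cast<size_t>(GET_THREADS));
}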
@@ -106,9 +106,9 @@ __global__ void BitonicSortByKeyKernel(const int outer, const int inner, const i
     index_arr = index_buff + blockIdx.x * ceil_power2;
   }
 
-  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
+  for (size_t i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
     data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max();
-    index_arr[i] = (i < inner) ? indices[blockIdx.x * inner + i] : std::numeric_limits<S>::max();;
+    index_arr[i] = (i < inner) ? indices[blockIdx.x * inner + i] : std::numeric_limits<S>::max();
   }
   __syncthreads();
 
@@ -141,9 +141,9 @@ __global__ void BitonicSortByKeyKernel(const int outer, const int inner, const i
 }
 
 template <typename T, typename S>
-void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices, T *data_buff, S *index_buff,
+void BitonicSortByKey(const size_t &outer, const size_t &inner, T *input, S *indices, T *data_buff, S *index_buff,
                       cudaStream_t stream) {
-  int ceil_power2 = RoundUpPower2(inner);
+  size_t ceil_power2 = RoundUpPower2(inner);
   size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
   if (share_mem > SHARED_MEM_PER_BLOCK) {
     share_mem = 0;
@@ -151,12 +151,12 @@ void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices,
     data_buff = nullptr;
     index_buff = nullptr;
   }
-  int thread = std::min(ceil_power2, GET_THREADS);
-  BitonicSortByKeyKernel<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, input, indices, data_buff,
-                                                               index_buff);
+  size_t thread_num = std::min(ceil_power2, static_cast<size_t>(GET_THREADS));
+  BitonicSortByKeyKernel<<<outer, thread_num, share_mem, stream>>>(outer, inner, ceil_power2, input, indices, data_buff,
+                                                                   index_buff);
 }
 
-template void TopK(const int &outer, const int &inner, const float *input_addr, const int *k, float *output,
+template void TopK(const size_t &outer, const size_t &inner, const float *input_addr, const int *k, float *output,
                    int *indices, float *data_buff, int *index_buff, cudaStream_t stream);
-template void BitonicSortByKey(const int &outer, const int &inner, float *input, int *indices, float *data_buff,
+template void BitonicSortByKey(const size_t &outer, const size_t &inner, float *input, int *indices, float *data_buff,
                                int *index_buff, cudaStream_t stream);
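A related detail carried over unchanged by this commit is the shared-memory fallback in the BitonicSortByKey launcher above: each block wants ceil_power2 key/index pairs in shared memory, and when that exceeds SHARED_MEM_PER_BLOCK the launcher zeroes share_mem and passes the global data_buff/index_buff instead. A small self-contained sketch of that size check, with SHARED_MEM_PER_BLOCK assumed to be the common 48 KB default rather than the value from cuda_common.h:

#include <cstddef>
#include <cstdio>

constexpr size_t SHARED_MEM_PER_BLOCK = 48 * 1024;  // assumption; the real value comes from cuda_common.h

// One padded row needs ceil_power2 keys plus ceil_power2 indices per block.
bool FitsInSharedMem(size_t ceil_power2, size_t key_bytes, size_t index_bytes) {
  return ceil_power2 * (key_bytes + index_bytes) <= SHARED_MEM_PER_BLOCK;
}

int main() {
  // float keys + int indices, inner dimension padded to 8192 -> 64 KB, so the
  // kernel would fall back to the global buffers (prints 0).
  std::printf("%d\n", FitsInSharedMem(8192, sizeof(float), sizeof(int)));
  return 0;
}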
@@ -21,12 +21,12 @@
 #include "runtime/device/gpu/cuda_common.h"
 
 template <typename T, typename S>
-void TopK(const int &outer, const int &inner, const T *input_addr, const S *k, T *output, S *indices, T *data_buff,
-          S *index_buff, cudaStream_t stream);
+void TopK(const size_t &outer, const size_t &inner, const T *input_addr, const S *k, T *output, S *indices,
+          T *data_buff, S *index_buff, cudaStream_t stream);
 
 template <typename T, typename S>
-void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices, T *data_buff, S *index_buff,
+void BitonicSortByKey(const size_t &outer, const size_t &inner, T *input, S *indices, T *data_buff, S *index_buff,
                       cudaStream_t stream);
-int RoundUpPower2(int v);
+size_t RoundUpPower2(size_t v);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_