forked from mindspore-Ecosystem/mindspore
!11802 Optimize GPU kernels for fasterrcnn
From: @robingrosman Reviewed-by: Signed-off-by:
This commit is contained in:
commit
ed5f9cb1f2
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TOPK_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TOPK_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TOPK_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TOPK_GPU_KERNEL_H_
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
@ -27,7 +28,7 @@ namespace kernel {
|
|||
template <typename T, typename S>
|
||||
class TopKGpuKernel : public GpuKernel {
|
||||
public:
|
||||
TopKGpuKernel() : sorted_(false), outer_size_(1), inner_size_(1), k_(1), use_share_mem_(true), ceil_power2_(0) {}
|
||||
TopKGpuKernel() : sorted_(false), outer_size_(1), inner_size_(1), k_(1), input_shape_size_(0) {}
|
||||
~TopKGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
|
@ -40,26 +41,17 @@ class TopKGpuKernel : public GpuKernel {
|
|||
S *k = GetDeviceAddress<S>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
S *indices = GetDeviceAddress<S>(outputs, 1);
|
||||
T *data_buff = nullptr;
|
||||
S *index_buff = nullptr;
|
||||
if (use_share_mem_ == false) {
|
||||
data_buff = GetDeviceAddress<T>(workspaces, 0);
|
||||
index_buff = GetDeviceAddress<S>(workspaces, 1);
|
||||
}
|
||||
const T init_k = std::numeric_limits<T>::lowest();
|
||||
|
||||
TopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, data_buff, index_buff,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
if (sorted_ == false) {
|
||||
BitonicSortByKey(outer_size_, k_, output_addr, indices, data_buff, index_buff,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
FastTopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, init_k,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
input_shape_size_ = input_shapes.size();
|
||||
for (size_t i = 0; i < input_shapes.size() - 1; i++) {
|
||||
outer_size_ *= input_shapes[i];
|
||||
}
|
||||
|
@ -68,13 +60,6 @@ class TopKGpuKernel : public GpuKernel {
|
|||
|
||||
sorted_ = GetAttr<bool>(kernel_node, "sorted");
|
||||
|
||||
ceil_power2_ = RoundUpPower2(inner_size_);
|
||||
size_t buffer_size = ceil_power2_ * (sizeof(T) + sizeof(S));
|
||||
if (buffer_size > SHARED_MEM_PER_BLOCK) {
|
||||
use_share_mem_ = false;
|
||||
MS_LOG(INFO) << "CUDA share memory not enough, sort with RAM";
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -85,10 +70,6 @@ class TopKGpuKernel : public GpuKernel {
|
|||
input_size_list_.push_back(sizeof(S));
|
||||
output_size_list_.push_back(outer_size_ * k_ * sizeof(T));
|
||||
output_size_list_.push_back(outer_size_ * k_ * sizeof(S));
|
||||
if (use_share_mem_ == false) {
|
||||
workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(T));
|
||||
workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(S));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -96,8 +77,7 @@ class TopKGpuKernel : public GpuKernel {
|
|||
size_t outer_size_;
|
||||
size_t inner_size_;
|
||||
size_t k_;
|
||||
bool use_share_mem_;
|
||||
size_t ceil_power2_;
|
||||
int input_shape_size_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
|
@ -106,4 +86,4 @@ class TopKGpuKernel : public GpuKernel {
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // TopKpuKernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TOPK_GPU_KERNEL_H_
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,6 +23,10 @@
|
|||
#define BLOCKSIZE 256
|
||||
#define MAX_DIMENSION 5
|
||||
|
||||
template <typename T, typename S, typename K>
|
||||
void CalRandomChoiceWithMaskSmall(int input_size, int seedc, int count, K *input, S *output_index, K *output_mask,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T, typename S>
|
||||
void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2,
|
||||
const int &d3, const int &d4, const int &d5, const int &seedc, const int &count,
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/topk_lib.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh"
|
||||
|
||||
// Kernel started from here
|
||||
#define L2_RCWM_HELPER(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, IS_DESCEND) \
|
||||
do { \
|
||||
L2Rcwm<T, S, K, NUM_WARP_Q, NUM_THREAD_Q, BLOCK, IS_DESCEND> \
|
||||
<<<1, BLOCK, 0, stream>>>(seedc, input_size, input, output_mask, output_index, k); \
|
||||
} while (0)
|
||||
|
||||
#define LEFT_INSERT_THREAD_QUEUE(_k, _v) \
|
||||
do { \
|
||||
if (is_descend ? Cmp<T>::gt(_k, warp_K_top) : Cmp<T>::lt(_k, warp_K_top)) { \
|
||||
{ \
|
||||
_Pragma("unroll") for (int i = thread_queue - 1; i > 0; --i) { \
|
||||
threadK[i] = threadK[i - 1]; \
|
||||
threadV[i] = threadV[i - 1]; \
|
||||
} \
|
||||
} \
|
||||
threadK[0] = _k; \
|
||||
threadV[0] = _v; \
|
||||
++num_vals; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
template <typename T, typename S, typename K, int warp_queue, int thread_queue, int threads_per_block, bool is_descend>
|
||||
__global__ void L2Rcwm(int seedc, int input_size, const K *input, K *output_mask, S *output_index, int k) {
|
||||
constexpr int kNumWarps = threads_per_block / kWarpSize;
|
||||
constexpr T init_K = static_cast<T>(-2.0);
|
||||
constexpr S init_V = static_cast<S>(0);
|
||||
|
||||
__shared__ T shared_K[kNumWarps * warp_queue];
|
||||
__shared__ S shared_V[kNumWarps * warp_queue];
|
||||
|
||||
curandState devState;
|
||||
curand_init(seedc, threadIdx.x, 0, &devState);
|
||||
|
||||
T threadK[thread_queue]; // NOLINT
|
||||
S threadV[thread_queue]; // NOLINT
|
||||
|
||||
T *warp_K;
|
||||
S *warp_V;
|
||||
|
||||
T warp_K_top = init_K;
|
||||
int k_minus_1 = k - 1;
|
||||
int num_vals = 0;
|
||||
int limit = (input_size / kWarpSize) * kWarpSize;
|
||||
int i = threadIdx.x;
|
||||
|
||||
// init begin
|
||||
_Pragma("unroll") for (int i = 0; i < thread_queue; ++i) {
|
||||
threadK[i] = init_K;
|
||||
threadV[i] = init_V;
|
||||
}
|
||||
|
||||
int laneId = GetLaneId();
|
||||
int warpId = threadIdx.x / kWarpSize; // 0,1,2 or 3
|
||||
|
||||
// warp shared memory start address
|
||||
warp_K = shared_K + warpId * warp_queue;
|
||||
warp_V = shared_V + warpId * warp_queue;
|
||||
|
||||
for (int i = laneId; i < warp_queue; i += kWarpSize) {
|
||||
warp_K[i] = init_K;
|
||||
warp_V[i] = init_V;
|
||||
}
|
||||
|
||||
// sync till all threads init done
|
||||
__syncwarp();
|
||||
|
||||
// insert begin
|
||||
for (; i < limit; i += threads_per_block) {
|
||||
T rand_num = input[i] ? __uint2float_rn(curand(&devState)) : init_K;
|
||||
LEFT_INSERT_THREAD_QUEUE(rand_num, i);
|
||||
|
||||
// CHECK_AND_MERGE_THREAD_QUEUE() begin
|
||||
bool needSort = (num_vals == thread_queue);
|
||||
needSort = __any_sync(0xffffffff, needSort);
|
||||
if (!needSort) continue;
|
||||
|
||||
MergeWarpQueue<T, S, warp_queue, thread_queue, is_descend>(threadK, threadV, warp_K, warp_V);
|
||||
|
||||
num_vals = 0;
|
||||
_Pragma("unroll") for (int i = 0; i < thread_queue; ++i) {
|
||||
threadK[i] = init_K;
|
||||
threadV[i] = init_V;
|
||||
}
|
||||
warp_K_top = warp_K[k_minus_1];
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
if (i < input_size) {
|
||||
T rand_num = input[i] ? __uint2float_rn(curand(&devState)) : init_K;
|
||||
LEFT_INSERT_THREAD_QUEUE(rand_num, i);
|
||||
}
|
||||
|
||||
// reduce begin
|
||||
MergeWarpQueue<T, S, warp_queue, thread_queue, is_descend>(threadK, threadV, warp_K, warp_V);
|
||||
__syncthreads();
|
||||
SortBlockWide<kNumWarps, threads_per_block, T, S, warp_queue, is_descend>(shared_K, shared_V);
|
||||
|
||||
// ship data from shared memory to output buffer
|
||||
for (int i = threadIdx.x; i < k; i += blockDim.x) {
|
||||
output_mask[i] = shared_K[i] > static_cast<T>(-1.0) ? true : false;
|
||||
output_index[i] = shared_V[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S, typename K>
|
||||
void RCWMScaleK(int seedc, int input_size, K *input, int k, S *output_index, K *output_mask, cudaStream_t stream) {
|
||||
if (k <= 32) {
|
||||
// num-threads-of-block, warp-queue-size, thread-queue-size
|
||||
L2_RCWM_HELPER(256, 32, 2, true);
|
||||
} else if (k <= 64) {
|
||||
L2_RCWM_HELPER(256, 64, 3, true);
|
||||
} else if (k <= 128) {
|
||||
L2_RCWM_HELPER(256, 128, 3, true);
|
||||
} else if (k <= 256) {
|
||||
L2_RCWM_HELPER(256, 256, 4, true);
|
||||
} else if (k <= 512) {
|
||||
L2_RCWM_HELPER(256, 512, 8, true);
|
||||
} else if (k <= 1024) {
|
||||
L2_RCWM_HELPER(128, 1024, 8, true);
|
||||
} else if (k <= 2048) {
|
||||
L2_RCWM_HELPER(64, 2048, 8, true);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S, typename K>
|
||||
void CalRandomChoiceWithMaskSmall(int input_size, int seedc, int count, K *input, S *output_index, K *output_mask,
|
||||
cudaStream_t stream) {
|
||||
RCWMScaleK<T, S, K>(seedc, input_size, input, count, output_index, output_mask, stream);
|
||||
}
|
||||
|
||||
template void CalRandomChoiceWithMaskSmall<float, int, bool>(int input_size, int seedc, int count, bool *input,
|
||||
int *output_index, bool *output_mask, cudaStream_t stream);
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,148 +15,213 @@
|
|||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/topk_lib.cuh"
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
|
||||
size_t RoundUpPower2(size_t v) {
|
||||
v--;
|
||||
v |= v >> 1;
|
||||
v |= v >> 2;
|
||||
v |= v >> 4;
|
||||
v |= v >> 8;
|
||||
v |= v >> 16;
|
||||
v++;
|
||||
return v;
|
||||
}
|
||||
const int kMaxQueue = 128;
|
||||
|
||||
template <typename T>
|
||||
__inline__ __device__ void Swap(T *lhs, T *rhs) {
|
||||
T tmp = lhs[0];
|
||||
lhs[0] = rhs[0];
|
||||
rhs[0] = tmp;
|
||||
}
|
||||
#define TOPK_HELPER(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, IS_DESCEND) \
|
||||
do { \
|
||||
TopKBlock<T, S, NUM_WARP_Q, NUM_THREAD_Q, BLOCK, IS_DESCEND> \
|
||||
<<<block_num_limit, BLOCK, 0, stream>>>(outer_size, inner_size, input, output, output_index, k_cut, init_K); \
|
||||
} while (0)
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void TopkKernel(const size_t outer, const size_t inner, const size_t ceil_power2, const T *input, const S *k,
|
||||
T *output, S *indices, T *data_buff, S *index_buff) {
|
||||
// default: sort with share memory
|
||||
extern __shared__ T share_mem[];
|
||||
T *data_arr = share_mem;
|
||||
S *index_arr = reinterpret_cast<S *>(data_arr + ceil_power2);
|
||||
// sort with RAM
|
||||
if (data_buff != nullptr && index_buff != nullptr) {
|
||||
data_arr = data_buff + blockIdx.x * ceil_power2;
|
||||
index_arr = index_buff + blockIdx.x * ceil_power2;
|
||||
#define LEFT_INSERT_THREAD_QUEUE(_k, _v) \
|
||||
do { \
|
||||
if (is_descend ? CmpKV<T, S>::gt(_k, _v, (*ceil_K), (*ceil_V)) : CmpKV<T, S>::lt(_k, _v, (*ceil_K), (*ceil_V))) \
|
||||
break; \
|
||||
if (is_descend ? CmpKV<T, S>::gt(_k, _v, warp_K_top, warp_V_top) \
|
||||
: CmpKV<T, S>::lt(_k, _v, warp_K_top, warp_V_top)) { \
|
||||
{ \
|
||||
_Pragma("unroll") for (int i = thread_queue - 1; i > 0; --i) { \
|
||||
threadK[i] = threadK[i - 1]; \
|
||||
threadV[i] = threadV[i - 1]; \
|
||||
} \
|
||||
} \
|
||||
threadK[0] = _k; \
|
||||
threadV[0] = _v; \
|
||||
++num_vals; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
template <typename T, typename S, int warp_queue, int thread_queue, int threads_per_block, bool is_descend>
|
||||
inline __device__ void TopKInBuffer(T *shared_K, S *shared_V, int *watermark, T *ceil_K, S *ceil_V, int laneId) {
|
||||
constexpr int kNumWarps = threads_per_block / kWarpSize; // kNumWarps is 1024/32=32
|
||||
|
||||
// find last_K, which is max of last element of warp queue
|
||||
T last_K = shared_K[laneId * warp_queue + warp_queue - 1];
|
||||
S last_V = shared_V[laneId * warp_queue + warp_queue - 1];
|
||||
|
||||
__syncwarp();
|
||||
|
||||
for (int offset = kNumWarps / 2; offset > 0; offset /= 2) {
|
||||
// kNumWarps is 32 if block size is 1024
|
||||
T other_K = __shfl_down_sync(0xffffffff, last_K, offset);
|
||||
S other_V = __shfl_down_sync(0xffffffff, last_V, offset);
|
||||
|
||||
bool is_greater = CmpKV<T, S>::gt(other_K, other_V, last_K, last_V);
|
||||
ConditionalAssign(is_greater, &last_K, other_K);
|
||||
ConditionalAssign(is_greater, &last_V, other_V);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
if (laneId == 0) {
|
||||
*ceil_K = last_K;
|
||||
*ceil_V = last_V;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// calculate index cut by last_K
|
||||
int L = 0;
|
||||
int R = warp_queue;
|
||||
while (L < R) {
|
||||
int m = (L + R) / 2;
|
||||
CmpKV<T, S>::gt(shared_K[laneId * warp_queue + m], shared_V[laneId * warp_queue + m], (*ceil_K), (*ceil_V))
|
||||
? L = m + 1
|
||||
: R = m;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// merge top number which value is greater than last_K
|
||||
for (int offset = kNumWarps / 2; offset > 0; offset /= 2) {
|
||||
R += __shfl_down_sync(0xffffffff, R, offset);
|
||||
}
|
||||
|
||||
for (size_t i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
|
||||
data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max();
|
||||
index_arr[i] = i;
|
||||
__syncwarp();
|
||||
|
||||
if (laneId == 0) {
|
||||
watermark[0] = R;
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
template <typename T, typename S, int warp_queue, int thread_queue, int threads_per_block, bool is_descend>
|
||||
inline __device__ void TopKStep(const int &outer_size, const int &inner_size, const T *input, T *output,
|
||||
S *output_index, S k_cut, const T &init_K, const int &outer_id, T *shared_K,
|
||||
S *shared_V, int *watermark, T *threadK, S *threadV, T *ceil_K, S *ceil_V, S *k_prime) {
|
||||
constexpr int kNumWarps = threads_per_block / kWarpSize;
|
||||
constexpr S init_V = static_cast<S>(-1);
|
||||
|
||||
T *warp_K;
|
||||
S *warp_V;
|
||||
|
||||
T warp_K_top = init_K;
|
||||
S warp_V_top = init_V;
|
||||
int k_minus_1 = (k_cut <= kMaxQueue ? k_cut - 1 : kMaxQueue - 1);
|
||||
int num_vals = 0;
|
||||
int limit = (inner_size / kWarpSize) * kWarpSize;
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < thread_queue; ++i) {
|
||||
threadK[i] = init_K;
|
||||
threadV[i] = init_V;
|
||||
}
|
||||
|
||||
int laneId = GetLaneId();
|
||||
int warpId = threadIdx.x / kWarpSize; // 0,1,2 or 3
|
||||
|
||||
warp_K = shared_K + warpId * warp_queue;
|
||||
warp_V = shared_V + warpId * warp_queue;
|
||||
|
||||
for (int i = laneId; i < warp_queue; i += kWarpSize) {
|
||||
warp_K[i] = init_K;
|
||||
warp_V[i] = init_V;
|
||||
}
|
||||
|
||||
__syncwarp();
|
||||
|
||||
int i = threadIdx.x;
|
||||
for (; i < limit; i += threads_per_block) {
|
||||
LEFT_INSERT_THREAD_QUEUE((input[outer_id * inner_size + i]), (outer_id * inner_size + i));
|
||||
|
||||
bool needSort = (num_vals == thread_queue);
|
||||
needSort = __any_sync(0xffffffff, needSort);
|
||||
if (!needSort) continue;
|
||||
|
||||
MergeWarpQueue<T, S, warp_queue, thread_queue, is_descend>(threadK, threadV, warp_K, warp_V);
|
||||
|
||||
num_vals = 0;
|
||||
_Pragma("unroll") for (int i = 0; i < thread_queue; ++i) {
|
||||
threadK[i] = init_K;
|
||||
threadV[i] = init_V;
|
||||
}
|
||||
warp_K_top = warp_K[k_minus_1];
|
||||
warp_V_top = warp_V[k_minus_1];
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
if (i < inner_size) {
|
||||
LEFT_INSERT_THREAD_QUEUE((input[outer_id * inner_size + i]), (outer_id * inner_size + i));
|
||||
}
|
||||
|
||||
MergeWarpQueue<T, S, warp_queue, thread_queue, is_descend>(threadK, threadV, warp_K, warp_V);
|
||||
__syncthreads();
|
||||
|
||||
if (k_cut > kMaxQueue && warpId == 0) {
|
||||
TopKInBuffer<T, S, warp_queue, thread_queue, threads_per_block, is_descend>(shared_K, shared_V, watermark, ceil_K,
|
||||
ceil_V, laneId);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (size_t i = 2; i <= ceil_power2; i <<= 1) {
|
||||
for (size_t j = (i >> 1); j > 0; j >>= 1) {
|
||||
for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
|
||||
size_t tid_comp = tid ^ j;
|
||||
if (tid_comp > tid) {
|
||||
if ((tid & i) == 0) {
|
||||
if (data_arr[tid] > data_arr[tid_comp]) {
|
||||
Swap(&data_arr[tid], &data_arr[tid_comp]);
|
||||
Swap(&index_arr[tid], &index_arr[tid_comp]);
|
||||
}
|
||||
} else {
|
||||
if (data_arr[tid] < data_arr[tid_comp]) {
|
||||
Swap(&data_arr[tid], &data_arr[tid_comp]);
|
||||
Swap(&index_arr[tid], &index_arr[tid_comp]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
SortBlockWide<kNumWarps, threads_per_block, T, S, warp_queue, is_descend>(shared_K, shared_V);
|
||||
|
||||
for (size_t tid = threadIdx.x; tid < k[0]; tid += blockDim.x) {
|
||||
output[blockIdx.x * k[0] + tid] = data_arr[inner - tid - 1];
|
||||
indices[blockIdx.x * k[0] + tid] = index_arr[inner - tid - 1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void TopK(const size_t &outer, const size_t &inner, const T *input, const S *k, T *output, S *indices, T *data_buff,
|
||||
S *index_buff, cudaStream_t stream) {
|
||||
size_t ceil_power2 = RoundUpPower2(inner);
|
||||
size_t share_mem = (data_buff == nullptr) ? ceil_power2 * (sizeof(T) + sizeof(S)) : 0;
|
||||
size_t thread_num = std::min(ceil_power2, static_cast<size_t>(GET_THREADS));
|
||||
TopkKernel<<<outer, thread_num, share_mem, stream>>>(outer, inner, ceil_power2, input, k, output, indices, data_buff,
|
||||
index_buff);
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void BitonicSortByKeyKernel(const size_t outer, const size_t inner, const size_t ceil_power2, T *input,
|
||||
S *indices, T *data_buff, S *index_buff) {
|
||||
// default: sort with share memory
|
||||
extern __shared__ T share_mem[];
|
||||
T *data_arr = share_mem;
|
||||
S *index_arr = reinterpret_cast<S *>(data_arr + ceil_power2);
|
||||
// sort with RAM
|
||||
if (data_buff != nullptr && index_buff != nullptr) {
|
||||
data_arr = data_buff + blockIdx.x * ceil_power2;
|
||||
index_arr = index_buff + blockIdx.x * ceil_power2;
|
||||
}
|
||||
|
||||
for (size_t i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
|
||||
data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max();
|
||||
index_arr[i] = (i < inner) ? indices[blockIdx.x * inner + i] : std::numeric_limits<S>::max();
|
||||
S k_step = (*k_prime) + watermark[0] <= k_cut ? watermark[0] : k_cut - (*k_prime);
|
||||
for (int i = threadIdx.x; i < k_step; i += blockDim.x) {
|
||||
output[outer_id * k_cut + (*k_prime) + i] = shared_K[i];
|
||||
output_index[outer_id * k_cut + (*k_prime) + i] = shared_V[i] % inner_size;
|
||||
}
|
||||
*k_prime += k_step;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
for (size_t i = 2; i <= ceil_power2; i <<= 1) {
|
||||
for (size_t j = (i >> 1); j > 0; j >>= 1) {
|
||||
for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
|
||||
size_t tid_comp = tid ^ j;
|
||||
if (tid_comp > tid) {
|
||||
if ((tid & i) == 0) {
|
||||
if (index_arr[tid] > index_arr[tid_comp]) {
|
||||
Swap(&data_arr[tid], &data_arr[tid_comp]);
|
||||
Swap(&index_arr[tid], &index_arr[tid_comp]);
|
||||
}
|
||||
} else {
|
||||
if (index_arr[tid] < index_arr[tid_comp]) {
|
||||
Swap(&data_arr[tid], &data_arr[tid_comp]);
|
||||
Swap(&index_arr[tid], &index_arr[tid_comp]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
template <typename T, typename S, int warp_queue, int thread_queue, int threads_per_block, bool is_descend>
|
||||
__global__ void TopKBlock(int outer_size, int inner_size, const T *input, T *output, S *output_index, S k_cut,
|
||||
const T init_K) {
|
||||
constexpr int kNumWarps = threads_per_block / kWarpSize;
|
||||
|
||||
for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) {
|
||||
input[blockIdx.x * inner + tid] = data_arr[tid];
|
||||
indices[blockIdx.x * inner + tid] = index_arr[tid];
|
||||
__shared__ T shared_K[kNumWarps * warp_queue];
|
||||
__shared__ S shared_V[kNumWarps * warp_queue];
|
||||
__shared__ int watermark[1];
|
||||
__shared__ T ceil_K;
|
||||
__shared__ S ceil_V;
|
||||
|
||||
T threadK[thread_queue]; // NOLINT
|
||||
S threadV[thread_queue]; // NOLINT
|
||||
|
||||
for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < blockDim.x * outer_size;
|
||||
t_idx += blockDim.x * gridDim.x) {
|
||||
S k_prime = 0;
|
||||
int outer_id = t_idx / blockDim.x;
|
||||
ceil_K = -init_K;
|
||||
ceil_V = -1;
|
||||
watermark[0] = k_cut;
|
||||
do {
|
||||
TopKStep<T, S, warp_queue, thread_queue, threads_per_block, is_descend>(
|
||||
outer_size, inner_size, input, output, output_index, k_cut, init_K, outer_id, shared_K, shared_V, watermark,
|
||||
threadK, threadV, &ceil_K, &ceil_V, &k_prime);
|
||||
} while (k_prime < k_cut);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void BitonicSortByKey(const size_t &outer, const size_t &inner, T *input, S *indices, T *data_buff, S *index_buff,
|
||||
cudaStream_t stream) {
|
||||
size_t ceil_power2 = RoundUpPower2(inner);
|
||||
size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
|
||||
if (share_mem > SHARED_MEM_PER_BLOCK) {
|
||||
share_mem = 0;
|
||||
void FastTopK(const int outer_size, const int inner_size, const T *input, const S *k, T *output, S *output_index,
|
||||
const T init_K, cudaStream_t stream) {
|
||||
int block_num_limit = outer_size < 128 ? outer_size : 128;
|
||||
S k_cut = 0;
|
||||
cudaMemcpy(&k_cut, k, sizeof(S), cudaMemcpyDeviceToHost);
|
||||
if (k_cut > inner_size) k_cut = inner_size;
|
||||
|
||||
if (k_cut <= 32) {
|
||||
// num-threads-of-block, warp-queue-size, thread-queue-size
|
||||
TOPK_HELPER(256, 32, 2, true);
|
||||
} else if (k_cut <= 64) {
|
||||
TOPK_HELPER(256, 64, 3, true);
|
||||
} else if (k_cut <= 128) {
|
||||
TOPK_HELPER(256, 128, 3, true);
|
||||
} else {
|
||||
data_buff = nullptr;
|
||||
index_buff = nullptr;
|
||||
TOPK_HELPER(1024, 128, 3, true);
|
||||
}
|
||||
size_t thread_num = std::min(ceil_power2, static_cast<size_t>(GET_THREADS));
|
||||
BitonicSortByKeyKernel<<<outer, thread_num, share_mem, stream>>>(outer, inner, ceil_power2, input, indices, data_buff,
|
||||
index_buff);
|
||||
}
|
||||
|
||||
template void TopK(const size_t &outer, const size_t &inner, const float *input_addr, const int *k, float *output,
|
||||
int *indices, float *data_buff, int *index_buff, cudaStream_t stream);
|
||||
template void BitonicSortByKey(const size_t &outer, const size_t &inner, float *input, int *indices, float *data_buff,
|
||||
int *index_buff, cudaStream_t stream);
|
||||
template void FastTopK(const int outer_size, const int inner_size, const float *input, const int *k, float *output,
|
||||
int *output_index, const float init_K, cudaStream_t stream);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,19 +14,14 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TOPK_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TOPK_IMPL_CUH_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T, typename S>
|
||||
void TopK(const size_t &outer, const size_t &inner, const T *input_addr, const S *k, T *output, S *indices,
|
||||
T *data_buff, S *index_buff, cudaStream_t stream);
|
||||
void FastTopK(const int outer, const int inner, const T *input_addr, const S *k, T *output, S *indices, const T initK,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T, typename S>
|
||||
void BitonicSortByKey(const size_t &outer, const size_t &inner, T *input, S *indices, T *data_buff, S *index_buff,
|
||||
cudaStream_t stream);
|
||||
size_t RoundUpPower2(size_t v);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TOPK_IMPL_CUH_
|
||||
|
|
|
@ -0,0 +1,479 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
constexpr int kWarpSize = 32;
|
||||
|
||||
constexpr __host__ __device__ int Log2(int n, int p = 0) { return (n <= 1) ? p : Log2(n / 2, p + 1); }
|
||||
constexpr __host__ __device__ bool IsPow2(int v) { return (v && !(v & (v - 1))); }
|
||||
constexpr __host__ __device__ int NextPow2(int v) { return (IsPow2(v) ? 2 * v : (1 << static_cast<int>(Log2(v) + 1))); }
|
||||
|
||||
__device__ __forceinline__ int GetLaneId() {
|
||||
int laneId;
|
||||
asm("mov.u32 %0, %%laneid;" : "=r"(laneId));
|
||||
return laneId;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
struct CmpKV {
|
||||
__device__ static inline bool gt(T k1, S v1, T k2, S v2) { return k1 > k2 || (k1 == k2 && v1 < v2); }
|
||||
__device__ static inline bool lt(T k1, S v1, T k2, S v2) { return k1 < k2 || (k1 == k2 && v1 > v2); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Cmp {
|
||||
__device__ static inline bool lt(T a, T b) { return a < b; }
|
||||
__device__ static inline bool gt(T a, T b) { return a > b; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T shfl_xor(const T val, int laneMask, int width = kWarpSize) {
|
||||
return __shfl_xor_sync(0xffffffff, val, laneMask, width);
|
||||
}
|
||||
|
||||
template <typename T, typename S, bool is_descend>
|
||||
inline __device__ void L2CompareAndSwap(T *a, S *b, int i_1, int i_2) {
|
||||
bool swap =
|
||||
is_descend ? CmpKV<T, S>::gt(a[i_1], b[i_1], a[i_2], b[i_2]) : CmpKV<T, S>::lt(a[i_1], b[i_1], a[i_2], b[i_2]);
|
||||
|
||||
if (!swap) return;
|
||||
|
||||
T a_tmp = a[i_1];
|
||||
a[i_1] = a[i_2];
|
||||
a[i_2] = a_tmp;
|
||||
|
||||
T b_tmp = b[i_1];
|
||||
b[i_1] = b[i_2];
|
||||
b[i_2] = b_tmp;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void ConditionalAssign(bool is_assign, T *x, const T &y) {
|
||||
(*x) = is_assign ? y : (*x);
|
||||
}
|
||||
|
||||
// Merge pairs of lists smaller than threads-per-block
|
||||
// NumThreads is 128
|
||||
// N is 2, 1 etc
|
||||
// L is 32, 64 etc
|
||||
template <int NumThreads, typename T, typename S, int N, int L, bool AllThreads, bool is_descend, bool FullMerge>
|
||||
inline __device__ void BlockSortSmallK(T *list_k, S *list_v) {
|
||||
int mergeId = threadIdx.x / L;
|
||||
int tid = threadIdx.x % L;
|
||||
|
||||
list_k += 2 * L * mergeId;
|
||||
list_v += 2 * L * mergeId;
|
||||
|
||||
int pos = L - 1 - tid;
|
||||
int stride = 2 * tid + 1;
|
||||
|
||||
if (AllThreads || (static_cast<int>(threadIdx.x) < N * L)) {
|
||||
L2CompareAndSwap<T, S, is_descend>(list_k, list_v, pos, pos + stride);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
_Pragma("unroll") for (int stride = L / 2; stride > 0; stride /= 2) {
|
||||
int pos = 2 * tid - (tid & (stride - 1));
|
||||
|
||||
if (AllThreads || (static_cast<int>(threadIdx.x) < N * L)) {
|
||||
L2CompareAndSwap<T, S, is_descend>(list_k, list_v, pos, pos + stride);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// Merge pairs of lists larger than threads-per-block
|
||||
template <int NumThreads, typename T, typename S, int L, bool is_descend, bool FullMerge>
|
||||
inline __device__ void BlockSortBigK(T *list_k, S *list_v) {
|
||||
constexpr int kLoopPerThread = L / NumThreads;
|
||||
|
||||
_Pragma("unroll") for (int loop = 0; loop < kLoopPerThread; ++loop) {
|
||||
int tid = loop * NumThreads + threadIdx.x;
|
||||
int pos = L - 1 - tid;
|
||||
int stride = 2 * tid + 1;
|
||||
|
||||
L2CompareAndSwap<T, S, is_descend>(list_k, list_v, pos, pos + stride);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
constexpr int kSecondLoopPerThread = FullMerge ? kLoopPerThread : kLoopPerThread / 2;
|
||||
|
||||
_Pragma("unroll") for (int stride = L / 2; stride > 0; stride /= 2) {
|
||||
_Pragma("unroll") for (int loop = 0; loop < kSecondLoopPerThread; ++loop) {
|
||||
int tid = loop * NumThreads + threadIdx.x;
|
||||
int pos = 2 * tid - (tid & (stride - 1));
|
||||
L2CompareAndSwap<T, S, is_descend>(list_k, list_v, pos, pos + stride);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
/// Merging lists smaller than threads-per-block
|
||||
template <int NumThreads, typename T, typename S, int N, int L, bool is_descend, bool FullMerge = true>
|
||||
inline __device__ void SortBlockStep(T *list_k, S *list_v) {
|
||||
if (L <= NumThreads) {
|
||||
int kNumParallelMerges = NumThreads / L;
|
||||
int kNumIterations = N / kNumParallelMerges;
|
||||
|
||||
if (N < kNumParallelMerges) {
|
||||
BlockSortSmallK<NumThreads, T, S, N, L, false, is_descend, FullMerge>(list_k, list_v);
|
||||
} else {
|
||||
_Pragma("unroll") for (int i = 0; i < kNumIterations; ++i) {
|
||||
int start = i * kNumParallelMerges * 2 * L;
|
||||
BlockSortSmallK<NumThreads, T, S, N, L, true, is_descend, FullMerge>(list_k + start, list_v + start);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_Pragma("unroll") for (int i = 0; i < N; ++i) {
|
||||
int start = i * 2 * L;
|
||||
BlockSortBigK<NumThreads, T, S, L, is_descend, FullMerge>(list_k + start, list_v + start);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Block-wide merge
|
||||
template <int NumWarps, int NumThreads, typename T, typename S, int warp_queue, bool is_descend>
|
||||
inline __device__ void SortBlockWide(T *shared_K, S *shared_V) {
|
||||
if (NumWarps == 2) {
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 2), warp_queue, !is_descend, false>(shared_K, shared_V);
|
||||
} else if (NumWarps == 4) {
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 2), warp_queue, !is_descend>(shared_K, shared_V);
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 4), warp_queue * 2, !is_descend, false>(shared_K,
|
||||
shared_V);
|
||||
} else if (NumWarps == 8) {
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 2), warp_queue, !is_descend>(shared_K, shared_V);
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 4), warp_queue * 2, !is_descend>(shared_K, shared_V);
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 8), warp_queue * 4, !is_descend, false>(shared_K,
|
||||
shared_V);
|
||||
} else if (NumWarps == 16) {
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 2), warp_queue, !is_descend>(shared_K, shared_V);
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 4), warp_queue * 2, !is_descend>(shared_K, shared_V);
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 8), warp_queue * 4, !is_descend>(shared_K, shared_V);
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 16), warp_queue * 8, !is_descend>(shared_K, shared_V);
|
||||
} else if (NumWarps == 32) {
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 2), warp_queue, !is_descend>(shared_K, shared_V);
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 4), warp_queue * 2, !is_descend>(shared_K, shared_V);
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 8), warp_queue * 4, !is_descend>(shared_K, shared_V);
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 16), warp_queue * 8, !is_descend>(shared_K, shared_V);
|
||||
SortBlockStep<NumThreads, T, S, NumThreads / (kWarpSize * 32), warp_queue * 16, !is_descend>(shared_K, shared_V);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S, int L, bool is_descend, bool IsBitonic>
|
||||
inline __device__ void BitonicSortWarpLE16(T *k, S *v) {
|
||||
int laneId = GetLaneId();
|
||||
|
||||
if (!IsBitonic) {
|
||||
// Reverse the first comparison stage. head-tail swap.
|
||||
T other_K = shfl_xor((*k), 2 * L - 1);
|
||||
S other_V = shfl_xor((*v), 2 * L - 1);
|
||||
|
||||
bool small = !(laneId & L);
|
||||
bool small_compare = small ? CmpKV<T, S>::gt((*k), (*v), other_K, other_V) :
|
||||
CmpKV<T, S>::lt((*k), (*v), other_K, other_V);
|
||||
bool small_compare_descend = is_descend ? small_compare : !small_compare;
|
||||
ConditionalAssign(small_compare_descend, k, other_K);
|
||||
ConditionalAssign(small_compare_descend, v, other_V);
|
||||
}
|
||||
|
||||
_Pragma("unroll") for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
|
||||
T other_K = shfl_xor((*k), stride);
|
||||
S other_V = shfl_xor((*v), stride);
|
||||
|
||||
bool small = !(laneId & stride);
|
||||
bool small_compare = small ? CmpKV<T, S>::gt((*k), (*v), other_K, other_V) :
|
||||
CmpKV<T, S>::lt((*k), (*v), other_K, other_V);
|
||||
bool small_compare_descend = is_descend ? small_compare : !small_compare;
|
||||
ConditionalAssign(small_compare_descend, k, other_K);
|
||||
ConditionalAssign(small_compare_descend, v, other_V);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S, int N, bool is_descend, bool Low, bool Pow2>
|
||||
struct MergeWarpStepBitonic {};
|
||||
|
||||
// All merges call this
|
||||
template <typename T, typename S, bool is_descend, bool Low>
|
||||
struct MergeWarpStepBitonic<T, S, 1, is_descend, Low, true> {
|
||||
static inline __device__ void merge(T k[1], S v[1]) { BitonicSortWarpLE16<T, S, 16, is_descend, true>(&k[0], &v[0]); }
|
||||
};
|
||||
|
||||
template <typename T, typename S, int N, bool is_descend, bool Low>
|
||||
struct MergeWarpStepBitonic<T, S, N, is_descend, Low, true> {
|
||||
static inline __device__ void merge(T k[N], S v[N]) {
|
||||
_Pragma("unroll") for (int i = 0; i < N / 2; ++i) { L2CompareAndSwap<T, S, is_descend>(k, v, i, i + N / 2); }
|
||||
|
||||
{
|
||||
T newK[N / 2];
|
||||
S newV[N / 2];
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < N / 2; ++i) {
|
||||
newK[i] = k[i];
|
||||
newV[i] = v[i];
|
||||
}
|
||||
|
||||
MergeWarpStepBitonic<T, S, N / 2, is_descend, true, true>::merge(newK, newV);
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < N / 2; ++i) {
|
||||
k[i] = newK[i];
|
||||
v[i] = newV[i];
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
T newK[N / 2];
|
||||
S newV[N / 2];
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < N / 2; ++i) {
|
||||
newK[i] = k[i + N / 2];
|
||||
newV[i] = v[i + N / 2];
|
||||
}
|
||||
|
||||
MergeWarpStepBitonic<T, S, N / 2, is_descend, false, true>::merge(newK, newV);
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < N / 2; ++i) {
|
||||
k[i + N / 2] = newK[i];
|
||||
v[i + N / 2] = newV[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Low recursion
|
||||
template <typename T, typename S, int N, bool is_descend>
|
||||
struct MergeWarpStepBitonic<T, S, N, is_descend, true, false> {
|
||||
static inline __device__ void merge(T k[N], S v[N]) {
|
||||
constexpr int kNextHighestPowerOf2 = NextPow2(N);
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
|
||||
L2CompareAndSwap<T, S, is_descend>(k, v, i, i + kNextHighestPowerOf2 / 2);
|
||||
}
|
||||
|
||||
constexpr int kLowSize = N - kNextHighestPowerOf2 / 2;
|
||||
constexpr int kHighSize = kNextHighestPowerOf2 / 2;
|
||||
{
|
||||
T newK[kLowSize];
|
||||
S newV[kLowSize];
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kLowSize; ++i) {
|
||||
newK[i] = k[i];
|
||||
newV[i] = v[i];
|
||||
}
|
||||
|
||||
constexpr bool kLowIsPowerOf2 = IsPow2(N - kNextHighestPowerOf2 / 2);
|
||||
MergeWarpStepBitonic<T, S, kLowSize, is_descend, true, kLowIsPowerOf2>::merge(newK, newV);
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kLowSize; ++i) {
|
||||
k[i] = newK[i];
|
||||
v[i] = newV[i];
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
T newK[kHighSize];
|
||||
S newV[kHighSize];
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kHighSize; ++i) {
|
||||
newK[i] = k[i + kLowSize];
|
||||
newV[i] = v[i + kLowSize];
|
||||
}
|
||||
|
||||
constexpr bool kHighIsPowerOf2 = IsPow2(kNextHighestPowerOf2 / 2);
|
||||
MergeWarpStepBitonic<T, S, kHighSize, is_descend, false, kHighIsPowerOf2>::merge(newK, newV);
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kHighSize; ++i) {
|
||||
k[i + kLowSize] = newK[i];
|
||||
v[i + kLowSize] = newV[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// High recursion
|
||||
template <typename T, typename S, int N, bool is_descend>
|
||||
struct MergeWarpStepBitonic<T, S, N, is_descend, false, false> {
|
||||
static inline __device__ void merge(T k[N], S v[N]) {
|
||||
constexpr int kNextHighestPowerOf2 = NextPow2(N);
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
|
||||
L2CompareAndSwap<T, S, is_descend>(k, v, i, i + kNextHighestPowerOf2 / 2);
|
||||
}
|
||||
|
||||
constexpr int kLowSize = kNextHighestPowerOf2 / 2;
|
||||
constexpr int kHighSize = N - kNextHighestPowerOf2 / 2;
|
||||
{
|
||||
T newK[kLowSize];
|
||||
S newV[kLowSize];
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kLowSize; ++i) {
|
||||
newK[i] = k[i];
|
||||
newV[i] = v[i];
|
||||
}
|
||||
|
||||
constexpr bool kLowIsPowerOf2 = IsPow2(kNextHighestPowerOf2 / 2);
|
||||
MergeWarpStepBitonic<T, S, kLowSize, is_descend, true, kLowIsPowerOf2>::merge(newK, newV);
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kLowSize; ++i) {
|
||||
k[i] = newK[i];
|
||||
v[i] = newV[i];
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
T newK[kHighSize];
|
||||
S newV[kHighSize];
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kHighSize; ++i) {
|
||||
newK[i] = k[i + kLowSize];
|
||||
newV[i] = v[i + kLowSize];
|
||||
}
|
||||
|
||||
constexpr bool kHighIsPowerOf2 = IsPow2(N - kNextHighestPowerOf2 / 2);
|
||||
MergeWarpStepBitonic<T, S, kHighSize, is_descend, false, kHighIsPowerOf2>::merge(newK, newV);
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kHighSize; ++i) {
|
||||
k[i + kLowSize] = newK[i];
|
||||
v[i + kLowSize] = newV[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Merges two sets of registers across the warp of any size;
|
||||
template <typename T, typename S, int N1, int N2, bool is_descend, bool FullMerge = true>
|
||||
inline __device__ void MergeWarpByRegister(T k1[N1], S v1[N1], T k2[N2], S v2[N2]) {
|
||||
constexpr int kSmallestN = N1 < N2 ? N1 : N2;
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kSmallestN; ++i) {
|
||||
T &ka = k1[N1 - 1 - i];
|
||||
S &va = v1[N1 - 1 - i];
|
||||
|
||||
T &kb = k2[i];
|
||||
S &vb = v2[i];
|
||||
|
||||
T other_Ka;
|
||||
S other_Va;
|
||||
|
||||
if (FullMerge) {
|
||||
other_Ka = shfl_xor(ka, kWarpSize - 1);
|
||||
other_Va = shfl_xor(va, kWarpSize - 1);
|
||||
}
|
||||
|
||||
T other_Kb = shfl_xor(kb, kWarpSize - 1);
|
||||
S other_Vb = shfl_xor(vb, kWarpSize - 1);
|
||||
|
||||
bool swapa = is_descend ? CmpKV<T, S>::gt(ka, va, other_Kb, other_Vb) : CmpKV<T, S>::lt(ka, va, other_Kb, other_Vb);
|
||||
ConditionalAssign(swapa, &ka, other_Kb);
|
||||
ConditionalAssign(swapa, &va, other_Vb);
|
||||
|
||||
if (FullMerge) {
|
||||
bool swapb = is_descend ? CmpKV<T, S>::lt(kb, vb, other_Ka, other_Va) :
|
||||
CmpKV<T, S>::gt(kb, vb, other_Ka, other_Va);
|
||||
ConditionalAssign(swapb, &kb, other_Ka);
|
||||
ConditionalAssign(swapb, &vb, other_Va);
|
||||
}
|
||||
}
|
||||
|
||||
MergeWarpStepBitonic<T, S, N1, is_descend, true, IsPow2(N1)>::merge(k1, v1);
|
||||
if (FullMerge) {
|
||||
MergeWarpStepBitonic<T, S, N2, is_descend, false, IsPow2(N2)>::merge(k2, v2);
|
||||
}
|
||||
}
|
||||
|
||||
// Recursive template that uses the above bitonic merge
|
||||
template <typename T, typename S, int N, bool is_descend>
|
||||
struct SortWarpStepBitonic {
|
||||
static inline __device__ void Sort(T k[N], S v[N]) {
|
||||
constexpr int kSizeA = N / 2;
|
||||
constexpr int kSizeB = N - kSizeA;
|
||||
|
||||
T aK[kSizeA];
|
||||
S aV[kSizeA];
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kSizeA; ++i) {
|
||||
aK[i] = k[i];
|
||||
aV[i] = v[i];
|
||||
}
|
||||
|
||||
// Recursive sort
|
||||
SortWarpStepBitonic<T, S, kSizeA, is_descend>::Sort(aK, aV);
|
||||
|
||||
T bK[kSizeB];
|
||||
S bV[kSizeB];
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kSizeB; ++i) {
|
||||
bK[i] = k[i + kSizeA];
|
||||
bV[i] = v[i + kSizeA];
|
||||
}
|
||||
|
||||
SortWarpStepBitonic<T, S, kSizeB, is_descend>::Sort(bK, bV);
|
||||
|
||||
// Merge halves
|
||||
MergeWarpByRegister<T, S, kSizeA, kSizeB, is_descend>(aK, aV, bK, bV);
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kSizeA; ++i) {
|
||||
k[i] = aK[i];
|
||||
v[i] = aV[i];
|
||||
}
|
||||
|
||||
_Pragma("unroll") for (int i = 0; i < kSizeB; ++i) {
|
||||
k[i + kSizeA] = bK[i];
|
||||
v[i + kSizeA] = bV[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename S, bool is_descend>
|
||||
struct SortWarpStepBitonic<T, S, 1, is_descend> {
|
||||
static inline __device__ void Sort(T k[1], S v[1]) {
|
||||
// up to warp-size/2
|
||||
BitonicSortWarpLE16<T, S, 1, is_descend, false>(&k[0], &v[0]);
|
||||
BitonicSortWarpLE16<T, S, 2, is_descend, false>(&k[0], &v[0]);
|
||||
BitonicSortWarpLE16<T, S, 4, is_descend, false>(&k[0], &v[0]);
|
||||
BitonicSortWarpLE16<T, S, 8, is_descend, false>(&k[0], &v[0]);
|
||||
BitonicSortWarpLE16<T, S, 16, is_descend, false>(&k[0], &v[0]);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename S, int N, bool is_descend>
|
||||
inline __device__ void SortWarpByRegister(T k[N], S v[N]) {
|
||||
SortWarpStepBitonic<T, S, N, is_descend>::Sort(k, v);
|
||||
}
|
||||
|
||||
template <typename T, typename S, int warp_queue, int thread_queue, bool is_descend>
|
||||
inline __device__ void MergeWarpQueue(T *threadK, S *threadV, T *warp_K, S *warp_V) {
|
||||
int laneId = GetLaneId();
|
||||
SortWarpByRegister<T, S, thread_queue, !is_descend>(threadK, threadV);
|
||||
|
||||
constexpr int kWarpQueueRegisters = warp_queue / kWarpSize;
|
||||
T warp_KRegisters[kWarpQueueRegisters];
|
||||
S warp_VRegisters[kWarpQueueRegisters];
|
||||
_Pragma("unroll") for (int i = 0; i < kWarpQueueRegisters; ++i) {
|
||||
warp_KRegisters[i] = warp_K[i * kWarpSize + laneId];
|
||||
warp_VRegisters[i] = warp_V[i * kWarpSize + laneId];
|
||||
}
|
||||
__syncwarp();
|
||||
MergeWarpByRegister<T, S, kWarpQueueRegisters, thread_queue, !is_descend, false>(warp_KRegisters, warp_VRegisters,
|
||||
threadK, threadV);
|
||||
_Pragma("unroll") for (int i = 0; i < kWarpQueueRegisters; ++i) {
|
||||
warp_K[i * kWarpSize + laneId] = warp_KRegisters[i];
|
||||
warp_V[i * kWarpSize + laneId] = warp_VRegisters[i];
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -39,17 +39,22 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
|
|||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
S *output_index = GetDeviceAddress<S>(outputs, 0);
|
||||
T *output_mask = GetDeviceAddress<T>(outputs, 1);
|
||||
S *index_buff = GetDeviceAddress<S>(workspaces, 0);
|
||||
S *mask_buff = GetDeviceAddress<S>(workspaces, 1);
|
||||
S *rank_buff = GetDeviceAddress<S>(workspaces, 2);
|
||||
S *Tnum_buff = GetDeviceAddress<S>(workspaces, 3);
|
||||
S *tmp_buff = GetDeviceAddress<S>(workspaces, 4);
|
||||
void *States = GetDeviceAddress<void *>(workspaces, 5);
|
||||
curandState *devStates = reinterpret_cast<curandState *>(States);
|
||||
CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], input_shape_5D_[2],
|
||||
input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, output_index, output_mask,
|
||||
index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, devStates,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
if (count_ > kSmallK || input_shape_size_ > 1) {
|
||||
S *index_buff = GetDeviceAddress<S>(workspaces, 0);
|
||||
S *mask_buff = GetDeviceAddress<S>(workspaces, 1);
|
||||
S *rank_buff = GetDeviceAddress<S>(workspaces, 2);
|
||||
S *Tnum_buff = GetDeviceAddress<S>(workspaces, 3);
|
||||
S *tmp_buff = GetDeviceAddress<S>(workspaces, 4);
|
||||
void *States = GetDeviceAddress<void *>(workspaces, 5);
|
||||
curandState *devStates = reinterpret_cast<curandState *>(States);
|
||||
CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1],
|
||||
input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input,
|
||||
output_index, output_mask, index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff,
|
||||
devStates, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc_, count_, input, output_index, output_mask,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -94,7 +99,9 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
|
|||
}
|
||||
count_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "count"));
|
||||
// upper ceiling for input for ceil_power2
|
||||
ceil_power2_ = RcwmRoundUpPower2(input_size_);
|
||||
if (count_ > kSmallK || input_shape_size_ > 1) {
|
||||
ceil_power2_ = RcwmRoundUpPower2(input_size_);
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -104,16 +111,19 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
|
|||
input_size_list_.push_back(input_size_ * sizeof(T));
|
||||
output_size_list_.push_back(count_ * input_shape_size_ * sizeof(S));
|
||||
output_size_list_.push_back(count_ * sizeof(T));
|
||||
workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S));
|
||||
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
|
||||
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
|
||||
int blocknum = std::ceil(static_cast<float>(ceil_power2_) / BLOCKSIZE);
|
||||
workspace_size_list_.push_back(blocknum * sizeof(S));
|
||||
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
|
||||
workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState));
|
||||
if (count_ > kSmallK || input_shape_size_ > 1) {
|
||||
workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S));
|
||||
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
|
||||
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
|
||||
int blocknum = std::ceil(static_cast<float>(ceil_power2_) / BLOCKSIZE);
|
||||
workspace_size_list_.push_back(blocknum * sizeof(S));
|
||||
workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
|
||||
workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const int kSmallK = 2048;
|
||||
int input_shape_size_;
|
||||
int seedc_;
|
||||
int input_size_;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-21 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -21,6 +21,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class RCWM_count_in(nn.Cell):
|
||||
def __init__(self):
|
||||
super(RCWM_count_in, self).__init__()
|
||||
|
@ -29,6 +30,7 @@ class RCWM_count_in(nn.Cell):
|
|||
def construct(self, x):
|
||||
return self.RCWM_count_in(x)
|
||||
|
||||
|
||||
class RCWM_count_out(nn.Cell):
|
||||
def __init__(self):
|
||||
super(RCWM_count_out, self).__init__()
|
||||
|
@ -37,6 +39,7 @@ class RCWM_count_out(nn.Cell):
|
|||
def construct(self, x):
|
||||
return self.RCWM_count_out(x)
|
||||
|
||||
|
||||
class RCWM_3D(nn.Cell):
|
||||
def __init__(self):
|
||||
super(RCWM_3D, self).__init__()
|
||||
|
@ -45,6 +48,16 @@ class RCWM_3D(nn.Cell):
|
|||
def construct(self, x):
|
||||
return self.RCWM_3D(x)
|
||||
|
||||
|
||||
class RCWM_1D(nn.Cell):
|
||||
def __init__(self):
|
||||
super(RCWM_1D, self).__init__()
|
||||
self.RCWM_1D = P.RandomChoiceWithMask(count=10, seed=9)
|
||||
|
||||
def construct(self, x):
|
||||
return self.RCWM_1D(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -58,12 +71,14 @@ def test_RCWM_3D():
|
|||
assert output1.shape == expect1
|
||||
assert output2.shape == expect2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_RCWM_count_out():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool))
|
||||
input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1],
|
||||
[0, 0, 0, 1]]).astype(np.bool))
|
||||
expect1 = (10, 2)
|
||||
expect2 = (10,)
|
||||
rcwm = RCWM_count_out()
|
||||
|
@ -71,15 +86,36 @@ def test_RCWM_count_out():
|
|||
assert output1.shape == expect1
|
||||
assert output2.shape == expect2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_RCWM_count_in():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool))
|
||||
input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1],
|
||||
[0, 0, 0, 1]]).astype(np.bool))
|
||||
expect1 = (4, 2)
|
||||
expect2 = (4,)
|
||||
rcwm = RCWM_count_in()
|
||||
output1, output2 = rcwm(input_tensor)
|
||||
assert output1.shape == expect1
|
||||
assert output2.shape == expect2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_RCWM_1D():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(
|
||||
np.array([1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]).astype(np.bool))
|
||||
expect_index = np.array([[11], [9], [2], [15], [10], [7],
|
||||
[8], [0], [0], [0]]).astype(np.int32)
|
||||
expect_mask = np.array(
|
||||
[True, True, True, True, True, True, True, True, False, False])
|
||||
rcwm = RCWM_1D()
|
||||
output1, output2 = rcwm(input_tensor)
|
||||
print(output1.asnumpy())
|
||||
print(output2)
|
||||
assert np.array_equal(output1.asnumpy(), expect_index)
|
||||
assert np.array_equal(output2.asnumpy(), expect_mask)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-21 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -24,7 +24,7 @@ from mindspore.ops import operations as P
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_topk():
|
||||
def test_topk_small_2d():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
x_np = np.random.rand(3, 4).astype(np.float32)
|
||||
|
@ -36,7 +36,20 @@ def test_topk():
|
|||
x_np = np.random.rand(3, 4).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = P.TopK(False)(Tensor(x_np), k)
|
||||
assert np.allclose(ms_output[0].asnumpy(), x_np)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_topk_3d():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x_np = np.random.rand(2, 256, 128).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(2, 3, 4).astype(np.float32)
|
||||
k = 2
|
||||
|
@ -44,6 +57,12 @@ def test_topk():
|
|||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_topk_big_2d():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x_np = np.random.rand(512, 1024).astype(np.float32)
|
||||
k = 512
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
|
@ -51,32 +70,69 @@ def test_topk():
|
|||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
# sorted elements num greater than max thread per block
|
||||
x_np = np.random.rand(512, 2048).astype(np.float32)
|
||||
x_np = np.random.rand(128, 2048).astype(np.float32)
|
||||
k = 1
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(512, 2048).astype(np.float32)
|
||||
x_np = np.random.rand(32, 2048).astype(np.float32)
|
||||
k = 2048
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
# sorted elements num greater than max share memory per block
|
||||
x_np = np.random.rand(512, 40960).astype(np.float32)
|
||||
x_np = np.random.rand(16, 40960).astype(np.float32)
|
||||
k = 1
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(512, 40960).astype(np.float32)
|
||||
k = 40960
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_topk_big_k():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x_np = np.random.rand(8, 40960).astype(np.float32)
|
||||
k = 4096
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(512, 40960).astype(np.float32)
|
||||
k = 40960
|
||||
ms_output = P.TopK(False)(Tensor(x_np), k)
|
||||
assert np.allclose(ms_output[0].asnumpy(), x_np)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_topk_1d():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x_np = np.random.rand(12).astype(np.float32)
|
||||
k = 4
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
np_output = np.sort(x_np)[::-1][0:k]
|
||||
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
x_np = np.random.rand(1200).astype(np.float32)
|
||||
k = 256
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
np_output = np.sort(x_np)[::-1][0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(250000).astype(np.float32)
|
||||
k = 2000
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
np_output = np.sort(x_np)[::-1][0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(10240).astype(np.float32)
|
||||
k = 4096
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
np_output = np.sort(x_np)[::-1][0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy(), np_output)
|
||||
|
||||
x_np = np.random.rand(720).astype(np.float32)
|
||||
k = 720
|
||||
ms_output = P.TopK(True)(Tensor(x_np), k)
|
||||
np_output = np.sort(x_np)[::-1][0:k]
|
||||
assert np.allclose(ms_output[0].asnumpy()[:k], np_output)
|
||||
|
|
Loading…
Reference in New Issue