!49775 Fix topk on cuda11.6

Merge pull request !49775 from RobinGrosman/topk_cuda116
i-robot 2023-03-08 02:45:51 +00:00 committed by Gitee
commit 8b61ee1640
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 26 additions and 14 deletions

@@ -50,16 +50,23 @@ template <typename T, typename S, int warp_queue, int thread_queue, int threads_
inline __device__ void TopKInBuffer(T *shared_K, S *shared_V, int *watermark, T *ceil_K, S *ceil_V, int laneId) {
constexpr int kNumWarps = threads_per_block / kWarpSize; // kNumWarps is 1024/32=32
// find last_K, which is max of last element of warp queue
T last_K = shared_K[laneId * warp_queue + warp_queue - 1];
S last_V = shared_V[laneId * warp_queue + warp_queue - 1];
// If kNumWarps != kWarpSize, this code needs adjusting, as we are using lanes to aggregate kNumWarps queues
// if kWarpSize > kNumWarps, each warp queue now gets kWarpSize/kNumWarps lanes when we only need kNumWarps in total
constexpr int kWarpQueuePerLane = warp_queue * kNumWarps / kWarpSize;
constexpr int kLanesPerWarp = kWarpSize / kNumWarps;
T last_K = shared_K[laneId * kWarpQueuePerLane + kWarpQueuePerLane - 1];
S last_V = shared_V[laneId * kWarpQueuePerLane + kWarpQueuePerLane - 1];
__syncwarp();
// Find KCut:
// - The last element of each warp is the lowest in that warp
// --- If we have multiple lanes per warp, look at the last lane of each warp
// - k_cut will be the highest of the last elements of each warp
for (int offset = kNumWarps / 2; offset > 0; offset /= 2) {
// kNumWarps is 32 if block size is 1024
T other_K = __shfl_down_sync(0xffffffff, last_K, offset);
S other_V = __shfl_down_sync(0xffffffff, last_V, offset);
T other_K = __shfl_down_sync(0xffffffff, last_K, offset * kLanesPerWarp);
S other_V = __shfl_down_sync(0xffffffff, last_V, offset * kLanesPerWarp);
bool is_greater = CmpKV<T, S>::gt(other_K, other_V, last_K, last_V);
ConditionalAssign(is_greater, &last_K, other_K);
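Editor's note: a standalone sketch (ordinary host C++, hypothetical names mirroring the kernel's constants) of how kWarpQueuePerLane and kLanesPerWarp work out for the two block sizes used in this file:

#include <cstdio>

constexpr int kWarpSize = 32;

constexpr int QueuePerLane(int threads_per_block, int warp_queue) {
  // kWarpQueuePerLane: how many buffer slots each lane of the first warp scans.
  return warp_queue * (threads_per_block / kWarpSize) / kWarpSize;
}

constexpr int LanesPerWarp(int threads_per_block) {
  // kLanesPerWarp: how many lanes now cover one warp's queue.
  return kWarpSize / (threads_per_block / kWarpSize);
}

int main() {
  // 1024-thread block: 32 warps, one lane per warp queue (the old assumption).
  printf("1024 threads: slots/lane=%d, lanes/warp=%d\n", QueuePerLane(1024, 128), LanesPerWarp(1024));
  // 256-thread block: 8 warps, so 4 lanes share each 128-slot warp queue.
  printf("256 threads:  slots/lane=%d, lanes/warp=%d\n", QueuePerLane(256, 128), LanesPerWarp(256));
  return 0;
}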
@@ -67,25 +74,27 @@ inline __device__ void TopKInBuffer(T *shared_K, S *shared_V, int *watermark, T
}
__syncwarp();
if (laneId == 0) {
// we want to fetch last_K from the last lane of the first warp
if (laneId == kLanesPerWarp - 1) {
*ceil_K = last_K;
*ceil_V = last_V;
}
__syncwarp();
// calculate index cut by last_K
// calculate index cut by last_K. Do this per thread/lane instead of per warp
int L = 0;
int R = warp_queue;
int R = kWarpQueuePerLane;
while (L < R) {
int m = (L + R) / 2;
CmpKV<T, S>::gt(shared_K[laneId * warp_queue + m], shared_V[laneId * warp_queue + m], (*ceil_K), (*ceil_V))
CmpKV<T, S>::gt(shared_K[laneId * kWarpQueuePerLane + m],
shared_V[laneId * kWarpQueuePerLane + m], (*ceil_K), (*ceil_V))
? L = m + 1
: R = m;
}
__syncwarp();
// merge the counts of top elements whose values are greater than last_K
for (int offset = kNumWarps / 2; offset > 0; offset /= 2) {
// R is calculated per thread --> sum over all threads, not just over all warps
for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
R += __shfl_down_sync(0xffffffff, R, offset);
}
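Editor's note: the reduction now runs over the full warp (offsets 16, 8, 4, 2, 1) because every lane holds its own partial count R. A minimal, hypothetical CUDA sketch of that full-warp sum pattern (not the MindSpore kernel itself):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void WarpSumDemo(int *out) {
  constexpr int kWarpSize = 32;
  int lane = threadIdx.x % kWarpSize;
  int r = lane;  // pretend each lane found `r` elements above the cut
  // Sum across all 32 lanes, mirroring the loop over kWarpSize / 2 above.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    r += __shfl_down_sync(0xffffffff, r, offset);
  }
  if (lane == 0) *out = r;  // 0 + 1 + ... + 31 = 496
}

int main() {
  int *out = nullptr;
  cudaMallocManaged(&out, sizeof(int));
  WarpSumDemo<<<1, 32>>>(out);
  cudaDeviceSynchronize();
  printf("warp total = %d\n", *out);
  cudaFree(out);
  return 0;
}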
@@ -164,7 +173,9 @@ inline __device__ void TopKStep(const int &outer_size, const int &inner_size, co
}
__syncthreads();
SortBlockWide<kNumWarps, threads_per_block, T, S, warp_queue, is_descend>(shared_K, shared_V);
// Wide sort doesn't sort properly if kNumWarps != kWarpSize, so pass kWarpSize
SortBlockWide<kWarpSize, threads_per_block, T, S, warp_queue, is_descend>(shared_K, shared_V);
__syncthreads();
S k_step = (*k_prime) + watermark[0] <= k_cut ? watermark[0] : k_cut - (*k_prime);
for (int i = threadIdx.x; i < k_step; i += blockDim.x) {
@@ -218,7 +229,8 @@ void FastTopK(const int outer_size, const int inner_size, const T *input, S k_cu
} else if (k_cut <= 128) {
TOPK_HELPER(256, 128, 3, true);
} else {
TOPK_HELPER(1024, 128, 3, true);
// CUDA 11.6 has a lower thread limit; set the lower number for all platforms for consistency
TOPK_HELPER(256, 128, 3, true);
}
}
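Editor's note: the last hunk caps the largest configuration at 256 threads per block. A hedged host-side sketch of the dispatch-by-k_cut idea (TopKConfig and LaunchTopK are stand-ins for the real TOPK_HELPER macro; only the thresholds visible in this diff are used):

#include <cstdio>

struct TopKConfig {
  int threads_per_block;
  int warp_queue;
  int thread_queue;
};

// Stand-in for TOPK_HELPER; just reports the chosen launch configuration.
static void LaunchTopK(const TopKConfig &cfg, int k_cut) {
  printf("k_cut=%d -> %d threads/block, warp_queue=%d, thread_queue=%d\n",
         k_cut, cfg.threads_per_block, cfg.warp_queue, cfg.thread_queue);
}

static void FastTopKDispatch(int k_cut) {
  if (k_cut <= 128) {
    LaunchTopK({256, 128, 3}, k_cut);
  } else {
    // Was {1024, 128, 3}; capped at 256 threads so the same launch also fits CUDA 11.6.
    LaunchTopK({256, 128, 3}, k_cut);
  }
}

int main() {
  FastTopKDispatch(100);
  FastTopKDispatch(500);
  return 0;
}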