forked from mindspore-Ecosystem/mindspore
!49775 Fix topk on cuda11.6
Merge pull request !49775 from RobinGrosman/topk_cuda116
commit 8b61ee1640
@@ -50,16 +50,23 @@ template <typename T, typename S, int warp_queue, int thread_queue, int threads_per_block>
 inline __device__ void TopKInBuffer(T *shared_K, S *shared_V, int *watermark, T *ceil_K, S *ceil_V, int laneId) {
   constexpr int kNumWarps = threads_per_block / kWarpSize;  // kNumWarps is 1024/32=32
 
   // find last_K, which is max of last element of warp queue
-  T last_K = shared_K[laneId * warp_queue + warp_queue - 1];
-  S last_V = shared_V[laneId * warp_queue + warp_queue - 1];
+  // If kNumWarps != kWarpSize, this code needs adjusting, since lanes are used to aggregate the kNumWarps warp queues:
+  // if kWarpSize > kNumWarps, each warp queue now spans kWarpSize / kNumWarps lanes when only kNumWarps lanes are needed
+  constexpr int kWarpQueuePerLane = warp_queue * kNumWarps / kWarpSize;
+  constexpr int kLanesPerWarp = kWarpSize / kNumWarps;
+
+  T last_K = shared_K[laneId * kWarpQueuePerLane + kWarpQueuePerLane - 1];
+  S last_V = shared_V[laneId * kWarpQueuePerLane + kWarpQueuePerLane - 1];
 
   __syncwarp();
 
+  // Find k_cut:
+  // - the last element of each warp queue is the lowest in that warp
+  // --- if we have multiple lanes per warp, look at the last lane of each warp
+  // - k_cut will be the highest of the last elements of each warp
   for (int offset = kNumWarps / 2; offset > 0; offset /= 2) {
-    // kNumWarps is 32 if block size is 1024
-    T other_K = __shfl_down_sync(0xffffffff, last_K, offset);
-    S other_V = __shfl_down_sync(0xffffffff, last_V, offset);
+    T other_K = __shfl_down_sync(0xffffffff, last_K, offset * kLanesPerWarp);
+    S other_V = __shfl_down_sync(0xffffffff, last_V, offset * kLanesPerWarp);
 
     bool is_greater = CmpKV<T, S>::gt(other_K, other_V, last_K, last_V);
     ConditionalAssign(is_greater, &last_K, other_K);
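The hunk above remaps the reduction so that kWarpSize lanes can aggregate kNumWarps warp queues when the block has fewer than 32 warps. Below is a minimal standalone sketch of that lane-aggregation pattern, assuming kWarpSize = 32, kNumWarps = 8 and an illustrative queue length of 8 (so kLanesPerWarp = 4 and kWarpQueuePerLane = 2); the kernel name, buffer layout and the plain float max (in place of CmpKV's key/value comparison) are simplifications, not the real FastTopK code.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;
constexpr int kNumWarps = 8;                                           // e.g. 256 threads / 32
constexpr int kWarpQueue = 8;                                          // per-warp queue length (illustrative)
constexpr int kWarpQueuePerLane = kWarpQueue * kNumWarps / kWarpSize;  // 2
constexpr int kLanesPerWarp = kWarpSize / kNumWarps;                   // 4

// One warp scans kNumWarps descending-sorted queues packed in `queues` and
// writes the maximum of their last (i.e. smallest) elements, mirroring how
// TopKInBuffer derives the ceiling value after this change.
__global__ void MaxOfQueueTails(const float *queues, float *out) {
  int laneId = threadIdx.x;  // a single warp is launched
  // Each lane owns kWarpQueuePerLane consecutive elements; its last element is
  // a true queue tail only on lanes kLanesPerWarp - 1, 2 * kLanesPerWarp - 1, ...
  float last = queues[laneId * kWarpQueuePerLane + kWarpQueuePerLane - 1];
  // Hop in strides of kLanesPerWarp so tail-holding lanes only ever combine
  // with other tail-holding lanes.
  for (int offset = kNumWarps / 2; offset > 0; offset /= 2) {
    float other = __shfl_down_sync(0xffffffff, last, offset * kLanesPerWarp);
    last = fmaxf(last, other);
  }
  // The reduction lands on the last lane of the first warp's group of lanes.
  if (laneId == kLanesPerWarp - 1) {
    *out = last;
  }
}

int main() {
  float h_in[kNumWarps * kWarpQueue];
  for (int i = 0; i < kNumWarps * kWarpQueue; ++i) {
    h_in[i] = 100.0f - i;  // every queue of 8 is sorted descending
  }
  float *d_in = nullptr, *d_out = nullptr, h_out = 0.0f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  MaxOfQueueTails<<<1, kWarpSize>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("max of queue tails: %.1f\n", h_out);  // expect 93.0, the tail of the first queue
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

With 256-thread blocks (kNumWarps = 8) the old laneId * warp_queue indexing no longer matches the buffer of kNumWarps queues, which appears to be what the kWarpQueuePerLane / kLanesPerWarp remapping addresses.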
@@ -67,25 +74,27 @@ inline __device__ void TopKInBuffer(T *shared_K, S *shared_V, int *watermark, T *ceil_K, S *ceil_V, int laneId) {
   }
   __syncwarp();
 
-  if (laneId == 0) {
+  // want to fetch last_K from the last lane of the first warp
+  if (laneId == kLanesPerWarp - 1) {
     *ceil_K = last_K;
     *ceil_V = last_V;
   }
   __syncwarp();
 
-  // calculate index cut by last_K
+  // calculate index cut by last_K; do this per thread/lane instead of per warp
   int L = 0;
-  int R = warp_queue;
+  int R = kWarpQueuePerLane;
   while (L < R) {
     int m = (L + R) / 2;
-    CmpKV<T, S>::gt(shared_K[laneId * warp_queue + m], shared_V[laneId * warp_queue + m], (*ceil_K), (*ceil_V))
+    CmpKV<T, S>::gt(shared_K[laneId * kWarpQueuePerLane + m],
+                    shared_V[laneId * kWarpQueuePerLane + m], (*ceil_K), (*ceil_V))
       ? L = m + 1
       : R = m;
   }
   __syncwarp();
 
   // merge top number which value is greater than last_K
-  for (int offset = kNumWarps / 2; offset > 0; offset /= 2) {
+  // R is calculated per thread --> sum over all threads and not just all warps
+  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
     R += __shfl_down_sync(0xffffffff, R, offset);
   }
 
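With the ceiling published, each lane now binary-searches its own slice of kWarpQueuePerLane elements, so R becomes a per-lane count and the final reduction has to sum across all kWarpSize lanes rather than kNumWarps warps. A host-side sketch of that bisection on one lane's slice (illustrative names, plain float comparison instead of CmpKV):

#include <cstdio>

// Counts how many leading elements of a descending-sorted slice are strictly
// greater than `cut` -- the same L/R bisection the kernel runs per lane.
int CountGreater(const float *slice, int n, float cut) {
  int L = 0;
  int R = n;  // n plays the role of kWarpQueuePerLane
  while (L < R) {
    int m = (L + R) / 2;
    slice[m] > cut ? L = m + 1 : R = m;  // same ternary pattern as the kernel
  }
  return R;  // == L: index of the first element <= cut
}

int main() {
  const float lane_slice[] = {9.0f, 7.5f, 7.5f, 3.0f};  // one lane's slice, sorted descending
  printf("%d elements above the cut\n", CountGreater(lane_slice, 4, 7.0f));  // prints 3
  return 0;
}

Summing the per-lane counts with __shfl_down_sync at offsets kWarpSize / 2, kWarpSize / 4, ... then yields the total number of buffered candidates above the cut, which is what the loop after the search does.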
@@ -164,7 +173,9 @@ inline __device__ void TopKStep(const int &outer_size, const int &inner_size, co
   }
   __syncthreads();
 
-  SortBlockWide<kNumWarps, threads_per_block, T, S, warp_queue, is_descend>(shared_K, shared_V);
+  // The wide sort doesn't sort properly if kNumWarps != kWarpSize, so pass kWarpSize
+  SortBlockWide<kWarpSize, threads_per_block, T, S, warp_queue, is_descend>(shared_K, shared_V);
   __syncthreads();
 
   S k_step = (*k_prime) + watermark[0] <= k_cut ? watermark[0] : k_cut - (*k_prime);
   for (int i = threadIdx.x; i < k_step; i += blockDim.x) {
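Hard-coding kWarpSize as the sort width bakes in an assumption about SortBlockWide rather than deriving it from the block shape. A hypothetical compile-time guard (not part of this PR, and not the real SortBlockWide interface) could keep that assumption visible at the call site:

#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

// Hypothetical guard: makes the "sort width == kWarpSize" assumption behind the
// SortBlockWide change fail loudly at compile time if a future block
// configuration breaks it.
template <int sort_width, int threads_per_block>
__device__ __forceinline__ void AssertWideSortConfig() {
  static_assert(sort_width == kWarpSize,
                "wide sort is only exercised with kWarpSize sorting groups");
  static_assert(threads_per_block % kWarpSize == 0,
                "block size must be a whole number of warps");
}

__global__ void ExampleKernel() {
  AssertWideSortConfig<kWarpSize, 256>();  // compiles; <16, 256> would not
}

int main() {
  ExampleKernel<<<1, 256>>>();
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}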
@@ -218,7 +229,8 @@ void FastTopK(const int outer_size, const int inner_size, const T *input, S k_cu
   } else if (k_cut <= 128) {
     TOPK_HELPER(256, 128, 3, true);
   } else {
-    TOPK_HELPER(1024, 128, 3, true);
+    // CUDA 11.6 has a lower thread limit; set the lower number for all platforms for consistency
+    TOPK_HELPER(256, 128, 3, true);
   }
 }
 
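Dropping from 1024 to 256 threads sidesteps a per-kernel launch limit: under CUDA 11.6 this kernel apparently can no longer be launched with 1024 threads per block (most likely higher per-thread resource use, though the commit message does not say), so the same lower count is used everywhere. A hedged way to inspect that limit for any kernel is sketched below; DummyKernel is a stand-in, not the real TopK kernel.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void DummyKernel(float *data) {  // stand-in for the real TopK kernel
  data[threadIdx.x] += 1.0f;
}

int main() {
  cudaFuncAttributes attr;
  cudaFuncGetAttributes(&attr, DummyKernel);
  // maxThreadsPerBlock is the per-kernel limit after register and shared-memory
  // usage are accounted for; it can drop below the device-wide 1024 between
  // toolkit versions. Launching above it fails with "too many resources
  // requested for launch" (cudaErrorLaunchOutOfResources).
  printf("regs per thread: %d, max threads per block: %d\n",
         attr.numRegs, attr.maxThreadsPerBlock);
  return 0;
}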