diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/topk_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/topk_impl.cu
index e56af86ba13..bf3f2c33a9b 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/topk_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/topk_impl.cu
@@ -50,16 +50,23 @@ template <typename T, typename S, int threads_per_block, int warp_queue>
 inline __device__ void TopKInBuffer(T *shared_K, S *shared_V, int *watermark, T *ceil_K, S *ceil_V, int laneId) {
   constexpr int kNumWarps = threads_per_block / kWarpSize;
 
-  T last_K = shared_K[laneId * warp_queue + warp_queue - 1];
-  S last_V = shared_V[laneId * warp_queue + warp_queue - 1];
+  // With threads_per_block < 1024, kWarpSize > kNumWarps: each warp queue gets kWarpSize/kNumWarps lanes when we only need kNumWarps.
+  constexpr int kWarpQueuePerLane = warp_queue * kNumWarps / kWarpSize;
+  constexpr int kLanesPerWarp = kWarpSize / kNumWarps;
+
+  T last_K = shared_K[laneId * kWarpQueuePerLane + kWarpQueuePerLane - 1];
+  S last_V = shared_V[laneId * kWarpQueuePerLane + kWarpQueuePerLane - 1];
 
   __syncwarp();
 
+  // Find k_cut:
+  // - The last element of each warp queue is the lowest in that queue.
+  //   If multiple lanes share one warp queue, look at the last lane of each group.
+  // - k_cut will be the highest of the last elements across all warp queues.
   for (int offset = kNumWarps / 2; offset > 0; offset /= 2) {
-    // kNumWarps is 32 if block size is 1024
-    T other_K = __shfl_down_sync(0xffffffff, last_K, offset);
-    S other_V = __shfl_down_sync(0xffffffff, last_V, offset);
+    T other_K = __shfl_down_sync(0xffffffff, last_K, offset * kLanesPerWarp);
+    S other_V = __shfl_down_sync(0xffffffff, last_V, offset * kLanesPerWarp);
 
     bool is_greater = CmpKV<T, S>::gt(other_K, other_V, last_K, last_V);
     ConditionalAssign(is_greater, &last_K, other_K);
@@ -67,25 +74,27 @@ inline __device__ void TopKInBuffer(T *shared_K, S *shared_V, int *watermark, T
   }
   __syncwarp();
 
-  if (laneId == 0) {
+  // Fetch last_K from the last lane of the first warp's group of lanes.
+  if (laneId == kLanesPerWarp - 1) {
     *ceil_K = last_K;
     *ceil_V = last_V;
   }
   __syncwarp();
 
-  // calculate index cut by last_K
+  // Calculate the index cut by last_K. Do this per thread/lane instead of per warp.
   int L = 0;
-  int R = warp_queue;
+  int R = kWarpQueuePerLane;
   while (L < R) {
     int m = (L + R) / 2;
-    CmpKV<T, S>::gt(shared_K[laneId * warp_queue + m], shared_V[laneId * warp_queue + m], (*ceil_K), (*ceil_V))
+    CmpKV<T, S>::gt(shared_K[laneId * kWarpQueuePerLane + m],
+                    shared_V[laneId * kWarpQueuePerLane + m], (*ceil_K), (*ceil_V))
       ? L = m + 1
       : R = m;
   }
   __syncwarp();
 
-  // merge top number which value is greater than last_K
-  for (int offset = kNumWarps / 2; offset > 0; offset /= 2) {
+  // R is calculated per lane, so sum over all kWarpSize lanes and not just over kNumWarps.
+  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
     R += __shfl_down_sync(0xffffffff, R, offset);
   }
@@ -164,7 +173,9 @@ inline __device__ void TopKStep(const int &outer_size, const int &inner_size, co
   }
   __syncthreads();
 
-  SortBlockWide(shared_K, shared_V);
+  // The wide sort does not sort properly if kNumWarps != kWarpSize, so pass kWarpSize.
+  SortBlockWide(shared_K, shared_V);
+  __syncthreads();
 
   S k_step = (*k_prime) + watermark[0] <= k_cut ? watermark[0] : k_cut - (*k_prime);
   for (int i = threadIdx.x; i < k_step; i += blockDim.x) {
@@ -218,7 +229,8 @@ void FastTopK(const int outer_size, const int inner_size, const T *input, S k_cu
   } else if (k_cut <= 128) {
     TOPK_HELPER(256, 128, 3, true);
   } else {
-    TOPK_HELPER(1024, 128, 3, true);
+    // CUDA 11.6 supports a lower thread count here; use the lower number on all platforms for consistency.
+    TOPK_HELPER(256, 128, 3, true);
   }
 }
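
Note (illustrative sketch, not part of the patch): the host-side arithmetic below shows how the new constants split each warp's sorted queue in shared memory across several lanes of the single merging warp once the block has fewer than 1024 threads. The concrete values assume threads_per_block = 256 and warp_queue = 128, matching the TOPK_HELPER(256, 128, 3, true) configurations above, and kWarpSize = 32; the one-contiguous-queue-per-warp layout of shared_K is inferred from the kernel's indexing.

#include <cstdio>

constexpr int kWarpSize = 32;  // warp size assumed by the kernel

int main() {
  constexpr int threads_per_block = 256;  // block size used by TOPK_HELPER(256, ...)
  constexpr int warp_queue = 128;         // per-warp queue length in shared memory

  constexpr int kNumWarps = threads_per_block / kWarpSize;               // 8 producer warps
  constexpr int kLanesPerWarp = kWarpSize / kNumWarps;                   // 4 lanes share one warp's queue
  constexpr int kWarpQueuePerLane = warp_queue * kNumWarps / kWarpSize;  // 32 queue slots per lane

  // Lane i of the merging warp owns shared_K[i * kWarpQueuePerLane .. (i + 1) * kWarpQueuePerLane - 1].
  // The last element of warp w's queue is its lowest, so it sits at the end of the slice owned by
  // lane (w + 1) * kLanesPerWarp - 1. That is why the shuffle reduction strides by kLanesPerWarp and
  // the ceiling value is read from lane kLanesPerWarp - 1 rather than lane 0.
  printf("kNumWarps=%d kLanesPerWarp=%d kWarpQueuePerLane=%d\n", kNumWarps, kLanesPerWarp, kWarpQueuePerLane);
  for (int w = 0; w < kNumWarps; ++w) {
    printf("warp %d: queue minimum held by lane %d\n", w, (w + 1) * kLanesPerWarp - 1);
  }
  return 0;
}

With this split, the binary search in the patch runs over kWarpQueuePerLane elements per lane rather than warp_queue per warp, which is why R is then reduced across all kWarpSize lanes instead of only kNumWarps.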