[ROCm] Fix warp and lane calculation in blockReduceSum (#3321)

This commit is contained in:
kliuae 2024-03-12 04:14:07 +08:00 committed by GitHub
parent 4c922709b6
commit c9415c19d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 2 deletions

View File

@ -29,12 +29,22 @@ __inline__ __device__ T warpReduceSum(T val) {
return val;
}
__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
return warp_size - 1;
}
__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
return 5 + (warp_size >> 6);
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
int lane = threadIdx.x & LANE_MASK;
int wid = threadIdx.x >> WID_SHIFT;
val = warpReduceSum<T>(val);