mirror of https://github.com/vllm-project/vllm
[ROCm] Fix warp and lane calculation in blockReduceSum (#3321)
This commit is contained in:
parent
4c922709b6
commit
c9415c19d3
|
@ -29,12 +29,22 @@ __inline__ __device__ T warpReduceSum(T val) {
|
|||
return val;
|
||||
}
|
||||
|
||||
/*
 * Compile-time helper: bit mask used to extract a thread's lane index
 * (its position inside the warp) from a thread index.
 *
 * Returns warp_size - 1. As a mask this is only meaningful when warp_size
 * is a power of two (e.g. 32 -> 0x1f, 64 -> 0x3f).
 */
__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
  return -1 + warp_size;
}
|
||||
|
||||
/*
 * Compile-time helper: right-shift amount that converts a thread index into
 * a warp index, i.e. floor(log2(warp_size)).
 *
 * The previous closed form `5 + (warp_size >> 6)` only matched log2 for
 * warp sizes 32, 64 and 128; this version is exact for every power-of-two
 * warp size while returning the same values for 32 (-> 5) and 64 (-> 6).
 *
 * warp_size: number of threads per warp; expected to be a power of two.
 *            Values <= 1 yield a shift of 0.
 */
__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
  // Single-return recursion keeps the function valid under C++11 constexpr
  // rules (no loops or local variables required).
  return warp_size <= 1 ? 0 : 1 + _calculateWidShift(warp_size >> 1);
}
|
||||
|
||||
/* Calculate the sum of all elements in a block */
|
||||
template<typename T>
|
||||
__inline__ __device__ T blockReduceSum(T val) {
|
||||
static __shared__ T shared[WARP_SIZE];
|
||||
int lane = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
|
||||
constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
|
||||
int lane = threadIdx.x & LANE_MASK;
|
||||
int wid = threadIdx.x >> WID_SHIFT;
|
||||
|
||||
val = warpReduceSum<T>(val);
|
||||
|
||||
|
|
Loading…
Reference in New Issue