mirror of https://github.com/vllm-project/vllm
[ROCM] Fix blockReduceSum to use correct warp counts for ROCm and CUDA (#3262)
This commit is contained in:
parent
0bba88df03
commit
e4a28e5316
|
@ -15,9 +15,6 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
@ -31,11 +28,6 @@
|
|||
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
|
|
|
@ -1,5 +1,15 @@
|
|||
#pragma once
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_LDG(arg) __ldg(arg)
|
||||
#else
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace vllm {
|
|||
template<typename T>
|
||||
__inline__ __device__ T warpReduceSum(T val) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1)
|
||||
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
||||
return val;
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ __inline__ __device__ T warpReduceSum(T val) {
|
|||
/* Calculate the sum of all elements in a block */
|
||||
template<typename T>
|
||||
__inline__ __device__ T blockReduceSum(T val) {
|
||||
static __shared__ T shared[32];
|
||||
static __shared__ T shared[WARP_SIZE];
|
||||
int lane = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
|
@ -45,7 +45,7 @@ __inline__ __device__ T blockReduceSum(T val) {
|
|||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
|
||||
val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
|
||||
val = warpReduceSum<T>(val);
|
||||
return val;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue