[ROCM] Fix blockReduceSum to use correct warp counts for ROCm and CUDA (#3262)

This commit is contained in:
Douglas Lehr 2024-03-10 17:27:45 -05:00 committed by GitHub
parent 0bba88df03
commit e4a28e5316
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 11 deletions

View File

@ -15,9 +15,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
@ -31,11 +28,6 @@
#include <algorithm>
// WARP_SIZE: the literal 32 by default; when USE_ROCM is defined it expands
// to the identifier `warpSize` instead.
#if defined(USE_ROCM)
  #define WARP_SIZE warpSize
#else
  #define WARP_SIZE 32
#endif
// Function-like helper macros. Because they are macros, an argument may be
// evaluated more than once; avoid passing expressions with side effects
// (e.g. MAX(i++, j)).
// MAX: yields a when a > b, otherwise b (so b when the two compare equal).
#define MAX(a, b) ((a) > (b) ? (a) : (b))
// MIN: yields a when a < b, otherwise b (so b when the two compare equal).
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// DIVIDE_ROUND_UP: (a + b - 1) / b, i.e. a / b rounded up when a and b are
// positive integers.
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

View File

@ -1,5 +1,15 @@
#pragma once
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
// WARP_SIZE: expands to `warpSize` when USE_ROCM is defined, and to the
// literal 32 in every other build.
#if defined(USE_ROCM)
  #define WARP_SIZE warpSize
#else
  #define WARP_SIZE 32
#endif
#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else

View File

@ -24,7 +24,7 @@ namespace vllm {
/*
 * Accumulate `val` across a warp.
 *
 * Starting with mask = WARP_SIZE / 2 and halving it each step, the value
 * returned by VLLM_SHFL_XOR_SYNC(val, mask) is added into `val`. The loop is
 * executed exactly once, and its starting offset follows WARP_SIZE rather than
 * a hard-coded 16.
 *
 * NOTE(review): VLLM_SHFL_XOR_SYNC and WARP_SIZE are defined outside this
 * block; presumably the former is an XOR lane shuffle, in which case every
 * lane ends up holding the full warp sum -- confirm against its definition.
 *
 * Returns the accumulated value.
 */
template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
  // A single reduction pass: nesting a second mask loop around this one would
  // repeat the whole reduction and inflate the result.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1)
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  return val;
}
@ -32,7 +32,7 @@ __inline__ __device__ T warpReduceSum(T val) {
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
@ -45,7 +45,7 @@ __inline__ __device__ T blockReduceSum(T val) {
// Changed from blockDim.x >> 5 to blockDim.x / 32.f so that the trailing
// partial warp is still counted when blockDim.x is not a multiple of 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}