mirror of https://github.com/vllm-project/vllm
[Kernel] Tuned int8 Cutlass Kernels for SM75 (T4) (#6996)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
parent 93548eb37e
commit 35e9c12bfa
@@ -112,13 +112,20 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
 
     timers = []
-    # pytorch impl
+    # pytorch impl - bfloat16
     timers.append(
         bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                  b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
                  torch.bfloat16, label, sub_label, pytorch_mm_impl,
                  "pytorch_bf16_bf16_bf16_matmul-no-scales"))
 
+    # pytorch impl - float16
+    timers.append(
+        bench_fn(a.to(dtype=torch.float16, device="cuda"),
+                 b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
+                 torch.float16, label, sub_label, pytorch_mm_impl,
+                 "pytorch_fp16_fp16_fp16_matmul-no-scales"))
+
     # cutlass impl
     timers.append(
         bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
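For orientation, the benchmark above compares plain bf16/fp16 matmul baselines against the CUTLASS scaled int8 GEMM, in which the int32 accumulator is scaled by scale_a * scale_b and converted to the output dtype in the epilogue. Below is a minimal CPU-side C++ sketch of that arithmetic for the per-tensor-scale case set up here; it is illustrative only (layouts simplified to row-major), not vLLM's kernel or benchmark code.

#include <cstdint>
#include <cstdio>
#include <vector>

// Reference math for a scaled int8 GEMM with per-tensor scales:
//   D[i][j] = convert( scale_a * scale_b * sum_k A[i][k] * B[k][j] )
// The SM75 CUTLASS kernels in this commit compute the same quantity on the GPU
// with the scaling fused into the epilogue; float stands in for bf16/fp16, and
// the row-major layouts here are a simplification for illustration.
void scaled_int8_mm_ref(const std::vector<int8_t>& A,  // m x k, row-major
                        const std::vector<int8_t>& B,  // k x n, row-major
                        std::vector<float>& D,         // m x n, row-major
                        int m, int n, int k, float scale_a, float scale_b) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;  // int32 accumulation, as in the tensor-core MMA
      for (int kk = 0; kk < k; ++kk) {
        acc += static_cast<int32_t>(A[i * k + kk]) *
               static_cast<int32_t>(B[kk * n + j]);
      }
      D[i * n + j] = scale_a * scale_b * static_cast<float>(acc);
    }
  }
}

int main() {
  int m = 2, n = 2, k = 3;
  std::vector<int8_t> A = {1, 2, 3, 4, 5, 6};
  std::vector<int8_t> B = {1, 0, 0, 1, 1, 1};
  std::vector<float> D(m * n);
  scaled_int8_mm_ref(A, B, D, m, n, k, /*scale_a=*/1.0f, /*scale_b=*/1.0f);
  std::printf("%g %g\n%g %g\n", D[0], D[1], D[2], D[3]);
  return 0;
}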
@@ -3,6 +3,7 @@
 #include "cutlass/cutlass.h"
 
 #include "scaled_mm_c2x.cuh"
+#include "scaled_mm_c2x_sm75_dispatch.cuh"
 #include "scaled_mm_c2x_sm80_dispatch.cuh"
 #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
 #include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
@@ -20,21 +21,13 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
 
-  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
-  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
-
   if (out.dtype() == torch::kBFloat16) {
-    return vllm::cutlass_gemm_caller<
-        vllm::cutlass_2x_gemm<cutlass::arch::Sm75, vllm::enable_sm75_to_sm80,
-                              int8_t, cutlass::bfloat16_t, Epilogue, TileShape,
-                              WarpShape, InstructionShape, 2>>(
+    return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
+                                            Epilogue>(
         out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return vllm::cutlass_gemm_caller<vllm::cutlass_2x_gemm<
-        cutlass::arch::Sm75, vllm::enable_sm75_to_sm80, int8_t, cutlass::half_t,
-        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+    return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
         out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
@@ -0,0 +1,123 @@
+#pragma once
+
+#include "scaled_mm_c2x.cuh"
+
+/**
+ * This file defines Gemm kernel configurations for SM75 based on the Gemm
+ * shape.
+ */
+
+namespace vllm {
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_default {
+  // This config is used in 2 cases,
+  // - M in (256, inf]
+  // - M in (64, 128]
+  // Shared memory required by this Gemm 32768
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_M256 {
+  // M in (128, 256]
+  // Shared memory required by this Gemm 65536
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_M64 {
+  // M in (32, 64]
+  // Shared memory required by this Gemm 49152
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_M32 {
+  // M in [1, 32]
+  // Shared memory required by this Gemm 49152
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
+  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
+                                       torch::Tensor const& a,
+                                       torch::Tensor const& b,
+                                       EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, int8_t>());
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
+
+  using Cutlass2xGemmDefault =
+      typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM256 =
+      typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
+  using Cutlass2xGemmM64 =
+      typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM32 =
+      typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
+
+  // Due to shared memory requirements, some Gemms may fail to run on some
+  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
+  // in such cases.
+  // sm75_config_default has the least shared-memory requirements.
+  using FallbackGemm = Cutlass2xGemmDefault;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
+  if (mp2 <= 32) {
+    // M in [1, 32]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 64) {
+    // M in (32, 64]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
+    // M in (64, 128]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 256) {
+    // M in (128, 256]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // M in (256, inf)
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
+} // namespace vllm
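The dispatch in the new header buckets M by rounding it up to a power of two and picks a tile configuration for that bucket, with sm75_config_default serving as the fallback when a preferred config exceeds the GPU's shared-memory budget. Below is a standalone sketch of just that selection logic; next_pow_2 is reimplemented here as an assumption (its definition is not part of this diff), and the actual kernel launch is replaced by a printed label.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for vLLM's next_pow_2 helper: the smallest power of
// two >= x. Uses a GCC/Clang builtin; the real helper is defined elsewhere.
static uint32_t next_pow_2(uint32_t x) {
  if (x <= 1) return 1;
  return 1u << (32 - __builtin_clz(x - 1));
}

// Mirrors the M-bucketing in cutlass_gemm_sm75_dispatch: each bucket names a
// config from the header above. In the real code, fallback_cutlass_gemm_caller
// would retry with the default (lowest shared-memory) config if the chosen
// one cannot launch on the device.
static const char* select_sm75_int8_config(uint32_t m) {
  uint32_t const mp2 = std::max<uint32_t>(32, next_pow_2(m));
  if (mp2 <= 32) return "sm75_config_M32  (TileShape 32x128x64)";
  if (mp2 <= 64) return "sm75_config_M64  (TileShape 64x128x128)";
  if (mp2 <= 128) return "sm75_config_default (TileShape 128x128x64)";
  if (mp2 <= 256) return "sm75_config_M256 (TileShape 128x128x128)";
  return "sm75_config_default (TileShape 128x128x64)";
}

int main() {
  for (uint32_t m : {1u, 16u, 48u, 100u, 200u, 512u}) {
    std::printf("M=%4u -> %s\n", static_cast<unsigned>(m),
                select_sm75_int8_config(m));
  }
  return 0;
}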