mirror of https://github.com/vllm-project/vllm
Fix w8a8 benchmark and add Llama-3-8B (#5562)
This commit is contained in:
parent
845a3f26f9
commit
e2b85cf86a
|
@ -46,7 +46,7 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
|||
# impl
|
||||
|
||||
|
||||
def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
scale_b: torch.tensor,
|
||||
out_dtype: torch.dtype) -> torch.tensor:
|
||||
return torch.mm(a, b)
|
||||
|
@ -115,7 +115,7 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
|||
timers.append(
|
||||
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
||||
torch.bfloat16, label, sub_label, pytorch_i8_impl,
|
||||
torch.bfloat16, label, sub_label, pytorch_mm_impl,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
||||
|
||||
# cutlass impl
|
||||
|
@ -136,6 +136,13 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
|||
|
||||
timers = []
|
||||
|
||||
# pytorch impl w. bf16
|
||||
timers.append(
|
||||
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
||||
torch.bfloat16, label, sub_label, pytorch_mm_impl,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
||||
|
||||
# pytorch impl: bf16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
|
@ -160,14 +167,12 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
|||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
|
||||
torch.bfloat16, label, sub_label, cutlass_impl,
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm"))
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
|
||||
# cutlass impl: fp16 output
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
|
||||
torch.float16, label, sub_label, cutlass_impl,
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm"))
|
||||
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
||||
cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
|
||||
return timers
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,12 @@ WEIGHT_SHAPES = {
|
|||
([4096, 22016], 1),
|
||||
([11008, 4096], 0),
|
||||
],
|
||||
"meta-llama/Llama-3-8b": [
|
||||
([4096, 6144], 1),
|
||||
([4096, 4096], 0),
|
||||
([4096, 28672], 1),
|
||||
([14336, 4096], 0),
|
||||
],
|
||||
"meta-llama/Llama-2-13b-hf": [
|
||||
([5120, 15360], 1),
|
||||
([5120, 5120], 0),
|
||||
|
|
Loading…
Reference in New Issue