[Kernel] (1/N) Machete - Hopper Optimized Mixed Precision Linear Kernel (#7174)

Lucas Wilkinson 2024-08-20 09:09:33 -04:00 committed by GitHub
parent b6f99a6ffe
commit 5288c06aa0
28 changed files with 4828 additions and 2 deletions

.gitignore
View File

@@ -87,6 +87,9 @@ target/
profile_default/
ipython_config.py
# generated files
**/generated/**
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:

View File

@@ -227,6 +227,46 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"-gencode arch=compute_90a,code=sm_90a")
endif()
#
# For the Machete kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:$PYTHONPATH
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
)
if (NOT machete_generation_result EQUAL 0)
message(FATAL_ERROR "Machete generation failed."
" Result: \"${machete_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
else()
message(STATUS "Machete generation completed successfully.")
endif()
# Add machete generated sources
file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})
message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")
# See comment above for scaled_mm_c3x (same if condition)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
${MACHETE_GEN_SOURCES}
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
endif()
# Add pytorch binding
list(APPEND VLLM_EXT_SRC
csrc/quantization/machete/machete_pytorch.cu)
endif()
define_gpu_extension_target(

View File

@@ -0,0 +1,372 @@
import argparse
import copy
import itertools
import math
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, pack_rows, quantize_weights)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]
def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t() # make col major
return ops.machete_prepack_B(w_q, wtype)
def make_bench_tensors(
atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
k: int
) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
torch.tensor]]]:
assert wtype.is_integer(), "TODO: support floating point weights"
# We want to make sure the weights do not stay resident in L2 cache between
# runs, so we construct enough weight copies to exceed the L2 cache, which is
# 50 MB on an H100; we therefore target a total weight size > 2 * 50 MB.
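# e.g. for k = n = 4096 and 4-bit weights:
# ceil(2 * 50 * 1024**2 * 8 / (4096 * 4096 * 4)) = ceil(12.5) = 13 weight copies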
num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
a = torch.randn((m, k), device="cuda", dtype=atype) * 5
weights = [
torch.randn((k, n), device="cuda", dtype=atype)
for _ in range(num_weights)
]
quantized_weights = [
quantize_weights(w, wtype, group_size) for w in weights
]
return a, quantized_weights
# impl
# bench
def bench_fn(label: str, sub_label: str, description: str,
fn: Callable) -> TMeasurement:
min_run_time = 1
return TBenchmark.Timer(
stmt="fn()",
globals={
"fn": fn
},
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def loop_over_weights(
a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor,
torch.tensor, torch.tensor]],
fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
None]):
for w_ref, w_q, w_s, _ in weights:
fn(a, w_ref, w_q, w_s)
def bench(atype: torch.dtype,
wtype: ScalarType,
group_size: int,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
benchmark_marlinv1: bool = True,
sweep_schedules: bool = True) -> Iterable[TMeasurement]:
a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
sub_label += f", L={len(weights)}"
weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
for w_ref, w_q, w_s, w_zp in weights]
timers = []
# pytorch impl
timers.append(
bench_fn(
label, sub_label, "torch.matmul", lambda: loop_over_weights(
a,
weights,
lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
)))
if benchmark_marlinv1:
w_ref = weights[0][0]
w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)
def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor:
w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
wtype.size_bits)
def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
return marlin_permute_scales(w_s, *w_ref.shape, group_size)
weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
marlinv1_permute_scales(w_s), w_zp)
for w_ref, w_q, w_s, w_zp in weights]
workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
# marlinv1
timers.append(
bench_fn(
label, sub_label, "marlin_orig", lambda: loop_over_weights(
a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
gptq_marlin_gemm(a,
w_q,
w_s,
w_zp_empty,
g_idx,
sort_indices,
workspace.scratch,
wtype,
size_m=a.shape[0],
size_n=w_ref.shape[1],
size_k=w_ref.shape[0],
is_k_full=True))))
# machete
timers.append(
bench_fn(
label, sub_label, "machete_heuristic", lambda: loop_over_weights(
a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))
if sweep_schedules:
print("Finding best schedule for machete")
best = None
best_schedule = None
schedules = ops.machete_supported_schedules(wtype)
for schedule in reversed(schedules):
def run(a, _, w_q, w_s, schedule=schedule):
ops.machete_gemm(a,
w_q,
wtype,
w_s,
b_group_size=group_size,
schedule=schedule)
res = bench_fn(label, sub_label, "machete_best",
lambda: loop_over_weights(a, weights_machete, run))
print(f" {res.median:5.5} ", schedule)
if not best or res.median < best.median:
best = res
best_schedule = schedule
print("Best schedule:", best_schedule)
timers.append(best)
return timers
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def run(dtype: torch.dtype, sweep_schedules: bool,
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype,
scalar_types.uint4b8,
128,
m,
k,
n,
f"{dtype}-gemm",
f"MKN=({m}x{k}x{n})",
sweep_schedules=sweep_schedules)
print_timers(timers)
results.extend(timers)
return results
# output makers
def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====")
print_timers(data)
# pickle all the results
timestamp = int(time.time()) if timestamp is None else timestamp
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
pkl.dump(data, f)
# argparse runners
def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}")
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs
model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
for k, n in KNs:
MKNs.append((m, k, n))
data = run(args.dtype, args.sweep_schedules, MKNs)
model_bench_data.append(data)
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
print_timers(data)
timestamp = int(time.time())
all_data = []
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
pkl.dump(all_data, f)
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "bfloat16":
return torch.bfloat16
if dt == "float16":
return torch.float16
raise ValueError("unsupported dtype")
parser = FlexibleArgumentParser(
description="""
Benchmark Machete GEMM.
To run square GEMMs:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, which is a list of raw torch.utils.benchmark.Measurement objects for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['bfloat16', 'float16']",
)
parser.add_argument(
"--sweep-schedules",
action="store_true",
help="Run a sweep over all supported schedules",
)
subparsers = parser.add_subparsers(dest="cmd", required=True)
square_parser = subparsers.add_parser("square_bench")
square_parser.add_argument("--dim-start", type=int, required=True)
square_parser.add_argument("--dim-end", type=int, required=True)
square_parser.add_argument("--dim-increment", type=int, required=True)
square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
args.func(args)

View File

@@ -0,0 +1,64 @@
import math
import pickle
import re
from collections import defaultdict
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement
from vllm.utils import FlexibleArgumentParser
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Plot the benchmark results (.pkl file) produced by '
'benchmark_machete.py.')
parser.add_argument('filename', type=str)
args = parser.parse_args()
with open(args.filename, 'rb') as f:
data: List[TMeasurement] = pickle.load(f)
results = defaultdict(lambda: list())
for v in data:
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
if result is not None:
KN = result.group(1)
else:
raise Exception("MKN not found")
result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
if result is not None:
M = result.group(1)
else:
raise Exception("MKN not found")
kernel = v.task_spec.description
results[KN].append({
"kernel": kernel,
"batch_size": M,
"median": v.median
})
rows = int(math.ceil(len(results) / 2))
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
axs = axs.flatten()
axs_idx = 0
for shape, data in results.items():
plt.sca(axs[axs_idx])
df = pd.DataFrame(data)
sns.lineplot(data=df,
x="batch_size",
y="median",
hue="kernel",
style="kernel",
markers=True,
dashes=False,
palette="Dark2")
plt.title(f"Shape: {shape}")
plt.ylabel("time (median, s)")
axs_idx += 1
plt.tight_layout()
plt.savefig("graph_machete_bench.pdf")
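The script takes the results pickle written by `benchmark_machete.py` as its only argument and writes `graph_machete_bench.pdf`; for example (the file name pattern follows the benchmark's output, the timestamp is whatever the run produced):
python3 ./benchmarks/kernels/graph_machete_bench.py model_bench-torch.float16-<timestamp>.pkl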

View File

@@ -0,0 +1,43 @@
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"mistralai/Mistral-7B-v0.1": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-7b-hf": [
([4096, 12288], 1),
([4096, 4096], 0),
([4096, 22016], 1),
([11008, 4096], 0),
],
"meta-llama/Llama-3-8b": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-13b-hf": [
([5120, 15360], 1),
([5120, 5120], 0),
([5120, 27648], 1),
([13824, 5120], 0),
],
"meta-llama/Llama-2-70b-hf": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
}

View File

@@ -1,5 +1,15 @@
#pragma once
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
#define DEVICE_INLINE __forceinline__ __device__
#define HOST_INLINE __forceinline__ __host__
#else
#define HOST_DEVICE_INLINE inline
#define DEVICE_INLINE inline
#define HOST_INLINE inline
#endif
int64_t get_device_attribute(int64_t attribute, int64_t device_id);
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

View File

@@ -0,0 +1,68 @@
#pragma once
#include <cute/tensor.hpp>
#include <torch/all.h>
namespace cute {
////////////////////////////////////////////////////////////////////
// layout utils
////////////////////////////////////////////////////////////////////
// Permute layout based on indices, example:
// permute_layout<1, 0>(layout) will swap the two dimensions
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
template <size_t... I, typename Layout>
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
return cute::make_layout(cute::get<I>(l)...);
}
// is the layout f(x) = x
template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
if constexpr (std::is_same_v<Layout, void>)
return true;
else {
constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) {
return true;
}
return false;
}
}
////////////////////////////////////////////////////////////////////
// Pointer utils
////////////////////////////////////////////////////////////////////
template <class PointerType>
static constexpr auto get_logical_ptr(PointerType* ptr) {
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
return cute::subbyte_iterator<PointerType>(ptr);
} else {
return ptr;
}
}
////////////////////////////////////////////////////////////////////
// Misc utils
////////////////////////////////////////////////////////////////////
template <typename T, typename Elements>
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
constexpr auto bits = sizeof_bits_v<T> * Elements{};
if constexpr (bits % 128 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<128>{};
} else if constexpr (bits % 64 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<64>{};
} else if constexpr (bits % 32 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<32>{};
} else if constexpr (bits % 16 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<16>{};
} else {
return AutoVectorizingCopyWithAssumedAlignment<8>{};
}
}
}; // namespace cute

View File

@@ -0,0 +1,154 @@
#pragma once
#include <torch/all.h>
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using RowMajor = typename cutlass::layout::RowMajor;
namespace cute {
namespace detail {
template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
seq<I...>) {
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
}
template <class F, int... I>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
return make_shape(f(I)...);
}
}; // namespace detail
template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
if constexpr (cute::is_tuple<T>::value) {
return detail::tapply_with_idx(
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
tuple_seq<T>{});
} else {
return f(t);
}
CUTE_GCC_UNREACHABLE;
}
// calls: make_shape(f(0), f(1), ..., f(N-1))
template <int N, class F>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
return detail::make_shape_from_idx(f, make_seq<N>{});
}
}; // namespace cute
// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
// shape of the passed in tensor and the strides are of type `Stride` and
// contain the strides of the passed in tensor, checking that any static strides
// in `Stride{}` match the strides of the passed in tensor.
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(torch::Tensor const& tensor,
std::string_view name = "tensor") {
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(
Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
return tensor.stride(idx);
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.dim())
return tensor.size(idx);
else
return int64_t(1);
});
return make_layout(shape, stride);
}
template <typename Stride>
static inline auto maybe_make_cute_layout(
c10::optional<torch::Tensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
if (tensor) {
return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
} else {
return std::optional<Layout>{};
}
}
//
// Torch Type to Cutlass Type (equivalent_cutlass_type)
//
template <typename T>
struct equivalent_cutlass_type {
using type = T;
};
template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
struct equivalent_cutlass_type<c10::Half> {
using type = cutlass::half_t;
};
template <>
struct equivalent_cutlass_type<c10::BFloat16> {
using type = cutlass::bfloat16_t;
};
//
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
template <typename T>
struct equivalent_scalar_type {
using type = T;
};
template <typename T>
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
using type = c10::Half;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = c10::BFloat16;
};
// get equivalent c10::ScalarType tag from compile time type
template <typename T>
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;

View File

@@ -0,0 +1,43 @@
#pragma once
#include "cutlass/gemm/collective/collective_builder.hpp"
namespace cutlass::gemm::collective {
using namespace cute;
//
// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for
// custom kernel tags, letting you build custom collectives without touching
// the cutlass library headers. Using `CutlassKernelTag` will fall back to the
// standard cutlass collective builder.
//
// Use the default Cutlass collective builder, i.e. use an unmodified cutlass
// collective
struct CutlassKernelTag {};
template <class KernelTag, class ArchTag, class OpClass, class ElementA,
class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
int AlignmentB, class ElementAccumulator, class TileShape_MNK,
class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, class Enable = void>
struct VLLMCollectiveBuilder {
static_assert(sizeof(ElementA) == 0,
"Could not build a collective for given parameters.");
};
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct VLLMCollectiveBuilder<
CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType> {
using CollectiveOp = typename CollectiveBuilder<
ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp;
};
}; // namespace cutlass::gemm::collective

View File

@@ -0,0 +1,50 @@
#pragma once
#include "cutlass/integer_subbyte.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, int Bias, bool Signed = false>
struct vllm_biased_integer_subbyte : public integer_subbyte<Bits, Signed> {
using Base = integer_subbyte<Bits, Signed>;
using Storage = typename Base::Storage;
using xint_t = typename Base::xint_t;
using Base::bits_mask_;
using Base::sign_mask_;
using Base::storage;
//
// Methods
//
/// No operation
vllm_biased_integer_subbyte() = default;
/// Conversion from integer type
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value)
: Base(value) {}
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value)
: Base(value) {}
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value)
: Base(value) {}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// "GPTQ" types, i.e. symmetric quantization
using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8
using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, int Bias, bool Signed>
struct sizeof_bits<vllm_biased_integer_subbyte<Bits, Bias, Signed>> {
static constexpr int value = Bits;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@@ -0,0 +1,49 @@
import enum
from typing import Dict, Union
from cutlass_library import *
#
# Extend cutlass library with custom types, and missing values
#
class VLLMDataType(enum.Enum):
u4b8 = enum_auto()
u8b128 = enum_auto()
class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedMixedInput = enum_auto()
TmaWarpSpecializedPingpongMixedInput = enum_auto()
TmaWarpSpecializedCooperativeMixedInput = enum_auto()
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
**DataTypeNames, # type: ignore
**{
VLLMDataType.u4b8: "u4b8",
VLLMDataType.u8b128: "u8b128",
}
}
VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
**DataTypeTag, # type: ignore
**{
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
}
}
VLLMKernelScheduleTag: Dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
**{
MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
}
}

View File

@@ -0,0 +1,795 @@
#pragma once
#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
#include "cutlass_extensions/cute_utils.cuh"
// this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t
// as well as adds interleaved numeric array converters for specific types.
// (interleaved numeric array converters can be more efficient for subbyte
// types)
namespace cutlass {
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
// deinterleaves converted elements based on IlvBlkLayout; interleaving can
// make subbyte conversions more efficient by allowing for efficient extraction
// of subbyte elements from a 32-bit register.
template <typename IlvBlkLayout, typename T, typename S, int N,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
class Enable = void>
struct InterleavedNumericArrayConverter {
using Converter = NumericArrayConverter<T, S, N, Round>;
using result_type = typename Converter::result_type;
using source_type = typename Converter::source_type;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
CUTE_INVALID_CONTROL_PATH(
"InterleavedNumericArrayConverter not implemented\n");
return {};
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
template <typename IlvBlkLayout, typename T, typename S, int N,
FloatRoundStyle Round>
struct InterleavedNumericArrayConverter<
IlvBlkLayout, T, S, N, Round,
std::enable_if_t<is_identity_layout<IlvBlkLayout>()>> {
using Converter = NumericArrayConverter<T, S, N, Round>;
using result_type = typename Converter::result_type;
using source_type = typename Converter::source_type;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return Converter::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// TODO (LucasWilkinson): Implement
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
// ....
template <typename RegConvert32bit, typename T, typename S, int N>
struct ArrayConverterPacked32Bit {
using result_type = Array<T, N>;
using source_type = Array<S, N>;
using result_packed_8_t = Array<T, 8>;
using result_packed_4_t = Array<T, 4>;
using result_packed_2_t = Array<T, 2>;
using src_packed_8_t = Array<S, 8>;
using src_packed_4_t = Array<S, 4>;
using src_packed_2_t = Array<S, 2>;
static_assert(N % 2 == 0, "N must be a multiple of 2");
static_assert(cutlass::sizeof_bits_v<S> >= 4); // TODO: add 16 packed sources
static_assert(32 % cutlass::sizeof_bits_v<S> == 0);
static constexpr auto src_elems_per_32bit_reg =
32 / cutlass::sizeof_bits_v<S>;
// Maybe not valid: ScalarConverter will not actually work unless
// NumericConverter<T, S, Round> is implemented. However, it won't be used
// anyway since we assert N % 2 == 0; it is just here for compliance with
// VectorizedConverter.
using ScalarConverter = NumericConverter<T, S>;
template <typename PackedSrc>
CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) {
if constexpr (sizeof(PackedSrc) == 1) {
return static_cast<uint32_t>(reinterpret_cast<const uint8_t&>(source));
} else if constexpr (sizeof(PackedSrc) == 2) {
return static_cast<uint32_t>(reinterpret_cast<const uint16_t&>(source));
} else {
static_assert(sizeof(PackedSrc) == 4);
return reinterpret_cast<const uint32_t&>(source);
}
}
// The core converter uses bit tricks to construct a known FP16 number, then
// does a subtraction in FP16 for the final result.
template <typename PackedResultType, typename PackedSrcType>
CUTLASS_DEVICE static PackedResultType packed_convert(
PackedSrcType const& source) {
static_assert(PackedSrcType::kElements == PackedResultType::kElements);
static_assert(PackedResultType::kElements == 2 ||
PackedResultType::kElements == 4 ||
PackedResultType::kElements == 8,
"Invalid PackedResultType must be 2, 4 or 8.");
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
return RegConvert32bit::template convert<PackedResultType>(to_reg(source));
}
friend class detail::VectorizedConverter;
public:
CUTLASS_DEVICE static result_type convert(source_type const& source) {
result_type result;
using ConverterType =
ArrayConverterPacked32Bit<RegConvert32bit,
typename result_type::Element,
typename source_type::Element, N>;
if constexpr (src_elems_per_32bit_reg >= 8) {
detail::VectorizedConverter::convert<
ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t,
src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source);
} else if constexpr (src_elems_per_32bit_reg >= 4) {
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
src_packed_4_t, result_packed_2_t,
src_packed_2_t>(result, source);
} else {
detail::VectorizedConverter::convert<ConverterType, result_packed_2_t,
src_packed_2_t>(result, source);
}
return result;
}
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
// Below constructs the following temporary:
// fp16s_01 = {0x00, i4_01, 0x00, i4_01}
// fp16s_23 = {0x00, i4_23, 0x00, i4_23}
// fp16s_45 = {0x00, i4_45, 0x00, i4_45}
// fp16s_67 = {0x00, i4_67, 0x00, i4_67}
// We use inline asm instead of __byte_perm intrinsic since we don't want
// the documented (& 0x7) on the index. NVCC might be able to optimize it
// out since the index is a constexpr, but we choose to be safe about it
// here.
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
static_assert(RegArray::kElements <= 4,
"Too many inputs for I4 -> F16 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(r[ii])
: "r"(src), "n"(0), "r"(prmt_indices[ii]));
}
// Since the stored 4-bit values are biased by 8, the stored value is (x+8)
// and we want to recover x as an fp16 value.
// The below XOR does the following:
// 1) Sets the exponent bits of the FP16 to the correct value for the
// FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)},
// where x1 is the high nibble and x0 is the low nibble, then using hfma
// to subtract {72, 1032} from that.
// The AND does the following:
// 1) Clears the bits of the int4 we want to ignore.
// We use lop3 so that the AND and XOR can be done in 1 instruction.
static constexpr uint32_t xor_mask = 0x64006400;
static constexpr uint32_t and_mask = 0xFFF0FF0F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
}
// We will issue 2 hfmas that do the following:
// {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032}
// = {16*x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032}
static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {-72, -1032}
static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1}
const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::half_t, vllm_uint4b8_t, N,
Round, void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t xor_mask = 0x64006400;
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
auto src_ = src >> (4 * (ii));
r[ii + 0] = src_;
r[ii + 1] = src_;
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
static constexpr uint32_t high_nib_mask = 0x00F000F0;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 1])
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
// For low nibble:
// {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
// For high nibble:
// {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
// - {72, 72}
static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032}
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
fp16x2_val =
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
fp16x2_val = __hfma2(fp16x2_val,
reinterpret_cast<const half2&>(high_nib_scale),
reinterpret_cast<const half2&>(high_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::half_t, uint4_t, N, Round,
void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<uint4_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t xor_mask = 0x64006400;
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
auto src_ = src >> (4 * (ii));
r[ii + 0] = src_;
r[ii + 1] = src_;
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
static constexpr uint32_t high_nib_mask = 0x00F000F0;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 1])
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
// For low nibble:
// {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
// For high nibble:
// {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024}
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
fp16x2_val =
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
fp16x2_val = __hfma2(fp16x2_val,
reinterpret_cast<const half2&>(high_nib_scale),
reinterpret_cast<const half2&>(high_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> {
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<vllm_uint8b128_t, N>;
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
// Hold output FP16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
uint32_t const prmt_indices[2] = {0x5150, 0x5352};
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(r[ii])
: "r"(src), "n"(start_byte_for_fp16),
"r"(prmt_indices[ii]));
}
// -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
static constexpr uint32_t bias_rep = 0x64806480;
const half2& bias = reinterpret_cast<const half2&>(bias_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
fp16x2_val = __hsub2(fp16x2_val, bias);
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::float, N> <= Array<vllm_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> {
using result_type = Array<float, N>;
using source_type = Array<vllm_uint8b128_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
PackedResultType r;
// __byte_perm simulates an add.u32 of 0x4B000000 to every u8 element of the
// u8x4 source and stores the result in r (without introducing an extra
// cvt.u32.u8 instruction)
uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653};
uint32_t* result_as_int = reinterpret_cast<uint32_t*>(&r);
for (int ii = 0; ii < PackedResultType::kElements; ++ii) {
result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]);
// Subtract the magic number 0x4B000000 from tmp in floating-point
// arithmetic to obtain final result
r[ii] -= (8388608.f + 128.f); // fold in -128 bias
}
return r;
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) {
// Hold output BF16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
uint32_t src_reg_shifted = src_reg >> 4;
// Below constructs the following temporary:
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
static_assert(RegArray::kElements <= 4,
"Too many inputs for uint4b8_t -> BF16 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(r[ii])
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x and a BF16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the BF16 to the correct value for the
// BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)}
// and subtracting 136 to get {x1, x0}
static constexpr uint32_t xor_mask = 0x43004300;
static constexpr uint32_t and_mask = 0x000F000F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
}
// Each constructed BF16 is 128 + (x+8) = 136 + x, so we subtract the bias
// {136, 136} to recover {x1, x0}.
// bias_rep is the BF16 pair {136, 136} represented as an integer.
static constexpr uint32_t bias_rep = 0x43084308;
const __nv_bfloat162& bias =
reinterpret_cast<const __nv_bfloat162&>(bias_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hsub2(bf16x2_val, bias);
}
return reinterpret_cast<PackedResultType&>(r);
}
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::bfloat16_t, vllm_uint4b8_t, N,
Round, void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t or_mask = 0x43004300;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for (int ii = 0; ii < RegArray::kElements; ++ii) {
r[ii] = src >> (4 * ii);
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
// For low nibble:
// {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136}
{
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
fp16x2_val =
__hsub2(fp16x2_val,
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::bfloat16_t, uint4_t, N, Round,
void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<uint4_t, N>;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t or_mask = 0x43004300;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for (int ii = 0; ii < RegArray::kElements; ++ii) {
r[ii] = src >> (4 * ii);
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
// For low nibble:
// {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128}
{
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
fp16x2_val =
__hsub2(fp16x2_val,
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> {
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<vllm_uint8b128_t, N>;
static FloatRoundStyle const round_style = Round;
private:
using result_packed_4_t = Array<cutlass::bfloat16_t, 4>;
using result_packed_2_t = Array<cutlass::bfloat16_t, 2>;
using src_packed_4_t = Array<vllm_uint8b128_t, 4>;
using src_packed_2_t = Array<vllm_uint8b128_t, 2>;
// Not Valid, not supported, only here to satisfy the interface and to avoid
// a compile error. ScalarConverter will not actually work until
// NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round> is
// implemented
using ScalarConverter =
NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round>;
template <typename PackedResultType, typename PackedSrcType>
CUTLASS_DEVICE static PackedResultType packed_convert(
PackedSrcType const& source) {
static_assert(
(platform::is_same<PackedSrcType, src_packed_2_t>::value &&
platform::is_same<PackedResultType, result_packed_2_t>::value) ||
(platform::is_same<PackedSrcType, src_packed_4_t>::value &&
platform::is_same<PackedResultType, result_packed_4_t>::value),
"Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
"convert dispatch.");
NumericArrayConverter<float, vllm_uint8b128_t, PackedResultType::kElements,
Round>
convert_uint8_to_f32;
Array<float, PackedResultType::kElements> tmp =
convert_uint8_to_f32(source);
NumericArrayConverter<cutlass::bfloat16_t, float,
PackedResultType::kElements, Round>
convert_f32_to_bf16_;
return convert_f32_to_bf16_(tmp);
}
friend class detail::VectorizedConverter;
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
result_type result;
using ConverterType =
NumericArrayConverter<typename result_type::Element,
typename source_type::Element, N, Round>;
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
src_packed_4_t, result_packed_2_t,
src_packed_2_t>(result, source);
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
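The FP16 bit trick used by the `vllm_uint4b8_t` converters above can be reproduced on the CPU; a small numpy sketch (illustrative only, names of our own choosing, not from the kernel sources):

```python
# Reproduces the int4 (biased by 8) -> fp16 decode described in the comments above.
import numpy as np

def decode_byte(byte_val: int):
    """Decode one byte holding two biased-by-8 int4 values: x0 (low), x1 (high)."""
    # Temporaries built by prmt + lop3 in the kernel:
    #   low  halfword: 0x6400 | (x0 + 8)        -> fp16 value 1024 + (x0 + 8)
    #   high halfword: 0x6400 | ((x1 + 8) << 4) -> fp16 value 1024 + 16 * (x1 + 8)
    low = np.array([0x6400 | (byte_val & 0x0F)], dtype=np.uint16).view(np.float16)[0]
    high = np.array([0x6400 | (byte_val & 0xF0)], dtype=np.uint16).view(np.float16)[0]
    # The hfma then computes {x1, x0} = {high, low} * {1/16, 1} - {72, 1032}
    return float(low) - 1032.0, float(high) / 16.0 - 72.0

# x0 = -3 (stored as 5), x1 = 5 (stored as 0xD) -> packed byte 0xD5
print(decode_byte(0xD5))  # (-3.0, 5.0)
```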

View File

@@ -83,6 +83,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
namespace machete {
std::vector<std::string> supported_schedules(
vllm::ScalarTypeTorchPtr const& btype);
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule);
torch::Tensor prepack_B(torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype);
}; // namespace machete
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,

View File

@@ -0,0 +1,45 @@
# Machete (Mixed Precision Cutlass-Based GEMM)
Machete is a spiritual successor to the Marlin kernel but optimized for Hopper architectures and based on Cutlass. Being based on Cutlass, new type pairs and epilogues are easier to add compared to Marlin.
## Overview
Machete effectively performs
```
scale_type = w_s.dtype
compute_type = a.dtype
out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a
```
Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and
`w_z` is the quantization zeropoints.
> **_NOTE:_** `w_z` is applied after the scales so that FMA operations can be
used. This means the zeropoints must have the scales pre-applied if the
supplied zeropoints assume they will be subtracted before the scales are
applied.
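For checkpoints whose zeropoints follow the `(w_q - w_z_unscaled) * w_s` convention, a minimal conversion sketch (variable names are illustrative, not part of the API) is:
```python
# (w_q - w_z_unscaled) * w_s  ==  w_q * w_s - (w_z_unscaled * w_s)
# so pre-apply the scales to the zeropoints before handing them to Machete:
w_z = w_z_unscaled.to(w_s.dtype) * w_s
```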
## API
The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like:
```
from vllm import _custom_ops as ops
...
W_q_packed = ops.machete_prepack_B(w_q, wtype)
output = ops.machete_gemm(
a,
b_q=W_q_packed,
b_type=wtype,
b_scales=w_s,
b_group_size=group_size
)
```
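Supported schedule names can be queried with `ops.machete_supported_schedules` and passed explicitly via the `schedule` argument (this mirrors how `benchmarks/kernels/benchmark_machete.py` sweeps schedules); a sketch:
```python
# Sketch: pick an explicit schedule instead of relying on the built-in heuristic.
schedules = ops.machete_supported_schedules(wtype)
output = ops.machete_gemm(
    a,
    b_q=W_q_packed,
    b_type=wtype,
    b_scales=w_s,
    b_group_size=group_size,
    schedule=schedules[0],  # any name returned above
)
```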
## Code Generation
Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`.
New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic.

The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in: for each `TypeConfig`, a default heuristic should be provided. It maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`.

The `Specialization`s define which feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate.
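As a rough illustration, a new schedule is just another `ScheduleConfig` entry (the dataclass defined in `generate.py`); the tile and cluster values below are placeholders, not tuned defaults, and the tile scheduler choice assumes `TileSchedulerType.StreamK` is available in `cutlass_library`:
```python
# Illustrative only: shapes are placeholders; this would be added to the
# ScheduleConfig list of the relevant ImplConfig in generate().
extra_schedule = ScheduleConfig(
    tile_shape_mn=(128, 64),
    cluster_shape_mnk=(1, 1, 1),
    kernel_schedule=MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput,
    epilogue_schedule=EpilogueScheduleType.TmaWarpSpecializedCooperative,
    tile_scheduler=TileSchedulerType.StreamK,
)
```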

View File

@@ -0,0 +1,446 @@
import itertools
import math
import os
import shutil
from collections.abc import Iterable
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import jinja2
# yapf conflicts with isort for this block
# yapf: disable
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
EpilogueScheduleType,
MixedInputKernelScheduleType,
TileSchedulerTag,
TileSchedulerType, VLLMDataType,
VLLMDataTypeNames, VLLMDataTypeTag,
VLLMKernelScheduleTag)
# yapf: enable
#
# Generator templating
#
DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
using GemmDispatcher_ = GemmDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
{% for s in schedules %}extern torch::Tensor
impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args);
{% endfor %}
template <>
torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) {
[[maybe_unused]] auto M = args.A.size(0);
[[maybe_unused]] auto N = args.B.size(1);
[[maybe_unused]] auto K = args.A.size(1);
if (!args.schedule) {
{%- for cond, s in heuristic %}
{%if cond is not none%}if ({{cond}})
{%- else %}else
{%- endif %}
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %}
}
{% for s in schedules %}
if (*args.schedule == "{{ gen_sch_name(s) }}") {
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);
}
{% endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
"schedule = ", *args.schedule);
}
template <>
std::vector<std::string> GemmDispatcher_::supported_schedules() {
return {
{% for s in schedules -%}
"{{ gen_sch_name(s) }}"{{ ",
" if not loop.last }}{%- endfor %}
};
}
}; // namespace machete
"""
IMPL_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
template <typename Config, bool with_C, bool with_scales, bool with_zeropoints>
using Kernel = MacheteKernelTemplate<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
Config, with_C, with_scales, with_zeropoints>;
{% for sch in schedules %}
{% set schedule_name = gen_sch_name(sch) -%}
struct sch_{{schedule_name}} {
using TileShapeNM = Shape<{{
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
using ClusterShape = Shape<{{
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
// TODO: Reimplement
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
};
torch::Tensor
impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) {
bool with_C = args.C.has_value(), with_scales = args.scales.has_value(),
with_zeropoints = args.zeros.has_value();
{% for s in specializations %}
if (with_C == {{s.with_C|lower}}
&& with_zeropoints == {{s.with_zeropoints|lower}}
&& with_scales == {{s.with_scales|lower}}) {
return run_impl<Kernel<sch_{{schedule_name}}, {{s.with_C|lower}},
{{s.with_scales|lower}}, {{s.with_zeropoints|lower}}>>(args);
}{% endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(
false, "for the sake of compile times and binary size machete_mm(..) is "
" not implemented for with_C=", with_C, ", with_scales=", with_scales,
", with_zeropoints=", with_zeropoints,
" (for {{type_name}}_sch_{{schedule_name}})");
}
{% endfor %}
}; // namespace machete
"""
PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"
namespace machete {
using PrepackBDispatcher_ = PrepackBDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
using PrepackedLayoutB = PrepackedLayoutBTemplate<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;
template <>
torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) {
return prepack_impl<PrepackedLayoutB>(B);
}
}; // namespace machete
"""
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@dataclass
class ScheduleConfig:
tile_shape_mn: Tuple[int, int]
cluster_shape_mnk: Tuple[int, int, int]
kernel_schedule: MixedInputKernelScheduleType
epilogue_schedule: EpilogueScheduleType
tile_scheduler: TileSchedulerType
@dataclass
class TypeConfig:
element_a: DataType
element_b: Union[DataType, VLLMDataType]
element_b_scale: DataType
element_b_zeropoint: DataType
element_d: DataType
accumulator: DataType
@dataclass
class Specialization:
with_C: bool
with_zeropoints: bool
with_scales: bool
@dataclass
class ImplConfig:
type_config: TypeConfig
schedule_configs: List[ScheduleConfig]
specializations: List[Specialization]
heuristic: List[Tuple[Optional[str], ScheduleConfig]]
def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
tile_shape = (
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
)
cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" +
f"x{schedule_config.cluster_shape_mnk[1]}" +
f"x{schedule_config.cluster_shape_mnk[2]}")
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\
.split("::")[-1]
epilogue_schedule = EpilogueScheduleTag[
schedule_config.epilogue_schedule].split("::")[-1]
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\
.split("::")[-1]
return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" +
f"_{epilogue_schedule}_{tile_scheduler}")
# mostly unique shorter schedule_name
def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str:
kernel_terse_names_replace = {
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
"TmaWarpSpecializedCooperative_": "TmaCoop_",
"StreamKScheduler": "streamK",
}
schedule_name = generate_schedule_name(schedule_config)
for orig, terse in kernel_terse_names_replace.items():
schedule_name = schedule_name.replace(orig, terse)
return schedule_name
# unique type_name
def generate_type_signature(kernel_type_config: TypeConfig):
element_a = VLLMDataTypeNames[kernel_type_config.element_a]
element_b = VLLMDataTypeNames[kernel_type_config.element_b]
element_d = VLLMDataTypeNames[kernel_type_config.element_d]
accumulator = VLLMDataTypeNames[kernel_type_config.accumulator]
element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale]
element_zeropoint = VLLMDataTypeNames[
kernel_type_config.element_b_zeropoint]
return (f"{element_a}{element_b}{element_d}"
f"{accumulator}{element_scale}{element_zeropoint}")
# non-unique shorter type_name
def generate_terse_type_signature(kernel_type_config: TypeConfig):
element_a = VLLMDataTypeNames[kernel_type_config.element_a]
element_b = VLLMDataTypeNames[kernel_type_config.element_b]
return f"{element_a}{element_b}"
def is_power_of_two(n):
return (n != 0) and (n & (n - 1) == 0)
def to_cute_constant(value: List[int]):
def _to_cute_constant(value: int):
if is_power_of_two(value):
return f"_{value}"
else:
return f"Int<{value}>"
if isinstance(value, Iterable):
return [_to_cute_constant(value) for value in value]
else:
return _to_cute_constant(value)
template_globals = {
"DataTypeTag": VLLMDataTypeTag,
"KernelScheduleTag": VLLMKernelScheduleTag,
"EpilogueScheduleTag": EpilogueScheduleTag,
"TileSchedulerTag": TileSchedulerTag,
"to_cute_constant": to_cute_constant,
"gen_sch_name": generate_terse_schedule_name,
}
def create_template(template_str):
template = jinja2.Template(template_str)
template.globals.update(template_globals)
return template
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_config: ImplConfig, num_impl_files=2):
sources = []
type_name = generate_type_signature(impl_config.type_config)
terse_type_name = generate_terse_type_signature(impl_config.type_config)
sources.append((
f"machete_mm_{terse_type_name}",
mm_dispatch_template.render(type_name=type_name,
type_config=impl_config.type_config,
schedules=impl_config.schedule_configs,
heuristic=impl_config.heuristic),
))
sources.append((
f"machete_prepack_{terse_type_name}",
prepack_dispatch_template.render(
type_name=type_name,
type_config=impl_config.type_config,
),
))
num_schedules = len(impl_config.schedule_configs)
schedules_per_file = math.ceil(num_schedules / num_impl_files)
for part, i in enumerate(range(0, num_schedules, schedules_per_file)):
file_schedules = impl_config.schedule_configs[i:i + schedules_per_file]
sources.append((
f"machete_mm_{terse_type_name}_impl_part{part}",
mm_impl_template.render(
type_name=type_name,
type_config=impl_config.type_config,
schedules=file_schedules,
specializations=impl_config.specializations,
),
))
return sources
def generate():
    # See csrc/quantization/machete/Readme.md, the "Code Generation" section,
    # for more info about how this works
SCRIPT_DIR = os.path.dirname(__file__)
schedules = [
ScheduleConfig(
tile_shape_mn=tile_shape_mn,
cluster_shape_mnk=cluster_shape_mnk,
kernel_schedule=kernel_schedule,
epilogue_schedule=epilogue_schedule,
tile_scheduler=tile_scheduler,
) for tile_shape_mn, cluster_shape_mnk in (
((128, 16), (1, 1, 1)),
((128, 32), (1, 1, 1)),
((128, 64), (1, 1, 1)),
((128, 128), (1, 1, 1)),
) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, )
for tile_scheduler in (TileSchedulerType.StreamK, )
]
# For now we use the same heuristic for all types
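    # Conditions are evaluated top to bottom in the generated dispatch code;
    # larger M selects a larger tile, and the final None entry is the fallback
    # used when no condition matches.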
default_heuristic = [
("M > 64",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)),
("M > 32",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)),
("M > 16",
ScheduleConfig(
tile_shape_mn=(128, 32),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)),
(None,
ScheduleConfig(tile_shape_mn=(128, 16),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK))
]
impl_configs = []
GPTQ_kernel_type_configs = list(
(TypeConfig(
element_a=element_a,
element_b=element_b,
element_b_scale=element_a,
element_b_zeropoint=element_a,
element_d=element_a,
accumulator=DataType.f32,
) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for element_a in (DataType.f16, DataType.bf16)))
GPTQ_kernel_specializations = [
Specialization(with_C=False, with_zeropoints=False, with_scales=True)
]
impl_configs += [
ImplConfig(x[0], x[1], x[2], x[3])
for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules),
itertools.repeat(GPTQ_kernel_specializations),
itertools.repeat(default_heuristic))
]
AWQ_kernel_type_configs = list(
(TypeConfig(
element_a=element_a,
element_b=element_b,
element_b_scale=element_a,
element_b_zeropoint=element_a,
element_d=element_a,
accumulator=DataType.f32,
) for element_b in (DataType.u4, DataType.u8)
for element_a in (DataType.f16, DataType.bf16)))
AWQ_kernel_specializations = [
Specialization(with_C=False, with_zeropoints=True, with_scales=True)
]
impl_configs += [
ImplConfig(x[0], x[1], x[2], x[3])
for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules),
itertools.repeat(AWQ_kernel_specializations),
itertools.repeat(default_heuristic))
]
output_dir = os.path.join(SCRIPT_DIR, "generated")
# Delete the "generated" directory if it exists
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
# Create the "generated" directory
os.makedirs(output_dir)
# Render each group of configurations into separate files
for impl_config in impl_configs:
for filename, code in create_sources(impl_config):
filepath = os.path.join(output_dir, f"{filename}.cu")
with open(filepath, "w") as output_file:
output_file.write(code)
print(f"Rendered template to {filepath}")
if __name__ == "__main__":
generate()

View File

@ -0,0 +1,33 @@
#pragma once
#include "cutlass_extensions/vllm_collective_builder.cuh"
#include "machete_mainloop.cuh"
namespace cutlass::gemm::collective {
using namespace cute;
struct MacheteKernelTag {};
template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
class ElementPairB_, class GmemLayoutB_, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct VLLMCollectiveBuilder<
MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
KernelScheduleType,
cute::enable_if_t<(
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
using CollectiveOp = machete::MacheteCollectiveMma<
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType>;
};
}; // namespace cutlass::gemm::collective

View File

@ -0,0 +1,35 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace machete {
using namespace cute;
// get an interleaved block layout where each consecutive element has a
// stride of bit_stride and the block width is blk_bit_width,
// examples:
// size_bits<T> = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1
// size_bits<T> = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1)
// size_bits<T> = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1)
// size_bits<T> = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1)
template <typename T, int bit_stride, int blk_bit_width>
CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() {
static_assert(blk_bit_width % bit_stride == 0);
static_assert(bit_stride % cute::sizeof_bits_v<T> == 0);
constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v<T>;
if constexpr (cute::sizeof_bits_v<T> == bit_stride) {
// identity layout
return Layout<Shape<Int<elems_per_blk>>>{};
} else {
constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v<T>;
constexpr auto num_strides = elems_per_blk / elems_per_stride;
return Layout<Shape<Int<num_strides>, Int<elems_per_stride>>,
Stride<Int<elems_per_stride>, Int<1>>>{};
}
}
}; // namespace machete

File diff suppressed because it is too large

View File

@ -0,0 +1,237 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
// D^t = alpha * B^t * A^t + beta * C^t. This is because the wgmma
// instructions only support sourcing from registers for the left-hand
// operand, and we want to upconvert/decompress the quantized operand in
// registers. Since the primary use case we want to support is Y = XW^t where
// W is quantized, in this situation our right-hand operand is quantized, so
// we compute the transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename ScaleT, typename ZeroT,
class KernelSchedule, typename ScheduleConfig, bool with_C,
bool with_scales, bool with_zeropoints>
struct MacheteKernelTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementZ = ZeroT;
using ElementS = ScaleT;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementCompute = AccumulatorT; // For Epilogue
using BTypeTuple = cute::conditional_t<
with_scales,
cute::conditional_t<with_zeropoints,
cute::tuple<ElementB, ElementS, ElementZ>,
cute::tuple<ElementB, ElementS>>,
ElementB>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
using LayoutScale = cutlass::layout::RowMajor;
// not actually used since B has the prepacked layout, but required by cutlass
using _LayoutB = cutlass::layout::ColumnMajor;
// Interface strides expected by create_arguments (will get transposed)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZ = StrideS;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutC_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
LayoutA_Transpose, KernelSchedule>;
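  // Fix the K extent of the tile so that each K-tile spans 128 bytes of the
  // MMA (activation) type, e.g. 64 elements for 16-bit activations.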
static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
static int constexpr AlignmentC =
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
cute::Int<TileShapeK>{}));
using ClusterShape = typename ScheduleConfig::ClusterShape;
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using TileScheduler = typename ScheduleConfig::TileScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::VLLMCollectiveBuilder<
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// stride_B is unused (since B is prepacked), but still required by cutlass
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
using Arguments = typename Gemm::Arguments;
using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
typename ShapeZ>
static Arguments create_arguments(
cudaStream_t stream,
ElementA const* A_ptr, // A is an MxK matrix
Layout<ShapeA, StrideA> const& layout_A,
ElementB const* B_ptr, // B is an KxN prepacked matrix
ElementD* D_ptr, // D is an MxN matrix
Layout<ShapeD, StrideD> const& layout_D,
ElementC const* C_ptr, // C is an MxN matrix
std::optional<Layout<ShapeC, StrideC>> const& layout_C,
      ElementS const* S_ptr,  // S is a scale_KxN matrix
      std::optional<Layout<ShapeS, StrideS>> const& layout_S,
      ElementZ const* Z_ptr,  // Z is a scale_KxN matrix
std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
ElementCompute alpha, ElementCompute beta,
std::optional<int> maybe_group_size) {
static_assert(!with_zeropoints || with_scales);
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
int const group_size = maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size;
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
if constexpr (with_C) {
TORCH_CHECK(C_ptr && layout_C);
} else {
TORCH_CHECK(!C_ptr, "C not supported");
}
if constexpr (with_scales) {
TORCH_CHECK(S_ptr && layout_S);
TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
} else {
TORCH_CHECK(!S_ptr, "Scales not supported");
}
if constexpr (with_zeropoints) {
TORCH_CHECK(Z_ptr && layout_Z);
TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
TORCH_CHECK(layout_S && *layout_Z == *layout_S,
"Scales and zeros must have the same layout");
} else {
TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
}
// Transpose A and D
    // A doesn't need to be transposed since cutlass expects an NxK matrix
    // for B (which is At)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
auto stride_Ct = stride_Dt;
if (layout_C) {
stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
}
MainloopArguments mainloop_arguments{};
EpilogueArguments epilogue_arguments{
{alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
if constexpr (with_scales && with_zeropoints) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_ptr, stride_S, group_size, Z_ptr};
} else if constexpr (with_scales) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments = MainloopArguments{
B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
} else {
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
}
return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, 1},
mainloop_arguments,
epilogue_arguments};
};
static size_t get_workspace_size(Arguments const& args) {
return Gemm::get_workspace_size(args);
}
static bool can_implement(Arguments const& args) {
return Gemm::can_implement(args) == cutlass::Status::kSuccess;
}
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
Gemm gemm_op;
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Machete kernel failed to initialize workspace");
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
}
};
}; // namespace machete

View File

@ -0,0 +1,95 @@
#pragma once
#include <torch/all.h>
#include <Python.h>
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
struct PyTorchArguments {
torch::Tensor const& A;
torch::Tensor const& B;
c10::optional<torch::Tensor> const& scales;
c10::optional<torch::Tensor> const& zeros;
c10::optional<int64_t> group_size;
c10::optional<torch::Tensor> const& C;
c10::optional<double> alpha;
c10::optional<double> beta;
c10::optional<std::string> schedule;
};
template <typename MacheteKernel>
torch::Tensor run_impl(PyTorchArguments args) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
auto device = args.A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
using EleA = typename MacheteKernel::ElementA;
using EleB = typename MacheteKernel::ElementB;
using EleC = typename MacheteKernel::ElementC;
using EleD = typename MacheteKernel::ElementD;
using EleScale = typename MacheteKernel::ElementS;
using EleZero = typename MacheteKernel::ElementZ;
using StrideA = typename MacheteKernel::StrideA;
using StrideC = typename MacheteKernel::StrideC;
using StrideD = typename MacheteKernel::StrideD;
using StrideS = typename MacheteKernel::StrideS;
using StrideZ = typename MacheteKernel::StrideZ;
int M = args.A.size(0);
int N = args.B.size(1);
int K = args.A.size(1);
// Allocate output
torch::Tensor D =
torch::empty({M, N}, torch::TensorOptions()
.dtype(equivalent_scalar_type_v<EleD>)
.device(device));
auto const &A = args.A, &B = args.B;
auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;
auto layout_A = make_cute_layout<StrideA>(A, "A");
auto layout_D = make_cute_layout<StrideD>(D, "D");
auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");
auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
auto S_ptr =
static_cast<EleScale const*>(scales ? scales->const_data_ptr() : nullptr);
auto Z_ptr =
static_cast<EleZero const*>(zeros ? zeros->const_data_ptr() : nullptr);
auto arguments = MacheteKernel::create_arguments(
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
args.group_size.value_or(K));
TORCH_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments");
size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
torch::Tensor workspace = torch::empty(
workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device));
MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream);
return D;
};
template <typename ElementA, typename ElementB, typename ElementD = ElementA,
typename AccumulatorT = float, typename ScaleT = ElementA,
typename ZeroT = ElementA>
struct GemmDispatcher {
static torch::Tensor dispatch(PyTorchArguments args);
static std::vector<std::string> supported_schedules();
};
}; // namespace machete

View File

@ -0,0 +1,62 @@
#pragma once
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
template <typename TileShapeNKL, typename ElementB, typename BInTensor,
typename BTiledOutTensor>
static __global__ void prepack_B_kernel(BInTensor B_in,
BTiledOutTensor B_tiled_out) {
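  // Each thread block prepacks one (TileShapeN, TileShapeK) tile of B;
  // blockIdx.{x,y,z} selects the (N, K, L) tile coordinate (see the launch
  // in prepack_B below).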
auto tB_in = local_tile(B_in, TileShapeNKL{},
make_coord(blockIdx.x, blockIdx.y, blockIdx.z));
auto tB_out = B_tiled_out(make_coord(_, _),
make_coord(blockIdx.x, blockIdx.y), blockIdx.z);
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{},
Layout<Shape<_4, _32>, Stride<_32, _1>>{},
Layout<Shape<_1, _2>>{});
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
Tensor thr_tile_S = thr_copy.partition_S(tB_in);
Tensor thr_tile_D = thr_copy.partition_D(tB_out);
// Construct a register-backed Tensor with the same shape as each thread's
// partition
auto fragment = make_tensor<ElementB>(shape(thr_tile_D));
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(tiled_copy, thr_tile_S, fragment);
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tile_D);
}
template <typename PrepackedLayoutB, typename InLayout>
static void prepack_B(cudaStream_t stream,
typename PrepackedLayoutB::ElementB const* B_in_ptr,
InLayout B_layout,
typename PrepackedLayoutB::ElementB* B_out_ptr) {
using TileShapeNKL =
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
auto ilvd_NKbNbKL_to_offset =
PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout));
TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0);
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{});
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
auto B_tiled_out =
make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);
prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB>
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out);
}
}; // namespace machete

View File

@ -0,0 +1,71 @@
#pragma once
#include "machete_prepack_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
template <typename PrepackedLayoutB>
torch::Tensor prepack_impl(torch::Tensor const B) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
using ElementB = typename PrepackedLayoutB::ElementB;
using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
auto device = B.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
// elements per storage item for B
auto eles_per_storage =
(B.dtype().itemsize() * 8) / cute::sizeof_bits_v<ElementB>;
// torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
// match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
auto Bt_packed = B.t();
TORCH_CHECK(
(B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
"B.shape[0] (in terms of unpacked elements) must be a multiple of ",
size<1>(PPBlockShape_NK{}));
TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0,
"B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
// convert (N,packed_K,L) layout to (N,K,L) layout
// in effect we want to do: blocked_product(layout_Bt_packed,
// make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
// Step<_1, _0, _2>{}));
// but blocked_product does not support dynamic strides so we implement the
// equivalent manually,
// new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
// new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
// when s1 == 1
TORCH_CHECK(stride<1>(l_Bt_packed) == 1);
// clang-format off
auto const layout_Bt = make_layout(
transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
return idx == 1 ? ele * eles_per_storage : ele;
}),
transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
return idx != 1 ? ele * eles_per_storage : ele;
}));
// clang-format on
// Allocate output
torch::Tensor D = torch::empty_like(B);
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
static_cast<ElementB*>(D.mutable_data_ptr()));
return D;
};
template <typename ElementA, typename ElementB, typename ElementD,
typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
typename ZeroT = cutlass::half_t>
struct PrepackBDispatcher {
static torch::Tensor dispatch(torch::Tensor B);
};
}; // namespace machete

View File

@ -0,0 +1,220 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
struct IlvBlkLayoutAuto {};
// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the same
// shape as prepacked block, all the data for a given thread is contiguous in
// memory. This allows us to use wider shared memory loads when loading B from
// shared memory. The values within a thread are also potentially interleaved
// in order to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, class LayoutB, class KernelSchedule,
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementMma = MmaType;
  // Only use interleaved layouts for subbyte weights; prmt instructions make
  // non-interleaved layouts efficient enough for 8bit+ weights that we don't
  // need interleaved layouts
using IlvdBlkLayout = std::conditional_t<
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
std::conditional_t<sizeof_bits_v<ElementB> <= 4,
decltype(get_interleaved_blk_layout<
ElementB, sizeof_bits_v<ElementA>, 32>()),
void>,
IlvBlkLayout_>;
// TODO (LucasWilkinson): compare the performance for other sizes
// Prepacked block shape, smallest layout atom for loading into registers
// (can contain multiple wgmma instructions worth of data in one block)
  // We ideally want this to be configured such that a thread can perform 128bit
  // loads, i.e. the amount of data associated with each thread within a
  // prepacked block is a multiple of 128bits. When using a cooperative schedule
  // we have 256 threads working on a single block at a time, which means each
  // thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data;
  // for a 4bit type this would be 128bits
using PPBlockShape_NK = Shape<_128, _64>;
// Create the shape of the tile anticipated to be used by the GEMM kernel,
  // when the kernel executes we will compute `Ct = Bt * At` since the
  // quantized weights (B) must be the lhs operand so that they flow through
  // registers.
// The _128 here doesn't actually impact the shape of the stored tile directly
// but may impact the op selected by rs_op_selector
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
size<1>(PPBlockShape_NK{})));
static constexpr cute::GMMA::Major GmmaMajorB =
gmma_rs_tag_to_major_B<LayoutB>();
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelSchedule,
KernelTmaWarpSpecializedCooperativeMixedInput>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
AtomLayoutMNK{}));
// Prepacked block, (athrid, val) -> (N,K)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
}
// Prepacked block, (N,K) -> (athrid, val)
// i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
    // Return the storage-offset layout (no interleaving applied here)
return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
auto layout_no_interleave =
make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
return layout_no_interleave;
} else {
// interleave by transforming FrgV into interleaved blocks where each
// block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
// (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
// if FrgV is {A, B, C, D, E, F, G, H}
      // then ((IlvBlk), FrgV) is {A, C, B, D, E, G, F, H}
auto frgV = get<1, 0>(layout_no_interleave);
auto ilvdBlk = IlvdBlkLayout{};
static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4");
auto ilvd_FrgV = make_layout(
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
make_stride(stride(ilvdBlk), size(ilvdBlk)));
      // Return interleaved layout
return make_layout(
get<0>(layout_no_interleave),
make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
}
}
// Prepacked block, (M,K) -> (storage_offset)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
// do (M,K) -> (athrid, val) -> (storage_idx)
return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
}
// ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_TV_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L))
// => ((athrid, val), (BlocksN, BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN,
// BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
make_layout(size<1>(PPBlockShape_NK{})));
// ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
return tiled_A.compose(ppblock_TV_to_NK(), _);
}
// (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
return blocked_product(ppblock_NK_to_TV(),
make_layout(shape<1>(TVbNbK_to_NKL_layout)));
}
};
}; // namespace machete

View File

@ -0,0 +1,79 @@
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"
namespace machete {
using namespace vllm;
//
// Utils (type dispatching)
//
template <typename Fn>
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
if (type == vllm::kU4) {
return fn(cutlass::uint4b_t{});
} else if (type == vllm::kU8) {
return fn(cutlass::uint8_t{});
} else if (type == vllm::kU4B8) {
return fn(cutlass::vllm_uint4b8_t{});
} else if (type == vllm::kU8B128) {
return fn(cutlass::vllm_uint8b128_t{});
} else {
TORCH_CHECK(false, "Unsupported type ", type.str());
}
}
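// AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES covers the fp16/bf16 activation
// types that the generated kernels are instantiated for.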
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
//
// Interface
//
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
return scalar_type_dispatch(*btype, [&](auto BType) {
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
});
}
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule) {
auto args = PyTorchArguments{.A = A,
.B = B,
.scales = scales,
.zeros = zeros,
.group_size = group_size,
.C = C,
.alpha = alpha,
.beta = beta,
.schedule = schedule};
return scalar_type_dispatch(*btype, [&](auto BType) {
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
A.scalar_type(), "machete_gemm", [&] {
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
});
});
}
torch::Tensor prepack_B(torch::Tensor const& B,
ScalarTypeTorchPtr const& btype) {
return scalar_type_dispatch(*btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
}
}; // namespace machete

View File

@ -133,6 +133,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops.def("machete_supported_schedules", &machete::supported_schedules);
ops.def(
"machete_gemm(Tensor A, Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype,"
" Tensor? scales, Tensor? zeros, int? group_size,"
" Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor");
ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
ops.def(
"machete_prepack_B(Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype)"
"-> Tensor");
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);

View File

@ -0,0 +1,272 @@
"""Tests for the machete kernel.
Run `pytest tests/kernels/test_machete_gemm.py`.
"""
import math
from typing import Optional, Tuple
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
MNK_SHAPES = [
(1, 128, 128),
(1, 512, 1024),
(1, 4096, 4096),
(13, 8192, 4096),
(26, 4096, 8192),
(1, 4096, 4096),
(257, 128, 4096),
(257, 4224, 4160),
(257, 4096, 4096),
(64, 4096, 4096),
]
ACT_TYPES = [torch.float16, torch.bfloat16]
WTYPE_ZEROPOINTS = [
# GPTQ style
(scalar_types.uint4b8, False),
(scalar_types.uint8b128, False),
# AWQ style
(scalar_types.uint4, True),
(scalar_types.uint8, True),
]
# TODO: in a future PR refactor this and `is_quant_method_supported` in the
# kernel unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods, an
# assumption which is breaking down as quantization methods can have multiple
# kernels and some kernels support multiple quantization methods.
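# Machete targets Hopper: the kernels are compiled for sm_90a only, so we
# require compute capability >= 9.0.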
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
def rand_data(shape, dtype=torch.float16):
return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3)
def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype))
def machete_quantize_and_pack(w: torch.Tensor,
wtype: ScalarType,
group_size: int,
zero_points: bool = False):
assert wtype.is_integer(), "TODO: support floating point weights"
w_ref, w_q, w_s, w_zp = quantize_weights(
w,
wtype,
group_size,
zero_points=zero_points,
# to match how the kernel applies zps
ref_zero_points_after_scales=True)
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t() # convert to col major
w_q_machete = ops.machete_prepack_B(w_q, wtype)
return w_ref, w_q_machete, w_s, w_zp
def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor,
wtype: ScalarType, group_size: int,
zero_points: bool):
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
b, wtype, group_size, zero_points)
output_ref = torch.matmul(a, w_ref)
output = ops.machete_gemm(
a=a,
b_q=w_q_packed,
b_type=wtype,
b_scales=w_s,
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
b_group_size=group_size,
)
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1)
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
@pytest.mark.parametrize("group_size", [128, None])
def test_machete_all_schedules(shape, atype: torch.dtype,
wtype_zeropoints: Tuple[ScalarType, bool],
group_size: Optional[int]):
m, n, k = shape
wtype, zero_points = wtype_zeropoints
if group_size is not None and k % group_size != 0:
return
print(f"MNK = {m} {n} {k}")
# Normalize group_size
if group_size is None:
group_size = k
assert group_size <= k
a = rand_data((m, k), atype)
w = rand_data((k, n), atype)
w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack(
w, wtype, group_size, zero_points)
output_ref = torch.matmul(a, w_ref)
for schedule in ops.machete_supported_schedules(wtype):
output = ops.machete_gemm(
a,
b_q=w_q_machete,
b_type=wtype,
b_scales=w_s,
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
b_group_size=group_size,
schedule=schedule,
)
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\
f"Schedule failed {schedule}"
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
@pytest.mark.parametrize("group_size", [128, None])
def test_machete_heuristic(shape, atype: torch.dtype,
wtype_zeropoints: Tuple[ScalarType, bool],
group_size: Optional[int]):
m, n, k = shape
wtype, zero_points = wtype_zeropoints
if group_size is not None and k % group_size != 0:
return
# Normalize group_size
if group_size is None:
group_size = k
assert group_size <= k
a = rand_data((m, k), atype)
b = rand_data((k, n), atype)
machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
# Test working on other devices
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_machete_devices(device: str):
m, n, k = 512, 4096, 4096
wtype = scalar_types.uint4b8
group_size = 128
zero_points = False
print(f"MNK = {m} {n} {k}, device = {device}")
a = rand_data((m, k), torch.float16).to(device)
b = rand_data((k, n), torch.float16).to(device)
machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
# Test working with a subset of A and B
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
def test_machete_subset():
big_m, big_n, big_k = 1024, 1024, 1024
m, n, k = 512, 512, 512
wtype = scalar_types.uint4b8
group_size = 128
zero_points = False
whole_a = rand_data((big_m, big_k), torch.float16)
whole_b = rand_data((big_k, big_n), torch.float16)
a = whole_a[0:m, 0:k]
b = whole_b[0:k, 0:n]
machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
# Test to make sure cuda graphs work
class MacheteLayer(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs
def forward(self, a):
return ops.machete_gemm(**self.kwargs)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
def test_machete_cuda_graph():
m, n, k = 512, 4096, 4096
a = rand_data((m, k), torch.float16)
b = rand_data((k, n), torch.float16)
wtype = scalar_types.uint4b8
group_size = 128
zero_points = False
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
b, wtype, group_size, zero_points)
# Construct a trivial model with a single layer that calls a machete kernel
model = MacheteLayer(
a=a,
b_q=w_q_packed,
b_type=wtype,
b_scales=w_s,
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
b_group_size=group_size,
)
output_ref = torch.matmul(a, w_ref)
# Run the model with a cuda graph
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
output = model(a)
output.zero_()
g.replay()
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)

View File

@ -329,6 +329,32 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
num_bits, size_m, size_n, size_k)
# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
return torch.ops._C.machete_supported_schedules(b_type)
def machete_gemm(
a: torch.Tensor,
b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B
b_type: ScalarType,
b_scales: Optional[torch.Tensor] = None,
b_zeros: Optional[torch.Tensor] = None,
b_group_size: Optional[int] = None,
c: Optional[torch.Tensor] = None,
alpha: Optional[float] = None,
beta: Optional[float] = None,
schedule: Optional[str] = None,
) -> torch.Tensor:
return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros,
b_group_size, c, alpha, beta, schedule)
def machete_prepack_B(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
# fp8
def scaled_fp8_quant(
input: torch.Tensor,

View File

@ -81,7 +81,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
def quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
zero_points: bool = False):
zero_points: bool = False,
ref_zero_points_after_scales: bool = False):
assert quant_type.is_integer(), \
"Floating point quantization may work but has not been tested"
@ -126,7 +127,13 @@ def quantize_weights(w: torch.Tensor,
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied; for this case, computing the reference in a similar
    # way allows us to use tighter error tolerances in our unit tests.
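    # (i.e. compute w_ref as w_q * s - zp * s rather than (w_q - zp) * s; the
    # two only differ by floating-point rounding)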
if ref_zero_points_after_scales and zero_points:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
w_q += quant_type.bias