refactor to reduce the amount of code wrapped in template syntax (#2002)

This commit is contained in:
Thomas Santerre 2024-04-04 02:13:12 -04:00 committed by GitHub
parent 318d143224
commit bd8db2a771
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 368 additions and 261 deletions

View File

@ -21,6 +21,59 @@ METAL_FUNC uint get_strided_index(
constant int THREADGROUP_SIZE = 2048;
template<typename T>
METAL_FUNC void argmin(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
device const T *src,
device uint *dst,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup T *shared_memory,
threadgroup uint *shared_indices
) {
bool notset = true;
/*
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
*/
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
/*
// TODO: Fast version for the contiguous case.
*/
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || src[strided_i] < shared_memory[tid]) {
shared_memory[tid] = src[strided_i];
/* Assume that the reduction takes place over the last dimension which is contiguous. */
shared_indices[tid] = idx % dims[num_dims - 1];
notset = false;
}
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
/*
// reduction in shared memory
*/
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
shared_indices[tid] = shared_indices[tid + s];
shared_memory[tid] = shared_memory[tid + s];
} \
threadgroup_barrier(mem_flags::mem_none);
}
if (tid == 0){
dst[dst_id] = shared_indices[0];
}
}
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
@ -35,53 +88,71 @@ kernel void NAME( \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
bool notset = true; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
argmin<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
} \
template<typename T>
METAL_FUNC void argmax(
constant size_t & num_dims,
constant size_t * dims,
constant size_t * strides,
constant size_t & el_to_sum_per_block,
device const T * src,
device uint * dst,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup T * shared_memory,
threadgroup uint * shared_indices
) {
/*
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
*/
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
bool notset = true;
while (idx < stop_idx) {
/*
// TODO: Fast version for the contiguous case.
*/
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || shared_memory[tid] < src[strided_i]) {
shared_memory[tid] = src[strided_i];
shared_indices[tid] = idx % dims[num_dims - 1];
notset = false;
}
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
/*
// reduction in shared memory
*/
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
shared_indices[tid] = shared_indices[tid + s];
shared_memory[tid] = shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
/*
// Thread 0 writes the result of the reduction
*/
if (tid == 0) {
dst[dst_id] = shared_indices[0];
}
}
#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
@ -95,223 +166,279 @@ kernel void NAME( \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
argmax<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
} \
template<typename T>
METAL_FUNC void reduce(
constant size_t & num_dims,
constant size_t * dims,
constant size_t * strides,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup T * shared_memory,
T (*fn)(T, T)
) {
/*
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
*/
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
/*
// TODO: Fast version for the contiguous case.
*/
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
T x = shared_memory[tid];
T y = src[strided_i];
shared_memory[tid] = fn(x, y);
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
/*
// reduction in shared memory
*/
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
T x = shared_memory[tid];
T y = shared_memory[tid + s];
shared_memory[tid] = fn(x, y);
}
threadgroup_barrier(mem_flags::mem_none);
}
if (tid == 0) {
dst[dst_id] = shared_memory[0];
}
}
#define REDUCE(FN, NAME, T, START) \
METAL_FUNC T NAME##_##op(T x, T y) { return FN; } \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device const T *src, \
device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = START; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
T x = shared_memory[tid]; \
T y = src[strided_i]; \
shared_memory[tid] = FN; \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
T x = shared_memory[tid]; \
T y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
dst[dst_id] = shared_memory[0]; \
threadgroup T shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = START; \
reduce<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, NAME##_##op); \
} \
template<typename T>
METAL_FUNC void softmax(
constant size_t & src_numel,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup float * shared_memory
) {
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
\
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = -INFINITY; \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
\
float tmp = -INFINITY; \
while (idx < stop_idx) { \
tmp = MAX(tmp, float(src[idx])); \
idx += block_dim; \
} \
shared_memory[tid] = tmp; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
/* wait for shared_memory[0] to be filled */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float _max = shared_memory[0]; \
\
/* prevent tid=0 from overwriting _max before other threads have written it */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
shared_memory[tid] = 0; \
\
idx = start_idx + tid; \
while (idx < stop_idx) { \
const float val = exp(float(src[idx]) - _max); \
dst[idx] = T(val); \
shared_memory[tid] += val; \
idx += block_dim; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1.0/shared_memory[0]); \
idx = start_idx + tid; \
while (idx < stop_idx) { \
dst[idx] *= inv_acc; \
idx += block_dim; \
} \
} \
float tmp = -INFINITY;
while (idx < stop_idx) {
tmp = MAX(tmp, float(src[idx]));
idx += block_dim;
}
shared_memory[tid] = tmp;
#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
device const T *alpha, \
constant float &eps, \
\
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = 0; \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
\
float tmp = 0; \
while (idx < stop_idx) { \
tmp = tmp + float(src[idx]) * float(src[idx]); \
idx += block_dim; \
} \
shared_memory[tid] = tmp; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
/* wait for shared_memory[0] to be filled */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); \
float inv_norm = 1.0f / norm; \
idx = start_idx + tid; \
while (idx < stop_idx) { \
float val = float(src[idx]) * inv_norm; \
if (alpha != nullptr) { \
val *= float(alpha[idx - start_idx]); \
} \
dst[idx] = T(val); \
idx += block_dim; \
} \
} \
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]);\
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
/* wait for shared_memory[0] to be filled */
\
threadgroup_barrier(mem_flags::mem_threadgroup);
float _max = shared_memory[0];
/* prevent tid=0 from overwriting _max before other threads have written it */
threadgroup_barrier(mem_flags::mem_threadgroup);
shared_memory[tid] = 0;
idx = start_idx + tid;
while (idx < stop_idx) {
const float val = exp(float(src[idx]) - _max);
dst[idx] = T(val);
shared_memory[tid] += val;
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
const T inv_acc = T(1.0 / shared_memory[0]);
idx = start_idx + tid;
while (idx < stop_idx) {
dst[idx] *= inv_acc;
idx += block_dim;
}
}
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = -INFINITY; \
softmax<T>(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \
} \
template<typename T>
METAL_FUNC void rmsnorm(
constant size_t & src_numel,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
device const T * alpha,
constant float & eps,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup float * shared_memory
) {
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
float tmp = 0;
while (idx < stop_idx) {
tmp = tmp + float(src[idx]) * float(src[idx]);
idx += block_dim;
}
shared_memory[tid] = tmp;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
/* wait for shared_memory[0] to be filled */
threadgroup_barrier(mem_flags::mem_threadgroup);
float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps);
float inv_norm = 1.0f / norm;
idx = start_idx + tid;
while (idx < stop_idx) {
float val = float(src[idx]) * inv_norm;
if (alpha != nullptr) {
val *= float(alpha[idx - start_idx]);
}
dst[idx] = T(val);
idx += block_dim;
}
}
#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
device const T *alpha, \
constant float &eps, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = 0; \
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
constant size_t &td,
device const T *src,
device const T *cos,
device const T *sin,
device T *dst,
uint tid
) {
if (2 * tid >= bh * td) {
return;
}
size_t rope_idx = tid % (td / 2);
T c = cos[rope_idx];
T s = sin[rope_idx];
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c;
}
template<typename T>
METAL_FUNC void rope(
constant size_t &bh,
constant size_t &td,
constant size_t &d,
device const T *src,
device const T *cos,
device const T *sin,
device T *dst,
uint idx
) {
if (2 * idx >= bh * td) {
return;
}
size_t i_bh = idx / (td / 2);
size_t i_td = idx - (td / 2) * i_bh;
size_t i_t = i_td / (d / 2);
size_t i_d = i_td - (d / 2) * i_t;
size_t i1 = i_bh * td + i_t * d + i_d;
size_t i2 = i1 + d / 2;
size_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
}
#define ROPEI(FN_NAME, FN_NAME_I, TYPENAME) \
kernel void FN_NAME_I( \
@ -323,14 +450,7 @@ kernel void FN_NAME_I( \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (2 * tid >= bh * td) { \
return; \
} \
size_t rope_idx = tid % (td / 2); \
TYPENAME c = cos[rope_idx]; \
TYPENAME s = sin[rope_idx]; \
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; \
dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; \
ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
}\
kernel void FN_NAME( \
constant size_t &bh, \
@ -342,20 +462,7 @@ kernel void FN_NAME( \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
if (2 * idx >= bh * td) { \
return; \
} \
size_t i_bh = idx / (td / 2); \
size_t i_td = idx - (td / 2) * i_bh; \
size_t i_t = i_td / (d / 2); \
size_t i_d = i_td - (d / 2) * i_t; \
size_t i1 = i_bh * td + i_t * d + i_d; \
size_t i2 = i1 + d / 2; \
size_t i_cs = i_t * (d / 2) + i_d; \
TYPENAME c = cos[i_cs]; \
TYPENAME s = sin[i_cs]; \
dst[i1] = src[i1] * c - src[i2] * s; \
dst[i2] = src[i1] * s + src[i2] * c; \
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
}\
REDUCE(x + y, fast_sum_f32_strided, float, 0)