refactor to reduce the amount of code wrapped in template syntax (#2002)
This commit is contained in:
parent
318d143224
commit
bd8db2a771
|
@ -21,6 +21,59 @@ METAL_FUNC uint get_strided_index(
|
|||
|
||||
constant int THREADGROUP_SIZE = 2048;
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void argmin(
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
device const T *src,
|
||||
device uint *dst,
|
||||
uint id,
|
||||
uint tid,
|
||||
uint dst_id,
|
||||
uint block_dim,
|
||||
threadgroup T *shared_memory,
|
||||
threadgroup uint *shared_indices
|
||||
) {
|
||||
bool notset = true;
|
||||
/*
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
*/
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
size_t idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
/*
|
||||
// TODO: Fast version for the contiguous case.
|
||||
*/
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||
if (notset || src[strided_i] < shared_memory[tid]) {
|
||||
shared_memory[tid] = src[strided_i];
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */
|
||||
shared_indices[tid] = idx % dims[num_dims - 1];
|
||||
notset = false;
|
||||
}
|
||||
idx += block_dim;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
/*
|
||||
// reduction in shared memory
|
||||
*/
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
|
||||
shared_indices[tid] = shared_indices[tid + s];
|
||||
shared_memory[tid] = shared_memory[tid + s];
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
if (tid == 0){
|
||||
dst[dst_id] = shared_indices[0];
|
||||
}
|
||||
}
|
||||
|
||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||
kernel void NAME( \
|
||||
|
@ -35,53 +88,71 @@ kernel void NAME( \
|
|||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
bool notset = true; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || src[strided_i] < shared_memory[tid]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
argmin<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
|
||||
} \
|
||||
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void argmax(
|
||||
constant size_t & num_dims,
|
||||
constant size_t * dims,
|
||||
constant size_t * strides,
|
||||
constant size_t & el_to_sum_per_block,
|
||||
device const T * src,
|
||||
device uint * dst,
|
||||
uint id,
|
||||
uint tid,
|
||||
uint dst_id,
|
||||
uint block_dim,
|
||||
threadgroup T * shared_memory,
|
||||
threadgroup uint * shared_indices
|
||||
) {
|
||||
/*
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
*/
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
size_t idx = start_idx + tid;
|
||||
bool notset = true;
|
||||
while (idx < stop_idx) {
|
||||
/*
|
||||
// TODO: Fast version for the contiguous case.
|
||||
*/
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||
if (notset || shared_memory[tid] < src[strided_i]) {
|
||||
shared_memory[tid] = src[strided_i];
|
||||
shared_indices[tid] = idx % dims[num_dims - 1];
|
||||
notset = false;
|
||||
}
|
||||
idx += block_dim;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
/*
|
||||
// reduction in shared memory
|
||||
*/
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
|
||||
shared_indices[tid] = shared_indices[tid + s];
|
||||
shared_memory[tid] = shared_memory[tid + s];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
/*
|
||||
// Thread 0 writes the result of the reduction
|
||||
*/
|
||||
if (tid == 0) {
|
||||
dst[dst_id] = shared_indices[0];
|
||||
}
|
||||
}
|
||||
|
||||
#define ARGMAX(NAME, T, MINVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
|
@ -95,223 +166,279 @@ kernel void NAME( \
|
|||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MINVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
bool notset = true; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || shared_memory[tid] < src[strided_i]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
argmax<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void reduce(
|
||||
constant size_t & num_dims,
|
||||
constant size_t * dims,
|
||||
constant size_t * strides,
|
||||
constant size_t & el_to_sum_per_block,
|
||||
device const T * src,
|
||||
device T * dst,
|
||||
uint id,
|
||||
uint tid,
|
||||
uint dst_id,
|
||||
uint block_dim,
|
||||
threadgroup T * shared_memory,
|
||||
T (*fn)(T, T)
|
||||
) {
|
||||
/*
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
*/
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
size_t idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
/*
|
||||
// TODO: Fast version for the contiguous case.
|
||||
*/
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||
T x = shared_memory[tid];
|
||||
T y = src[strided_i];
|
||||
shared_memory[tid] = fn(x, y);
|
||||
idx += block_dim;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
/*
|
||||
// reduction in shared memory
|
||||
*/
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
T x = shared_memory[tid];
|
||||
T y = shared_memory[tid + s];
|
||||
shared_memory[tid] = fn(x, y);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
dst[dst_id] = shared_memory[0];
|
||||
}
|
||||
}
|
||||
|
||||
#define REDUCE(FN, NAME, T, START) \
|
||||
METAL_FUNC T NAME##_##op(T x, T y) { return FN; } \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = START; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[strided_i]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = shared_memory[tid + s]; \
|
||||
shared_memory[tid] = FN; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = START; \
|
||||
reduce<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, NAME##_##op); \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void softmax(
|
||||
constant size_t & src_numel,
|
||||
constant size_t & el_to_sum_per_block,
|
||||
device const T * src,
|
||||
device T * dst,
|
||||
uint id,
|
||||
uint tid,
|
||||
uint dst_id,
|
||||
uint block_dim,
|
||||
threadgroup float * shared_memory
|
||||
) {
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
|
||||
size_t idx = start_idx + tid;
|
||||
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = -INFINITY; \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
\
|
||||
float tmp = -INFINITY; \
|
||||
while (idx < stop_idx) { \
|
||||
tmp = MAX(tmp, float(src[idx])); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
shared_memory[tid] = tmp; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
/* wait for shared_memory[0] to be filled */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float _max = shared_memory[0]; \
|
||||
\
|
||||
/* prevent tid=0 from overwriting _max before other threads have written it */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
shared_memory[tid] = 0; \
|
||||
\
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
const float val = exp(float(src[idx]) - _max); \
|
||||
dst[idx] = T(val); \
|
||||
shared_memory[tid] += val; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] += shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
const T inv_acc = T(1.0/shared_memory[0]); \
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
dst[idx] *= inv_acc; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
} \
|
||||
float tmp = -INFINITY;
|
||||
while (idx < stop_idx) {
|
||||
tmp = MAX(tmp, float(src[idx]));
|
||||
idx += block_dim;
|
||||
}
|
||||
shared_memory[tid] = tmp;
|
||||
|
||||
#define RMSNORM(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
device const T *alpha, \
|
||||
constant float &eps, \
|
||||
\
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = 0; \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
\
|
||||
float tmp = 0; \
|
||||
while (idx < stop_idx) { \
|
||||
tmp = tmp + float(src[idx]) * float(src[idx]); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
shared_memory[tid] = tmp; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
/* wait for shared_memory[0] to be filled */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); \
|
||||
float inv_norm = 1.0f / norm; \
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
float val = float(src[idx]) * inv_norm; \
|
||||
if (alpha != nullptr) { \
|
||||
val *= float(alpha[idx - start_idx]); \
|
||||
} \
|
||||
dst[idx] = T(val); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]);\
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
/* wait for shared_memory[0] to be filled */
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float _max = shared_memory[0];
|
||||
|
||||
/* prevent tid=0 from overwriting _max before other threads have written it */
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
shared_memory[tid] = 0;
|
||||
|
||||
idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
const float val = exp(float(src[idx]) - _max);
|
||||
dst[idx] = T(val);
|
||||
shared_memory[tid] += val;
|
||||
idx += block_dim;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared_memory[tid] += shared_memory[tid + s];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
const T inv_acc = T(1.0 / shared_memory[0]);
|
||||
idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
dst[idx] *= inv_acc;
|
||||
idx += block_dim;
|
||||
}
|
||||
}
|
||||
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = -INFINITY; \
|
||||
softmax<T>(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void rmsnorm(
|
||||
constant size_t & src_numel,
|
||||
constant size_t & el_to_sum_per_block,
|
||||
device const T * src,
|
||||
device T * dst,
|
||||
device const T * alpha,
|
||||
constant float & eps,
|
||||
uint id,
|
||||
uint tid,
|
||||
uint dst_id,
|
||||
uint block_dim,
|
||||
threadgroup float * shared_memory
|
||||
) {
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
|
||||
size_t idx = start_idx + tid;
|
||||
|
||||
float tmp = 0;
|
||||
while (idx < stop_idx) {
|
||||
tmp = tmp + float(src[idx]) * float(src[idx]);
|
||||
idx += block_dim;
|
||||
}
|
||||
shared_memory[tid] = tmp;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
/* wait for shared_memory[0] to be filled */
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps);
|
||||
float inv_norm = 1.0f / norm;
|
||||
idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
float val = float(src[idx]) * inv_norm;
|
||||
if (alpha != nullptr) {
|
||||
val *= float(alpha[idx - start_idx]);
|
||||
}
|
||||
dst[idx] = T(val);
|
||||
idx += block_dim;
|
||||
}
|
||||
}
|
||||
|
||||
#define RMSNORM(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
device const T *alpha, \
|
||||
constant float &eps, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = 0; \
|
||||
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void ropei(
|
||||
constant size_t &bh,
|
||||
constant size_t &td,
|
||||
device const T *src,
|
||||
device const T *cos,
|
||||
device const T *sin,
|
||||
device T *dst,
|
||||
uint tid
|
||||
) {
|
||||
if (2 * tid >= bh * td) {
|
||||
return;
|
||||
}
|
||||
size_t rope_idx = tid % (td / 2);
|
||||
T c = cos[rope_idx];
|
||||
T s = sin[rope_idx];
|
||||
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
|
||||
dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void rope(
|
||||
constant size_t &bh,
|
||||
constant size_t &td,
|
||||
constant size_t &d,
|
||||
device const T *src,
|
||||
device const T *cos,
|
||||
device const T *sin,
|
||||
device T *dst,
|
||||
uint idx
|
||||
) {
|
||||
if (2 * idx >= bh * td) {
|
||||
return;
|
||||
}
|
||||
size_t i_bh = idx / (td / 2);
|
||||
size_t i_td = idx - (td / 2) * i_bh;
|
||||
size_t i_t = i_td / (d / 2);
|
||||
size_t i_d = i_td - (d / 2) * i_t;
|
||||
size_t i1 = i_bh * td + i_t * d + i_d;
|
||||
size_t i2 = i1 + d / 2;
|
||||
size_t i_cs = i_t * (d / 2) + i_d;
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
dst[i1] = src[i1] * c - src[i2] * s;
|
||||
dst[i2] = src[i1] * s + src[i2] * c;
|
||||
}
|
||||
|
||||
#define ROPEI(FN_NAME, FN_NAME_I, TYPENAME) \
|
||||
kernel void FN_NAME_I( \
|
||||
|
@ -323,14 +450,7 @@ kernel void FN_NAME_I( \
|
|||
device TYPENAME *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (2 * tid >= bh * td) { \
|
||||
return; \
|
||||
} \
|
||||
size_t rope_idx = tid % (td / 2); \
|
||||
TYPENAME c = cos[rope_idx]; \
|
||||
TYPENAME s = sin[rope_idx]; \
|
||||
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; \
|
||||
dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; \
|
||||
ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
|
||||
}\
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &bh, \
|
||||
|
@ -342,20 +462,7 @@ kernel void FN_NAME( \
|
|||
device TYPENAME *dst, \
|
||||
uint idx [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (2 * idx >= bh * td) { \
|
||||
return; \
|
||||
} \
|
||||
size_t i_bh = idx / (td / 2); \
|
||||
size_t i_td = idx - (td / 2) * i_bh; \
|
||||
size_t i_t = i_td / (d / 2); \
|
||||
size_t i_d = i_td - (d / 2) * i_t; \
|
||||
size_t i1 = i_bh * td + i_t * d + i_d; \
|
||||
size_t i2 = i1 + d / 2; \
|
||||
size_t i_cs = i_t * (d / 2) + i_d; \
|
||||
TYPENAME c = cos[i_cs]; \
|
||||
TYPENAME s = sin[i_cs]; \
|
||||
dst[i1] = src[i1] * c - src[i2] * s; \
|
||||
dst[i2] = src[i1] * s + src[i2] * c; \
|
||||
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
|
||||
}\
|
||||
|
||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||
|
|
Loading…
Reference in New Issue