Fixed all bugs. Improved code quality. Added tests.
This commit is contained in:
parent
077e781f53
commit
8babfe0411
|
@ -3,9 +3,9 @@ mod benchmarks;
|
|||
use criterion::criterion_main;
|
||||
|
||||
criterion_main!(
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::random::benches,
|
||||
//benchmarks::affine::benches,
|
||||
//benchmarks::matmul::benches,
|
||||
//benchmarks::random::benches,
|
||||
benchmarks::reduce::benches,
|
||||
benchmarks::where_cond::benches
|
||||
//benchmarks::where_cond::benches
|
||||
);
|
||||
|
|
|
@ -61,13 +61,21 @@ fn criterion_benchmark(c: &mut Criterion) {
|
|||
run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
|
||||
run_reduce(c, &device, (lo, up));
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
run_reduce(c, &device, (lo, up), false);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up));
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
run_arg_reduce(c, &device, (lo, up), false);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_reduce(c, &device, (lo, up), true);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up), true);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -89,6 +97,7 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
|
|||
DType::BF16 => "softmax_bf16",
|
||||
_ => "softmax",
|
||||
};
|
||||
softmax(&a).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
|
@ -105,19 +114,49 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
|
|||
group.finish();
|
||||
}
|
||||
|
||||
fn run_reduce<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
|
||||
fn run_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => "reduce_f32",
|
||||
DType::F16 => "reduce_f16",
|
||||
DType::BF16 => "reduce_bf16",
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"reduce_f32_strided"
|
||||
} else {
|
||||
"reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"reduce_f16_strided"
|
||||
} else {
|
||||
"reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"reduce_bf16_strided"
|
||||
} else {
|
||||
"reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "reduce",
|
||||
};
|
||||
|
||||
|
@ -140,20 +179,46 @@ fn run_arg_reduce<T: candle_core::FloatDType>(
|
|||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes());
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => "arg_reduce_f32",
|
||||
DType::F16 => "arg_reduce_f16",
|
||||
DType::BF16 => "arg_reduce_bf16",
|
||||
_ => "reduce",
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"arg_reduce_f32_strided"
|
||||
} else {
|
||||
"arg_reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"arg_reduce_f16_strided"
|
||||
} else {
|
||||
"arg_reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"arg_reduce_bf16_strided"
|
||||
} else {
|
||||
"arg_reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "unknown",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,346 @@
|
|||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
constant int THREADGROUP_SIZE = 2048;
|
||||
|
||||
|
||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
bool notset = true; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || src[strided_i] < shared_memory[tid]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
|
||||
#define ARGMAX(NAME, T, MINVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MINVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
bool notset = true; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || shared_memory[tid] < src[strided_i]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define REDUCE(FN, NAME, T, START) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = START; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[strided_i]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = shared_memory[tid + s]; \
|
||||
shared_memory[tid] = FN; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
} \
|
||||
|
||||
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = -INFINITY; \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
\
|
||||
float tmp = -INFINITY; \
|
||||
while (idx < stop_idx) { \
|
||||
tmp = MAX(tmp, float(src[idx])); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
shared_memory[tid] = tmp; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
/* wait for shared_memory[0] to be filled */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float _max = shared_memory[0]; \
|
||||
\
|
||||
/* prevent tid=0 from overwriting _max before other threads have written it */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
shared_memory[tid] = 0; \
|
||||
\
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
const float val = exp(float(src[idx]) - _max); \
|
||||
dst[idx] = T(val); \
|
||||
shared_memory[tid] += val; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] += shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
const T inv_acc = T(1.0/shared_memory[0]); \
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
dst[idx] *= inv_acc; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
} \
|
||||
|
||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_f32, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32, uint, 0)
|
||||
ARGMAX(fast_argmax_u8, uint8_t, 0)
|
||||
|
||||
SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_i64, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64, int64_t, INT_MIN)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16_strided, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16_strided, bfloat, -HUGE_VALBF)
|
||||
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
#endif
|
|
@ -622,7 +622,7 @@ fn cos_f16() {
|
|||
assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
|
||||
}
|
||||
|
||||
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
|
||||
fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<U> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
|
@ -630,10 +630,10 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
|||
let input = new_buffer(&device, v);
|
||||
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||
let output = device.new_buffer((out_length * core::mem::size_of::<U>()) as u64, options);
|
||||
let dims = vec![v.len()];
|
||||
let strides = vec![1];
|
||||
call_reduce_strided(
|
||||
match call_reduce_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
|
@ -644,8 +644,13 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
|||
&input,
|
||||
0,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
println!("Error: {}", e);
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
|
@ -677,22 +682,114 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
|
|||
read_to_vec(&output, v.len())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reduce_sum() {
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let out_length = 1;
|
||||
const fn create_array<const N: usize>() -> [f32; N] {
|
||||
let mut array: [f32; N] = [0.0; N];
|
||||
let mut i = 1;
|
||||
while i <= N {
|
||||
array[i - 1] = i as f32;
|
||||
i += 1;
|
||||
}
|
||||
array
|
||||
}
|
||||
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), vec![21.0]);
|
||||
const fn correct_sum<const N: usize, const D: usize>() -> [f32; D] {
|
||||
let mut sum = 0;
|
||||
let mut results: [f32; D] = [0.0; D];
|
||||
let mut i = 1;
|
||||
let mut j = 1;
|
||||
while i <= N {
|
||||
sum += i;
|
||||
i += 1;
|
||||
if i > j * N / D {
|
||||
results[j - 1] = sum as f32;
|
||||
j += 1;
|
||||
sum = 0;
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
fn correct_argmax<const N: usize, const D: usize>(arr: [f32; N]) -> [u32; D] {
|
||||
let mut max = 0.0;
|
||||
let mut max_index: u32 = 0;
|
||||
let mut results: [u32; D] = [0; D];
|
||||
let mut i = 0;
|
||||
let mut j = 1;
|
||||
while i <= N {
|
||||
if i >= (j * N / D) {
|
||||
results[j - 1] = max_index;
|
||||
max = 0.0;
|
||||
max_index = 0;
|
||||
j += 1;
|
||||
}
|
||||
if i == N {
|
||||
break;
|
||||
}
|
||||
if arr[i] > max {
|
||||
max = arr[i];
|
||||
max_index = i as u32;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
fn reduce_sum_case<const N: usize, const D: usize>() {
|
||||
let v = create_array::<N>();
|
||||
let results = run_reduce(&v, D, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), correct_sum::<N, D>());
|
||||
}
|
||||
|
||||
fn reduce_argmax_case<const N: usize, const D: usize>() {
|
||||
let v = create_array::<N>();
|
||||
let results: Vec<u32> = run_reduce(&v, D, "fast_argmax_f32_strided");
|
||||
assert_eq!(results, correct_argmax::<N, D>(v));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reduce_sum2() {
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let out_length = 2;
|
||||
fn reduce_sum() {
|
||||
reduce_sum_case::<6, 1>();
|
||||
reduce_sum_case::<10, 1>();
|
||||
reduce_sum_case::<64, 1>();
|
||||
reduce_sum_case::<128, 1>();
|
||||
reduce_sum_case::<256, 1>();
|
||||
reduce_sum_case::<512, 1>();
|
||||
reduce_sum_case::<1024, 1>();
|
||||
reduce_sum_case::<2048, 1>();
|
||||
reduce_sum_case::<4096, 1>();
|
||||
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
|
||||
reduce_sum_case::<6, 2>();
|
||||
reduce_sum_case::<10, 2>();
|
||||
reduce_sum_case::<64, 2>();
|
||||
reduce_sum_case::<128, 2>();
|
||||
reduce_sum_case::<256, 2>();
|
||||
reduce_sum_case::<512, 2>();
|
||||
reduce_sum_case::<1024, 2>();
|
||||
reduce_sum_case::<2048, 2>();
|
||||
reduce_sum_case::<4096, 2>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reduce_argmax() {
|
||||
reduce_argmax_case::<6, 1>();
|
||||
reduce_argmax_case::<10, 1>();
|
||||
reduce_argmax_case::<64, 1>();
|
||||
reduce_argmax_case::<128, 1>();
|
||||
reduce_argmax_case::<256, 1>();
|
||||
reduce_argmax_case::<512, 1>();
|
||||
reduce_argmax_case::<1024, 1>();
|
||||
reduce_argmax_case::<2048, 1>();
|
||||
reduce_argmax_case::<4096, 1>();
|
||||
|
||||
reduce_argmax_case::<6, 2>();
|
||||
reduce_argmax_case::<10, 2>();
|
||||
reduce_argmax_case::<64, 2>();
|
||||
reduce_argmax_case::<128, 2>();
|
||||
reduce_argmax_case::<256, 2>();
|
||||
reduce_argmax_case::<512, 2>();
|
||||
reduce_argmax_case::<1024, 2>();
|
||||
reduce_argmax_case::<2048, 2>();
|
||||
reduce_argmax_case::<4096, 2>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
Loading…
Reference in New Issue