Fixed all bugs. Improved code quality. Added tests.

This commit is contained in:
Ivar Flakstad 2024-01-30 14:12:57 +01:00
parent 077e781f53
commit 8babfe0411
5 changed files with 1062 additions and 561 deletions

View File

@ -3,9 +3,9 @@ mod benchmarks;
use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
//benchmarks::affine::benches,
//benchmarks::matmul::benches,
//benchmarks::random::benches,
benchmarks::reduce::benches,
benchmarks::where_cond::benches
//benchmarks::where_cond::benches
);

View File

@ -61,13 +61,21 @@ fn criterion_benchmark(c: &mut Criterion) {
run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
run_reduce(c, &device, (lo, up));
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
run_reduce(c, &device, (lo, up), false);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
run_arg_reduce(c, &device, (lo, up));
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
run_arg_reduce(c, &device, (lo, up), false);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
run_reduce(c, &device, (lo, up), true);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
run_arg_reduce(c, &device, (lo, up), true);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
}
}
@ -89,6 +97,7 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
DType::BF16 => "softmax_bf16",
_ => "softmax",
};
softmax(&a).unwrap();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
@ -105,19 +114,49 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
group.finish();
}
fn run_reduce<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
fn run_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};
let flops = b * m * k * T::DTYPE.size_in_bytes();
let name = match T::DTYPE {
DType::F32 => "reduce_f32",
DType::F16 => "reduce_f16",
DType::BF16 => "reduce_bf16",
DType::F32 => {
if strided {
"reduce_f32_strided"
} else {
"reduce_f32"
}
}
DType::F16 => {
if strided {
"reduce_f16_strided"
} else {
"reduce_f16"
}
}
DType::BF16 => {
if strided {
"reduce_bf16_strided"
} else {
"reduce_bf16"
}
}
_ => "reduce",
};
@ -140,20 +179,46 @@ fn run_arg_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};
let flops = b * m * k * T::DTYPE.size_in_bytes();
let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes());
let name = match T::DTYPE {
DType::F32 => "arg_reduce_f32",
DType::F16 => "arg_reduce_f16",
DType::BF16 => "arg_reduce_bf16",
_ => "reduce",
DType::F32 => {
if strided {
"arg_reduce_f32_strided"
} else {
"arg_reduce_f32"
}
}
DType::F16 => {
if strided {
"arg_reduce_f16_strided"
} else {
"arg_reduce_f16"
}
}
DType::BF16 => {
if strided {
"arg_reduce_bf16_strided"
} else {
"arg_reduce_bf16"
}
}
_ => "unknown",
};
let mut group = c.benchmark_group(device.bench_name(name));

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,346 @@
#include <metal_stdlib>
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
constant int THREADGROUP_SIZE = 2048;
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
bool notset = true; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define REDUCE(FN, NAME, T, START) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = START; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
T x = shared_memory[tid]; \
T y = src[strided_i]; \
shared_memory[tid] = FN; \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
T x = shared_memory[tid]; \
T y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
dst[dst_id] = shared_memory[0]; \
} \
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
\
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = -INFINITY; \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
\
float tmp = -INFINITY; \
while (idx < stop_idx) { \
tmp = MAX(tmp, float(src[idx])); \
idx += block_dim; \
} \
shared_memory[tid] = tmp; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
/* wait for shared_memory[0] to be filled */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float _max = shared_memory[0]; \
\
/* prevent tid=0 from overwriting _max before other threads have written it */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
shared_memory[tid] = 0; \
\
idx = start_idx + tid; \
while (idx < stop_idx) { \
const float val = exp(float(src[idx]) - _max); \
dst[idx] = T(val); \
shared_memory[tid] += val; \
idx += block_dim; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1.0/shared_memory[0]); \
idx = start_idx + tid; \
while (idx < stop_idx) { \
dst[idx] *= inv_acc; \
idx += block_dim; \
} \
} \
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
REDUCE(x + y, fast_sum_f32, float, 0)
REDUCE(x + y, fast_sum_u32, uint, 0)
REDUCE(x + y, fast_sum_f16, half, 0)
REDUCE(x + y, fast_sum_u8, uint8_t, 0)
REDUCE(x * y, fast_mul_f32, float, 1)
REDUCE(x * y, fast_mul_u32, uint, 1)
REDUCE(x * y, fast_mul_f16, half, 1)
REDUCE(MAX(x, y), fast_max_f32, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32, uint, 0)
REDUCE(MAX(x, y), fast_max_f16, half, -HUGE_VALH)
REDUCE(MAX(x, y), fast_max_u8, uint8_t, 0)
REDUCE(MIN(x, y), fast_min_f32, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16, half, HUGE_VALH)
REDUCE(MIN(x, y), fast_min_u8, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32, float, HUGE_VALF)
ARGMIN(fast_argmin_f16, half, HUGE_VALH)
ARGMIN(fast_argmin_u32, uint, 0xFFFFFFFF)
ARGMIN(fast_argmin_u8, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32, uint, 0)
ARGMAX(fast_argmax_u8, uint8_t, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
REDUCE(x + y, fast_sum_i64, int64_t, 0)
REDUCE(MIN(x, y), fast_min_i64, int64_t, INT_MAX)
REDUCE(MAX(x, y), fast_max_i64, int64_t, INT_MIN)
ARGMIN(fast_argmin_i64, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64, int64_t, INT_MIN)
#endif
#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16_strided, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16_strided, bfloat, -HUGE_VALBF)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
#endif

View File

@ -622,7 +622,7 @@ fn cos_f16() {
assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
}
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
@ -630,10 +630,10 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
let output = device.new_buffer((out_length * core::mem::size_of::<U>()) as u64, options);
let dims = vec![v.len()];
let strides = vec![1];
call_reduce_strided(
match call_reduce_strided(
&device,
command_buffer,
&kernels,
@ -644,8 +644,13 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
&input,
0,
&output,
)
.unwrap();
) {
Ok(_) => {}
Err(e) => {
println!("Error: {}", e);
panic!();
}
}
command_buffer.commit();
command_buffer.wait_until_completed();
@ -677,22 +682,114 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
read_to_vec(&output, v.len())
}
#[test]
fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1;
const fn create_array<const N: usize>() -> [f32; N] {
let mut array: [f32; N] = [0.0; N];
let mut i = 1;
while i <= N {
array[i - 1] = i as f32;
i += 1;
}
array
}
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![21.0]);
const fn correct_sum<const N: usize, const D: usize>() -> [f32; D] {
let mut sum = 0;
let mut results: [f32; D] = [0.0; D];
let mut i = 1;
let mut j = 1;
while i <= N {
sum += i;
i += 1;
if i > j * N / D {
results[j - 1] = sum as f32;
j += 1;
sum = 0;
}
}
results
}
fn correct_argmax<const N: usize, const D: usize>(arr: [f32; N]) -> [u32; D] {
let mut max = 0.0;
let mut max_index: u32 = 0;
let mut results: [u32; D] = [0; D];
let mut i = 0;
let mut j = 1;
while i <= N {
if i >= (j * N / D) {
results[j - 1] = max_index;
max = 0.0;
max_index = 0;
j += 1;
}
if i == N {
break;
}
if arr[i] > max {
max = arr[i];
max_index = i as u32;
}
i += 1;
}
results
}
fn reduce_sum_case<const N: usize, const D: usize>() {
let v = create_array::<N>();
let results = run_reduce(&v, D, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), correct_sum::<N, D>());
}
fn reduce_argmax_case<const N: usize, const D: usize>() {
let v = create_array::<N>();
let results: Vec<u32> = run_reduce(&v, D, "fast_argmax_f32_strided");
assert_eq!(results, correct_argmax::<N, D>(v));
}
#[test]
fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2;
fn reduce_sum() {
reduce_sum_case::<6, 1>();
reduce_sum_case::<10, 1>();
reduce_sum_case::<64, 1>();
reduce_sum_case::<128, 1>();
reduce_sum_case::<256, 1>();
reduce_sum_case::<512, 1>();
reduce_sum_case::<1024, 1>();
reduce_sum_case::<2048, 1>();
reduce_sum_case::<4096, 1>();
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
reduce_sum_case::<6, 2>();
reduce_sum_case::<10, 2>();
reduce_sum_case::<64, 2>();
reduce_sum_case::<128, 2>();
reduce_sum_case::<256, 2>();
reduce_sum_case::<512, 2>();
reduce_sum_case::<1024, 2>();
reduce_sum_case::<2048, 2>();
reduce_sum_case::<4096, 2>();
}
#[test]
fn reduce_argmax() {
reduce_argmax_case::<6, 1>();
reduce_argmax_case::<10, 1>();
reduce_argmax_case::<64, 1>();
reduce_argmax_case::<128, 1>();
reduce_argmax_case::<256, 1>();
reduce_argmax_case::<512, 1>();
reduce_argmax_case::<1024, 1>();
reduce_argmax_case::<2048, 1>();
reduce_argmax_case::<4096, 1>();
reduce_argmax_case::<6, 2>();
reduce_argmax_case::<10, 2>();
reduce_argmax_case::<64, 2>();
reduce_argmax_case::<128, 2>();
reduce_argmax_case::<256, 2>();
reduce_argmax_case::<512, 2>();
reduce_argmax_case::<1024, 2>();
reduce_argmax_case::<2048, 2>();
reduce_argmax_case::<4096, 2>();
}
#[test]