Merge pull request #1318 from huggingface/metal4

Starting to fix some tests.
This commit is contained in:
Nicolas Patry 2023-12-20 15:37:31 +01:00 committed by GitHub
commit 9fc210fae8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 2784 additions and 785 deletions

View File

@ -32,6 +32,7 @@ accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
@ -61,7 +62,7 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
metal = { version = "0.27.0", features = ["mps"]}
[profile.release-with-debug]
inherits = "release"

View File

@ -34,6 +34,8 @@ zip = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
criterion = { workspace = true }
[features]
default = []
@ -42,3 +44,8 @@ cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "matmul"
harness = false

View File

@ -0,0 +1,43 @@
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor) {
a.matmul(&b.t().unwrap()).unwrap();
}
fn criterion_benchmark(c: &mut Criterion) {
let b = 1;
let m = 1;
let n = 2048;
let k = 2048;
let device = Device::new_metal(0).unwrap();
let dtype = DType::F32;
let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group("matmul_metal");
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&lhs), black_box(&rhs));
}
if let Device::Metal(device) = &device {
device.wait_until_completed().unwrap();
} else {
panic!("Expected metal device");
}
start.elapsed()
})
});
group.finish();
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

View File

@ -201,10 +201,9 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
Device::Metal(_device) => {
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
// Ok(Storage::Metal(storage))
crate::bail!("Metal rand_uniform not implemented")
Device::Metal(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Metal(storage))
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1877,10 +1877,7 @@ impl Tensor {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => {
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
Storage::Cpu(storage.to_cpu_storage()?)
}
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.

View File

@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
[[example]]
name = "llama_multiprocess"

View File

@ -10,7 +10,7 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
metal = { version = "0.27.0", features = ["mps"]}
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -29,15 +29,96 @@ kernel void FN_NAME( \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[id] * m + a; \
output[id] = TYPENAME(float(input[id]) * mul + add); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
}
#define POWF(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \
}
#define ELU(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME x = input[id]; \
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
} \
AFFINE(affine_float, float)
AFFINE(affine_half, half)
AFFINE(affine_f32, float)
AFFINE(affine_f16, half)
POWF(powf_f32, float)
POWF(powf_f16, half)
ELU(elu_f32, float)
ELU(elu_f16, half)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
AFFINE(affine_bf16, bfloat);
POWF(powf_bf16, bfloat);
ELU(elu_bf16, bfloat);
#endif

View File

@ -1,5 +1,8 @@
#include <metal_stdlib>
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
@ -22,15 +25,15 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
device OUT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
if (tid >= dim) { \
return; \
} \
TYPENAME x = left[thread_position_in_grid]; \
TYPENAME y = right[thread_position_in_grid]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
TYPENAME x = left[tid]; \
TYPENAME y = right[tid]; \
output[tid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -40,33 +43,48 @@ kernel void FN_NAME_STRIDED( \
constant size_t *right_strides, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
device OUT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
if (tid >= dim) { \
return; \
} \
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \
output[tid] = OUT_TYPENAME(FN); \
}
#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
BINARY_OP(MIN(x, y), min)
BINARY_OP(MAX(x, y), max)
BINARY_OP_OUT(eq, x == y)
BINARY_OP_OUT(ne, x != y)
BINARY_OP_OUT(le, x <= y)
BINARY_OP_OUT(lt, x < y)
BINARY_OP_OUT(ge, x >= y)
BINARY_OP_OUT(gt, x > y)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
BFLOAT_BINARY_OP(MIN(x, y), min)
BFLOAT_BINARY_OP(MAX(x, y), max)
#endif

View File

@ -23,12 +23,12 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
uint tid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
if (tid >= dim) { \
return; \
} \
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
output[tid] = RIGHT_TYPENAME(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -37,15 +37,20 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint i [[ thread_position_in_grid ]] \
uint tid [[ thread_position_in_grid ]] \
) { \
if (i >= dim) { \
if (tid >= dim) { \
return; \
} \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 310
#endif

View File

@ -1,6 +1,34 @@
#include <metal_stdlib>
using namespace metal;
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
output[tid] = input[src_i];
}
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
@ -11,92 +39,160 @@ kernel void NAME( \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint gid [[ thread_position_in_grid ]] \
uint tid [[ thread_position_in_grid ]] \
) { \
if (gid >= dst_size) { \
return; \
} \
const size_t id_i = gid / right_size / left_size; \
const size_t right_rank_i = gid % right_size; \
const size_t left_rank_i = gid % left_size; \
/* \
// Force prevent out of bounds indexing \
// since there doesn't seem to be a good way to force crash \
// No need to check for zero we're only allowing unsized. \
*/ \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
output[gid] = input[src_i]; \
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void gather(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const INDEX_TYPENAME input_i = input_ids[tid];
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i];
}
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
device T *inp [[buffer(1)]],
device T *out [[buffer(2)]],
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
constant uint &ids_dim_size,
constant uint &left_size,
constant uint &dst_dim_size,
constant uint &right_size,
uint gid [[ thread_position_in_grid ]] \
) {
if (gid >= left_size * right_size) {
return;
}
const uint i = gid;
const uint pre = i / right_size;
const uint post = i % right_size;
for (uint j = 0; j < ids_dim_size; j++) {
const uint idx = ids[j];
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter_add(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
device INDEX_TYPENAME *ids [[buffer(0)]], \
device TYPENAME *inp [[buffer(1)]], \
device TYPENAME *out [[buffer(2)]], \
constant uint &ids_dim_size, \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint gid [[ thread_position_in_grid ]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &dst_dim_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index_add(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
constant size_t &ids_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const INDEX_TYPENAME idx = input_ids[j];
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &dst_dim_size, \
constant size_t &ids_dim_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
index_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \
}
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
SCATTER_ADD_OP(sa_u32_f32, uint, float)
SCATTER_ADD_OP(sa_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
#endif
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)
INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,9 @@
#include <metal_stdlib>
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
@ -16,39 +19,160 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
constant int THREADGROUP_SIZE = 256;
constant int THREADGROUP_SIZE = 2048;
# define REDUCE(FN, NAME, TYPENAME) \
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const TYPENAME *src, \
device TYPENAME *dst, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
bool notset = true; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define REDUCE(FN, NAME, T, START) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint blockDim [[ threads_per_threadgroup ]] \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup float shared_memory[THREADGROUP_SIZE]; \
threadgroup T shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = 0; \
shared_memory[tid] = START; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = src[idx]; \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
T x = shared_memory[tid]; \
T y = src[strided_i]; \
shared_memory[tid] = FN; \
idx += blockDim; \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
@ -56,10 +180,10 @@ kernel void NAME( \
/* \
// reduction in shared memory \
*/ \
for (uint s = blockDim / 2; s > 0; s >>= 1) { \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
TYPENAME x = shared_memory[tid]; \
TYPENAME y = shared_memory[tid + s]; \
T x = shared_memory[tid]; \
T y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
@ -68,72 +192,101 @@ kernel void NAME( \
dst[dst_id] = shared_memory[0]; \
} \
kernel void softmax_float(
constant size_t &src_numel,
constant size_t &el_to_sum_per_block,
device const float *src,
device float *dst,
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint blockDim [[ threads_per_threadgroup ]]
) {
threadgroup float shared_memory[THREADGROUP_SIZE];
shared_memory[tid] = -INFINITY;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
shared_memory[tid] = max(shared_memory[tid], src[idx]);
idx += blockDim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
}
threadgroup_barrier(mem_flags::mem_none);
}
float max = shared_memory[0];
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
\
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = -INFINITY; \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
\
float tmp = -INFINITY; \
while (idx < stop_idx) { \
tmp = MAX(tmp, float(src[idx])); \
idx += block_dim; \
} \
shared_memory[tid] = tmp; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
/* wait for shared_memory[0] to be filled */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float _max = shared_memory[0]; \
\
/* prevent tid=0 from overwriting _max before other threads have written it */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
shared_memory[tid] = 0; \
\
idx = start_idx + tid; \
while (idx < stop_idx) { \
const float val = exp(float(src[idx]) - _max); \
dst[idx] = T(val); \
shared_memory[tid] += val; \
idx += block_dim; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1.0/shared_memory[0]); \
idx = start_idx + tid; \
while (idx < stop_idx) { \
dst[idx] *= inv_acc; \
idx += block_dim; \
} \
} \
shared_memory[tid] = 0;
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
// Restart
idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
const float val = exp(src[idx] - max);
dst[idx] = val;
shared_memory[tid] += val;
idx += blockDim;
}
// reduction in shared memory
for (uint s = blockDim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
const float inv_acc = 1/shared_memory[0];
idx = start_idx + tid;
while (idx < stop_idx) {
dst[idx] *= inv_acc;
idx += blockDim;
}
}
REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
#if __METAL_VERSION__ >= 310
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
#endif

View File

@ -32,6 +32,9 @@ kernel void FN_NAME( \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
if (i >= numel){ \
return; \
} \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \

View File

@ -1,7 +1,14 @@
use super::*;
use half::f16;
use half::{bf16, f16};
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
assert!(!ptr.is_null());
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
slice.to_vec()
}
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const core::ffi::c_void;
@ -23,13 +30,19 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
command_buffer,
@ -37,23 +50,24 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
name,
v.len(),
&input,
&mut output,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
read_to_vec(&output, v.len())
}
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let left = new_buffer(&device, x);
let right = new_buffer(&device, y);
let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
call_binary_contiguous(
&device,
command_buffer,
@ -62,12 +76,12 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
x.len(),
&left,
&right,
&mut output,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(x.len())
read_to_vec(&output, x.len())
}
fn run_strided<T: Clone>(
@ -81,8 +95,9 @@ fn run_strided<T: Clone>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let kernels = Kernels::new();
let output = new_buffer(&device, v);
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_unary_strided(
&device,
command_buffer,
@ -92,13 +107,13 @@ fn run_strided<T: Clone>(
&input,
strides,
offset,
&mut output,
&output,
0,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
read_to_vec(&output, v.len())
}
#[test]
@ -200,6 +215,25 @@ fn cos_strided_random() {
);
}
#[test]
fn gelu_f16() {
let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0];
let results = run(&v, unary::contiguous::gelu::HALF);
assert_eq!(approx_f16(results, 2), expected);
}
#[test]
fn gelu_f32() {
let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0];
let results = run(&v, unary::contiguous::gelu::FLOAT);
assert_eq!(approx(results, 3), expected);
}
#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];
@ -216,11 +250,14 @@ fn binary_add_f32() {
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
let size = (v.len() * std::mem::size_of::<U>()) as u64;
let output = device.new_buffer(size, options);
call_cast_contiguous(
&device,
@ -229,12 +266,13 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
name,
v.len(),
&input,
&mut output,
0,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<U>(v.len())
read_to_vec(&output, v.len())
}
#[test]
@ -245,21 +283,28 @@ fn cast_u32_f32() {
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
let v = vec![1.0f32, 2.0, 3.0];
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
let results: Vec<f32> = cast(&input, "cast_f16_f32");
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
let v = vec![1.0f32; 10_000];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
let results: Vec<f32> = cast(&input, "cast_f16_f32");
assert_eq!(results.len(), 10_000);
assert_eq!(&results[..10], vec![1.0f32; 10]);
assert_eq!(results, vec![1.0f32; 10_000]);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let output = new_buffer(&device, v);
let size = v.len();
@ -267,9 +312,10 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&device,
command_buffer,
&kernels,
"affine_f32",
size,
&input,
&mut output,
&output,
mul as f32,
add as f32,
)
@ -277,7 +323,44 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
read_to_vec(&output, v.len())
}
fn run_affine_strided<T: Clone>(
v: &[T],
shape: &[usize],
strides: &[usize],
mul: f64,
add: f64,
) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_affine_strided(
&device,
command_buffer,
&kernels,
"affine_f32_strided",
shape,
&input,
strides,
0,
&output,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
let len: usize = shape.iter().product();
read_to_vec(&output, len)
}
#[test]
@ -295,6 +378,18 @@ fn affine() {
assert_eq!(result, vec![2.6; 40_000]);
}
#[test]
fn affine_strided() {
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mul = 1.5;
let add = 1.1;
let shape = [4];
let strides = [2];
let result = run_affine_strided(&input, &shape, &strides, mul, add);
// 1 on 2
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
}
#[test]
fn index_select() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
@ -313,7 +408,26 @@ fn index_select() {
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
);
}
#[test]
fn index_select_f16() {
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
.into_iter()
.map(|x| f16::from_f32(x))
.collect();
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
);
}
#[test]
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
@ -321,7 +435,7 @@ fn index_select() {
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
);
}
@ -341,27 +455,34 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let dst_el = ids.len() * left_size * right_size;
let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
let kernels = Kernels::new();
let name = match core::mem::size_of::<T>() {
4 => "is_u32_f32",
2 => "is_u32_f16",
_ => unimplemented!(),
};
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_index_select(
&device,
&command_buffer,
&kernels,
"is_u32_f32",
name,
shape,
ids.len(),
dim,
&embeddings_buffer,
&ids_buffer,
&mut dst_buffer,
&dst_buffer,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
dst_buffer.read_to_vec::<T>(dst_el)
read_to_vec(&dst_buffer, dst_el)
}
#[test]
@ -427,7 +548,7 @@ fn index_add() {
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs_buffer.read_to_vec::<f32>(right.len());
let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len());
assert_eq!(result, expected);
}
@ -439,43 +560,49 @@ fn cos_f16() {
.collect();
let results = run(&v, unary::contiguous::cos::HALF);
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]);
assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
}
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
call_reduce_contiguous(
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
let dims = vec![v.len()];
let strides = vec![1];
call_reduce_strided(
&device,
command_buffer,
&kernels,
name,
v.len(),
&dims,
&strides,
out_length,
&input,
&mut output,
0,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(out_length)
read_to_vec(&output, out_length)
}
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_last_softmax(
&device,
command_buffer,
@ -484,13 +611,14 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
v.len(),
last_dim,
&input,
&mut output,
0,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
read_to_vec(&output, v.len())
}
#[test]
@ -498,7 +626,7 @@ fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1;
let results = run_reduce(&v, out_length, "fast_sum_float");
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![21.0]);
}
@ -507,7 +635,7 @@ fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2;
let results = run_reduce(&v, out_length, "fast_sum_float");
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}
@ -515,15 +643,33 @@ fn reduce_sum2() {
fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_float");
let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
);
let last_dim = 4096;
let n = 200;
let mut v = vec![0.0; n * last_dim];
for i in 0..n {
v[i * last_dim] = 20.0;
}
let results = run_softmax(&v, last_dim, "softmax_f32");
let results = approx(results, 4);
println!("{results:?}");
assert_eq!(
results.iter().map(|&s| s.round() as usize).sum::<usize>(),
n
);
assert_eq!(results[0], 1.0);
assert_eq!(results[1], 0.0);
assert_eq!(results[last_dim], 1.0);
assert_eq!(results[2 * last_dim], 1.0);
let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_float");
let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@ -531,11 +677,33 @@ fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 3;
let results = run_softmax(&v, last_dim, "softmax_float");
let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_f16");
assert_eq!(
approx_f16(results, 4),
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_bf16");
assert_eq!(
approx_bf16(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
);
}
fn run_where_cond<I: Clone, T: Clone>(
@ -549,7 +717,8 @@ fn run_where_cond<I: Clone, T: Clone>(
name: &'static str,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@ -571,7 +740,7 @@ fn run_where_cond<I: Clone, T: Clone>(
options,
);
let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_where_cond_strided(
&device,
command_buffer,
@ -584,13 +753,13 @@ fn run_where_cond<I: Clone, T: Clone>(
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
&mut output,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(length)
read_to_vec(&output, length)
}
#[test]
@ -614,3 +783,93 @@ fn where_cond() {
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
fn run_gemm<T: Clone>(
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: Vec<usize>,
lhs_offset: usize,
rhs: &[T],
rhs_stride: Vec<usize>,
rhs_offset: usize,
) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(rhs) as u64,
options,
);
let length = b * m * n;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_gemm(
&device,
command_buffer,
&kernels,
"sgemm",
(b, m, n, k),
&lhs_stride,
lhs_offset,
&lhs,
&rhs_stride,
rhs_offset,
&rhs,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, length)
}
#[test]
fn gemm() {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
let (b, m, n, k) = (2, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
assert_eq!(
approx(results, 4),
vec![
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
518.0, 548.0, 578.0
]
);
// OFFSET
let (b, m, n, k) = (2, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4);
assert_eq!(
approx(results, 4),
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
}

View File

@ -1,4 +1,7 @@
#include <metal_stdlib>
#include <metal_math>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
@ -17,10 +20,44 @@ METAL_FUNC uint get_strided_index(
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T erf(T in){
float x = (float) in;
// constants
float a1 = 0.254829592;
float a2 = -0.284496736;
float a3 = 1.421413741;
float a4 = -1.453152027;
float a5 = 1.061405429;
float p = 0.3275911;
// Save the sign of x
int sign = 1;
if (x < 0)
sign = -1;
x = fabs(x);
// A&S formula 7.1.26
float t = 1.0/(1.0 + p*x);
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
return T(sign*y);
}
template <typename T> METAL_FUNC T id(T in) { return in; }
template <typename T> METAL_FUNC T gelu_erf(T x) {
return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2);
}
template <typename T> METAL_FUNC T gelu(T x) {
if (x > 5) {
return x;
}
T x_sq = x * x;
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
using namespace metal;
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@ -32,7 +69,7 @@ kernel void FN_NAME( \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -46,15 +83,15 @@ kernel void FN_NAME_STRIDED( \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \
}
#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \
UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
UNARY_OP(cos)
@ -64,8 +101,17 @@ UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY_OP(gelu)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
@ -75,6 +121,13 @@ BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif

View File

@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
&device,
command_buffer,
&kernels,
"affine_float",
v.len(),
&input,
&mut output,

View File

@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
kernel_name.0,
v.len(),
iterations,
total_time,
@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in strided {
for kernel_name in &strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
kernel_name.0,
v.len(),
iterations,
total_time,

View File

@ -19,6 +19,8 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
metal = { workspace = true, optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -29,3 +31,4 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]

View File

@ -201,6 +201,47 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::{backend::BackendStorage, DType};
let device = storage.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
let name = match storage.dtype() {
DType::F32 => "softmax_f32",
DType::F16 => "softmax_f16",
DType::BF16 => "softmax_bf16",
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
};
let n = layout.stride().len();
if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
candle::bail!("Non contiguous softmax-last-dim is not implemented");
}
let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,
kernels,
name,
elem_count,
last_dim,
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&output,
)
.unwrap();
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
Ok((newstorage, layout.shape().clone()))
}
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {

View File

@ -31,3 +31,4 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
metal = ["candle/metal", "candle-nn/metal"]