Integrate the MLX gemm kernels (#2468)

* Include the MLX gemm kernels.

* Clippy lints.

* Export the gemm_f32 kernel.

* Add the f16/bf16 variants.

* Add the initial dispatch code.

* Plug in more of the MLX kernels.

* Add a currently broken test.

* Tweaks.

* Bugfix + get the tests to pass.

* Enable the gemm bf16 tests.

* Add some randomized tests.

* Update candle-metal-kernels/src/lib.rs

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>

* More fixes.

* More clippy fixes.

---------

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
Laurent Mazare, 2024-09-11 15:56:48 +01:00, committed by GitHub
parent 13b2a8a4a0
commit 5635650d38
5 changed files with 1874 additions and 55 deletions

candle-metal-kernels/Cargo.toml

@@ -23,3 +23,4 @@ half = { version = "2.3.1", features = [
"rand_distr",
] }
rand = "0.8.5"
rand_distr = "0.4.3"

candle-metal-kernels/src/lib.rs

@@ -11,33 +11,35 @@ pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
const QUANTIZED: &str = include_str!("quantized.metal");
const RANDOM: &str = include_str!("random.metal");
const REDUCE: &str = include_str!("reduce.metal");
const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
Indexing,
Unary,
Binary,
Ternary,
Cast,
Reduce,
Mfa,
Conv,
Random,
Gemm,
Indexing,
Mfa,
Quantized,
Random,
Reduce,
Sort,
Ternary,
Unary,
}
pub mod copy2d {
@@ -191,16 +193,17 @@ impl Kernels {
fn get_library_source(&self, source: Source) -> &'static str {
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
Source::Binary => BINARY,
Source::Ternary => TERNARY,
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Gemm => MLX_GEMM,
Source::Indexing => INDEXING,
Source::Quantized => QUANTIZED,
Source::Random => RANDOM,
Source::Reduce => REDUCE,
Source::Sort => SORT,
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -2178,5 +2181,181 @@ pub fn call_arg_sort(
Ok(())
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum GemmDType {
BF16,
F16,
F32,
}
#[allow(clippy::too_many_arguments)]
pub fn call_mlx_gemm(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: GemmDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
lhs_offset: usize,
lhs_buffer: &Buffer,
rhs_stride: &[usize],
rhs_offset: usize,
rhs_buffer: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
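// NOTE: this struct is copied byte-for-byte to the kernel via set_bytes below, so its
// #[repr(C)] layout has to stay in sync with the GEMMParams struct in mlx_gemm.metal.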
#[derive(Debug)]
#[repr(C)]
struct GemmParams {
m: i32,
n: i32,
k: i32,
lda: i32,
ldb: i32,
ldd: i32,
tiles_n: i32,
tiles_m: i32,
batch_stride_a: isize,
batch_stride_b: isize,
batch_stride_d: isize,
swizzle_log: i32,
gemm_k_iterations_aligned: i32,
batch_ndim: i32,
}
assert!(rhs_stride.len() >= 2);
assert!(lhs_stride.len() >= 2);
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// lhs has shape b, m, k
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
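// E.g. a contiguous (b, m, k) lhs has strides [m * k, k, 1], so lhs_m1 == 1 and
// lhs_m2 == k: lda is k and the matrix is used in non-transposed ("n") form.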
let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
(k as i32, false)
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
(m as i32, true)
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
// rhs has shape b, k, n
let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
(n as i32, false)
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
(k as i32, true)
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
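// Threadgroup tile sizes (a bm x bn output tile, stepping by bk along k) and the
// wm x wn simdgroup grid; these match the _32_32_16_2_2 suffix of the kernel names below.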
let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
let constants = Some(ConstantValues::new(vec![
(10, Value::Bool(/* has_batch */ b > 1)),
(100, Value::Bool(/* use_out_source */ false)),
(110, Value::Bool(/* do_axpby */ false)),
(200, Value::Bool(/* align_m */ m % bm == 0)),
(201, Value::Bool(/* align_n */ n % bn == 0)),
(202, Value::Bool(/* align_k */ k % bk == 0)),
(300, Value::Bool(/* do_gather */ false)),
]));
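// With swizzle_log set to 0 no threadgroup swizzling takes place, so tn and tm are
// simply the number of output tiles along the n and m dimensions.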
let swizzle_log = 0;
let tile = 1 << swizzle_log;
let tn = n.div_ceil(bn);
let tm = m.div_ceil(bm);
let tn = tn * tile;
let tm = tm.div_ceil(tile);
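// Use the explicit batch stride when the input has more than two dims, otherwise
// assume densely packed (m, k) / (k, n) matrices.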
let batch_stride_a = if lhs_stride.len() > 2 {
lhs_stride[lhs_stride.len() - 3]
} else {
m * k
};
let batch_stride_b = if rhs_stride.len() > 2 {
rhs_stride[rhs_stride.len() - 3]
} else {
n * k
};
let gemm_params = GemmParams {
m: m as i32,
n: n as i32,
k: k as i32,
lda,
ldb,
ldd: n as i32,
tiles_n: tn as i32,
tiles_m: tm as i32,
swizzle_log,
batch_stride_a: batch_stride_a as isize,
batch_stride_b: batch_stride_b as isize,
batch_stride_d: (m * n) as isize,
batch_ndim: 1i32,
gemm_k_iterations_aligned: (k / bk) as i32,
};
let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
// TODO(laurent): generate the name
// template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
let name = match (dtype, a_trans, b_trans) {
(GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
(GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
(GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
(GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
(GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
(GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
};
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
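// Buffer indices 2 and 5 are intentionally left unbound: in mlx_gemm.metal they carry
// the out-source matrix and the axpby parameters, which are only read when the
// use_out_source / do_axpby function constants above are true.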
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(3, Some(output), 0);
encoder.set_bytes(
4,
std::mem::size_of::<GemmParams>() as u64,
&gemm_params as *const GemmParams as *const c_void,
);
encoder.set_bytes(
6, // batch_shape
std::mem::size_of::<i32>() as u64,
&(b as i32) as *const i32 as *const c_void,
);
encoder.set_bytes(
7,
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
batch_strides.as_ptr() as *const c_void,
);
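// Launch one threadgroup per output tile: the grid is tiles_n x tiles_m x batch, and
// each threadgroup runs wm * wn simdgroups of 32 threads.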
let grid_size = MTLSize {
width: tn as u64,
height: tm as u64,
depth: /* batch_size_out */ b as u64,
};
let group_size = MTLSize {
width: 32,
height: wn,
depth: wm,
};
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
Ok(())
}
#[cfg(test)]
mod tests;

candle-metal-kernels/src/mlx_gemm.metal (file diff suppressed because it is too large)

candle-metal-kernels/src/tests.rs

@@ -329,7 +329,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
#[test]
fn cast_f32() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -360,7 +360,7 @@ fn cast_f32() {
#[test]
fn cast_f16() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -391,7 +391,7 @@ fn cast_f16() {
#[test]
fn cast_bf16() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -422,7 +422,7 @@ fn cast_bf16() {
#[test]
fn cast_u32() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -453,7 +453,7 @@ fn cast_u32() {
#[test]
fn cast_u8() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -484,7 +484,7 @@ fn cast_u8() {
#[test]
fn cast_i64() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -911,7 +911,7 @@ fn softmax() {
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
@@ -922,7 +922,7 @@ fn softmax() {
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();
@@ -1045,14 +1045,15 @@ fn where_cond_u32_f32() {
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
#[allow(clippy::too_many_arguments)]
fn run_gemm<T: Clone>(
name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: Vec<usize>,
lhs_stride: &[usize],
lhs_offset: usize,
rhs: &[T],
rhs_stride: Vec<usize>,
rhs_stride: &[usize],
rhs_offset: usize,
) -> Vec<T> {
let device = device();
@@ -1079,10 +1080,10 @@ fn run_gemm<T: Clone>(
&kernels,
name,
(b, m, n, k),
&lhs_stride,
lhs_stride,
lhs_offset,
&lhs,
&rhs_stride,
rhs_stride,
rhs_offset,
&rhs,
&output,
@@ -1105,10 +1106,10 @@ fn gemm() {
"sgemm",
(b, m, n, k),
&lhs,
lhs_stride,
&lhs_stride,
0,
&rhs,
rhs_stride,
&rhs_stride,
0,
);
assert_eq!(
@@ -1125,10 +1126,10 @@ fn gemm() {
"sgemm",
(b, m, n, k),
&lhs,
lhs_stride,
&lhs_stride,
0,
&rhs,
rhs_stride,
&rhs_stride,
0,
);
assert_eq!(
@@ -1150,10 +1151,10 @@ fn gemm() {
"sgemm",
(1, m, n, k),
&lhs,
lhs_stride,
&lhs_stride,
0,
&rhs,
rhs_stride,
&rhs_stride,
12 * 4,
);
assert_eq!(
@@ -1172,10 +1173,10 @@ fn gemm() {
"bgemm",
(b, m, n, k),
&lhs,
lhs_stride,
&lhs_stride,
0,
&rhs,
rhs_stride,
&rhs_stride,
0,
);
assert_eq!(
@@ -1194,10 +1195,10 @@ fn gemm() {
"hgemm",
(b, m, n, k),
&lhs,
lhs_stride,
&lhs_stride,
0,
&rhs,
rhs_stride,
&rhs_stride,
0,
);
assert_eq!(
@@ -1206,6 +1207,204 @@ fn gemm() {
);
}
#[allow(clippy::too_many_arguments)]
fn run_mlx_gemm<T: Clone>(
dtype: GemmDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: &[usize],
lhs_offset: usize,
rhs: &[T],
rhs_stride: &[usize],
rhs_offset: usize,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(rhs) as u64,
options,
);
let length = b * m * n;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_mlx_gemm(
&device,
command_buffer,
&kernels,
dtype,
(b, m, n, k),
lhs_stride,
lhs_offset,
&lhs,
rhs_stride,
rhs_offset,
&rhs,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, length)
}
fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) {
use rand::SeedableRng;
use rand_distr::Distribution;
let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect();
let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect();
let v1: Vec<f32> = run_mlx_gemm(
dtype,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[k * n, n, 1],
0,
);
let v2: Vec<f32> = run_gemm(
"sgemm",
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[k * n, n, 1],
0,
);
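// The MLX kernel output should match the MFA sgemm reference to within 5e-5:
// (diff * 1e4).round() == 0 iff diff < 5e-5.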
for (a, b) in v1.iter().zip(v2.iter()) {
let diff = (a - b).abs();
assert_eq!((diff * 1e4).round(), 0.)
}
}
#[test]
fn mlx_vs_mfa() {
mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32);
mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32);
mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32);
mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32);
mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32);
}
#[test]
fn mlx_gemm() {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
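// lhs is [[0, 1, 2], [3, 4, 5]] and the rows of rhs are [0, 1, 2, 3], [4, 5, 6, 7],
// [8, 9, 10, 11], so e.g. the first output value is 0*0 + 1*4 + 2*8 = 20.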
let results = run_mlx_gemm(
GemmDType::F32,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
let (b, m, n, k) = (2, 2, 4, 3);
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_mlx_gemm(
GemmDType::F32,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx(results, 4),
vec![
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
518.0, 548.0, 578.0
]
);
// OFFSET
let (b, m, n, k) = (2, 2, 4, 3);
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
let results = run_mlx_gemm(
GemmDType::F32,
(1, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
12 * 4,
);
assert_eq!(
approx(results, 4),
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
// bgemm sanity test
{
let (b, m, n, k) = (1, 2, 4, 3);
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
let results = run_mlx_gemm(
GemmDType::BF16,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx_bf16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}
{
// hgemm sanity test
let (b, m, n, k) = (1, 2, 4, 3);
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
let results = run_mlx_gemm(
GemmDType::F16,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx_f16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}
}
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
@@ -1280,7 +1479,7 @@ fn random() {
variance.sqrt()
}
let shape = vec![1024, 10];
let shape = [1024, 10];
let length = shape.iter().product::<usize>();
let seed = 299792458;
@@ -1636,7 +1835,7 @@ fn max_pool2d_f16() {
&strides,
"max_pool2d_f16",
);
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
.iter()
.map(|v| half::f16::from_f32(*v))
.collect::<Vec<_>>();
@@ -1656,7 +1855,7 @@ fn max_pool2d_f16() {
&strides,
"max_pool2d_f16",
);
let expected = vec![5.0, 7.0, 13.0, 15.0]
let expected = [5.0, 7.0, 13.0, 15.0]
.iter()
.map(|v| half::f16::from_f32(*v))
.collect::<Vec<_>>();
@@ -1679,7 +1878,7 @@ fn max_pool2d_bf16() {
&strides,
"max_pool2d_bf16",
);
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
.iter()
.map(|v| half::bf16::from_f32(*v))
.collect::<Vec<_>>();
@@ -1699,7 +1898,7 @@ fn max_pool2d_bf16() {
&strides,
"max_pool2d_bf16",
);
let expected = vec![5.0, 7.0, 13.0, 15.0]
let expected = [5.0, 7.0, 13.0, 15.0]
.iter()
.map(|v| half::bf16::from_f32(*v))
.collect::<Vec<_>>();
@@ -1818,7 +2017,7 @@ fn avg_pool2d_f16() {
&strides,
"avg_pool2d_f16",
);
let expected = vec![
let expected = [
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
]
.iter()
@@ -1843,7 +2042,7 @@ fn avg_pool2d_bf16() {
&strides,
"avg_pool2d_bf16",
);
let expected = vec![
let expected = [
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
]
.iter()
@@ -1981,14 +2180,14 @@ fn conv_transpose1d_f32() {
#[test]
fn conv_transpose1d_f16() {
let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
let input: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let input_shape = &[1, 1, 4];
let input_stride = &[4, 4, 1];
let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
let kernel: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
@@ -2009,7 +2208,7 @@ fn conv_transpose1d_f16() {
"conv_transpose1d_f16",
);
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
let expected = [1., 4., 10., 20., 25., 24., 16.]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
@@ -2018,14 +2217,14 @@ fn conv_transpose1d_f16() {
#[test]
fn conv_transpose1d_bf16() {
let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
let input: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect();
let input_shape = &[1, 1, 4];
let input_stride = &[4, 4, 1];
let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
let kernel: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect();
@@ -2046,7 +2245,7 @@ fn conv_transpose1d_bf16() {
"conv_transpose1d_bf16",
);
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
let expected = [1., 4., 10., 20., 25., 24., 16.]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();

candle-metal-kernels/src/utils.rs

@@ -165,7 +165,7 @@ pub trait EncoderProvider {
type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
where
Self: 'a;
fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
fn encoder(&self) -> Self::Encoder<'_>;
}
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
@@ -178,7 +178,7 @@ impl<'a> Drop for WrappedEncoder<'a> {
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
&self.0
self.0
}
}
@@ -186,7 +186,7 @@ impl EncoderProvider for &metal::CommandBuffer {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder(self.new_compute_command_encoder())
}
}
@@ -195,7 +195,7 @@ impl EncoderProvider for &metal::CommandBufferRef {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder(self.new_compute_command_encoder())
}
}