diff --git a/Cargo.toml b/Cargo.toml index c37bd75b..ba09b1d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,7 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } [profile.release-with-debug] inherits = "release" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 592f5bdf..e7d3ab6a 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -13,7 +13,7 @@ readme = "README.md" accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true } -candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } +candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true } metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 597c2f01..27475efe 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -54,10 +54,6 @@ impl std::ops::Deref for MetalDevice { } impl MetalDevice { - // pub fn metal_device(&self) -> &metal::DeviceRef { - // self.device.as_ref() - // } - pub fn id(&self) -> NSUInteger { self.registry_id() } @@ -76,7 +72,6 @@ impl MetalDevice { pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - // debug!("Allocate 1 - buffer size {size}"); self.device .new_buffer(size, MTLResourceOptions::StorageModeManaged) } @@ -105,28 +100,22 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { + let length = self.buffer.length() as usize; + let size = self.dtype.size_in_bytes(); + if length % size != 0 { + crate::bail!( + "The Metal buffer length is not aligned with dtype {:?}", + self.dtype + ); + } match self.dtype { - DType::U8 => Ok(CpuStorage::U8( - self.buffer.read_to_vec(self.buffer.length() as usize / 1), - )), - DType::U32 => Ok(CpuStorage::U32( - self.buffer.read_to_vec(self.buffer.length() as usize / 4), - )), - DType::I64 => Ok(CpuStorage::I64( - self.buffer.read_to_vec(self.buffer.length() as usize / 8), - )), - DType::F16 => Ok(CpuStorage::F16( - self.buffer.read_to_vec(self.buffer.length() as usize / 2), - )), - DType::BF16 => Ok(CpuStorage::BF16( - self.buffer.read_to_vec(self.buffer.length() as usize / 2), - )), - DType::F32 => Ok(CpuStorage::F32( - self.buffer.read_to_vec(self.buffer.length() as usize / 4), - )), - DType::F64 => Ok(CpuStorage::F64( - self.buffer.read_to_vec(self.buffer.length() as usize / 8), - )), + DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))), + DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))), + DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))), + DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))), + DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))), + DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))), + DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))), } } @@ -137,9 +126,9 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = self.dtype; - assert!(layout.is_contiguous()); - assert!(layout.start_offset() == 0); - assert_eq!(dtype, DType::F32); + if layout.is_contiguous() || layout.start_offset() != 0|| dtype != DType::F32{ + crate::bail!("Not contiguous, non-f32 affine is not implemented yet."); + } let mut buffer = device.new_buffer(el, self.dtype); let command_buffer = self.device.command_queue.new_command_buffer(); @@ -153,7 +142,7 @@ impl BackendStorage for MetalStorage { mul as f32, add as f32, ) - .unwrap(); + .map_err(MetalError::from)?; command_buffer.commit(); command_buffer.wait_until_completed(); return Ok(Self { @@ -164,18 +153,18 @@ impl BackendStorage for MetalStorage { } fn powf(&self, _: &Layout, _: f64) -> Result { - todo!() + crate::bail!("powf metal") } fn elu(&self, _: &Layout, _: f64) -> Result { - todo!() + crate::bail!("elu metal") } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { - assert!(sum_dims.len() == 1); - assert!(sum_dims[0] == layout.shape().rank() - 1); - assert!(layout.is_contiguous()); - assert!(layout.start_offset() == 0); + + if !(sum_dims.len() == 1 && sum_dims[0] == layout.shape().rank() - 1 && layout.is_contiguous() && layout.start_offset() == 0){ + crate::bail!("Non contiguous reduce op not supported yet"); + } let device = self.device.clone(); let src_stride = layout.stride(); let src_dims = layout.shape().dims(); @@ -204,7 +193,7 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false), (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true), (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true), - _ => todo!("Reduce op for non float"), + _ => crate::bail!("Reduce op for non float"), }; if check_empty && layout.shape().elem_count() == 0 { Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? @@ -234,7 +223,7 @@ impl BackendStorage for MetalStorage { } fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result { - todo!() + crate::bail!("cmp metal") } fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { @@ -246,7 +235,7 @@ impl BackendStorage for MetalStorage { if layout.is_contiguous() { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", - (left, right) => todo!("to dtype {left:?} - {right:?}"), + (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), }; candle_metal_kernels::call_cast_contiguous( &device.device, @@ -259,7 +248,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } else { - todo!( + crate::bail!( "TODO Implement the kernel calling cast {:?}-{:?}", self.dtype, dtype @@ -293,7 +282,7 @@ impl BackendStorage for MetalStorage { ("uneg", DType::F32) => contiguous::neg::FLOAT, ("uexp", DType::F32) => contiguous::exp::FLOAT, ("ulog", DType::F32) => contiguous::log::FLOAT, - (name, dtype) => todo!("Match {name} - {dtype:?}"), + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( &device.device, @@ -306,7 +295,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } else { - todo!("TODO Implement the kernel calling {}", B::KERNEL); + crate::bail!("TODO Implement the kernel calling {}", B::KERNEL); } command_buffer.commit(); command_buffer.wait_until_completed(); @@ -344,7 +333,7 @@ impl BackendStorage for MetalStorage { ("bmul", DType::F32) => contiguous::mul::FLOAT, ("div", DType::F32) => contiguous::div::FLOAT, ("bdiv", DType::F32) => contiguous::div::FLOAT, - (name, dtype) => todo!("Match {name} - {dtype:?}"), + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_contiguous( &device.device, @@ -365,7 +354,7 @@ impl BackendStorage for MetalStorage { ("bsub", DType::F32) => strided::sub::FLOAT, ("bmul", DType::F32) => strided::mul::FLOAT, ("bdiv", DType::F32) => strided::div::FLOAT, - (name, dtype) => todo!("Match {name} - {dtype:?}"), + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_strided( &device.device, @@ -442,7 +431,7 @@ impl BackendStorage for MetalStorage { _kernel_l: &Layout, _params: &ParamsConv1D, ) -> Result { - todo!() + crate::bail!("conv1d metal") } fn conv_transpose1d( @@ -452,7 +441,7 @@ impl BackendStorage for MetalStorage { _kernel_l: &Layout, _params: &ParamsConvTranspose1D, ) -> Result { - todo!() + crate::bail!("conv_transpose1d metal") } fn conv2d( @@ -462,7 +451,7 @@ impl BackendStorage for MetalStorage { _kernel_l: &Layout, _params: &ParamsConv2D, ) -> Result { - todo!() + crate::bail!("conv2d metal") } fn conv_transpose2d( @@ -472,27 +461,27 @@ impl BackendStorage for MetalStorage { _kernel_l: &Layout, _params: &ParamsConvTranspose2D, ) -> Result { - todo!() + crate::bail!("conv_tranpose2d metal") } fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { - todo!() + crate::bail!("avg_pool2d metal") } fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { - todo!() + crate::bail!("max_pool2d metal") } fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { - todo!() + crate::bail!("upsample_nearest1d metal") } fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { - todo!() + crate::bail!("upsample_nearest2d metal") } fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result { - todo!() + crate::bail!("gather metal") } fn scatter_add( @@ -504,14 +493,13 @@ impl BackendStorage for MetalStorage { _: &Layout, _: usize, ) -> Result { - todo!() + crate::bail!("scatter_add metal") } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { - assert!(src_l.is_contiguous()); - assert!(src_l.start_offset() == 0); - assert!(ids_l.is_contiguous()); - assert!(ids_l.start_offset() == 0); + if !(src_l.is_contiguous() && src_l.start_offset() == 0 && ids_l.is_contiguous() && ids_l.start_offset() == 0){ + crate::bail!("Non contiguous index select not implemented"); + } let left_size: usize = src_l.dims()[..dim].iter().product(); let right_size: usize = src_l.dims()[dim + 1..].iter().product(); let ids_el = ids_l.shape().elem_count(); @@ -519,10 +507,10 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let device = self.device(); let mut buffer = device.new_buffer(dst_el, dtype); - let out = self.to_cpu_storage().unwrap(); + let out = self.to_cpu_storage()?; let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", - (left, right) => todo!("index select metal {left:?} {right:?}"), + (left, right) => crate::bail!("index select metal {left:?} {right:?}"), }; let command_buffer = self.device.command_queue.new_command_buffer(); candle_metal_kernels::call_index_select( @@ -556,7 +544,7 @@ impl BackendStorage for MetalStorage { _: &Layout, _: usize, ) -> Result { - todo!() + crate::bail!("index_add metal") } fn matmul( @@ -666,11 +654,6 @@ impl BackendStorage for MetalStorage { command_buffer.commit(); command_buffer.wait_until_completed(); - // let left = self.buffer.read_to_vec::(10); - // let right = rhs.buffer.read_to_vec::(10); - // let out = out_buffer.read_to_vec::(40); - // todo!("Out {left:?} {right:?} {out:?}"); - Ok(Self { buffer: out_buffer, device: self.device.clone(), @@ -681,7 +664,6 @@ impl BackendStorage for MetalStorage { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let src_shape = src_l.shape(); let el_count = src_shape.elem_count(); - // todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}"); if el_count == 0 { return Ok(()); } @@ -690,7 +672,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, - dtype => todo!("copy_strided not implemented for {dtype:?}"), + dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), }; candle_metal_kernels::call_unary_strided( &self.device.device, @@ -741,7 +723,7 @@ impl BackendDevice for MetalDevice { } fn set_seed(&self, _seed: u64) -> Result<()> { - todo!("set_seed") + crate::bail!("set_seed") } fn location(&self) -> crate::DeviceLocation { diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index f164dc2f..186f3209 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-metal-kernels" -version = "0.3.0" +version = "0.3.1" edition = "2021" description = "Metal kernels for Candle" @@ -10,7 +10,7 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 7288216a..6c2e5f2b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,4 +1,3 @@ -#![allow(clippy::too_many_arguments)] use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, ComputePipelineState, Device, Function, Library, MTLSize, @@ -156,14 +155,6 @@ pub mod binary { ops!(add, sub, mul, div); } -// static LIBRARY_SOURCES: Lazy> = Lazy::new(|| { -// let mut l = HashMap::new(); -// l.insert("affine", AFFINE); -// l.insert("indexing", INDEXING); -// l.insert("unary", UNARY); -// l -// }); -// #[derive(thiserror::Error, Debug)] pub enum MetalKernelError { #[error("Could not lock kernel map: {0}")] @@ -197,21 +188,7 @@ impl Kernels { Self { libraries, funcs } } - // pub fn init(device: &Device) -> Result { - // let kernels = Self::new(); - // kernels.load_libraries(device)?; - // Ok(kernels) - // } - - // fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> { - // for name in LIBRARY_SOURCES.keys() { - // self.load_library(device, name)?; - // } - // Ok(()) - // } - fn get_library_source(&self, source: Source) -> &'static str { - // LIBRARY_SOURCES.get(name).cloned() match source { Source::Affine => AFFINE, Source::Unary => UNARY, @@ -261,6 +238,7 @@ impl Kernels { } } +#[allow(clippy::too_many_arguments)] pub fn call_unary_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -270,8 +248,6 @@ pub fn call_unary_contiguous( input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -292,6 +268,8 @@ pub fn call_unary_contiguous( encoder.end_encoding(); Ok(()) } + +#[allow(clippy::too_many_arguments)] pub fn call_unary_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -339,6 +317,7 @@ pub fn call_unary_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_binary_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -349,8 +328,6 @@ pub fn call_binary_contiguous( right: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -373,6 +350,7 @@ pub fn call_binary_contiguous( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_binary_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -425,6 +403,7 @@ pub fn call_binary_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_cast_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -434,8 +413,6 @@ pub fn call_cast_contiguous( input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); let func = kernels.load_function(device, Source::Cast, kernel_name)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -458,6 +435,7 @@ pub fn call_cast_contiguous( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -508,6 +486,7 @@ pub fn call_reduce_contiguous( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_last_softmax( device: &Device, command_buffer: &CommandBufferRef, @@ -543,7 +522,6 @@ pub fn call_last_softmax( let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - // (elements_to_sum as u64 + 2 - 1) / 2, elements_to_sum as u64, ) .next_power_of_two(); @@ -559,6 +537,7 @@ pub fn call_last_softmax( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, @@ -590,6 +569,7 @@ pub fn call_affine( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -643,6 +623,7 @@ pub fn call_where_cond_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_index_select( device: &Device, command_buffer: &CommandBufferRef, @@ -813,7 +794,6 @@ mod tests { #[test] fn cos_f32_strided() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - // Shape = [6], strides = [1]; let shape = vec![6]; let strides = vec![1]; let offset = 0;