copy2d (metal): let each kernel thread copy up to 4 consecutive elements.

The host side rounds the thread count up so a tail of 1-3 elements is still
copied, and the kernel bounds-checks and re-derives row/column for every
extra element so a group of 4 may straddle a row boundary safely.
(The blob hashes on the "index" lines are carried over from the original patch.)

diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 4cff9bda..3b239f19 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -406,7 +406,9 @@ pub fn call_copy2d(
     );
 
     let width: usize = d1 * d2;
-    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
+    // Each kernel thread copies up to 4 consecutive elements, so round the
+    // thread count up to keep the tail when `width` is not a multiple of 4.
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, (width + 3) / 4);
 
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 809522d7..22d53177 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -112,6 +112,8 @@ kernel void FN_NAME( \
     device TYPENAME *output, \
     uint tid [[ thread_position_in_grid ]] \
 ) { \
+    /* Each thread handles the linear elements [tid * 4, tid * 4 + 4). */ \
+    tid *= 4; \
     if (tid >= d1 * d2) { \
         return; \
     } \
@@ -120,6 +122,13 @@ kernel void FN_NAME( \
     size_t src_idx = idx1 * src_s + idx2; \
     size_t dst_idx = idx1 * dst_s + idx2; \
     output[dst_idx] = input[src_idx]; \
+    /* The following elements may sit in a later row, or past the end of the */ \
+    /* data, so bounds-check each one and recompute its row/column offsets. */ \
+    for (size_t e = (size_t)tid + 1; e < (size_t)tid + 4 && e < d1 * d2; ++e) { \
+        size_t row = e / d2; \
+        size_t col = e - row * d2; \
+        output[row * dst_s + col] = input[row * src_s + col]; \
+    } \
 }
 
 COPY2D(copy2d_f32, float)