Copy multi metal [do not merge]

This commit is contained in:
Laurent 2024-04-06 10:11:16 +02:00
parent ab892274d1
commit 09fafcfa99
2 changed files with 5 additions and 1 deletions

View File

@ -406,7 +406,7 @@ pub fn call_copy2d(
);
let width: usize = d1 * d2;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width / 4);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);

View File

@ -112,6 +112,7 @@ kernel void FN_NAME( \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
tid *= 4; \
if (tid >= d1 * d2) { \
return; \
} \
@ -120,6 +121,9 @@ kernel void FN_NAME( \
size_t src_idx = idx1 * src_s + idx2; \
size_t dst_idx = idx1 * dst_s + idx2; \
output[dst_idx] = input[src_idx]; \
output[dst_idx+1] = input[src_idx+1]; \
output[dst_idx+2] = input[src_idx+2]; \
output[dst_idx+3] = input[src_idx+3]; \
}
COPY2D(copy2d_f32, float)