Copy multiple elements per thread in the Metal copy2d kernel [do not merge]
This commit is contained in:
parent
ab892274d1
commit
09fafcfa99
|
@ -406,7 +406,7 @@ pub fn call_copy2d(
|
|||
);
|
||||
|
||||
let width: usize = d1 * d2;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width / 4);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
|
|
|
@ -112,6 +112,7 @@ kernel void FN_NAME( \
|
|||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
tid *= 4; \
|
||||
if (tid >= d1 * d2) { \
|
||||
return; \
|
||||
} \
|
||||
|
@ -120,6 +121,9 @@ kernel void FN_NAME( \
|
|||
size_t src_idx = idx1 * src_s + idx2; \
|
||||
size_t dst_idx = idx1 * dst_s + idx2; \
|
||||
output[dst_idx] = input[src_idx]; \
|
||||
output[dst_idx+1] = input[src_idx+1]; \
|
||||
output[dst_idx+2] = input[src_idx+2]; \
|
||||
output[dst_idx+3] = input[src_idx+3]; \
|
||||
}
|
||||
|
||||
COPY2D(copy2d_f32, float)
|
||||
|
|
Loading…
Reference in New Issue