Where cond get_strided_index conditionally based on function constants

This commit is contained in:
Ivar Flakstad 2024-01-22 20:59:02 +01:00
parent fd7c856564
commit 933716b374
4 changed files with 41 additions and 9 deletions

View File

@ -822,10 +822,13 @@ impl BackendStorage for MetalStorage {
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
),
!layout.is_contiguous(),
&t.buffer,
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
!t_l.is_contiguous(),
&f.buffer,
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
!f_l.is_contiguous(),
&buffer,
)
.map_err(MetalError::from)?;

View File

@ -909,13 +909,22 @@ pub fn call_where_cond_strided(
shape: &[usize],
cond: &Buffer,
(cond_stride, cond_offset): (&[usize], usize),
cond_is_strided: bool,
left: &Buffer,
(left_stride, left_offset): (&[usize], usize),
left_is_strided: bool,
right: &Buffer,
(right_stride, right_offset): (&[usize], usize),
right_is_strided: bool,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let constants = Some(ConstantValues::new(vec![
(0, Value::Bool(cond_is_strided)),
(1, Value::Bool(left_is_strided)),
(2, Value::Bool(right_is_strided)),
]));
let pipeline =
kernels.load_pipeline_with_constants(device, Source::Ternary, name, constants)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);

View File

@ -1,14 +1,20 @@
#include <metal_stdlib>
#
using namespace metal;
constant bool IDS_STRIDED [[function_constant(0)]];
constant bool T_STRIDED [[function_constant(1)]];
constant bool F_STRIDED [[function_constant(2)]];
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
constant const size_t &num_dims,
constant const size_t *dims,
constant const size_t *strides
) {
uint strided_i = 0;
#pragma clang loop unroll(full)
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
@ -17,6 +23,7 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
template<typename T, typename ID>
METAL_FUNC void where_cond(
constant size_t &numel,
@ -34,10 +41,20 @@ METAL_FUNC void where_cond(
if (i >= numel){
return;
}
uint strided_i = get_strided_index(i, num_dims, dims, strides);
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
uint strided_i = i;
uint strided_i_t = i;
uint strided_i_f = i;
if (IDS_STRIDED) {
strided_i = get_strided_index(i, num_dims, dims, strides);
}
if (T_STRIDED) {
strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
}
if (F_STRIDED) {
strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
}
out[i] = select(f[strided_i_t], t[strided_i_f], ids[strided_i]);
}
#define WHERE_OP(T, ID, FN_NAME) \

View File

@ -803,10 +803,13 @@ fn run_where_cond<I: Clone, T: Clone>(
shape,
&cond,
(&cond_stride, cond_offset),
true,
&left,
(&left_stride, left_offset),
true,
&right,
(&cond_stride, cond_offset),
true,
&output,
)
.unwrap();