Where cond get_strided_index conditionally based on function constants
This commit is contained in:
parent
fd7c856564
commit
933716b374
|
@ -822,10 +822,13 @@ impl BackendStorage for MetalStorage {
|
|||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
),
|
||||
!layout.is_contiguous(),
|
||||
&t.buffer,
|
||||
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
!t_l.is_contiguous(),
|
||||
&f.buffer,
|
||||
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
!f_l.is_contiguous(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
|
|
@ -909,13 +909,22 @@ pub fn call_where_cond_strided(
|
|||
shape: &[usize],
|
||||
cond: &Buffer,
|
||||
(cond_stride, cond_offset): (&[usize], usize),
|
||||
cond_is_strided: bool,
|
||||
left: &Buffer,
|
||||
(left_stride, left_offset): (&[usize], usize),
|
||||
left_is_strided: bool,
|
||||
right: &Buffer,
|
||||
(right_stride, right_offset): (&[usize], usize),
|
||||
right_is_strided: bool,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
let constants = Some(ConstantValues::new(vec![
|
||||
(0, Value::Bool(cond_is_strided)),
|
||||
(1, Value::Bool(left_is_strided)),
|
||||
(2, Value::Bool(right_is_strided)),
|
||||
]));
|
||||
let pipeline =
|
||||
kernels.load_pipeline_with_constants(device, Source::Ternary, name, constants)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
|
|
@ -1,14 +1,20 @@
|
|||
#include <metal_stdlib>
|
||||
#
|
||||
|
||||
using namespace metal;
|
||||
|
||||
constant bool IDS_STRIDED [[function_constant(0)]];
|
||||
constant bool T_STRIDED [[function_constant(1)]];
|
||||
constant bool F_STRIDED [[function_constant(2)]];
|
||||
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
constant const size_t &num_dims,
|
||||
constant const size_t *dims,
|
||||
constant const size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
|
@ -17,6 +23,7 @@ METAL_FUNC uint get_strided_index(
|
|||
return strided_i;
|
||||
}
|
||||
|
||||
|
||||
template<typename T, typename ID>
|
||||
METAL_FUNC void where_cond(
|
||||
constant size_t &numel,
|
||||
|
@ -34,10 +41,20 @@ METAL_FUNC void where_cond(
|
|||
if (i >= numel){
|
||||
return;
|
||||
}
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
|
||||
uint strided_i = i;
|
||||
uint strided_i_t = i;
|
||||
uint strided_i_f = i;
|
||||
if (IDS_STRIDED) {
|
||||
strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
}
|
||||
if (T_STRIDED) {
|
||||
strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
|
||||
}
|
||||
if (F_STRIDED) {
|
||||
strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
|
||||
}
|
||||
|
||||
out[i] = select(f[strided_i_t], t[strided_i_f], ids[strided_i]);
|
||||
}
|
||||
|
||||
#define WHERE_OP(T, ID, FN_NAME) \
|
||||
|
|
|
@ -803,10 +803,13 @@ fn run_where_cond<I: Clone, T: Clone>(
|
|||
shape,
|
||||
&cond,
|
||||
(&cond_stride, cond_offset),
|
||||
true,
|
||||
&left,
|
||||
(&left_stride, left_offset),
|
||||
true,
|
||||
&right,
|
||||
(&cond_stride, cond_offset),
|
||||
true,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
|
|
Loading…
Reference in New Issue