Fix select assign

This commit is contained in:
nathaniel 2024-08-05 18:59:00 -04:00
parent bf4b7d98d6
commit f1a131c6da
1 changed file with 11 additions and 8 deletions

View File

@ -21,23 +21,22 @@ fn select_assign_kernel<F: Numeric, I: Numeric>(
num_elems *= shape_tensor; num_elems *= shape_tensor;
let ogwl = ABSOLUTE_POS / tensor.stride(i); let ogwl = ABSOLUTE_POS / indices.stride(i);
offset_tensor += ogwl % shape_tensor * tensor.stride(i); offset_tensor += ogwl % shape_tensor * tensor.stride(i);
offset_value += ogwl % value.shape(i) * value.stride(i); offset_value += ogwl % value.shape(i) * value.stride(i);
} }
} }
if num_elems >= ABSOLUTE_POS { if ABSOLUTE_POS >= num_elems {
return; return;
} }
let strides_tensor_dim = tensor.stride(dim); let strides_tensor_dim = tensor.stride(dim);
let strides_value_dim = value.stride(dim); let strides_value_dim = value.stride(dim);
let shape_value_dim = value.shape(dim);
// Main operation // Main operation
for i in range(0, shape_value_dim, Comptime::new(false)) { for i in range(0, value.shape(dim), Comptime::new(false)) {
let index_tensor = UInt::cast_from(indices[i]) * strides_tensor_dim + offset_tensor; let index_tensor = UInt::cast_from(indices[i]) * strides_tensor_dim + offset_tensor;
let index_value = i * strides_value_dim + offset_value; let index_value = i * strides_value_dim + offset_value;
@ -56,6 +55,8 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
false => tensor.copy(), false => tensor.copy(),
}; };
let mut strides = [0; D];
let mut current = 1;
let mut num_elems = 1; let mut num_elems = 1;
tensor tensor
@ -63,14 +64,15 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
.dims .dims
.iter() .iter()
.enumerate() .enumerate()
.rev()
.filter(|(index, _val)| *index != dim) .filter(|(index, _val)| *index != dim)
.for_each(|(index, _val)| { .for_each(|(index, val)| {
strides[index] = current;
current *= val;
num_elems *= tensor.shape.dims[index]; num_elems *= tensor.shape.dims[index];
}); });
let cube_dim = CubeDim::default(); let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim); let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
println!("COUNT {:?} elem {:?}", cube_count, num_elems);
unsafe { unsafe {
select_assign_kernel::launch_unchecked::<E::Primitive, I::Primitive, R>( select_assign_kernel::launch_unchecked::<E::Primitive, I::Primitive, R>(
@ -78,7 +80,8 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
cube_count, cube_count,
cube_dim, cube_dim,
tensor.as_tensor_arg(1), tensor.as_tensor_arg(1),
TensorArg::from_raw_parts(&indices.handle, &tensor.strides, &tensor.shape.dims, 1), // Ignored shape + custom strides.
TensorArg::from_raw_parts(&indices.handle, &strides, &strides, 1),
value.as_tensor_arg(1), value.as_tensor_arg(1),
ScalarArg::new(dim as u32), ScalarArg::new(dim as u32),
); );