Fix select_assign: correct the kernel bounds check and the strides passed for indices

This commit is contained in:
nathaniel 2024-08-05 18:59:00 -04:00
parent bf4b7d98d6
commit f1a131c6da
1 changed files with 11 additions and 8 deletions

View File

@ -21,23 +21,22 @@ fn select_assign_kernel<F: Numeric, I: Numeric>(
num_elems *= shape_tensor;
let ogwl = ABSOLUTE_POS / tensor.stride(i);
let ogwl = ABSOLUTE_POS / indices.stride(i);
offset_tensor += ogwl % shape_tensor * tensor.stride(i);
offset_value += ogwl % value.shape(i) * value.stride(i);
}
}
if num_elems >= ABSOLUTE_POS {
if ABSOLUTE_POS >= num_elems {
return;
}
let strides_tensor_dim = tensor.stride(dim);
let strides_value_dim = value.stride(dim);
let shape_value_dim = value.shape(dim);
// Main operation
for i in range(0, shape_value_dim, Comptime::new(false)) {
for i in range(0, value.shape(dim), Comptime::new(false)) {
let index_tensor = UInt::cast_from(indices[i]) * strides_tensor_dim + offset_tensor;
let index_value = i * strides_value_dim + offset_value;
@ -56,6 +55,8 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
false => tensor.copy(),
};
let mut strides = [0; D];
let mut current = 1;
let mut num_elems = 1;
tensor
@ -63,14 +64,15 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
.dims
.iter()
.enumerate()
.rev()
.filter(|(index, _val)| *index != dim)
.for_each(|(index, _val)| {
.for_each(|(index, val)| {
strides[index] = current;
current *= val;
num_elems *= tensor.shape.dims[index];
});
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
println!("COUNT {:?} elem {:?}", cube_count, num_elems);
unsafe {
select_assign_kernel::launch_unchecked::<E::Primitive, I::Primitive, R>(
@ -78,7 +80,8 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
cube_count,
cube_dim,
tensor.as_tensor_arg(1),
TensorArg::from_raw_parts(&indices.handle, &tensor.strides, &tensor.shape.dims, 1),
// The shape of `indices` is ignored, so the custom strides are passed for both the strides and the shape arguments.
TensorArg::from_raw_parts(&indices.handle, &strides, &strides, 1),
value.as_tensor_arg(1),
ScalarArg::new(dim as u32),
);