diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index 23dbedba8..3cad192f6 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -21,23 +21,22 @@ fn select_assign_kernel( num_elems *= shape_tensor; - let ogwl = ABSOLUTE_POS / tensor.stride(i); + let ogwl = ABSOLUTE_POS / indices.stride(i); offset_tensor += ogwl % shape_tensor * tensor.stride(i); offset_value += ogwl % value.shape(i) * value.stride(i); } } - if num_elems >= ABSOLUTE_POS { + if ABSOLUTE_POS >= num_elems { return; } let strides_tensor_dim = tensor.stride(dim); let strides_value_dim = value.stride(dim); - let shape_value_dim = value.shape(dim); // Main operation - for i in range(0, shape_value_dim, Comptime::new(false)) { + for i in range(0, value.shape(dim), Comptime::new(false)) { let index_tensor = UInt::cast_from(indices[i]) * strides_tensor_dim + offset_tensor; let index_value = i * strides_value_dim + offset_value; @@ -56,6 +55,8 @@ pub(crate) fn select_assign tensor.copy(), }; + let mut strides = [0; D]; + let mut current = 1; let mut num_elems = 1; tensor @@ -63,14 +64,15 @@ pub(crate) fn select_assign( @@ -78,7 +80,8 @@ pub(crate) fn select_assign