From f1a131c6da771d0af27a05b7527a60b723e5a980 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 5 Aug 2024 18:59:00 -0400 Subject: [PATCH] Fix select assign --- .../src/kernel/index/select_assign.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index 23dbedba8..3cad192f6 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -21,23 +21,22 @@ fn select_assign_kernel( num_elems *= shape_tensor; - let ogwl = ABSOLUTE_POS / tensor.stride(i); + let ogwl = ABSOLUTE_POS / indices.stride(i); offset_tensor += ogwl % shape_tensor * tensor.stride(i); offset_value += ogwl % value.shape(i) * value.stride(i); } } - if num_elems >= ABSOLUTE_POS { + if ABSOLUTE_POS >= num_elems { return; } let strides_tensor_dim = tensor.stride(dim); let strides_value_dim = value.stride(dim); - let shape_value_dim = value.shape(dim); // Main operation - for i in range(0, shape_value_dim, Comptime::new(false)) { + for i in range(0, value.shape(dim), Comptime::new(false)) { let index_tensor = UInt::cast_from(indices[i]) * strides_tensor_dim + offset_tensor; let index_value = i * strides_value_dim + offset_value; @@ -56,6 +55,8 @@ pub(crate) fn select_assign tensor.copy(), }; + let mut strides = [0; D]; + let mut current = 1; let mut num_elems = 1; tensor @@ -63,14 +64,15 @@ pub(crate) fn select_assign( @@ -78,7 +80,8 @@ pub(crate) fn select_assign