From 9e23cc42279a6557060dcf7ee64b4a1c3a78175f Mon Sep 17 00:00:00 2001 From: mepatrick73 Date: Wed, 14 Aug 2024 19:01:48 -0400 Subject: [PATCH] trying part 2 --- crates/burn-jit/src/kernel/index/scatter.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index 77d92f449..a17ea3836 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -7,7 +7,6 @@ use crate::{ use cubecl::prelude::*; use cubecl::{calculate_cube_count_elemwise, CubeDim}; -#[allow(unused_assignments)] #[cube(launch_unchecked)] fn scatter_kernel( input: &mut Tensor, @@ -51,18 +50,14 @@ fn scatter_kernel( return; } - let mut index_input = UInt::new(0); - let mut idx = UInt::new(0); - let mut result_indices = UInt::new(0); - for i in range(0, shape_value, Comptime::new(false)) { - idx = stride_input * i; + let mut idx = stride_input * i; idx += offset_value; let result_value = value[idx]; - result_indices = UInt::cast_from(indices[idx]); + let result_indices = UInt::cast_from(indices[idx]); - index_input = stride_input * result_indices; + let mut index_input = stride_input * result_indices; index_input += offset_input; let mut result_input = input[index_input];