trying part 2

This commit is contained in:
mepatrick73 2024-08-14 19:01:48 -04:00
parent c111d9dd61
commit 9e23cc4227
1 changed files with 3 additions and 8 deletions

View File

@ -7,7 +7,6 @@ use crate::{
use cubecl::prelude::*; use cubecl::prelude::*;
use cubecl::{calculate_cube_count_elemwise, CubeDim}; use cubecl::{calculate_cube_count_elemwise, CubeDim};
#[allow(unused_assignments)]
#[cube(launch_unchecked)] #[cube(launch_unchecked)]
fn scatter_kernel<T: Numeric>( fn scatter_kernel<T: Numeric>(
input: &mut Tensor<T>, input: &mut Tensor<T>,
@ -51,18 +50,14 @@ fn scatter_kernel<T: Numeric>(
return; return;
} }
let mut index_input = UInt::new(0);
let mut idx = UInt::new(0);
let mut result_indices = UInt::new(0);
for i in range(0, shape_value, Comptime::new(false)) { for i in range(0, shape_value, Comptime::new(false)) {
idx = stride_input * i; let mut idx = stride_input * i;
idx += offset_value; idx += offset_value;
let result_value = value[idx]; let result_value = value[idx];
result_indices = UInt::cast_from(indices[idx]); let result_indices = UInt::cast_from(indices[idx]);
index_input = stride_input * result_indices; let mut index_input = stride_input * result_indices;
index_input += offset_input; index_input += offset_input;
let mut result_input = input[index_input]; let mut result_input = input[index_input];