mirror of https://github.com/tracel-ai/burn.git
trying part 2
This commit is contained in:
parent
c111d9dd61
commit
9e23cc4227
|
@ -7,7 +7,6 @@ use crate::{
|
|||
use cubecl::prelude::*;
|
||||
use cubecl::{calculate_cube_count_elemwise, CubeDim};
|
||||
|
||||
#[allow(unused_assignments)]
|
||||
#[cube(launch_unchecked)]
|
||||
fn scatter_kernel<T: Numeric>(
|
||||
input: &mut Tensor<T>,
|
||||
|
@ -51,18 +50,14 @@ fn scatter_kernel<T: Numeric>(
|
|||
return;
|
||||
}
|
||||
|
||||
let mut index_input = UInt::new(0);
|
||||
let mut idx = UInt::new(0);
|
||||
let mut result_indices = UInt::new(0);
|
||||
|
||||
for i in range(0, shape_value, Comptime::new(false)) {
|
||||
idx = stride_input * i;
|
||||
let mut idx = stride_input * i;
|
||||
idx += offset_value;
|
||||
|
||||
let result_value = value[idx];
|
||||
result_indices = UInt::cast_from(indices[idx]);
|
||||
let result_indices = UInt::cast_from(indices[idx]);
|
||||
|
||||
index_input = stride_input * result_indices;
|
||||
let mut index_input = stride_input * result_indices;
|
||||
index_input += offset_input;
|
||||
|
||||
let mut result_input = input[index_input];
|
||||
|
|
Loading…
Reference in New Issue