mirror of https://github.com/tracel-ai/burn.git
trying part 2
This commit is contained in:
parent
c111d9dd61
commit
9e23cc4227
|
@ -7,7 +7,6 @@ use crate::{
|
||||||
use cubecl::prelude::*;
|
use cubecl::prelude::*;
|
||||||
use cubecl::{calculate_cube_count_elemwise, CubeDim};
|
use cubecl::{calculate_cube_count_elemwise, CubeDim};
|
||||||
|
|
||||||
#[allow(unused_assignments)]
|
|
||||||
#[cube(launch_unchecked)]
|
#[cube(launch_unchecked)]
|
||||||
fn scatter_kernel<T: Numeric>(
|
fn scatter_kernel<T: Numeric>(
|
||||||
input: &mut Tensor<T>,
|
input: &mut Tensor<T>,
|
||||||
|
@ -51,18 +50,14 @@ fn scatter_kernel<T: Numeric>(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut index_input = UInt::new(0);
|
|
||||||
let mut idx = UInt::new(0);
|
|
||||||
let mut result_indices = UInt::new(0);
|
|
||||||
|
|
||||||
for i in range(0, shape_value, Comptime::new(false)) {
|
for i in range(0, shape_value, Comptime::new(false)) {
|
||||||
idx = stride_input * i;
|
let mut idx = stride_input * i;
|
||||||
idx += offset_value;
|
idx += offset_value;
|
||||||
|
|
||||||
let result_value = value[idx];
|
let result_value = value[idx];
|
||||||
result_indices = UInt::cast_from(indices[idx]);
|
let result_indices = UInt::cast_from(indices[idx]);
|
||||||
|
|
||||||
index_input = stride_input * result_indices;
|
let mut index_input = stride_input * result_indices;
|
||||||
index_input += offset_input;
|
index_input += offset_input;
|
||||||
|
|
||||||
let mut result_input = input[index_input];
|
let mut result_input = input[index_input];
|
||||||
|
|
Loading…
Reference in New Issue