mirror of https://github.com/tracel-ai/burn.git
Fix select assign
This commit is contained in:
parent
bf4b7d98d6
commit
f1a131c6da
|
@ -21,23 +21,22 @@ fn select_assign_kernel<F: Numeric, I: Numeric>(
|
||||||
|
|
||||||
num_elems *= shape_tensor;
|
num_elems *= shape_tensor;
|
||||||
|
|
||||||
let ogwl = ABSOLUTE_POS / tensor.stride(i);
|
let ogwl = ABSOLUTE_POS / indices.stride(i);
|
||||||
|
|
||||||
offset_tensor += ogwl % shape_tensor * tensor.stride(i);
|
offset_tensor += ogwl % shape_tensor * tensor.stride(i);
|
||||||
offset_value += ogwl % value.shape(i) * value.stride(i);
|
offset_value += ogwl % value.shape(i) * value.stride(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if num_elems >= ABSOLUTE_POS {
|
if ABSOLUTE_POS >= num_elems {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let strides_tensor_dim = tensor.stride(dim);
|
let strides_tensor_dim = tensor.stride(dim);
|
||||||
let strides_value_dim = value.stride(dim);
|
let strides_value_dim = value.stride(dim);
|
||||||
let shape_value_dim = value.shape(dim);
|
|
||||||
|
|
||||||
// Main operation
|
// Main operation
|
||||||
for i in range(0, shape_value_dim, Comptime::new(false)) {
|
for i in range(0, value.shape(dim), Comptime::new(false)) {
|
||||||
let index_tensor = UInt::cast_from(indices[i]) * strides_tensor_dim + offset_tensor;
|
let index_tensor = UInt::cast_from(indices[i]) * strides_tensor_dim + offset_tensor;
|
||||||
let index_value = i * strides_value_dim + offset_value;
|
let index_value = i * strides_value_dim + offset_value;
|
||||||
|
|
||||||
|
@ -56,6 +55,8 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
|
||||||
false => tensor.copy(),
|
false => tensor.copy(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut strides = [0; D];
|
||||||
|
let mut current = 1;
|
||||||
let mut num_elems = 1;
|
let mut num_elems = 1;
|
||||||
|
|
||||||
tensor
|
tensor
|
||||||
|
@ -63,14 +64,15 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
|
||||||
.dims
|
.dims
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
|
.rev()
|
||||||
.filter(|(index, _val)| *index != dim)
|
.filter(|(index, _val)| *index != dim)
|
||||||
.for_each(|(index, _val)| {
|
.for_each(|(index, val)| {
|
||||||
|
strides[index] = current;
|
||||||
|
current *= val;
|
||||||
num_elems *= tensor.shape.dims[index];
|
num_elems *= tensor.shape.dims[index];
|
||||||
});
|
});
|
||||||
|
|
||||||
let cube_dim = CubeDim::default();
|
let cube_dim = CubeDim::default();
|
||||||
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
|
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
|
||||||
println!("COUNT {:?} elem {:?}", cube_count, num_elems);
|
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
select_assign_kernel::launch_unchecked::<E::Primitive, I::Primitive, R>(
|
select_assign_kernel::launch_unchecked::<E::Primitive, I::Primitive, R>(
|
||||||
|
@ -78,7 +80,8 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
|
||||||
cube_count,
|
cube_count,
|
||||||
cube_dim,
|
cube_dim,
|
||||||
tensor.as_tensor_arg(1),
|
tensor.as_tensor_arg(1),
|
||||||
TensorArg::from_raw_parts(&indices.handle, &tensor.strides, &tensor.shape.dims, 1),
|
// Ignored shape + custom strides.
|
||||||
|
TensorArg::from_raw_parts(&indices.handle, &strides, &strides, 1),
|
||||||
value.as_tensor_arg(1),
|
value.as_tensor_arg(1),
|
||||||
ScalarArg::new(dim as u32),
|
ScalarArg::new(dim as u32),
|
||||||
);
|
);
|
||||||
|
|
Loading…
Reference in New Issue