Merge branch 'main' into index-cpa-to-cubecl

mepatrick73 2024-08-14 18:26:38 -04:00
commit ab5d437adf
10 changed files with 168 additions and 313 deletions

Cargo.lock generated (35 changed lines)
View File

@@ -1379,7 +1379,7 @@ dependencies = [
[[package]]
name = "cubecl"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
dependencies = [
"cubecl-core",
"cubecl-cuda",
@@ -1390,11 +1390,12 @@ dependencies = [
[[package]]
name = "cubecl-common"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
dependencies = [
"derive-new",
"getrandom",
"pollster",
"portable-atomic-util",
"rand",
"serde",
"spin",
@@ -1404,7 +1405,7 @@ dependencies = [
[[package]]
name = "cubecl-core"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
dependencies = [
"bytemuck",
"cubecl-macros",
@@ -1419,7 +1420,7 @@ dependencies = [
[[package]]
name = "cubecl-cuda"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
dependencies = [
"bytemuck",
"cubecl-common",
@@ -1434,7 +1435,7 @@ dependencies = [
[[package]]
name = "cubecl-linalg"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
dependencies = [
"bytemuck",
"cubecl-core",
@@ -1445,7 +1446,7 @@ dependencies = [
[[package]]
name = "cubecl-macros"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
dependencies = [
"derive-new",
"proc-macro2",
@@ -1456,7 +1457,7 @@ dependencies = [
[[package]]
name = "cubecl-runtime"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
dependencies = [
"async-channel",
"cubecl-common",
@@ -1475,7 +1476,7 @@ dependencies = [
[[package]]
name = "cubecl-wgpu"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
dependencies = [
"async-channel",
"bytemuck",
@@ -4747,6 +4748,15 @@ version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
[[package]]
name = "portable-atomic-util"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcdd8420072e66d54a407b3316991fe946ce3ab1083a7f575b2463866624704d"
dependencies = [
"portable-atomic",
]
[[package]]
name = "powerfmt"
version = "0.2.0"
@@ -5678,9 +5688,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"
[[package]]
name = "serde"
version = "1.0.206"
version = "1.0.207"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b3e4cd94123dd520a128bcd11e34d9e9e423e7e3e50425cb1b4b1e3549d0284"
checksum = "5665e14a49a4ea1b91029ba7d3bca9f299e1f7cfa194388ccc20f14743e784f2"
dependencies = [
"serde_derive",
]
@@ -5707,9 +5717,9 @@ dependencies = [
[[package]]
name = "serde_derive"
version = "1.0.206"
version = "1.0.207"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fabfb6138d2383ea8208cf98ccf69cdfb1aff4088460681d84189aa259762f97"
checksum = "6aea2634c86b0e8ef2cfdc0c340baede54ec27b1e46febd7f80dffb2aa44a00e"
dependencies = [
"proc-macro2",
"quote",
@@ -5924,6 +5934,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
"portable-atomic",
]
[[package]]

View File

@@ -143,8 +143,8 @@ sysinfo = "0.30.13"
systemstat = "0.2.3"
### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }

View File

@@ -8,8 +8,8 @@ use super::{
};
#[allow(dead_code)]
pub(crate) trait ReduceDimAlgorithm<E: JitElement>:
ReduceDimNaive<E> + ReduceDimShared<E>
pub(crate) trait ReduceDimAlgorithm<EI: JitElement>:
ReduceDimNaive<EI::Primitive> + ReduceDimShared<EI>
{
}
@@ -65,7 +65,7 @@ impl Default for ReduceStrategy {
macro_rules! reduce_operation {
($name:ident, $ops:ident) => {
pub(crate) struct $ops;
impl<E: JitElement> ReduceDimAlgorithm<E> for $ops {}
impl<EI: JitElement> ReduceDimAlgorithm<EI> for $ops {}
/// Executes the reduce operation with the given strategy.
pub fn $name<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(

View File

@@ -1,50 +1,36 @@
use crate::{kernel::reduce::Argmax, JitElement};
use cubecl::{
cpa,
ir::{Elem, Item, Scope, Variable},
};
use super::base::ReduceDimNaive;
use crate::kernel::reduce::Argmax;
use cubecl::cube;
use cubecl::frontend::{Float, Tensor, UInt, ABSOLUTE_POS, F32};
use cubecl::prelude::{Cast, Numeric};
impl<E: JitElement> ReduceDimNaive<E> for Argmax {
type Accumulator = (Variable, Variable);
#[allow(clippy::extra_unused_type_parameters)]
#[cube]
impl<EI: Numeric> ReduceDimNaive<EI> for Argmax {
type Accumulator = (F32, UInt);
fn initialize_naive(
scope: &mut Scope,
input_item: Item,
_output_item: Item,
) -> Self::Accumulator {
let index = scope.create_local(Elem::UInt);
let max = scope.create_local(input_item);
let max_initial = input_item
.elem()
.constant_from_f64(E::minimum_value().to_f64());
cpa!(scope, max = max_initial);
(max, index)
fn initialize_naive() -> (F32, UInt) {
// TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
let a = F32::new(0.0);
let b = F32::new(100000000.0);
(a - b, UInt::new(0))
}
fn inner_loop_naive(
scope: &mut Scope,
(max, index): Self::Accumulator,
value: Variable,
i: Variable,
) {
let condition = scope.create_local(Elem::Bool);
cpa!(scope, condition = value > max);
cpa!(scope, if(condition).then(|scope| {
cpa!(scope, max = value);
cpa!(scope, index = i);
}));
fn inner_loop_naive(accumulator: &mut (F32, UInt), current_value: EI, i: UInt) {
let (max, index) = accumulator;
let val = F32::cast_from(current_value);
if val > *max {
*max = val;
*index = i;
}
}
fn assign_naive(
scope: &mut Scope,
output: Variable,
(_max, index): Self::Accumulator,
_shape_reduce_dim: Variable,
fn assign_naive<EO: Numeric>(
output: &mut Tensor<EO>,
accumulator: (F32, UInt),
_shape_reduce_dim: UInt,
) {
let id = Variable::AbsolutePos;
cpa!(scope, output[id] = index);
let (_, index) = accumulator;
output[ABSOLUTE_POS] = EO::cast_from(index);
}
}
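Outside the #[cube] DSL, the new argmax inner loop is equivalent to this plain-Rust accumulator sketch (types simplified to f32/u32; illustration only, not part of the diff):

    // Running (max, index) accumulator, mirroring inner_loop_naive above.
    fn argmax_step(acc: &mut (f32, u32), value: f32, i: u32) {
        if value > acc.0 {
            acc.0 = value; // new running maximum
            acc.1 = i; // remember its index along the reduced dim
        }
    }

    fn main() {
        // Sentinel mirrors the kernel's stand-in for f32::NEG_INFINITY.
        let mut acc = (0.0 - 100_000_000.0_f32, 0u32);
        for (i, v) in [3.0f32, 9.0, 1.0].into_iter().enumerate() {
            argmax_step(&mut acc, v, i as u32);
        }
        assert_eq!(acc.1, 1); // the maximum (9.0) sits at index 1
    }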

View File

@@ -1,52 +1,34 @@
use cubecl::{
cpa,
ir::{Elem, Item, Scope, Variable},
};
use crate::{kernel::reduce::Argmin, JitElement};
use crate::kernel::reduce::Argmin;
use cubecl::cube;
use cubecl::prelude::{Cast, Float, Numeric, Tensor, UInt, ABSOLUTE_POS, F32};
use super::base::ReduceDimNaive;
impl<E: JitElement> ReduceDimNaive<E> for Argmin {
type Accumulator = (Variable, Variable);
#[allow(clippy::extra_unused_type_parameters)]
#[cube]
impl<EI: Numeric> ReduceDimNaive<EI> for Argmin {
type Accumulator = (F32, UInt);
fn initialize_naive(
scope: &mut Scope,
input_item: Item,
_output_item: Item,
) -> Self::Accumulator {
let index = scope.create_local(Elem::UInt);
let min = scope.create_local(input_item);
let min_initial = input_item
.elem()
.constant_from_f64(E::maximum_value().to_f64());
cpa!(scope, min = min_initial);
(min, index)
fn initialize_naive() -> (F32, UInt) {
// TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
(F32::new(100000000.0), UInt::new(0))
}
fn inner_loop_naive(
scope: &mut Scope,
(min, index): Self::Accumulator,
value: Variable,
i: Variable,
) {
let condition = scope.create_local(Elem::Bool);
cpa!(scope, condition = value < min);
cpa!(scope, if(condition).then(|scope| {
cpa!(scope, min = value);
cpa!(scope, index = i);
}));
fn inner_loop_naive(accumulator: &mut (F32, UInt), current_value: EI, i: UInt) {
let (min, index) = accumulator;
let val = F32::cast_from(current_value);
if val < *min {
*min = val;
*index = i;
}
}
fn assign_naive(
scope: &mut Scope,
output: Variable,
(_min, index): Self::Accumulator,
_shape_reduce_dim: Variable,
fn assign_naive<EO: Numeric>(
output: &mut Tensor<EO>,
accumulator: (F32, UInt),
_shape_reduce_dim: UInt,
) {
let id = Variable::AbsolutePos;
cpa!(scope, output[id] = index);
let (_, index) = accumulator;
output[ABSOLUTE_POS] = EO::cast_from(index);
}
}
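The 100000000.0 sentinel is a stopgap, so values at or beyond roughly ±1e8 would not reduce correctly. Once cubecl supports float infinities (tracel-ai/cubecl#68), the initializers could presumably become:

    // Assumed future form, not valid at this cubecl rev:
    // Argmin: (F32::new(f32::INFINITY), UInt::new(0))
    // Argmax: (F32::new(f32::NEG_INFINITY), UInt::new(0))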

View File

@@ -1,32 +1,23 @@
use cubecl::ir::{Item, Scope, Variable};
use crate::JitElement;
use cubecl::cube;
use cubecl::frontend::CubeType;
use cubecl::prelude::{Numeric, Tensor, UInt};
/// Specifies the reduce dim algorithm in use
pub trait ReduceDimNaive<E: JitElement>: Send + Sync + 'static {
#[cube]
pub trait ReduceDimNaive<EI: Numeric>: Send + Sync + 'static {
/// The reduction accumulator
type Accumulator: Copy;
type Accumulator: Copy + CubeType;
/// Initialization for naive algorithm
fn initialize_naive(
scope: &mut Scope,
input_item: Item,
output_item: Item,
) -> Self::Accumulator;
fn initialize_naive() -> Self::Accumulator;
/// Inner loop for naive algorithm
fn inner_loop_naive(
scope: &mut Scope,
accumulator: Self::Accumulator,
current_value: Variable,
i: Variable,
);
fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: UInt);
/// Assignment for the naive algorithm
fn assign_naive(
scope: &mut Scope,
output: Variable,
fn assign_naive<EO: Numeric>(
output: &mut Tensor<EO>,
accumulator: Self::Accumulator,
shape_reduce_dim: Variable,
shape_reduce_dim: UInt,
);
}
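To illustrate the new contract, a reducer only has to supply the three hooks. SumSquaresDim below is invented for this sketch and is not part of the diff (it assumes the same imports as the impls above):

    // Hypothetical op: sum of squares along the reduced dimension.
    pub(crate) struct SumSquaresDim;

    #[cube]
    impl<EI: Numeric> ReduceDimNaive<EI> for SumSquaresDim {
        type Accumulator = EI;

        fn initialize_naive() -> EI {
            EI::from_int(0)
        }

        fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: UInt) {
            *accumulator += current_value * current_value;
        }

        fn assign_naive<EO: Numeric>(
            output: &mut Tensor<EO>,
            accumulator: EI,
            _shape_reduce_dim: UInt,
        ) {
            output[ABSOLUTE_POS] = EO::cast_from(accumulator);
        }
    }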

View File

@@ -1,32 +1,22 @@
use crate::{kernel::reduce::MeanDim, JitElement};
use cubecl::{
cpa,
ir::{Item, Scope, Variable},
};
use crate::kernel::reduce::MeanDim;
use cubecl::cube;
use cubecl::prelude::{Cast, Numeric, Tensor, UInt, ABSOLUTE_POS};
use super::base::ReduceDimNaive;
impl<E: JitElement> ReduceDimNaive<E> for MeanDim {
type Accumulator = Variable;
#[cube]
impl<EI: Numeric> ReduceDimNaive<EI> for MeanDim {
type Accumulator = EI;
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
scope.zero(output_item)
fn initialize_naive() -> EI {
EI::from_int(0)
}
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
cpa!(scope, accumulator += value);
fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: UInt) {
*accumulator += current_value;
}
fn assign_naive(
scope: &mut Scope,
output: Variable,
accumulator: Variable,
shape_reduce_dim: Variable,
) {
let id = Variable::AbsolutePos;
let denominator = scope.create_local(accumulator.item());
cpa!(scope, denominator = cast(shape_reduce_dim));
cpa!(scope, accumulator = accumulator / denominator);
cpa!(scope, output[id] = accumulator);
fn assign_naive<EO: Numeric>(output: &mut Tensor<EO>, accumulator: EI, shape_reduce_dim: UInt) {
output[ABSOLUTE_POS] = EO::cast_from(accumulator) / EO::cast_from(shape_reduce_dim);
}
}
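Note that the division happens after casting to the output element type, so integer outputs truncate the mean. A plain-Rust analogue of that behavior (illustration only):

    fn main() {
        // Mean of [1, 2, 2] with an integer output type:
        let accumulator: i32 = 1 + 2 + 2;
        let mean = accumulator / 3; // integer division truncates, as in the kernel
        assert_eq!(mean, 1);
    }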

View File

@@ -1,29 +1,26 @@
use crate::{kernel::reduce::ProdDim, JitElement};
use cubecl::{
cpa,
ir::{Item, Scope, Variable},
};
use crate::kernel::reduce::ProdDim;
use cubecl::cube;
use cubecl::prelude::{Cast, Numeric, Tensor, UInt, ABSOLUTE_POS};
use super::base::ReduceDimNaive;
impl<E: JitElement> ReduceDimNaive<E> for ProdDim {
type Accumulator = Variable;
#[cube]
impl<EI: Numeric> ReduceDimNaive<EI> for ProdDim {
type Accumulator = EI;
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
scope.create_with_value(1, output_item)
fn initialize_naive() -> EI {
EI::from_int(1)
}
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
cpa!(scope, accumulator *= value);
fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: UInt) {
*accumulator *= current_value;
}
fn assign_naive(
scope: &mut Scope,
output: Variable,
accumulator: Variable,
_shape_reduce_dim: Variable,
fn assign_naive<EO: Numeric>(
output: &mut Tensor<EO>,
accumulator: EI,
_shape_reduce_dim: UInt,
) {
let id = Variable::AbsolutePos;
cpa!(scope, output[id] = accumulator);
output[ABSOLUTE_POS] = EO::cast_from(accumulator);
}
}

View File

@@ -1,148 +1,42 @@
use cubecl::{
cpa,
ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
use std::marker::PhantomData;
use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;
use super::base::ReduceDimNaive;
pub(crate) struct NaiveReduceDimComputeShader<E: JitElement, RD: ReduceDimNaive<E>> {
tensor: Variable,
dim: usize,
output: Variable,
_reduce_dim: PhantomData<RD>,
_elem: PhantomData<E>,
}
#[cube(launch_unchecked)]
pub(crate) fn naive_reduce_dim_compute_shader<RD: ReduceDimNaive<EI>, EI: Numeric, EO: Numeric>(
input: &Tensor<EI>,
output: &mut Tensor<EO>,
dim: UInt,
) {
if ABSOLUTE_POS >= output.len() {
return;
}
#[derive(new)]
pub(crate) struct NaiveReduceDimEagerKernel<
RD: ReduceDimNaive<EI>,
R: JitRuntime,
EI: JitElement,
EO: JitElement,
> {
dim: usize,
reduce_dim: PhantomData<RD>,
_runtime: PhantomData<R>,
_elem_in: PhantomData<EI>,
_elem_out: PhantomData<EO>,
}
let mut offset_input = UInt::new(0);
impl<RD: ReduceDimNaive<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Kernel
for NaiveReduceDimEagerKernel<RD, R, EI, EO>
{
fn define(&self) -> KernelDefinition {
let mut scope = Scope::root();
let item_input = EI::cube_elem().into();
let item_output = EO::cube_elem().into();
let tensor = Variable::GlobalInputArray {
id: 0,
item: item_input,
};
let output = Variable::GlobalOutputArray {
id: 0,
item: item_output,
};
NaiveReduceDimComputeShader {
tensor,
dim: self.dim,
output,
_reduce_dim: PhantomData::<RD>,
_elem: PhantomData::<EI>,
for i in range(0, input.rank(), Comptime::new(false)) {
let mut offset_local = ABSOLUTE_POS / output.stride(i);
offset_local = offset_local % output.shape(i);
if i != dim {
offset_input += offset_local * input.stride(i);
}
.expand(&mut scope);
scope.write_global_custom(output);
let tensor = InputInfo::Array {
item: item_input,
visibility: Visibility::Read,
};
let out = OutputInfo::Array { item: item_output };
let info = KernelExpansion {
inputs: vec![tensor],
outputs: vec![out],
scope,
};
let settings = KernelSettings::default();
KernelIntegrator::new(info).integrate(settings)
}
fn id(&self) -> cubecl::KernelId {
cubecl::KernelId::new::<Self>().info(self.dim)
let mut accumulator = RD::initialize_naive();
for i in range(0, input.shape(dim), Comptime::new(false)) {
let index = i * input.stride(dim) + offset_input;
RD::inner_loop_naive(&mut accumulator, input[index], i);
}
}
impl<E: JitElement, RD: ReduceDimNaive<E>> NaiveReduceDimComputeShader<E, RD> {
pub(crate) fn expand(self, scope: &mut Scope) {
let tensor = self.tensor;
let dim: Variable = self.dim.into();
let id = Variable::AbsolutePos;
let output = self.output;
let offset_input = scope.zero(Elem::UInt);
let stride_input_dim = scope.create_local(Elem::UInt);
let shape_input_dim = scope.create_local(Elem::UInt);
cpa!(
scope,
range(0u32, Variable::Rank).for_each(|i, scope| {
let stride_input = scope.create_local(Elem::UInt);
let stride_output = scope.create_local(Elem::UInt);
let shape_output = scope.create_local(Elem::UInt);
cpa!(scope, stride_input = stride(tensor, i));
cpa!(scope, stride_output = stride(output, i));
cpa!(scope, shape_output = shape(output, i));
let offset_local = scope.create_local(Elem::UInt);
cpa!(scope, offset_local = id / stride_output);
cpa!(scope, offset_local = offset_local % shape_output);
let is_dim_reduce = scope.create_local(Elem::Bool);
cpa!(scope, is_dim_reduce = i == dim);
cpa!(scope, if(is_dim_reduce).then(|scope|{
cpa!(scope, shape_input_dim = shape(tensor, i));
cpa!(scope, stride_input_dim = stride_input);
cpa!(scope, offset_input += offset_local);
}).else(|scope|{
cpa!(scope, offset_local = offset_local * stride_input);
cpa!(scope, offset_input += offset_local);
}));
})
);
let accumulator = RD::initialize_naive(scope, tensor.item(), output.item());
cpa!(
scope,
range(0u32, shape_input_dim).for_each(|i, scope| {
let index = scope.create_local(Elem::UInt);
cpa!(scope, index = i * stride_input_dim);
cpa!(scope, index += offset_input);
let value = scope.create_local(tensor.item());
cpa!(scope, value = tensor[index]);
RD::inner_loop_naive(scope, accumulator, value, i);
})
);
RD::assign_naive(scope, output, accumulator, shape_input_dim);
}
RD::assign_naive::<EO>(output, accumulator, input.shape(dim));
}
/// Executes the naive kernel for reduce dim
pub fn reduce_dim_naive<
RD: ReduceDimNaive<EI>,
RD: ReduceDimNaive<EI::Primitive>,
R: JitRuntime,
EI: JitElement,
EO: JitElement,
@@ -152,12 +46,20 @@ pub fn reduce_dim_naive<
output: JitTensor<R, EO, D>,
dim: usize,
) -> JitTensor<R, EO, D> {
let kernel = NaiveReduceDimEagerKernel::<RD, R, EI, EO>::new(dim);
let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise::<R::Server>(output.shape.num_elements(), cube_dim);
Execution::start(kernel, input.client.clone())
.inputs(&[input.as_handle_ref()])
.outputs(&[output.as_handle_ref()])
.execute(CubeCountSettings::Output { pos: 0 });
unsafe {
naive_reduce_dim_compute_shader::launch_unchecked::<RD, EI::Primitive, EO::Primitive, R>(
&input.client,
cube_count,
cube_dim,
input.as_tensor_arg(1),
output.as_tensor_arg(1),
ScalarArg::new(dim as u32),
);
}
output
}
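The heart of the new kernel is the stride arithmetic that maps each output position to the base offset of its input line. A plain-Rust analogue (function name and test values invented for this sketch):

    // Mirrors the offset loop in naive_reduce_dim_compute_shader.
    fn base_offset(
        pos: usize,
        out_strides: &[usize],
        out_shape: &[usize],
        in_strides: &[usize],
        dim: usize,
    ) -> usize {
        let mut offset = 0;
        for i in 0..out_strides.len() {
            // Coordinate of the output element along axis i...
            let coord = pos / out_strides[i] % out_shape[i];
            // ...contributes to the input offset on every axis except the reduced one.
            if i != dim {
                offset += coord * in_strides[i];
            }
        }
        offset
    }

    fn main() {
        // Contiguous 2x3 input reduced along dim 1; output shape is 2x1.
        let (out_strides, out_shape, in_strides) = ([1usize, 1], [2usize, 1], [3usize, 1]);
        // Output position 1 (second row) reads the input line starting at offset 3.
        assert_eq!(base_offset(1, &out_strides, &out_shape, &in_strides, 1), 3);
    }

The kernel then walks that line as index = i * input.stride(dim) + offset_input for i in 0..input.shape(dim).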

View File

@@ -1,29 +1,25 @@
use crate::{kernel::reduce::SumDim, JitElement};
use cubecl::{
cpa,
ir::{Item, Scope, Variable},
};
use super::base::ReduceDimNaive;
use crate::kernel::reduce::SumDim;
use cubecl::cube;
use cubecl::prelude::{Cast, Numeric, Tensor, UInt, ABSOLUTE_POS};
impl<E: JitElement> ReduceDimNaive<E> for SumDim {
type Accumulator = Variable;
#[cube]
impl<EI: Numeric> ReduceDimNaive<EI> for SumDim {
type Accumulator = EI;
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
scope.zero(output_item)
fn initialize_naive() -> EI {
EI::from_int(0)
}
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
cpa!(scope, accumulator += value);
fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: UInt) {
*accumulator += current_value;
}
fn assign_naive(
scope: &mut Scope,
output: Variable,
accumulator: Variable,
_shape_reduce_dim: Variable,
fn assign_naive<EO: Numeric>(
output: &mut Tensor<EO>,
accumulator: EI,
_shape_reduce_dim: UInt,
) {
let id = Variable::AbsolutePos;
cpa!(scope, output[id] = accumulator);
output[ABSOLUTE_POS] = EO::cast_from(accumulator);
}
}