Merge branch 'main' into index-cpa-to-cubecl

mepatrick73 2024-08-14 18:26:38 -04:00
commit ab5d437adf
10 changed files with 168 additions and 313 deletions

Cargo.lock (generated, 35 lines changed)

@@ -1379,7 +1379,7 @@ dependencies = [
 [[package]]
 name = "cubecl"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
+source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
 dependencies = [
  "cubecl-core",
  "cubecl-cuda",
@@ -1390,11 +1390,12 @@ dependencies = [
 [[package]]
 name = "cubecl-common"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
+source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
 dependencies = [
  "derive-new",
  "getrandom",
  "pollster",
+ "portable-atomic-util",
  "rand",
  "serde",
  "spin",
@@ -1404,7 +1405,7 @@ dependencies = [
 [[package]]
 name = "cubecl-core"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
+source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
 dependencies = [
  "bytemuck",
  "cubecl-macros",
@@ -1419,7 +1420,7 @@ dependencies = [
 [[package]]
 name = "cubecl-cuda"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
+source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
 dependencies = [
  "bytemuck",
  "cubecl-common",
@@ -1434,7 +1435,7 @@ dependencies = [
 [[package]]
 name = "cubecl-linalg"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
+source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
 dependencies = [
  "bytemuck",
  "cubecl-core",
@@ -1445,7 +1446,7 @@ dependencies = [
 [[package]]
 name = "cubecl-macros"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
+source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
 dependencies = [
  "derive-new",
  "proc-macro2",
@@ -1456,7 +1457,7 @@ dependencies = [
 [[package]]
 name = "cubecl-runtime"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
+source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
 dependencies = [
  "async-channel",
  "cubecl-common",
@@ -1475,7 +1476,7 @@ dependencies = [
 [[package]]
 name = "cubecl-wgpu"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
+source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
 dependencies = [
  "async-channel",
  "bytemuck",
@@ -4747,6 +4748,15 @@ version = "1.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"

+[[package]]
+name = "portable-atomic-util"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcdd8420072e66d54a407b3316991fe946ce3ab1083a7f575b2463866624704d"
+dependencies = [
+ "portable-atomic",
+]
+
 [[package]]
 name = "powerfmt"
 version = "0.2.0"
@@ -5678,9 +5688,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"
 [[package]]
 name = "serde"
-version = "1.0.206"
+version = "1.0.207"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5b3e4cd94123dd520a128bcd11e34d9e9e423e7e3e50425cb1b4b1e3549d0284"
+checksum = "5665e14a49a4ea1b91029ba7d3bca9f299e1f7cfa194388ccc20f14743e784f2"
 dependencies = [
  "serde_derive",
 ]
@@ -5707,9 +5717,9 @@ dependencies = [
 [[package]]
 name = "serde_derive"
-version = "1.0.206"
+version = "1.0.207"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fabfb6138d2383ea8208cf98ccf69cdfb1aff4088460681d84189aa259762f97"
+checksum = "6aea2634c86b0e8ef2cfdc0c340baede54ec27b1e46febd7f80dffb2aa44a00e"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -5924,6 +5934,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
 dependencies = [
  "lock_api",
+ "portable-atomic",
 ]

 [[package]]
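The lockfile picks up portable-atomic-util (via cubecl-common) and wires portable-atomic into spin; that pair provides atomics and an Arc that also work on targets without native atomic instructions. A minimal sketch of the kind of usage this enables; the feature flags required are an assumption, check the crate docs:

// Sketch only: portable-atomic-util's Arc mirrors std's Arc but is backed by
// portable-atomic, so it also compiles on targets lacking native CAS atomics.
// Assumes the crate's `alloc` feature is enabled.
use portable_atomic_util::Arc;

fn main() {
    let shared = Arc::new(42u32);
    let clone = Arc::clone(&shared);
    assert_eq!(*shared, *clone);
}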


@@ -143,8 +143,8 @@ sysinfo = "0.30.13"
 systemstat = "0.2.3"

 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl" }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common" }


@@ -8,8 +8,8 @@ use super::{
 };

 #[allow(dead_code)]
-pub(crate) trait ReduceDimAlgorithm<E: JitElement>:
-    ReduceDimNaive<E> + ReduceDimShared<E>
+pub(crate) trait ReduceDimAlgorithm<EI: JitElement>:
+    ReduceDimNaive<EI::Primitive> + ReduceDimShared<EI>
 {
 }
@@ -65,7 +65,7 @@ impl Default for ReduceStrategy {
 macro_rules! reduce_operation {
     ($name:ident, $ops:ident) => {
         pub(crate) struct $ops;
-        impl<E: JitElement> ReduceDimAlgorithm<E> for $ops {}
+        impl<EI: JitElement> ReduceDimAlgorithm<EI> for $ops {}

         /// Executes the reduce operation with the given strategy.
         pub fn $name<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
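For readers unfamiliar with the reduce_operation! pattern touched above: a declarative macro stamps out one marker struct plus one trait impl per reduction operation. A standalone toy version of that pattern, with illustrative names rather than the crate's actual items:

// Toy version of the macro pattern: a marker struct and a trait impl per
// reduction operation, generated from a single macro_rules! definition.
trait Algorithm {}

macro_rules! reduce_operation {
    ($name:ident, $ops:ident) => {
        struct $ops;
        impl Algorithm for $ops {}

        // Stand-in for the real launch function the crate generates.
        fn $name() -> &'static str {
            stringify!($ops)
        }
    };
}

reduce_operation!(sum_dim, SumDim);
reduce_operation!(prod_dim, ProdDim);

fn main() {
    assert_eq!(sum_dim(), "SumDim");
    assert_eq!(prod_dim(), "ProdDim");
}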


@@ -1,50 +1,36 @@
-use crate::{kernel::reduce::Argmax, JitElement};
-use cubecl::{
-    cpa,
-    ir::{Elem, Item, Scope, Variable},
-};
 use super::base::ReduceDimNaive;
+use crate::kernel::reduce::Argmax;
+use cubecl::cube;
+use cubecl::frontend::{Float, Tensor, UInt, ABSOLUTE_POS, F32};
+use cubecl::prelude::{Cast, Numeric};

-impl<E: JitElement> ReduceDimNaive<E> for Argmax {
-    type Accumulator = (Variable, Variable);
+#[allow(clippy::extra_unused_type_parameters)]
+#[cube]
+impl<EI: Numeric> ReduceDimNaive<EI> for Argmax {
+    type Accumulator = (F32, UInt);

-    fn initialize_naive(
-        scope: &mut Scope,
-        input_item: Item,
-        _output_item: Item,
-    ) -> Self::Accumulator {
-        let index = scope.create_local(Elem::UInt);
-        let max = scope.create_local(input_item);
-        let max_initial = input_item
-            .elem()
-            .constant_from_f64(E::minimum_value().to_f64());
-        cpa!(scope, max = max_initial);
-        (max, index)
+    fn initialize_naive() -> (F32, UInt) {
+        // TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
+        let a = F32::new(0.0);
+        let b = F32::new(100000000.0);
+        (a - b, UInt::new(0))
     }

-    fn inner_loop_naive(
-        scope: &mut Scope,
-        (max, index): Self::Accumulator,
-        value: Variable,
-        i: Variable,
-    ) {
-        let condition = scope.create_local(Elem::Bool);
-        cpa!(scope, condition = value > max);
-        cpa!(scope, if(condition).then(|scope| {
-            cpa!(scope, max = value);
-            cpa!(scope, index = i);
-        }));
+    fn inner_loop_naive(accumulator: &mut (F32, UInt), current_value: EI, i: UInt) {
+        let (max, index) = accumulator;
+        let val = F32::cast_from(current_value);
+        if val > *max {
+            *max = val;
+            *index = i;
+        }
     }

-    fn assign_naive(
-        scope: &mut Scope,
-        output: Variable,
-        (_max, index): Self::Accumulator,
-        _shape_reduce_dim: Variable,
-    ) {
-        let id = Variable::AbsolutePos;
-        cpa!(scope, output[id] = index);
+    fn assign_naive<EO: Numeric>(
+        output: &mut Tensor<EO>,
+        accumulator: (F32, UInt),
+        _shape_reduce_dim: UInt,
+    ) {
+        let (_, index) = accumulator;
+        output[ABSOLUTE_POS] = EO::cast_from(index);
     }
 }
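The new implementation is easiest to read as a plain scan. A CPU analogue of the (max, index) fold above, including the large-negative-constant stand-in for f32::NEG_INFINITY noted in the TODO; this is an illustration, not the kernel itself, and the argmin file below is the symmetric case:

// CPU analogue of the naive argmax fold: track (max, index) while scanning
// one reduced dimension. Mirrors the kernel's NEG_INFINITY workaround.
fn argmax_naive(values: &[f32]) -> (f32, u32) {
    let mut max = 0.0 - 100_000_000.0; // stand-in for f32::NEG_INFINITY
    let mut index = 0u32;
    for (i, &v) in values.iter().enumerate() {
        if v > max {
            max = v;
            index = i as u32;
        }
    }
    (max, index)
}

fn main() {
    assert_eq!(argmax_naive(&[1.0, 5.0, 3.0]), (5.0, 1));
}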


@@ -1,52 +1,34 @@
-use cubecl::{
-    cpa,
-    ir::{Elem, Item, Scope, Variable},
-};
-
-use crate::{kernel::reduce::Argmin, JitElement};
+use crate::kernel::reduce::Argmin;
+use cubecl::cube;
+use cubecl::prelude::{Cast, Float, Numeric, Tensor, UInt, ABSOLUTE_POS, F32};

 use super::base::ReduceDimNaive;

-impl<E: JitElement> ReduceDimNaive<E> for Argmin {
-    type Accumulator = (Variable, Variable);
+#[allow(clippy::extra_unused_type_parameters)]
+#[cube]
+impl<EI: Numeric> ReduceDimNaive<EI> for Argmin {
+    type Accumulator = (F32, UInt);

-    fn initialize_naive(
-        scope: &mut Scope,
-        input_item: Item,
-        _output_item: Item,
-    ) -> Self::Accumulator {
-        let index = scope.create_local(Elem::UInt);
-        let min = scope.create_local(input_item);
-        let min_initial = input_item
-            .elem()
-            .constant_from_f64(E::maximum_value().to_f64());
-        cpa!(scope, min = min_initial);
-        (min, index)
+    fn initialize_naive() -> (F32, UInt) {
+        // TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
+        (F32::new(100000000.0), UInt::new(0))
     }

-    fn inner_loop_naive(
-        scope: &mut Scope,
-        (min, index): Self::Accumulator,
-        value: Variable,
-        i: Variable,
-    ) {
-        let condition = scope.create_local(Elem::Bool);
-        cpa!(scope, condition = value < min);
-        cpa!(scope, if(condition).then(|scope| {
-            cpa!(scope, min = value);
-            cpa!(scope, index = i);
-        }));
+    fn inner_loop_naive(accumulator: &mut (F32, UInt), current_value: EI, i: UInt) {
+        let (min, index) = accumulator;
+        let val = F32::cast_from(current_value);
+        if val < *min {
+            *min = val;
+            *index = i;
+        }
     }

-    fn assign_naive(
-        scope: &mut Scope,
-        output: Variable,
-        (_min, index): Self::Accumulator,
-        _shape_reduce_dim: Variable,
-    ) {
-        let id = Variable::AbsolutePos;
-        cpa!(scope, output[id] = index);
+    fn assign_naive<EO: Numeric>(
+        output: &mut Tensor<EO>,
+        accumulator: (F32, UInt),
+        _shape_reduce_dim: UInt,
+    ) {
+        let (_, index) = accumulator;
+        output[ABSOLUTE_POS] = EO::cast_from(index);
     }
 }


@@ -1,32 +1,23 @@
-use cubecl::ir::{Item, Scope, Variable};
-
-use crate::JitElement;
+use cubecl::cube;
+use cubecl::frontend::CubeType;
+use cubecl::prelude::{Numeric, Tensor, UInt};

 /// Specifies the reduce dim algorithm in use
-pub trait ReduceDimNaive<E: JitElement>: Send + Sync + 'static {
+#[cube]
+pub trait ReduceDimNaive<EI: Numeric>: Send + Sync + 'static {
     /// The reduction accumulator
-    type Accumulator: Copy;
+    type Accumulator: Copy + CubeType;

     /// Initialization for naive algorithm
-    fn initialize_naive(
-        scope: &mut Scope,
-        input_item: Item,
-        output_item: Item,
-    ) -> Self::Accumulator;
+    fn initialize_naive() -> Self::Accumulator;

     /// Inner loop for naive algorithm
-    fn inner_loop_naive(
-        scope: &mut Scope,
-        accumulator: Self::Accumulator,
-        current_value: Variable,
-        i: Variable,
-    );
+    fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: UInt);

     /// Assignation for naive algorithm
-    fn assign_naive(
-        scope: &mut Scope,
-        output: Variable,
+    fn assign_naive<EO: Numeric>(
+        output: &mut Tensor<EO>,
         accumulator: Self::Accumulator,
-        shape_reduce_dim: Variable,
+        shape_reduce_dim: UInt,
     );
 }
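Stripped of the #[cube] machinery and tensor types, the trait describes a three-phase fold: initialize an accumulator, fold every element of the reduced dimension into it, then write one output value. A plain-Rust toy analogue with simplified stand-in types, not cubecl's:

// Toy analogue of the ReduceDimNaive contract.
trait ReduceDimNaiveToy {
    type Accumulator: Copy;
    fn initialize() -> Self::Accumulator;
    fn inner_loop(acc: &mut Self::Accumulator, value: f32, i: u32);
    fn assign(acc: Self::Accumulator, shape_reduce_dim: u32) -> f32;
}

struct SumDimToy;

impl ReduceDimNaiveToy for SumDimToy {
    type Accumulator = f32;
    fn initialize() -> f32 {
        0.0
    }
    fn inner_loop(acc: &mut f32, value: f32, _i: u32) {
        *acc += value;
    }
    fn assign(acc: f32, _shape: u32) -> f32 {
        acc
    }
}

fn main() {
    // Fold one reduced dimension of length 3, as the kernel's loop does.
    let mut acc = SumDimToy::initialize();
    for (i, v) in [1.0f32, 2.0, 3.0].into_iter().enumerate() {
        SumDimToy::inner_loop(&mut acc, v, i as u32);
    }
    assert_eq!(SumDimToy::assign(acc, 3), 6.0);
}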


@@ -1,32 +1,22 @@
-use crate::{kernel::reduce::MeanDim, JitElement};
-use cubecl::{
-    cpa,
-    ir::{Item, Scope, Variable},
-};
+use crate::kernel::reduce::MeanDim;
+use cubecl::cube;
+use cubecl::prelude::{Cast, Numeric, Tensor, UInt, ABSOLUTE_POS};

 use super::base::ReduceDimNaive;

-impl<E: JitElement> ReduceDimNaive<E> for MeanDim {
-    type Accumulator = Variable;
+#[cube]
+impl<EI: Numeric> ReduceDimNaive<EI> for MeanDim {
+    type Accumulator = EI;

-    fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
-        scope.zero(output_item)
+    fn initialize_naive() -> EI {
+        EI::from_int(0)
     }

-    fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
-        cpa!(scope, accumulator += value);
+    fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: UInt) {
+        *accumulator += current_value;
     }

-    fn assign_naive(
-        scope: &mut Scope,
-        output: Variable,
-        accumulator: Variable,
-        shape_reduce_dim: Variable,
-    ) {
-        let id = Variable::AbsolutePos;
-        let denominator = scope.create_local(accumulator.item());
-        cpa!(scope, denominator = cast(shape_reduce_dim));
-        cpa!(scope, accumulator = accumulator / denominator);
-        cpa!(scope, output[id] = accumulator);
+    fn assign_naive<EO: Numeric>(output: &mut Tensor<EO>, accumulator: EI, shape_reduce_dim: UInt) {
+        output[ABSOLUTE_POS] = EO::cast_from(accumulator) / EO::cast_from(shape_reduce_dim);
     }
 }
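Note that MeanDim divides only once, in assign_naive, rather than per element. The same shape in plain Rust, for illustration:

// Toy analogue of MeanDim: accumulate the sum in the loop, divide by the
// reduced length once at assignment time.
fn mean_naive(values: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for &v in values {
        acc += v; // inner_loop_naive
    }
    acc / values.len() as f32 // assign_naive
}

fn main() {
    assert_eq!(mean_naive(&[1.0, 2.0, 3.0]), 2.0);
}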


@@ -1,29 +1,26 @@
-use crate::{kernel::reduce::ProdDim, JitElement};
-use cubecl::{
-    cpa,
-    ir::{Item, Scope, Variable},
-};
+use crate::kernel::reduce::ProdDim;
+use cubecl::cube;
+use cubecl::prelude::{Cast, Numeric, Tensor, UInt, ABSOLUTE_POS};

 use super::base::ReduceDimNaive;

-impl<E: JitElement> ReduceDimNaive<E> for ProdDim {
-    type Accumulator = Variable;
+#[cube]
+impl<EI: Numeric> ReduceDimNaive<EI> for ProdDim {
+    type Accumulator = EI;

-    fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
-        scope.create_with_value(1, output_item)
+    fn initialize_naive() -> EI {
+        EI::from_int(1)
     }

-    fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
-        cpa!(scope, accumulator *= value);
+    fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: UInt) {
+        *accumulator *= current_value;
     }

-    fn assign_naive(
-        scope: &mut Scope,
-        output: Variable,
-        accumulator: Variable,
-        _shape_reduce_dim: Variable,
-    ) {
-        let id = Variable::AbsolutePos;
-        cpa!(scope, output[id] = accumulator);
+    fn assign_naive<EO: Numeric>(
+        output: &mut Tensor<EO>,
+        accumulator: EI,
+        _shape_reduce_dim: UInt,
+    ) {
+        output[ABSOLUTE_POS] = EO::cast_from(accumulator);
     }
 }


@@ -1,148 +1,42 @@
-use cubecl::{
-    cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
-    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
-    OutputInfo,
-};
-use std::marker::PhantomData;
-
-use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
+use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
+use cubecl::calculate_cube_count_elemwise;
+use cubecl::prelude::*;

 use super::base::ReduceDimNaive;

-pub(crate) struct NaiveReduceDimComputeShader<E: JitElement, RD: ReduceDimNaive<E>> {
-    tensor: Variable,
-    dim: usize,
-    output: Variable,
-    _reduce_dim: PhantomData<RD>,
-    _elem: PhantomData<E>,
-}
-
-#[derive(new)]
-pub(crate) struct NaiveReduceDimEagerKernel<
-    RD: ReduceDimNaive<EI>,
-    R: JitRuntime,
-    EI: JitElement,
-    EO: JitElement,
-> {
-    dim: usize,
-    reduce_dim: PhantomData<RD>,
-    _runtime: PhantomData<R>,
-    _elem_in: PhantomData<EI>,
-    _elem_out: PhantomData<EO>,
-}
-
-impl<RD: ReduceDimNaive<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Kernel
-    for NaiveReduceDimEagerKernel<RD, R, EI, EO>
-{
-    fn define(&self) -> KernelDefinition {
-        let mut scope = Scope::root();
-        let item_input = EI::cube_elem().into();
-        let item_output = EO::cube_elem().into();
-
-        let tensor = Variable::GlobalInputArray {
-            id: 0,
-            item: item_input,
-        };
-        let output = Variable::GlobalOutputArray {
-            id: 0,
-            item: item_output,
-        };
-
-        NaiveReduceDimComputeShader {
-            tensor,
-            dim: self.dim,
-            output,
-            _reduce_dim: PhantomData::<RD>,
-            _elem: PhantomData::<EI>,
-        }
-        .expand(&mut scope);
-
-        scope.write_global_custom(output);
-
-        let tensor = InputInfo::Array {
-            item: item_input,
-            visibility: Visibility::Read,
-        };
-
-        let out = OutputInfo::Array { item: item_output };
-
-        let info = KernelExpansion {
-            inputs: vec![tensor],
-            outputs: vec![out],
-            scope,
-        };
-
-        let settings = KernelSettings::default();
-        KernelIntegrator::new(info).integrate(settings)
-    }
-
-    fn id(&self) -> cubecl::KernelId {
-        cubecl::KernelId::new::<Self>().info(self.dim)
-    }
-}
-
-impl<E: JitElement, RD: ReduceDimNaive<E>> NaiveReduceDimComputeShader<E, RD> {
-    pub(crate) fn expand(self, scope: &mut Scope) {
-        let tensor = self.tensor;
-        let dim: Variable = self.dim.into();
-        let id = Variable::AbsolutePos;
-        let output = self.output;
-
-        let offset_input = scope.zero(Elem::UInt);
-        let stride_input_dim = scope.create_local(Elem::UInt);
-        let shape_input_dim = scope.create_local(Elem::UInt);
-
-        cpa!(
-            scope,
-            range(0u32, Variable::Rank).for_each(|i, scope| {
-                let stride_input = scope.create_local(Elem::UInt);
-                let stride_output = scope.create_local(Elem::UInt);
-                let shape_output = scope.create_local(Elem::UInt);
-
-                cpa!(scope, stride_input = stride(tensor, i));
-                cpa!(scope, stride_output = stride(output, i));
-                cpa!(scope, shape_output = shape(output, i));
-
-                let offset_local = scope.create_local(Elem::UInt);
-                cpa!(scope, offset_local = id / stride_output);
-                cpa!(scope, offset_local = offset_local % shape_output);
-
-                let is_dim_reduce = scope.create_local(Elem::Bool);
-                cpa!(scope, is_dim_reduce = i == dim);
-
-                cpa!(scope, if(is_dim_reduce).then(|scope|{
-                    cpa!(scope, shape_input_dim = shape(tensor, i));
-                    cpa!(scope, stride_input_dim = stride_input);
-                    cpa!(scope, offset_input += offset_local);
-                }).else(|scope|{
-                    cpa!(scope, offset_local = offset_local * stride_input);
-                    cpa!(scope, offset_input += offset_local);
-                }));
-            })
-        );
-
-        let accumulator = RD::initialize_naive(scope, tensor.item(), output.item());
-
-        cpa!(
-            scope,
-            range(0u32, shape_input_dim).for_each(|i, scope| {
-                let index = scope.create_local(Elem::UInt);
-                cpa!(scope, index = i * stride_input_dim);
-                cpa!(scope, index += offset_input);
-                let value = scope.create_local(tensor.item());
-                cpa!(scope, value = tensor[index]);
-                RD::inner_loop_naive(scope, accumulator, value, i);
-            })
-        );
-
-        RD::assign_naive(scope, output, accumulator, shape_input_dim);
-    }
-}
+#[cube(launch_unchecked)]
+pub(crate) fn naive_reduce_dim_compute_shader<RD: ReduceDimNaive<EI>, EI: Numeric, EO: Numeric>(
+    input: &Tensor<EI>,
+    output: &mut Tensor<EO>,
+    dim: UInt,
+) {
+    if ABSOLUTE_POS >= output.len() {
+        return;
+    }
+
+    let mut offset_input = UInt::new(0);
+
+    for i in range(0, input.rank(), Comptime::new(false)) {
+        let mut offset_local = ABSOLUTE_POS / output.stride(i);
+        offset_local = offset_local % output.shape(i);
+        if i != dim {
+            offset_input += offset_local * input.stride(i);
+        }
+    }
+
+    let mut accumulator = RD::initialize_naive();
+
+    for i in range(0, input.shape(dim), Comptime::new(false)) {
+        let index = i * input.stride(dim) + offset_input;
+        RD::inner_loop_naive(&mut accumulator, input[index], i);
+    }
+
+    RD::assign_naive::<EO>(output, accumulator, input.shape(dim));
+}

 /// Executes the naive kernel for reduce dim
 pub fn reduce_dim_naive<
-    RD: ReduceDimNaive<EI>,
+    RD: ReduceDimNaive<EI::Primitive>,
     R: JitRuntime,
     EI: JitElement,
     EO: JitElement,
@@ -152,12 +46,20 @@ pub fn reduce_dim_naive<
     output: JitTensor<R, EO, D>,
     dim: usize,
 ) -> JitTensor<R, EO, D> {
-    let kernel = NaiveReduceDimEagerKernel::<RD, R, EI, EO>::new(dim);
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise::<R::Server>(output.shape.num_elements(), cube_dim);

-    Execution::start(kernel, input.client.clone())
-        .inputs(&[input.as_handle_ref()])
-        .outputs(&[output.as_handle_ref()])
-        .execute(CubeCountSettings::Output { pos: 0 });
+    unsafe {
+        naive_reduce_dim_compute_shader::launch_unchecked::<RD, EI::Primitive, EO::Primitive, R>(
+            &input.client,
+            cube_count,
+            cube_dim,
+            input.as_tensor_arg(1),
+            output.as_tensor_arg(1),
+            ScalarArg::new(dim as u32),
+        );
+    }

     output
 }
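The heart of the new kernel is the stride arithmetic that maps one output position to its base offset in the input, skipping the reduced axis. A CPU sketch of the same computation, using a hypothetical helper and simplified types for illustration:

// CPU sketch of the kernel's indexing: for one output element at `pos`,
// recover its per-axis coordinate from the output strides/shapes, then
// rebuild the input offset, skipping the reduced axis `dim`.
fn base_offset(
    pos: usize,
    dim: usize,
    out_strides: &[usize],
    out_shape: &[usize],
    in_strides: &[usize],
) -> usize {
    let mut offset_input = 0;
    for i in 0..out_shape.len() {
        let coord = (pos / out_strides[i]) % out_shape[i];
        if i != dim {
            offset_input += coord * in_strides[i];
        }
    }
    offset_input
}

fn main() {
    // 2x3 input reduced over dim 1 gives a 2x1 output (row-major strides).
    // Output element 1 (the second row) starts at input offset 3.
    assert_eq!(base_offset(1, 1, &[1, 1], &[2, 1], &[3, 1]), 3);
}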


@@ -1,29 +1,25 @@
-use crate::{kernel::reduce::SumDim, JitElement};
-use cubecl::{
-    cpa,
-    ir::{Item, Scope, Variable},
-};
-
 use super::base::ReduceDimNaive;
+use crate::kernel::reduce::SumDim;
+use cubecl::cube;
+use cubecl::prelude::{Cast, Numeric, Tensor, UInt, ABSOLUTE_POS};

-impl<E: JitElement> ReduceDimNaive<E> for SumDim {
-    type Accumulator = Variable;
+#[cube]
+impl<EI: Numeric> ReduceDimNaive<EI> for SumDim {
+    type Accumulator = EI;

-    fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
-        scope.zero(output_item)
+    fn initialize_naive() -> EI {
+        EI::from_int(0)
     }

-    fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
-        cpa!(scope, accumulator += value);
+    fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: UInt) {
+        *accumulator += current_value;
     }

-    fn assign_naive(
-        scope: &mut Scope,
-        output: Variable,
-        accumulator: Variable,
-        _shape_reduce_dim: Variable,
-    ) {
-        let id = Variable::AbsolutePos;
-        cpa!(scope, output[id] = accumulator);
+    fn assign_naive<EO: Numeric>(
+        output: &mut Tensor<EO>,
+        accumulator: EI,
+        _shape_reduce_dim: UInt,
+    ) {
+        output[ABSOLUTE_POS] = EO::cast_from(accumulator);
     }
 }