mirror of https://github.com/tracel-ai/burn.git
Merge branch 'main' into index-cpa-to-cubecl
This commit is contained in:
commit
ab5d437adf
|
@ -1379,7 +1379,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cubecl"
|
name = "cubecl"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cubecl-core",
|
"cubecl-core",
|
||||||
"cubecl-cuda",
|
"cubecl-cuda",
|
||||||
|
@ -1390,11 +1390,12 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cubecl-common"
|
name = "cubecl-common"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"derive-new",
|
"derive-new",
|
||||||
"getrandom",
|
"getrandom",
|
||||||
"pollster",
|
"pollster",
|
||||||
|
"portable-atomic-util",
|
||||||
"rand",
|
"rand",
|
||||||
"serde",
|
"serde",
|
||||||
"spin",
|
"spin",
|
||||||
|
@ -1404,7 +1405,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cubecl-core"
|
name = "cubecl-core"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
"cubecl-macros",
|
"cubecl-macros",
|
||||||
|
@ -1419,7 +1420,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cubecl-cuda"
|
name = "cubecl-cuda"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
"cubecl-common",
|
"cubecl-common",
|
||||||
|
@ -1434,7 +1435,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cubecl-linalg"
|
name = "cubecl-linalg"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
"cubecl-core",
|
"cubecl-core",
|
||||||
|
@ -1445,7 +1446,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cubecl-macros"
|
name = "cubecl-macros"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"derive-new",
|
"derive-new",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
|
@ -1456,7 +1457,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cubecl-runtime"
|
name = "cubecl-runtime"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-channel",
|
"async-channel",
|
||||||
"cubecl-common",
|
"cubecl-common",
|
||||||
|
@ -1475,7 +1476,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cubecl-wgpu"
|
name = "cubecl-wgpu"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=4e17724fbc98de02d3cb4275e249ba660a4b2cb9#4e17724fbc98de02d3cb4275e249ba660a4b2cb9"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-channel",
|
"async-channel",
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
|
@ -4747,6 +4748,15 @@ version = "1.7.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
|
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "portable-atomic-util"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fcdd8420072e66d54a407b3316991fe946ce3ab1083a7f575b2463866624704d"
|
||||||
|
dependencies = [
|
||||||
|
"portable-atomic",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "powerfmt"
|
name = "powerfmt"
|
||||||
version = "0.2.0"
|
version = "0.2.0"
|
||||||
|
@ -5678,9 +5688,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde"
|
name = "serde"
|
||||||
version = "1.0.206"
|
version = "1.0.207"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5b3e4cd94123dd520a128bcd11e34d9e9e423e7e3e50425cb1b4b1e3549d0284"
|
checksum = "5665e14a49a4ea1b91029ba7d3bca9f299e1f7cfa194388ccc20f14743e784f2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
]
|
]
|
||||||
|
@ -5707,9 +5717,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_derive"
|
name = "serde_derive"
|
||||||
version = "1.0.206"
|
version = "1.0.207"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fabfb6138d2383ea8208cf98ccf69cdfb1aff4088460681d84189aa259762f97"
|
checksum = "6aea2634c86b0e8ef2cfdc0c340baede54ec27b1e46febd7f80dffb2aa44a00e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
@ -5924,6 +5934,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"lock_api",
|
"lock_api",
|
||||||
|
"portable-atomic",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -143,8 +143,8 @@ sysinfo = "0.30.13"
|
||||||
systemstat = "0.2.3"
|
systemstat = "0.2.3"
|
||||||
|
|
||||||
### For the main burn branch. ###
|
### For the main burn branch. ###
|
||||||
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
|
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" }
|
||||||
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4e17724fbc98de02d3cb4275e249ba660a4b2cb9" }
|
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" }
|
||||||
### For local development. ###
|
### For local development. ###
|
||||||
# cubecl = { path = "../cubecl/crates/cubecl" }
|
# cubecl = { path = "../cubecl/crates/cubecl" }
|
||||||
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
|
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
|
||||||
|
|
|
@ -8,8 +8,8 @@ use super::{
|
||||||
};
|
};
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub(crate) trait ReduceDimAlgorithm<E: JitElement>:
|
pub(crate) trait ReduceDimAlgorithm<EI: JitElement>:
|
||||||
ReduceDimNaive<E> + ReduceDimShared<E>
|
ReduceDimNaive<EI::Primitive> + ReduceDimShared<EI>
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ impl Default for ReduceStrategy {
|
||||||
macro_rules! reduce_operation {
|
macro_rules! reduce_operation {
|
||||||
($name:ident, $ops:ident) => {
|
($name:ident, $ops:ident) => {
|
||||||
pub(crate) struct $ops;
|
pub(crate) struct $ops;
|
||||||
impl<E: JitElement> ReduceDimAlgorithm<E> for $ops {}
|
impl<EI: JitElement> ReduceDimAlgorithm<EI> for $ops {}
|
||||||
|
|
||||||
/// Executes the reduce operation with the given strategy.
|
/// Executes the reduce operation with the given strategy.
|
||||||
pub fn $name<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
|
pub fn $name<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
|
||||||
|
|
|
@ -1,50 +1,36 @@
|
||||||
use crate::{kernel::reduce::Argmax, JitElement};
|
|
||||||
use cubecl::{
|
|
||||||
cpa,
|
|
||||||
ir::{Elem, Item, Scope, Variable},
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::base::ReduceDimNaive;
|
use super::base::ReduceDimNaive;
|
||||||
|
use crate::kernel::reduce::Argmax;
|
||||||
|
use cubecl::cube;
|
||||||
|
use cubecl::frontend::{Float, Tensor, UInt, ABSOLUTE_POS, F32};
|
||||||
|
use cubecl::prelude::{Cast, Numeric};
|
||||||
|
|
||||||
impl<E: JitElement> ReduceDimNaive<E> for Argmax {
|
#[allow(clippy::extra_unused_type_parameters)]
|
||||||
type Accumulator = (Variable, Variable);
|
#[cube]
|
||||||
|
impl<EI: Numeric> ReduceDimNaive<EI> for Argmax {
|
||||||
|
type Accumulator = (F32, UInt);
|
||||||
|
|
||||||
fn initialize_naive(
|
fn initialize_naive() -> (F32, UInt) {
|
||||||
scope: &mut Scope,
|
// TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
|
||||||
input_item: Item,
|
let a = F32::new(0.0);
|
||||||
_output_item: Item,
|
let b = F32::new(100000000.0);
|
||||||
) -> Self::Accumulator {
|
(a - b, UInt::new(0))
|
||||||
let index = scope.create_local(Elem::UInt);
|
|
||||||
let max = scope.create_local(input_item);
|
|
||||||
let max_initial = input_item
|
|
||||||
.elem()
|
|
||||||
.constant_from_f64(E::minimum_value().to_f64());
|
|
||||||
cpa!(scope, max = max_initial);
|
|
||||||
|
|
||||||
(max, index)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inner_loop_naive(
|
fn inner_loop_naive(accumulator: &mut (F32, UInt), current_value: EI, i: UInt) {
|
||||||
scope: &mut Scope,
|
let (max, index) = accumulator;
|
||||||
(max, index): Self::Accumulator,
|
let val = F32::cast_from(current_value);
|
||||||
value: Variable,
|
if val > *max {
|
||||||
i: Variable,
|
*max = val;
|
||||||
) {
|
*index = i;
|
||||||
let condition = scope.create_local(Elem::Bool);
|
}
|
||||||
cpa!(scope, condition = value > max);
|
|
||||||
cpa!(scope, if(condition).then(|scope| {
|
|
||||||
cpa!(scope, max = value);
|
|
||||||
cpa!(scope, index = i);
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assign_naive(
|
fn assign_naive<EO: Numeric>(
|
||||||
scope: &mut Scope,
|
output: &mut Tensor<EO>,
|
||||||
output: Variable,
|
accumulator: (F32, UInt),
|
||||||
(_max, index): Self::Accumulator,
|
_shape_reduce_dim: UInt,
|
||||||
_shape_reduce_dim: Variable,
|
|
||||||
) {
|
) {
|
||||||
let id = Variable::AbsolutePos;
|
let (_, index) = accumulator;
|
||||||
cpa!(scope, output[id] = index);
|
output[ABSOLUTE_POS] = EO::cast_from(index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,52 +1,34 @@
|
||||||
use cubecl::{
|
use crate::kernel::reduce::Argmin;
|
||||||
cpa,
|
use cubecl::cube;
|
||||||
ir::{Elem, Item, Scope, Variable},
|
use cubecl::prelude::{Cast, Float, Numeric, Tensor, UInt, ABSOLUTE_POS, F32};
|
||||||
};
|
|
||||||
|
|
||||||
use crate::{kernel::reduce::Argmin, JitElement};
|
|
||||||
|
|
||||||
use super::base::ReduceDimNaive;
|
use super::base::ReduceDimNaive;
|
||||||
|
|
||||||
impl<E: JitElement> ReduceDimNaive<E> for Argmin {
|
#[allow(clippy::extra_unused_type_parameters)]
|
||||||
type Accumulator = (Variable, Variable);
|
#[cube]
|
||||||
|
impl<EI: Numeric> ReduceDimNaive<EI> for Argmin {
|
||||||
|
type Accumulator = (F32, UInt);
|
||||||
|
|
||||||
fn initialize_naive(
|
fn initialize_naive() -> (F32, UInt) {
|
||||||
scope: &mut Scope,
|
// TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
|
||||||
input_item: Item,
|
(F32::new(100000000.0), UInt::new(0))
|
||||||
_output_item: Item,
|
|
||||||
) -> Self::Accumulator {
|
|
||||||
let index = scope.create_local(Elem::UInt);
|
|
||||||
let min = scope.create_local(input_item);
|
|
||||||
let min_initial = input_item
|
|
||||||
.elem()
|
|
||||||
.constant_from_f64(E::maximum_value().to_f64());
|
|
||||||
|
|
||||||
cpa!(scope, min = min_initial);
|
|
||||||
|
|
||||||
(min, index)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inner_loop_naive(
|
fn inner_loop_naive(accumulator: &mut (F32, UInt), current_value: EI, i: UInt) {
|
||||||
scope: &mut Scope,
|
let (min, index) = accumulator;
|
||||||
(min, index): Self::Accumulator,
|
let val = F32::cast_from(current_value);
|
||||||
value: Variable,
|
if val < *min {
|
||||||
i: Variable,
|
*min = val;
|
||||||
) {
|
*index = i;
|
||||||
let condition = scope.create_local(Elem::Bool);
|
}
|
||||||
cpa!(scope, condition = value < min);
|
|
||||||
cpa!(scope, if(condition).then(|scope| {
|
|
||||||
cpa!(scope, min = value);
|
|
||||||
cpa!(scope, index = i);
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assign_naive(
|
fn assign_naive<EO: Numeric>(
|
||||||
scope: &mut Scope,
|
output: &mut Tensor<EO>,
|
||||||
output: Variable,
|
accumulator: (F32, UInt),
|
||||||
(_min, index): Self::Accumulator,
|
_shape_reduce_dim: UInt,
|
||||||
_shape_reduce_dim: Variable,
|
|
||||||
) {
|
) {
|
||||||
let id = Variable::AbsolutePos;
|
let (_, index) = accumulator;
|
||||||
cpa!(scope, output[id] = index);
|
output[ABSOLUTE_POS] = EO::cast_from(index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,32 +1,23 @@
|
||||||
use cubecl::ir::{Item, Scope, Variable};
|
use cubecl::cube;
|
||||||
|
use cubecl::frontend::CubeType;
|
||||||
use crate::JitElement;
|
use cubecl::prelude::{Numeric, Tensor, UInt};
|
||||||
|
|
||||||
/// Specifies the reduce dim algorithm in use
|
/// Specifies the reduce dim algorithm in use
|
||||||
pub trait ReduceDimNaive<E: JitElement>: Send + Sync + 'static {
|
#[cube]
|
||||||
|
pub trait ReduceDimNaive<EI: Numeric>: Send + Sync + 'static {
|
||||||
/// The reduction accumulator
|
/// The reduction accumulator
|
||||||
type Accumulator: Copy;
|
type Accumulator: Copy + CubeType;
|
||||||
|
|
||||||
/// Initialization for naive algorithm
|
/// Initialization for naive algorithm
|
||||||
fn initialize_naive(
|
fn initialize_naive() -> Self::Accumulator;
|
||||||
scope: &mut Scope,
|
|
||||||
input_item: Item,
|
|
||||||
output_item: Item,
|
|
||||||
) -> Self::Accumulator;
|
|
||||||
|
|
||||||
/// Inner loop for naive algorithm
|
/// Inner loop for naive algorithm
|
||||||
fn inner_loop_naive(
|
fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: UInt);
|
||||||
scope: &mut Scope,
|
|
||||||
accumulator: Self::Accumulator,
|
|
||||||
current_value: Variable,
|
|
||||||
i: Variable,
|
|
||||||
);
|
|
||||||
|
|
||||||
/// Assignation for naive algorithm
|
/// Assignation for naive algorithm
|
||||||
fn assign_naive(
|
fn assign_naive<EO: Numeric>(
|
||||||
scope: &mut Scope,
|
output: &mut Tensor<EO>,
|
||||||
output: Variable,
|
|
||||||
accumulator: Self::Accumulator,
|
accumulator: Self::Accumulator,
|
||||||
shape_reduce_dim: Variable,
|
shape_reduce_dim: UInt,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,32 +1,22 @@
|
||||||
use crate::{kernel::reduce::MeanDim, JitElement};
|
use crate::kernel::reduce::MeanDim;
|
||||||
use cubecl::{
|
use cubecl::cube;
|
||||||
cpa,
|
use cubecl::prelude::{Cast, Numeric, Tensor, UInt, ABSOLUTE_POS};
|
||||||
ir::{Item, Scope, Variable},
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::base::ReduceDimNaive;
|
use super::base::ReduceDimNaive;
|
||||||
|
|
||||||
impl<E: JitElement> ReduceDimNaive<E> for MeanDim {
|
#[cube]
|
||||||
type Accumulator = Variable;
|
impl<EI: Numeric> ReduceDimNaive<EI> for MeanDim {
|
||||||
|
type Accumulator = EI;
|
||||||
|
|
||||||
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
|
fn initialize_naive() -> EI {
|
||||||
scope.zero(output_item)
|
EI::from_int(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
|
fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: UInt) {
|
||||||
cpa!(scope, accumulator += value);
|
*accumulator += current_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assign_naive(
|
fn assign_naive<EO: Numeric>(output: &mut Tensor<EO>, accumulator: EI, shape_reduce_dim: UInt) {
|
||||||
scope: &mut Scope,
|
output[ABSOLUTE_POS] = EO::cast_from(accumulator) / EO::cast_from(shape_reduce_dim);
|
||||||
output: Variable,
|
|
||||||
accumulator: Variable,
|
|
||||||
shape_reduce_dim: Variable,
|
|
||||||
) {
|
|
||||||
let id = Variable::AbsolutePos;
|
|
||||||
let denominator = scope.create_local(accumulator.item());
|
|
||||||
cpa!(scope, denominator = cast(shape_reduce_dim));
|
|
||||||
cpa!(scope, accumulator = accumulator / denominator);
|
|
||||||
cpa!(scope, output[id] = accumulator);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,29 +1,26 @@
|
||||||
use crate::{kernel::reduce::ProdDim, JitElement};
|
use crate::kernel::reduce::ProdDim;
|
||||||
use cubecl::{
|
use cubecl::cube;
|
||||||
cpa,
|
use cubecl::prelude::{Cast, Numeric, Tensor, UInt, ABSOLUTE_POS};
|
||||||
ir::{Item, Scope, Variable},
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::base::ReduceDimNaive;
|
use super::base::ReduceDimNaive;
|
||||||
|
|
||||||
impl<E: JitElement> ReduceDimNaive<E> for ProdDim {
|
#[cube]
|
||||||
type Accumulator = Variable;
|
impl<EI: Numeric> ReduceDimNaive<EI> for ProdDim {
|
||||||
|
type Accumulator = EI;
|
||||||
|
|
||||||
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
|
fn initialize_naive() -> EI {
|
||||||
scope.create_with_value(1, output_item)
|
EI::from_int(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
|
fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: UInt) {
|
||||||
cpa!(scope, accumulator *= value);
|
*accumulator *= current_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assign_naive(
|
fn assign_naive<EO: Numeric>(
|
||||||
scope: &mut Scope,
|
output: &mut Tensor<EO>,
|
||||||
output: Variable,
|
accumulator: EI,
|
||||||
accumulator: Variable,
|
_shape_reduce_dim: UInt,
|
||||||
_shape_reduce_dim: Variable,
|
|
||||||
) {
|
) {
|
||||||
let id = Variable::AbsolutePos;
|
output[ABSOLUTE_POS] = EO::cast_from(accumulator);
|
||||||
cpa!(scope, output[id] = accumulator);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,148 +1,42 @@
|
||||||
use cubecl::{
|
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
|
||||||
cpa,
|
use cubecl::calculate_cube_count_elemwise;
|
||||||
ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
|
use cubecl::prelude::*;
|
||||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
|
||||||
OutputInfo,
|
|
||||||
};
|
|
||||||
use std::marker::PhantomData;
|
|
||||||
|
|
||||||
use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
|
|
||||||
|
|
||||||
use super::base::ReduceDimNaive;
|
use super::base::ReduceDimNaive;
|
||||||
|
|
||||||
pub(crate) struct NaiveReduceDimComputeShader<E: JitElement, RD: ReduceDimNaive<E>> {
|
#[cube(launch_unchecked)]
|
||||||
tensor: Variable,
|
pub(crate) fn naive_reduce_dim_compute_shader<RD: ReduceDimNaive<EI>, EI: Numeric, EO: Numeric>(
|
||||||
dim: usize,
|
input: &Tensor<EI>,
|
||||||
output: Variable,
|
output: &mut Tensor<EO>,
|
||||||
_reduce_dim: PhantomData<RD>,
|
dim: UInt,
|
||||||
_elem: PhantomData<E>,
|
) {
|
||||||
}
|
if ABSOLUTE_POS >= output.len() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(new)]
|
let mut offset_input = UInt::new(0);
|
||||||
pub(crate) struct NaiveReduceDimEagerKernel<
|
|
||||||
RD: ReduceDimNaive<EI>,
|
|
||||||
R: JitRuntime,
|
|
||||||
EI: JitElement,
|
|
||||||
EO: JitElement,
|
|
||||||
> {
|
|
||||||
dim: usize,
|
|
||||||
reduce_dim: PhantomData<RD>,
|
|
||||||
_runtime: PhantomData<R>,
|
|
||||||
_elem_in: PhantomData<EI>,
|
|
||||||
_elem_out: PhantomData<EO>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<RD: ReduceDimNaive<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Kernel
|
for i in range(0, input.rank(), Comptime::new(false)) {
|
||||||
for NaiveReduceDimEagerKernel<RD, R, EI, EO>
|
let mut offset_local = ABSOLUTE_POS / output.stride(i);
|
||||||
{
|
offset_local = offset_local % output.shape(i);
|
||||||
fn define(&self) -> KernelDefinition {
|
if i != dim {
|
||||||
let mut scope = Scope::root();
|
offset_input += offset_local * input.stride(i);
|
||||||
let item_input = EI::cube_elem().into();
|
|
||||||
let item_output = EO::cube_elem().into();
|
|
||||||
|
|
||||||
let tensor = Variable::GlobalInputArray {
|
|
||||||
id: 0,
|
|
||||||
item: item_input,
|
|
||||||
};
|
|
||||||
let output = Variable::GlobalOutputArray {
|
|
||||||
id: 0,
|
|
||||||
item: item_output,
|
|
||||||
};
|
|
||||||
|
|
||||||
NaiveReduceDimComputeShader {
|
|
||||||
tensor,
|
|
||||||
dim: self.dim,
|
|
||||||
output,
|
|
||||||
_reduce_dim: PhantomData::<RD>,
|
|
||||||
_elem: PhantomData::<EI>,
|
|
||||||
}
|
}
|
||||||
.expand(&mut scope);
|
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
|
||||||
|
|
||||||
let tensor = InputInfo::Array {
|
|
||||||
item: item_input,
|
|
||||||
visibility: Visibility::Read,
|
|
||||||
};
|
|
||||||
|
|
||||||
let out = OutputInfo::Array { item: item_output };
|
|
||||||
|
|
||||||
let info = KernelExpansion {
|
|
||||||
inputs: vec![tensor],
|
|
||||||
outputs: vec![out],
|
|
||||||
scope,
|
|
||||||
};
|
|
||||||
|
|
||||||
let settings = KernelSettings::default();
|
|
||||||
KernelIntegrator::new(info).integrate(settings)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn id(&self) -> cubecl::KernelId {
|
let mut accumulator = RD::initialize_naive();
|
||||||
cubecl::KernelId::new::<Self>().info(self.dim)
|
|
||||||
|
for i in range(0, input.shape(dim), Comptime::new(false)) {
|
||||||
|
let index = i * input.stride(dim) + offset_input;
|
||||||
|
RD::inner_loop_naive(&mut accumulator, input[index], i);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl<E: JitElement, RD: ReduceDimNaive<E>> NaiveReduceDimComputeShader<E, RD> {
|
RD::assign_naive::<EO>(output, accumulator, input.shape(dim));
|
||||||
pub(crate) fn expand(self, scope: &mut Scope) {
|
|
||||||
let tensor = self.tensor;
|
|
||||||
let dim: Variable = self.dim.into();
|
|
||||||
let id = Variable::AbsolutePos;
|
|
||||||
let output = self.output;
|
|
||||||
|
|
||||||
let offset_input = scope.zero(Elem::UInt);
|
|
||||||
let stride_input_dim = scope.create_local(Elem::UInt);
|
|
||||||
let shape_input_dim = scope.create_local(Elem::UInt);
|
|
||||||
|
|
||||||
cpa!(
|
|
||||||
scope,
|
|
||||||
range(0u32, Variable::Rank).for_each(|i, scope| {
|
|
||||||
let stride_input = scope.create_local(Elem::UInt);
|
|
||||||
let stride_output = scope.create_local(Elem::UInt);
|
|
||||||
let shape_output = scope.create_local(Elem::UInt);
|
|
||||||
|
|
||||||
cpa!(scope, stride_input = stride(tensor, i));
|
|
||||||
cpa!(scope, stride_output = stride(output, i));
|
|
||||||
cpa!(scope, shape_output = shape(output, i));
|
|
||||||
|
|
||||||
let offset_local = scope.create_local(Elem::UInt);
|
|
||||||
cpa!(scope, offset_local = id / stride_output);
|
|
||||||
cpa!(scope, offset_local = offset_local % shape_output);
|
|
||||||
|
|
||||||
let is_dim_reduce = scope.create_local(Elem::Bool);
|
|
||||||
cpa!(scope, is_dim_reduce = i == dim);
|
|
||||||
|
|
||||||
cpa!(scope, if(is_dim_reduce).then(|scope|{
|
|
||||||
cpa!(scope, shape_input_dim = shape(tensor, i));
|
|
||||||
cpa!(scope, stride_input_dim = stride_input);
|
|
||||||
cpa!(scope, offset_input += offset_local);
|
|
||||||
}).else(|scope|{
|
|
||||||
cpa!(scope, offset_local = offset_local * stride_input);
|
|
||||||
cpa!(scope, offset_input += offset_local);
|
|
||||||
}));
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
let accumulator = RD::initialize_naive(scope, tensor.item(), output.item());
|
|
||||||
|
|
||||||
cpa!(
|
|
||||||
scope,
|
|
||||||
range(0u32, shape_input_dim).for_each(|i, scope| {
|
|
||||||
let index = scope.create_local(Elem::UInt);
|
|
||||||
cpa!(scope, index = i * stride_input_dim);
|
|
||||||
cpa!(scope, index += offset_input);
|
|
||||||
let value = scope.create_local(tensor.item());
|
|
||||||
cpa!(scope, value = tensor[index]);
|
|
||||||
RD::inner_loop_naive(scope, accumulator, value, i);
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
RD::assign_naive(scope, output, accumulator, shape_input_dim);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Executes the naive kernel for reduce dim
|
/// Executes the naive kernel for reduce dim
|
||||||
pub fn reduce_dim_naive<
|
pub fn reduce_dim_naive<
|
||||||
RD: ReduceDimNaive<EI>,
|
RD: ReduceDimNaive<EI::Primitive>,
|
||||||
R: JitRuntime,
|
R: JitRuntime,
|
||||||
EI: JitElement,
|
EI: JitElement,
|
||||||
EO: JitElement,
|
EO: JitElement,
|
||||||
|
@ -152,12 +46,20 @@ pub fn reduce_dim_naive<
|
||||||
output: JitTensor<R, EO, D>,
|
output: JitTensor<R, EO, D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
) -> JitTensor<R, EO, D> {
|
) -> JitTensor<R, EO, D> {
|
||||||
let kernel = NaiveReduceDimEagerKernel::<RD, R, EI, EO>::new(dim);
|
let cube_dim = CubeDim::default();
|
||||||
|
let cube_count =
|
||||||
|
calculate_cube_count_elemwise::<R::Server>(output.shape.num_elements(), cube_dim);
|
||||||
|
|
||||||
Execution::start(kernel, input.client.clone())
|
unsafe {
|
||||||
.inputs(&[input.as_handle_ref()])
|
naive_reduce_dim_compute_shader::launch_unchecked::<RD, EI::Primitive, EO::Primitive, R>(
|
||||||
.outputs(&[output.as_handle_ref()])
|
&input.client,
|
||||||
.execute(CubeCountSettings::Output { pos: 0 });
|
cube_count,
|
||||||
|
cube_dim,
|
||||||
|
input.as_tensor_arg(1),
|
||||||
|
output.as_tensor_arg(1),
|
||||||
|
ScalarArg::new(dim as u32),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
output
|
output
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,29 +1,25 @@
|
||||||
use crate::{kernel::reduce::SumDim, JitElement};
|
|
||||||
use cubecl::{
|
|
||||||
cpa,
|
|
||||||
ir::{Item, Scope, Variable},
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::base::ReduceDimNaive;
|
use super::base::ReduceDimNaive;
|
||||||
|
use crate::kernel::reduce::SumDim;
|
||||||
|
use cubecl::cube;
|
||||||
|
use cubecl::prelude::{Cast, Numeric, Tensor, UInt, ABSOLUTE_POS};
|
||||||
|
|
||||||
impl<E: JitElement> ReduceDimNaive<E> for SumDim {
|
#[cube]
|
||||||
type Accumulator = Variable;
|
impl<EI: Numeric> ReduceDimNaive<EI> for SumDim {
|
||||||
|
type Accumulator = EI;
|
||||||
|
|
||||||
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
|
fn initialize_naive() -> EI {
|
||||||
scope.zero(output_item)
|
EI::from_int(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
|
fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: UInt) {
|
||||||
cpa!(scope, accumulator += value);
|
*accumulator += current_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assign_naive(
|
fn assign_naive<EO: Numeric>(
|
||||||
scope: &mut Scope,
|
output: &mut Tensor<EO>,
|
||||||
output: Variable,
|
accumulator: EI,
|
||||||
accumulator: Variable,
|
_shape_reduce_dim: UInt,
|
||||||
_shape_reduce_dim: Variable,
|
|
||||||
) {
|
) {
|
||||||
let id = Variable::AbsolutePos;
|
output[ABSOLUTE_POS] = EO::cast_from(accumulator);
|
||||||
cpa!(scope, output[id] = accumulator);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue