diff --git a/Cargo.lock b/Cargo.lock
index cf12c3fb3..4b0de5d0c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -534,6 +534,7 @@ dependencies = [
  "derive-new",
  "half",
  "log",
+ "num-traits",
  "serde",
 ]
 
@@ -1983,6 +1984,15 @@ dependencies = [
  "slab",
 ]
 
+[[package]]
+name = "gelu"
+version = "0.14.0"
+dependencies = [
+ "burn-cube",
+ "burn-cuda",
+ "burn-wgpu",
+]
+
 [[package]]
 name = "gemm"
 version = "0.17.1"
diff --git a/crates/burn-cube/Cargo.toml b/crates/burn-cube/Cargo.toml
index dd6a925cd..dd10d48eb 100644
--- a/crates/burn-cube/Cargo.toml
+++ b/crates/burn-cube/Cargo.toml
@@ -29,5 +29,6 @@ half = { workspace = true, features = ["bytemuck"] }
 serde = { workspace = true }
 
 burn-cube-macros = { path = "../burn-cube-macros", version = "0.14.0" }
 derive-new = { workspace = true }
+num-traits = { workspace = true }
 log = { workspace = true }
diff --git a/crates/burn-cube/src/compute/launcher.rs b/crates/burn-cube/src/compute/launcher.rs
index 5f437e4b2..55fdb4674 100644
--- a/crates/burn-cube/src/compute/launcher.rs
+++ b/crates/burn-cube/src/compute/launcher.rs
@@ -1,9 +1,11 @@
 use crate::compute::{CubeCount, KernelTask};
 use crate::ir::{Elem, FloatKind, IntKind};
+use crate::prelude::ArrayHandle;
 use crate::{calculate_num_elems_dyn_rank, frontend::TensorHandle, Kernel, Runtime};
 use burn_compute::client::ComputeClient;
 use burn_compute::server::Binding;
 use bytemuck::NoUninit;
+use num_traits::ToPrimitive;
 
 /// Prepare a kernel for [launch](KernelLauncher::launch).
 pub struct KernelLauncher<R: Runtime> {
@@ -24,6 +26,11 @@ impl<R: Runtime> KernelLauncher<R> {
         self.tensors.push(tensor);
     }
 
+    /// Register an array to be launched.
+    pub fn register_array(&mut self, array: &ArrayHandle<'_, R>) {
+        self.tensors.push(&array.as_tensor());
+    }
+
     /// Register a u32 scalar to be launched.
     pub fn register_u32(&mut self, scalar: u32) {
         self.register_scalar(Elem::UInt);
@@ -165,17 +172,21 @@ impl<R: Runtime> TensorState<R> {
         bindings.push(tensor.handle.clone().binding());
 
-        if metadata.is_empty() {
-            metadata.push(tensor.strides.len() as u32);
-        }
+        let old_rank = if metadata.is_empty() {
+            let rank = tensor.strides.len() as u32;
+            metadata.push(rank);
+            None
+        } else if tensor.strides.len() > metadata[0] as usize {
+            let old_rank = metadata[0];
+            let rank = tensor.strides.len() as u32;
+            Self::adjust_rank(metadata, bindings.len(), rank);
+            Some(old_rank)
+        } else {
+            None
+        };
 
-        for s in tensor.strides.iter() {
-            metadata.push(*s as u32);
-        }
-
-        for s in tensor.shape.iter() {
-            metadata.push(*s as u32);
-        }
+        Self::register_strides(tensor.strides, tensor.shape, old_rank, metadata);
+        Self::register_shape(tensor.shape, old_rank, metadata);
 
         if R::require_array_lengths() {
             let len = calculate_num_elems_dyn_rank(tensor.shape);
             lengths.push(len as u32);
@@ -183,6 +194,82 @@ impl<R: Runtime> TensorState<R> {
         }
     }
 
+    fn adjust_rank(metadata: &mut Vec<u32>, num_registered: usize, rank: u32) {
+        let old_rank = metadata[0] as usize;
+        let rank_diff = rank as usize - old_rank;
+        let mut updated_metadata = Vec::with_capacity(2 * rank_diff * num_registered);
+
+        for pos in 0..num_registered {
+            let stride_index = (pos * old_rank * 2) + 1;
+            let shape_index = stride_index + old_rank;
+
+            let strides_old = &metadata[stride_index..stride_index + old_rank];
+            let shape_old = &metadata[shape_index..shape_index + old_rank];
+
+            Self::register_strides(
+                strides_old,
+                shape_old,
+                Some(old_rank as u32),
+                &mut updated_metadata,
+            );
+            Self::register_shape(shape_old, Some(old_rank as u32), &mut updated_metadata);
+        }
+
+        core::mem::swap(&mut updated_metadata, metadata);
+    }
+
+    fn register_strides<T: ToPrimitive>(
+        strides: &[T],
+        shape: &[T],
+        old_rank: Option<u32>,
+        output: &mut Vec<u32>,
+    ) {
+        let old_rank = if let Some(old_rank) = old_rank {
+            let rank = output[0];
+            let rank_diff = rank - old_rank;
+            let padded_strides = if rank_diff > 0 {
+                shape
+                    .iter()
+                    .take(old_rank as usize)
+                    .map(|a| a.to_u32().unwrap())
+                    .sum::<u32>()
+            } else {
+                0
+            };
+
+            for _ in 0..rank_diff {
+                output.push(padded_strides.to_u32().unwrap());
+            }
+
+            old_rank as usize
+        } else {
+            output[0] as usize // same as current.
+        };
+
+        for stride in strides.iter().take(old_rank) {
+            output.push(stride.to_u32().unwrap());
+        }
+    }
+
+    fn register_shape<T: ToPrimitive>(shape: &[T], old_rank: Option<u32>, output: &mut Vec<u32>) {
+        let old_rank = if let Some(old_rank) = old_rank {
+            let rank = output[0];
+            let rank_diff = rank - old_rank;
+
+            for _ in 0..rank_diff {
+                output.push(1);
+            }
+
+            old_rank as usize
+        } else {
+            output[0] as usize // same as current
+        };
+
+        for elem in shape.iter().take(old_rank) {
+            output.push(elem.to_u32().unwrap());
+        }
+    }
+
     fn register(
         self,
         client: &ComputeClient<R::Server, R::Channel>,
         bindings: &mut Vec<Binding<R::Server>>,
     ) {
@@ -205,6 +292,7 @@ impl<R: Runtime> TensorState<R> {
         }
     }
 }
+
 impl<T: NoUninit> ScalarState<T> {
     /// Add a new scalar value to the state.
     pub fn push(&mut self, val: T) {
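With this change, `TensorState` flattens all tensor metadata into a single `u32` buffer laid out as `[rank, strides_0..rank, shape_0..rank, ...]` per registered handle, and `adjust_rank` repacks previously registered entries when a higher-rank tensor shows up. The sketch below (illustration only, not part of the diff; names are hypothetical) shows the layout produced on the simple first-registration path, and why `register_array` can piggyback on the tensor path: an array of length `n` is just a rank-1 tensor with strides `[1]` and shape `[n]`.

```rust
// Standalone sketch of the metadata layout built by `TensorState::push`
// for the first registered handle.
fn push_metadata(metadata: &mut Vec<u32>, strides: &[usize], shape: &[usize]) {
    if metadata.is_empty() {
        metadata.push(strides.len() as u32); // metadata[0] stores the rank
    }
    metadata.extend(strides.iter().map(|s| *s as u32));
    metadata.extend(shape.iter().map(|s| *s as u32));
}

fn main() {
    let mut metadata = Vec::new();
    // An `ArrayHandle` of length 4 registers as a rank-1 tensor:
    // strides [1], shape [4] (see `ArrayHandle::as_tensor` below).
    push_metadata(&mut metadata, &[1], &[4]);
    assert_eq!(metadata, vec![1, 1, 4]);
}
```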
diff --git a/crates/burn-cube/src/frontend/element/array.rs b/crates/burn-cube/src/frontend/element/array.rs
index e1483a47f..9633f76f8 100644
--- a/crates/burn-cube/src/frontend/element/array.rs
+++ b/crates/burn-cube/src/frontend/element/array.rs
@@ -1,6 +1,13 @@
 use std::marker::PhantomData;
 
-use crate::frontend::{CubeType, ExpandElement};
+use crate::{
+    compute::{KernelBuilder, KernelLauncher},
+    frontend::{CubeType, ExpandElement},
+    ir::{Item, Vectorization},
+    unexpanded, Runtime,
+};
+
+use super::{ArgSettings, CubeElem, LaunchArg, TensorHandle, UInt};
 
 #[derive(new, Clone, Copy)]
 pub struct Array<E> {
@@ -10,3 +17,52 @@ pub struct Array<E> {
 impl<C: CubeType> CubeType for Array<C> {
     type ExpandType = ExpandElement;
 }
+
+impl<E: CubeType> Array<E> {
+    /// Obtain the array length of input
+    pub fn len(self) -> UInt {
+        unexpanded!()
+    }
+}
+
+impl<C: CubeElem> LaunchArg for Array<C> {
+    type RuntimeArg<'a, R: Runtime> = ArrayHandle<'a, R>;
+
+    fn compile_input(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
+        builder.input_array(Item::vectorized(C::as_elem(), vectorization))
+    }
+
+    fn compile_output(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
+        builder.output_array(Item::vectorized(C::as_elem(), vectorization))
+    }
+}
+
+pub struct ArrayHandle<'a, R: Runtime> {
+    pub handle: &'a burn_compute::server::Handle<R::Server>,
+    pub length: [usize; 1],
+}
+
+impl<'a, R: Runtime> ArgSettings<R> for ArrayHandle<'a, R> {
+    fn register(&self, launcher: &mut KernelLauncher<R>) {
+        launcher.register_array(self)
+    }
+}
+
+impl<'a, R: Runtime> ArrayHandle<'a, R> {
+    pub fn new(handle: &'a burn_compute::server::Handle<R::Server>, length: usize) -> Self {
+        Self {
+            handle,
+            length: [length],
+        }
+    }
+
+    pub fn as_tensor(&self) -> TensorHandle<'_, R> {
+        let shape = &self.length;
+
+        TensorHandle {
+            handle: self.handle,
+            strides: &[1],
+            shape,
+        }
+    }
+}
diff --git a/crates/burn-cube/src/prelude.rs b/crates/burn-cube/src/prelude.rs
index 66d2f3fe3..78b3acd50 100644
--- a/crates/burn-cube/src/prelude.rs
+++ b/crates/burn-cube/src/prelude.rs
@@ -9,7 +9,9 @@ pub use crate::ir::{CubeDim, KernelDefinition};
 pub use crate::runtime::Runtime;
 
 /// Elements
-pub use crate::frontend::{Float, LaunchArg, Tensor, TensorHandle, UInt, F16, F32, F64, I32, I64};
+pub use crate::frontend::{
+    Array, ArrayHandle, Float, LaunchArg, Tensor, TensorHandle, UInt, F16, F32, F64, I32, I64,
+};
 pub use crate::pod::CubeElement;
 
 /// Topology
diff --git a/crates/burn-cuda/src/lib.rs b/crates/burn-cuda/src/lib.rs
index 00fe2cd9c..8f5d181cd 100644
--- a/crates/burn-cuda/src/lib.rs
+++ b/crates/burn-cuda/src/lib.rs
@@ -10,7 +10,7 @@ pub mod compiler;
 pub use device::*;
 
 use burn_jit::JitBackend;
-use runtime::CudaRuntime;
+pub use runtime::CudaRuntime;
 
 #[cfg(not(feature = "fusion"))]
 pub type Cuda<F = f32, I = i32> = JitBackend<CudaRuntime, F, I>;
diff --git a/examples/custom-csv-dataset/Cargo.toml b/examples/custom-csv-dataset/Cargo.toml
index 5d450ad47..fecedd76a 100644
--- a/examples/custom-csv-dataset/Cargo.toml
+++ b/examples/custom-csv-dataset/Cargo.toml
@@ -18,4 +18,4 @@ reqwest = {workspace = true, features = ["blocking"]}
 
 # CSV parsing
 csv = {workspace = true}
-serde = {workspace = true, features = ["std", "derive"]}
\ No newline at end of file
+serde = {workspace = true, features = ["std", "derive"]}
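The remaining files add a standalone `gelu` example crate that exercises the new `Array` launch API end to end: a `#[cube(launch)]` kernel computes `GELU(x) = x * (1 + erf(x / sqrt(2))) / 2` over an input array, with the CUDA or wgpu runtime selected via cargo features (`wgpu` is the default). The `#[cube(launch)]` attribute generates the `gelu_launch` function called in `src/lib.rs` below. A minimal integration test for the example could look like this (a sketch, not part of the PR; it assumes the default `wgpu` feature is enabled):

```rust
// examples/gelu/tests/smoke.rs (hypothetical): `launch` panics on kernel
// failure, so reaching the end of the test is the assertion.
#[test]
fn gelu_runs_on_wgpu() {
    gelu::launch::<burn_wgpu::WgpuRuntime>(&Default::default());
}
```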
diff --git a/examples/gelu/Cargo.toml b/examples/gelu/Cargo.toml
new file mode 100644
index 000000000..c906548dd
--- /dev/null
+++ b/examples/gelu/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+authors = []
+name = "gelu"
+publish = false
+edition.workspace = true
+license.workspace = true
+version.workspace = true
+
+[features]
+default = ["wgpu"]
+cuda = ["burn-cuda"]
+wgpu = ["burn-wgpu"]
+
+[dependencies]
+burn-cube = { path = "../../crates/burn-cube", version = "0.14.0" }
+burn-cuda = { path = "../../crates/burn-cuda", version = "0.14.0", optional = true }
+burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", optional = true }
diff --git a/examples/gelu/examples/gelu.rs b/examples/gelu/examples/gelu.rs
new file mode 100644
index 000000000..374bfa396
--- /dev/null
+++ b/examples/gelu/examples/gelu.rs
@@ -0,0 +1,6 @@
+fn main() {
+    #[cfg(feature = "cuda")]
+    gelu::launch::<burn_cuda::CudaRuntime>(&Default::default());
+    #[cfg(feature = "wgpu")]
+    gelu::launch::<burn_wgpu::WgpuRuntime>(&Default::default());
+}
diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs
new file mode 100644
index 000000000..ff9e66ac5
--- /dev/null
+++ b/examples/gelu/src/lib.rs
@@ -0,0 +1,36 @@
+use burn_cube::prelude::*;
+
+#[cube(launch)]
+fn gelu<F: Float>(input: Array<F>, mut output: Array<F>) {
+    if ABSOLUTE_POS < input.len() {
+        output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
+    }
+}
+
+#[cube]
+fn gelu_scalar<F: Float>(x: F) -> F {
+    x * (F::new(1.0) + F::erf(x / F::sqrt(F::new(2.0)))) / F::new(2.0)
+}
+
+pub fn launch<R: Runtime>(device: &R::Device) {
+    let client = R::client(device);
+    println!("Executing gelu with runtime {:?}", R::name());
+
+    let input = &[-1., 0., 1., 5.];
+    let input_handle = client.create(f32::as_bytes(input));
+    let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
+
+    gelu_launch::<F32, R>(
+        client.clone(),
+        CubeCount::new(1, 1, 1),
+        KernelSettings::default(),
+        ArrayHandle::new(&input_handle, input.len()),
+        ArrayHandle::new(&output_handle, input.len()),
+    );
+
+    let output = client.read(output_handle.binding()).read_sync().unwrap();
+    let output = f32::from_bytes(&output);
+
+    // Should be [-0.1587, 0.0000, 0.8413, 5.0000]
+    println!("{output:?}");
+}
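For reference, the expected values in the final `println!` can be reproduced on the CPU. The sketch below (not part of the PR) applies the same formula as `gelu_scalar`, using the Abramowitz & Stegun polynomial approximation of `erf` since std Rust has no built-in error function:

```rust
// CPU reference for the expected output [-0.1587, 0.0, 0.8413, 5.0].
fn erf(x: f32) -> f32 {
    // Abramowitz & Stegun 7.1.26, |error| <= 1.5e-7.
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
        + 0.254829592)
        * t;
    sign * (1.0 - poly * (-x * x).exp())
}

fn gelu(x: f32) -> f32 {
    x * (1.0 + erf(x / 2.0_f32.sqrt())) / 2.0
}

fn main() {
    let output: Vec<f32> = [-1.0_f32, 0.0, 1.0, 5.0].iter().map(|&x| gelu(x)).collect();
    println!("{output:?}"); // ~[-0.15865, 0.0, 0.84134, 5.0]
}
```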