feat cube support Array (#1907)

This commit is contained in:
Nathaniel Simard 2024-06-19 17:03:02 -04:00 committed by GitHub
parent 14d1bbba64
commit d50bac165e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 230 additions and 14 deletions

10
Cargo.lock generated
View File

@ -534,6 +534,7 @@ dependencies = [
"derive-new", "derive-new",
"half", "half",
"log", "log",
"num-traits",
"serde", "serde",
] ]
@ -1983,6 +1984,15 @@ dependencies = [
"slab", "slab",
] ]
[[package]]
name = "gelu"
version = "0.14.0"
dependencies = [
"burn-cube",
"burn-cuda",
"burn-wgpu",
]
[[package]] [[package]]
name = "gemm" name = "gemm"
version = "0.17.1" version = "0.17.1"

View File

@ -29,5 +29,6 @@ half = { workspace = true, features = ["bytemuck"] }
serde = { workspace = true } serde = { workspace = true }
burn-cube-macros = { path = "../burn-cube-macros", version = "0.14.0" } burn-cube-macros = { path = "../burn-cube-macros", version = "0.14.0" }
derive-new = { workspace = true } derive-new = { workspace = true }
num-traits = { workspace = true }
log = { workspace = true } log = { workspace = true }

View File

@ -1,9 +1,11 @@
use crate::compute::{CubeCount, KernelTask}; use crate::compute::{CubeCount, KernelTask};
use crate::ir::{Elem, FloatKind, IntKind}; use crate::ir::{Elem, FloatKind, IntKind};
use crate::prelude::ArrayHandle;
use crate::{calculate_num_elems_dyn_rank, frontend::TensorHandle, Kernel, Runtime}; use crate::{calculate_num_elems_dyn_rank, frontend::TensorHandle, Kernel, Runtime};
use burn_compute::client::ComputeClient; use burn_compute::client::ComputeClient;
use burn_compute::server::Binding; use burn_compute::server::Binding;
use bytemuck::NoUninit; use bytemuck::NoUninit;
use num_traits::ToPrimitive;
/// Prepare a kernel for [launch](KernelLauncher::launch). /// Prepare a kernel for [launch](KernelLauncher::launch).
pub struct KernelLauncher<R: Runtime> { pub struct KernelLauncher<R: Runtime> {
@ -24,6 +26,11 @@ impl<R: Runtime> KernelLauncher<R> {
self.tensors.push(tensor); self.tensors.push(tensor);
} }
/// Register an array to be launched.
pub fn register_array(&mut self, array: &ArrayHandle<'_, R>) {
self.tensors.push(&array.as_tensor());
}
/// Register a u32 scalar to be launched. /// Register a u32 scalar to be launched.
pub fn register_u32(&mut self, scalar: u32) { pub fn register_u32(&mut self, scalar: u32) {
self.register_scalar(Elem::UInt); self.register_scalar(Elem::UInt);
@ -165,17 +172,21 @@ impl<R: Runtime> TensorState<R> {
bindings.push(tensor.handle.clone().binding()); bindings.push(tensor.handle.clone().binding());
if metadata.is_empty() { let old_rank = if metadata.is_empty() {
metadata.push(tensor.strides.len() as u32); let rank = tensor.strides.len() as u32;
} metadata.push(rank);
None
} else if tensor.strides.len() > metadata[0] as usize {
let old_rank = metadata[0];
let rank = tensor.strides.len() as u32;
Self::adjust_rank(metadata, bindings.len(), rank);
Some(old_rank)
} else {
None
};
for s in tensor.strides.iter() { Self::register_strides(tensor.strides, tensor.shape, old_rank, metadata);
metadata.push(*s as u32); Self::register_shape(tensor.shape, old_rank, metadata);
}
for s in tensor.shape.iter() {
metadata.push(*s as u32);
}
if R::require_array_lengths() { if R::require_array_lengths() {
let len = calculate_num_elems_dyn_rank(tensor.shape); let len = calculate_num_elems_dyn_rank(tensor.shape);
@ -183,6 +194,82 @@ impl<R: Runtime> TensorState<R> {
} }
} }
fn adjust_rank(metadata: &mut Vec<u32>, num_registered: usize, rank: u32) {
let old_rank = metadata[0] as usize;
let rank_diff = rank as usize - old_rank;
let mut updated_metadata = Vec::with_capacity(2 * rank_diff * num_registered);
for pos in 0..num_registered {
let stride_index = (pos * old_rank * 2) + 1;
let shape_index = stride_index + old_rank;
let strides_old = &metadata[stride_index..stride_index + old_rank];
let shape_old = &metadata[shape_index..shape_index + old_rank];
Self::register_strides(
strides_old,
shape_old,
Some(old_rank as u32),
&mut updated_metadata,
);
Self::register_shape(shape_old, Some(old_rank as u32), &mut updated_metadata);
}
core::mem::swap(&mut updated_metadata, metadata);
}
fn register_strides<T: ToPrimitive>(
strides: &[T],
shape: &[T],
old_rank: Option<u32>,
output: &mut Vec<u32>,
) {
let old_rank = if let Some(old_rank) = old_rank {
let rank = output[0];
let rank_diff = old_rank - rank;
let padded_strides = if rank_diff > 0 {
shape
.iter()
.take(old_rank as usize)
.map(|a| a.to_u32().unwrap())
.sum::<u32>()
} else {
0
};
for _ in 0..rank_diff {
output.push(padded_strides.to_u32().unwrap());
}
old_rank as usize
} else {
output[0] as usize // same as current.
};
for stride in strides.iter().take(old_rank) {
output.push(stride.to_u32().unwrap());
}
}
fn register_shape<T: ToPrimitive>(shape: &[T], old_rank: Option<u32>, output: &mut Vec<u32>) {
let old_rank = if let Some(old_rank) = old_rank {
let rank = output[0];
let rank_diff = rank - old_rank;
for _ in 0..rank_diff {
output.push(1);
}
old_rank as usize
} else {
output[0] as usize // same as current
};
for elem in shape.iter().take(old_rank) {
output.push(elem.to_u32().unwrap());
}
}
fn register( fn register(
self, self,
client: &ComputeClient<R::Server, R::Channel>, client: &ComputeClient<R::Server, R::Channel>,
@ -205,6 +292,7 @@ impl<R: Runtime> TensorState<R> {
} }
} }
} }
impl<T: NoUninit> ScalarState<T> { impl<T: NoUninit> ScalarState<T> {
/// Add a new scalar value to the state. /// Add a new scalar value to the state.
pub fn push(&mut self, val: T) { pub fn push(&mut self, val: T) {

View File

@ -1,6 +1,13 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use crate::frontend::{CubeType, ExpandElement}; use crate::{
compute::{KernelBuilder, KernelLauncher},
frontend::{CubeType, ExpandElement},
ir::{Item, Vectorization},
unexpanded, Runtime,
};
use super::{ArgSettings, CubeElem, LaunchArg, TensorHandle, UInt};
#[derive(new, Clone, Copy)] #[derive(new, Clone, Copy)]
pub struct Array<E> { pub struct Array<E> {
@ -10,3 +17,52 @@ pub struct Array<E> {
impl<C: CubeType> CubeType for Array<C> { impl<C: CubeType> CubeType for Array<C> {
type ExpandType = ExpandElement; type ExpandType = ExpandElement;
} }
impl<E: CubeType> Array<E> {
/// Obtain the array length of input
pub fn len(self) -> UInt {
unexpanded!()
}
}
impl<C: CubeElem> LaunchArg for Array<C> {
type RuntimeArg<'a, R: Runtime> = ArrayHandle<'a, R>;
fn compile_input(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
builder.input_array(Item::vectorized(C::as_elem(), vectorization))
}
fn compile_output(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
builder.output_array(Item::vectorized(C::as_elem(), vectorization))
}
}
pub struct ArrayHandle<'a, R: Runtime> {
pub handle: &'a burn_compute::server::Handle<R::Server>,
pub length: [usize; 1],
}
impl<'a, R: Runtime> ArgSettings<R> for ArrayHandle<'a, R> {
fn register(&self, launcher: &mut KernelLauncher<R>) {
launcher.register_array(self)
}
}
impl<'a, R: Runtime> ArrayHandle<'a, R> {
pub fn new(handle: &'a burn_compute::server::Handle<R::Server>, length: usize) -> Self {
Self {
handle,
length: [length],
}
}
pub fn as_tensor(&self) -> TensorHandle<'_, R> {
let shape = &self.length;
TensorHandle {
handle: self.handle,
strides: &[1],
shape,
}
}
}

View File

@ -9,7 +9,9 @@ pub use crate::ir::{CubeDim, KernelDefinition};
pub use crate::runtime::Runtime; pub use crate::runtime::Runtime;
/// Elements /// Elements
pub use crate::frontend::{Float, LaunchArg, Tensor, TensorHandle, UInt, F16, F32, F64, I32, I64}; pub use crate::frontend::{
Array, ArrayHandle, Float, LaunchArg, Tensor, TensorHandle, UInt, F16, F32, F64, I32, I64,
};
pub use crate::pod::CubeElement; pub use crate::pod::CubeElement;
/// Topology /// Topology

View File

@ -10,7 +10,7 @@ pub mod compiler;
pub use device::*; pub use device::*;
use burn_jit::JitBackend; use burn_jit::JitBackend;
use runtime::CudaRuntime; pub use runtime::CudaRuntime;
#[cfg(not(feature = "fusion"))] #[cfg(not(feature = "fusion"))]
pub type Cuda<F = f32, I = i32> = JitBackend<CudaRuntime, F, I>; pub type Cuda<F = f32, I = i32> = JitBackend<CudaRuntime, F, I>;

View File

@ -18,4 +18,4 @@ reqwest = {workspace = true, features = ["blocking"]}
# CSV parsing # CSV parsing
csv = {workspace = true} csv = {workspace = true}
serde = {workspace = true, features = ["std", "derive"]} serde = {workspace = true, features = ["std", "derive"]}

17
examples/gelu/Cargo.toml Normal file
View File

@ -0,0 +1,17 @@
[package]
authors = []
name = "gelu"
publish = false
edition.workspace = true
license.workspace = true
version.workspace = true
[features]
default = ["wgpu"]
cuda = ["burn-cuda"]
wgpu = ["burn-wgpu"]
[dependencies]
burn-cube = { path = "../../crates/burn-cube", version = "0.14.0" }
burn-cuda = { path = "../../crates/burn-cuda", version = "0.14.0", optional = true }
burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", optional = true }

View File

@ -0,0 +1,6 @@
fn main() {
#[cfg(feature = "cuda")]
gelu::launch::<burn_cuda::CudaRuntime>(&Default::default());
#[cfg(feature = "wgpu")]
gelu::launch::<burn_wgpu::WgpuRuntime>(&Default::default());
}

36
examples/gelu/src/lib.rs Normal file
View File

@ -0,0 +1,36 @@
use burn_cube::prelude::*;
#[cube(launch)]
fn gelu<F: Float>(input: Array<F>, mut output: Array<F>) {
if ABSOLUTE_POS < input.len() {
output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
}
}
#[cube]
fn gelu_scalar<F: Float>(x: F) -> F {
x * (F::new(1.0) + F::erf(x / F::sqrt(F::new(2.0)))) / F::new(2.0)
}
pub fn launch<R: Runtime>(device: &R::Device) {
let client = R::client(device);
println!("Executing gelu with runtime {:?}", R::name());
let input = &[-1., 0., 1., 5.];
let input_handle = client.create(f32::as_bytes(input));
let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
gelu_launch::<F32, R>(
client.clone(),
CubeCount::new(1, 1, 1),
KernelSettings::default(),
ArrayHandle::new(&input_handle, input.len()),
ArrayHandle::new(&output_handle, input.len()),
);
let output = client.read(output_handle.binding()).read_sync().unwrap();
let output = f32::from_bytes(&output);
// Should be [-0.1587, 0.0000, 0.8413, 5.0000]
println!("{output:?}");
}