mirror of https://github.com/tracel-ai/burn.git
feat cube support Array (#1907)
This commit is contained in:
parent
14d1bbba64
commit
d50bac165e
|
@ -534,6 +534,7 @@ dependencies = [
|
||||||
"derive-new",
|
"derive-new",
|
||||||
"half",
|
"half",
|
||||||
"log",
|
"log",
|
||||||
|
"num-traits",
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1983,6 +1984,15 @@ dependencies = [
|
||||||
"slab",
|
"slab",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "gelu"
|
||||||
|
version = "0.14.0"
|
||||||
|
dependencies = [
|
||||||
|
"burn-cube",
|
||||||
|
"burn-cuda",
|
||||||
|
"burn-wgpu",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "gemm"
|
name = "gemm"
|
||||||
version = "0.17.1"
|
version = "0.17.1"
|
||||||
|
|
|
@ -29,5 +29,6 @@ half = { workspace = true, features = ["bytemuck"] }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
burn-cube-macros = { path = "../burn-cube-macros", version = "0.14.0" }
|
burn-cube-macros = { path = "../burn-cube-macros", version = "0.14.0" }
|
||||||
derive-new = { workspace = true }
|
derive-new = { workspace = true }
|
||||||
|
num-traits = { workspace = true }
|
||||||
|
|
||||||
log = { workspace = true }
|
log = { workspace = true }
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
use crate::compute::{CubeCount, KernelTask};
|
use crate::compute::{CubeCount, KernelTask};
|
||||||
use crate::ir::{Elem, FloatKind, IntKind};
|
use crate::ir::{Elem, FloatKind, IntKind};
|
||||||
|
use crate::prelude::ArrayHandle;
|
||||||
use crate::{calculate_num_elems_dyn_rank, frontend::TensorHandle, Kernel, Runtime};
|
use crate::{calculate_num_elems_dyn_rank, frontend::TensorHandle, Kernel, Runtime};
|
||||||
use burn_compute::client::ComputeClient;
|
use burn_compute::client::ComputeClient;
|
||||||
use burn_compute::server::Binding;
|
use burn_compute::server::Binding;
|
||||||
use bytemuck::NoUninit;
|
use bytemuck::NoUninit;
|
||||||
|
use num_traits::ToPrimitive;
|
||||||
|
|
||||||
/// Prepare a kernel for [launch](KernelLauncher::launch).
|
/// Prepare a kernel for [launch](KernelLauncher::launch).
|
||||||
pub struct KernelLauncher<R: Runtime> {
|
pub struct KernelLauncher<R: Runtime> {
|
||||||
|
@ -24,6 +26,11 @@ impl<R: Runtime> KernelLauncher<R> {
|
||||||
self.tensors.push(tensor);
|
self.tensors.push(tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Register an array to be launched.
|
||||||
|
pub fn register_array(&mut self, array: &ArrayHandle<'_, R>) {
|
||||||
|
self.tensors.push(&array.as_tensor());
|
||||||
|
}
|
||||||
|
|
||||||
/// Register a u32 scalar to be launched.
|
/// Register a u32 scalar to be launched.
|
||||||
pub fn register_u32(&mut self, scalar: u32) {
|
pub fn register_u32(&mut self, scalar: u32) {
|
||||||
self.register_scalar(Elem::UInt);
|
self.register_scalar(Elem::UInt);
|
||||||
|
@ -165,17 +172,21 @@ impl<R: Runtime> TensorState<R> {
|
||||||
|
|
||||||
bindings.push(tensor.handle.clone().binding());
|
bindings.push(tensor.handle.clone().binding());
|
||||||
|
|
||||||
if metadata.is_empty() {
|
let old_rank = if metadata.is_empty() {
|
||||||
metadata.push(tensor.strides.len() as u32);
|
let rank = tensor.strides.len() as u32;
|
||||||
}
|
metadata.push(rank);
|
||||||
|
None
|
||||||
|
} else if tensor.strides.len() > metadata[0] as usize {
|
||||||
|
let old_rank = metadata[0];
|
||||||
|
let rank = tensor.strides.len() as u32;
|
||||||
|
Self::adjust_rank(metadata, bindings.len(), rank);
|
||||||
|
Some(old_rank)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
for s in tensor.strides.iter() {
|
Self::register_strides(tensor.strides, tensor.shape, old_rank, metadata);
|
||||||
metadata.push(*s as u32);
|
Self::register_shape(tensor.shape, old_rank, metadata);
|
||||||
}
|
|
||||||
|
|
||||||
for s in tensor.shape.iter() {
|
|
||||||
metadata.push(*s as u32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if R::require_array_lengths() {
|
if R::require_array_lengths() {
|
||||||
let len = calculate_num_elems_dyn_rank(tensor.shape);
|
let len = calculate_num_elems_dyn_rank(tensor.shape);
|
||||||
|
@ -183,6 +194,82 @@ impl<R: Runtime> TensorState<R> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn adjust_rank(metadata: &mut Vec<u32>, num_registered: usize, rank: u32) {
|
||||||
|
let old_rank = metadata[0] as usize;
|
||||||
|
let rank_diff = rank as usize - old_rank;
|
||||||
|
let mut updated_metadata = Vec::with_capacity(2 * rank_diff * num_registered);
|
||||||
|
|
||||||
|
for pos in 0..num_registered {
|
||||||
|
let stride_index = (pos * old_rank * 2) + 1;
|
||||||
|
let shape_index = stride_index + old_rank;
|
||||||
|
|
||||||
|
let strides_old = &metadata[stride_index..stride_index + old_rank];
|
||||||
|
let shape_old = &metadata[shape_index..shape_index + old_rank];
|
||||||
|
|
||||||
|
Self::register_strides(
|
||||||
|
strides_old,
|
||||||
|
shape_old,
|
||||||
|
Some(old_rank as u32),
|
||||||
|
&mut updated_metadata,
|
||||||
|
);
|
||||||
|
Self::register_shape(shape_old, Some(old_rank as u32), &mut updated_metadata);
|
||||||
|
}
|
||||||
|
|
||||||
|
core::mem::swap(&mut updated_metadata, metadata);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_strides<T: ToPrimitive>(
|
||||||
|
strides: &[T],
|
||||||
|
shape: &[T],
|
||||||
|
old_rank: Option<u32>,
|
||||||
|
output: &mut Vec<u32>,
|
||||||
|
) {
|
||||||
|
let old_rank = if let Some(old_rank) = old_rank {
|
||||||
|
let rank = output[0];
|
||||||
|
let rank_diff = old_rank - rank;
|
||||||
|
let padded_strides = if rank_diff > 0 {
|
||||||
|
shape
|
||||||
|
.iter()
|
||||||
|
.take(old_rank as usize)
|
||||||
|
.map(|a| a.to_u32().unwrap())
|
||||||
|
.sum::<u32>()
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
for _ in 0..rank_diff {
|
||||||
|
output.push(padded_strides.to_u32().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
old_rank as usize
|
||||||
|
} else {
|
||||||
|
output[0] as usize // same as current.
|
||||||
|
};
|
||||||
|
|
||||||
|
for stride in strides.iter().take(old_rank) {
|
||||||
|
output.push(stride.to_u32().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_shape<T: ToPrimitive>(shape: &[T], old_rank: Option<u32>, output: &mut Vec<u32>) {
|
||||||
|
let old_rank = if let Some(old_rank) = old_rank {
|
||||||
|
let rank = output[0];
|
||||||
|
let rank_diff = rank - old_rank;
|
||||||
|
|
||||||
|
for _ in 0..rank_diff {
|
||||||
|
output.push(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
old_rank as usize
|
||||||
|
} else {
|
||||||
|
output[0] as usize // same as current
|
||||||
|
};
|
||||||
|
|
||||||
|
for elem in shape.iter().take(old_rank) {
|
||||||
|
output.push(elem.to_u32().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn register(
|
fn register(
|
||||||
self,
|
self,
|
||||||
client: &ComputeClient<R::Server, R::Channel>,
|
client: &ComputeClient<R::Server, R::Channel>,
|
||||||
|
@ -205,6 +292,7 @@ impl<R: Runtime> TensorState<R> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: NoUninit> ScalarState<T> {
|
impl<T: NoUninit> ScalarState<T> {
|
||||||
/// Add a new scalar value to the state.
|
/// Add a new scalar value to the state.
|
||||||
pub fn push(&mut self, val: T) {
|
pub fn push(&mut self, val: T) {
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use crate::frontend::{CubeType, ExpandElement};
|
use crate::{
|
||||||
|
compute::{KernelBuilder, KernelLauncher},
|
||||||
|
frontend::{CubeType, ExpandElement},
|
||||||
|
ir::{Item, Vectorization},
|
||||||
|
unexpanded, Runtime,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{ArgSettings, CubeElem, LaunchArg, TensorHandle, UInt};
|
||||||
|
|
||||||
#[derive(new, Clone, Copy)]
|
#[derive(new, Clone, Copy)]
|
||||||
pub struct Array<E> {
|
pub struct Array<E> {
|
||||||
|
@ -10,3 +17,52 @@ pub struct Array<E> {
|
||||||
impl<C: CubeType> CubeType for Array<C> {
|
impl<C: CubeType> CubeType for Array<C> {
|
||||||
type ExpandType = ExpandElement;
|
type ExpandType = ExpandElement;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<E: CubeType> Array<E> {
|
||||||
|
/// Obtain the array length of input
|
||||||
|
pub fn len(self) -> UInt {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C: CubeElem> LaunchArg for Array<C> {
|
||||||
|
type RuntimeArg<'a, R: Runtime> = ArrayHandle<'a, R>;
|
||||||
|
|
||||||
|
fn compile_input(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
|
||||||
|
builder.input_array(Item::vectorized(C::as_elem(), vectorization))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compile_output(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
|
||||||
|
builder.output_array(Item::vectorized(C::as_elem(), vectorization))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ArrayHandle<'a, R: Runtime> {
|
||||||
|
pub handle: &'a burn_compute::server::Handle<R::Server>,
|
||||||
|
pub length: [usize; 1],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, R: Runtime> ArgSettings<R> for ArrayHandle<'a, R> {
|
||||||
|
fn register(&self, launcher: &mut KernelLauncher<R>) {
|
||||||
|
launcher.register_array(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, R: Runtime> ArrayHandle<'a, R> {
|
||||||
|
pub fn new(handle: &'a burn_compute::server::Handle<R::Server>, length: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
handle,
|
||||||
|
length: [length],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_tensor(&self) -> TensorHandle<'_, R> {
|
||||||
|
let shape = &self.length;
|
||||||
|
|
||||||
|
TensorHandle {
|
||||||
|
handle: self.handle,
|
||||||
|
strides: &[1],
|
||||||
|
shape,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -9,7 +9,9 @@ pub use crate::ir::{CubeDim, KernelDefinition};
|
||||||
pub use crate::runtime::Runtime;
|
pub use crate::runtime::Runtime;
|
||||||
|
|
||||||
/// Elements
|
/// Elements
|
||||||
pub use crate::frontend::{Float, LaunchArg, Tensor, TensorHandle, UInt, F16, F32, F64, I32, I64};
|
pub use crate::frontend::{
|
||||||
|
Array, ArrayHandle, Float, LaunchArg, Tensor, TensorHandle, UInt, F16, F32, F64, I32, I64,
|
||||||
|
};
|
||||||
pub use crate::pod::CubeElement;
|
pub use crate::pod::CubeElement;
|
||||||
|
|
||||||
/// Topology
|
/// Topology
|
||||||
|
|
|
@ -10,7 +10,7 @@ pub mod compiler;
|
||||||
pub use device::*;
|
pub use device::*;
|
||||||
|
|
||||||
use burn_jit::JitBackend;
|
use burn_jit::JitBackend;
|
||||||
use runtime::CudaRuntime;
|
pub use runtime::CudaRuntime;
|
||||||
|
|
||||||
#[cfg(not(feature = "fusion"))]
|
#[cfg(not(feature = "fusion"))]
|
||||||
pub type Cuda<F = f32, I = i32> = JitBackend<CudaRuntime, F, I>;
|
pub type Cuda<F = f32, I = i32> = JitBackend<CudaRuntime, F, I>;
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
[package]
|
||||||
|
authors = []
|
||||||
|
name = "gelu"
|
||||||
|
publish = false
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
version.workspace = true
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["wgpu"]
|
||||||
|
cuda = ["burn-cuda"]
|
||||||
|
wgpu = ["burn-wgpu"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
burn-cube = { path = "../../crates/burn-cube", version = "0.14.0" }
|
||||||
|
burn-cuda = { path = "../../crates/burn-cuda", version = "0.14.0", optional = true }
|
||||||
|
burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", optional = true }
|
|
@ -0,0 +1,6 @@
|
||||||
|
fn main() {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
gelu::launch::<burn_cuda::CudaRuntime>(&Default::default());
|
||||||
|
#[cfg(feature = "wgpu")]
|
||||||
|
gelu::launch::<burn_wgpu::WgpuRuntime>(&Default::default());
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
use burn_cube::prelude::*;
|
||||||
|
|
||||||
|
#[cube(launch)]
|
||||||
|
fn gelu<F: Float>(input: Array<F>, mut output: Array<F>) {
|
||||||
|
if ABSOLUTE_POS < input.len() {
|
||||||
|
output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
fn gelu_scalar<F: Float>(x: F) -> F {
|
||||||
|
x * (F::new(1.0) + F::erf(x / F::sqrt(F::new(2.0)))) / F::new(2.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn launch<R: Runtime>(device: &R::Device) {
|
||||||
|
let client = R::client(device);
|
||||||
|
println!("Executing gelu with runtime {:?}", R::name());
|
||||||
|
|
||||||
|
let input = &[-1., 0., 1., 5.];
|
||||||
|
let input_handle = client.create(f32::as_bytes(input));
|
||||||
|
let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
|
||||||
|
|
||||||
|
gelu_launch::<F32, R>(
|
||||||
|
client.clone(),
|
||||||
|
CubeCount::new(1, 1, 1),
|
||||||
|
KernelSettings::default(),
|
||||||
|
ArrayHandle::new(&input_handle, input.len()),
|
||||||
|
ArrayHandle::new(&output_handle, input.len()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let output = client.read(output_handle.binding()).read_sync().unwrap();
|
||||||
|
let output = f32::from_bytes(&output);
|
||||||
|
|
||||||
|
// Should be [-0.1587, 0.0000, 0.8413, 5.0000]
|
||||||
|
println!("{output:?}");
|
||||||
|
}
|
Loading…
Reference in New Issue