diff --git a/Cargo.toml b/Cargo.toml index 8eae18eb2..58c206bb1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ members = [ "burn-ndarray", "burn-no-std-tests", "burn-tch", + "burn-wgpu", "burn-tensor-testgen", "burn-tensor", "burn-train", @@ -53,6 +54,11 @@ syn = "2.0" tempfile = "3.5.0" thiserror = "1.0.40" topological-sort = "0.2.2" + +# WGPU stuff +wgpu = "0.16.0" +futures-intrusive = "0.5" +pollster = "0.3" # # The following packages disable the "std" feature for no_std compatibility # diff --git a/burn-core/src/nn/attention/mask.rs b/burn-core/src/nn/attention/mask.rs index 7d773b3d3..976d34b98 100644 --- a/burn-core/src/nn/attention/mask.rs +++ b/burn-core/src/nn/attention/mask.rs @@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask( mask = mask.to_device(device).repeat(0, batch_size); - mask.equal_elem(1_i64) + mask.equal_elem(1_i64.elem::()) } pub struct GeneratePaddingMask { diff --git a/burn-tensor/src/tensor/api/base.rs b/burn-tensor/src/tensor/api/base.rs index 951759fd3..ad614daf7 100644 --- a/burn-tensor/src/tensor/api/base.rs +++ b/burn-tensor/src/tensor/api/base.rs @@ -256,12 +256,6 @@ where K::equal(self.primitive, other.primitive) } - /// Applies element wise equal comparison and returns a boolean tensor. - pub fn equal_elem>(self, other: E) -> Tensor { - let elem: K::Elem = other.into(); - K::equal_elem::(self.primitive, elem) - } - /// Concatenates all tensors into a new one along the given dimension. 
/// /// # Panics @@ -400,7 +394,6 @@ pub trait BasicOps: TensorKind { lhs: Self::Primitive, rhs: Self::Primitive, ) -> Tensor; - fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor; fn elem_type_name() -> &'static str { core::any::type_name::() } @@ -478,10 +471,6 @@ impl BasicOps for Float { ) -> Tensor { Tensor::new(B::equal(lhs, rhs)) } - - fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::equal_elem(lhs, rhs)) - } } impl BasicOps for Int { @@ -553,10 +542,6 @@ impl BasicOps for Int { Tensor::new(B::int_equal(lhs, rhs)) } - fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::int_equal_elem(lhs, rhs)) - } - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { B::int_cat(vectors, dim) } @@ -631,10 +616,6 @@ impl BasicOps for Bool { Tensor::new(B::bool_equal(lhs, rhs)) } - fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { - Tensor::new(B::bool_equal_elem(lhs, rhs)) - } - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive { B::bool_cat(vectors, dim) } diff --git a/burn-tensor/src/tensor/api/numeric.rs b/burn-tensor/src/tensor/api/numeric.rs index af1ff3e70..234c751a0 100644 --- a/burn-tensor/src/tensor/api/numeric.rs +++ b/burn-tensor/src/tensor/api/numeric.rs @@ -133,6 +133,11 @@ where Self::new(K::sum_dim(self.primitive, dim)) } + /// Applies element wise equal comparison and returns a boolean tensor. + pub fn equal_elem(self, other: E) -> Tensor { + K::equal_elem::(self.primitive, other.elem()) + } + /// Applies element wise greater comparison and returns a boolean tensor. 
/// /// # Panics @@ -413,6 +418,7 @@ where fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; fn mean(tensor: Self::Primitive) -> Self::Primitive<1>; fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; + fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor; fn greater( lhs: Self::Primitive, rhs: Self::Primitive, @@ -559,6 +565,9 @@ impl Numeric for Int { B::int_mean_dim(tensor, dim) } + fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::int_equal_elem(lhs, rhs)) + } fn greater( lhs: Self::Primitive, rhs: Self::Primitive, @@ -777,6 +786,9 @@ impl Numeric for Float { B::mean_dim(tensor, dim) } + fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> Tensor { + Tensor::new(B::equal_elem(lhs, rhs)) + } fn greater( lhs: Self::Primitive, rhs: Self::Primitive, diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index 2114919c9..d3a300c7a 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -79,7 +79,7 @@ pub trait Backend: /// Tensor primitive to be used for all int operations. type IntTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; /// Int element type. - type IntElem: Element + From + Into; + type IntElem: Element; /// Tensor primitive to be used for all bool operations. 
type BoolTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; diff --git a/burn-tensor/src/tensor/ops/tensor.rs b/burn-tensor/src/tensor/ops/tensor.rs index d58f034d0..addda91e5 100644 --- a/burn-tensor/src/tensor/ops/tensor.rs +++ b/burn-tensor/src/tensor/ops/tensor.rs @@ -22,7 +22,9 @@ pub trait TensorOps { } fn shape(tensor: &B::TensorPrimitive) -> Shape; fn to_data(tensor: &B::TensorPrimitive) -> Data; - fn into_data(tensor: B::TensorPrimitive) -> Data; + fn into_data(tensor: B::TensorPrimitive) -> Data { + Self::to_data(&tensor) + } fn device(tensor: &B::TensorPrimitive) -> B::Device; fn to_device( tensor: B::TensorPrimitive, @@ -102,7 +104,9 @@ pub trait TensorOps { lhs: B::TensorPrimitive, rhs: B::TensorPrimitive, ) -> B::TensorPrimitive; - fn neg(tensor: B::TensorPrimitive) -> B::TensorPrimitive; + fn neg(tensor: B::TensorPrimitive) -> B::TensorPrimitive { + Self::mul_scalar(tensor, (-1.0_f32).elem::()) + } fn transpose(tensor: B::TensorPrimitive) -> B::TensorPrimitive { Self::swap_dims(tensor, D - 2, D - 1) } diff --git a/burn-tensor/src/tests/ops/add.rs b/burn-tensor/src/tests/ops/add.rs index 7608a3af4..ae350d550 100644 --- a/burn-tensor/src/tests/ops/add.rs +++ b/burn-tensor/src/tests/ops/add.rs @@ -15,4 +15,17 @@ mod tests { let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]); assert_eq!(data_expected, data_actual); } + + #[test] + fn test_add_broadcast() { + let data_1 = Data::from([[0.0, 1.0, 2.0]]); + let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + let data_expected = Data::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]); + assert_eq!(data_expected, data_actual); + } } diff --git a/burn-tensor/src/tests/ops/log.rs b/burn-tensor/src/tests/ops/log.rs index 5e2f8f361..f71387317 100644 --- a/burn-tensor/src/tests/ops/log.rs +++ b/burn-tensor/src/tests/ops/log.rs 
@@ -4,7 +4,7 @@ mod tests { use burn_tensor::{Data, Tensor}; #[test] - fn should_support_exp_ops() { + fn should_support_log_ops() { let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = Tensor::::from_data(data); diff --git a/burn-wgpu/Cargo.toml b/burn-wgpu/Cargo.toml new file mode 100644 index 000000000..2dcba4c9c --- /dev/null +++ b/burn-wgpu/Cargo.toml @@ -0,0 +1,32 @@ +[package] +authors = ["nathanielsimard "] +categories = ["science"] +description = "WGPU backend for burn" +edition = "2021" +keywords = ["deep-learning", "machine-learning", "data"] +license = "MIT/Apache-2.0" +name = "burn-wgpu" +readme = "README.md" +repository = "https://github.com/burn-rs/burn/tree/main/burn-wgpu" +version = "0.8.0" + +[dependencies] +burn-tensor = {path = "../burn-tensor", version = "0.8.0"} +burn-common = {path = "../burn-common", version = "0.8.0"} +derive-new = {workspace = true} +bytemuck = {workspace = true} +rand = {workspace = true} +num-traits = {workspace = true} + +# WGPU stuff +wgpu = {workspace = true} +futures-intrusive = {workspace = true} +pollster = {workspace = true} + +[dev-dependencies] +burn-autodiff = {path = "../burn-autodiff", version = "0.8.0", default-features = false, features = [ + "export_tests", +]} +burn-tensor = {path = "../burn-tensor", version = "0.8.0", default-features = false, features = [ + "export_tests", +]} diff --git a/burn-wgpu/LICENSE-APACHE b/burn-wgpu/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/burn-wgpu/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/burn-wgpu/LICENSE-MIT b/burn-wgpu/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/burn-wgpu/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/burn-wgpu/README.md b/burn-wgpu/README.md new file mode 100644 index 000000000..100034edf --- /dev/null +++ b/burn-wgpu/README.md @@ -0,0 +1,3 @@ +# Burn WGPU Backend + 
+[Burn](https://github.com/burn-rs/burn) WGPU backend diff --git a/burn-wgpu/src/backend.rs b/burn-wgpu/src/backend.rs new file mode 100644 index 000000000..aca470bfa --- /dev/null +++ b/burn-wgpu/src/backend.rs @@ -0,0 +1,45 @@ +use burn_tensor::backend::Backend; +use rand::{rngs::StdRng, SeedableRng}; + +use crate::{ + element::{FloatElement, IntElement}, + tensor::WGPUTensor, + GraphicsAPI, WGPUDevice, +}; +use std::{marker::PhantomData, sync::Mutex}; + +pub(crate) static SEED: Mutex> = Mutex::new(None); + +#[derive(Debug, Default, Clone)] +pub struct WGPUBackend { + _g: PhantomData, + _f: PhantomData, + _i: PhantomData, +} + +impl Backend for WGPUBackend { + type Device = WGPUDevice; + type FullPrecisionBackend = WGPUBackend; + + type FullPrecisionElem = f32; + type FloatElem = F; + type IntElem = I; + + type TensorPrimitive = WGPUTensor; + type IntTensorPrimitive = WGPUTensor; + type BoolTensorPrimitive = WGPUTensor; + + fn name() -> String { + String::from("wgpu") + } + + fn seed(seed: u64) { + let rng = StdRng::seed_from_u64(seed); + let mut seed = SEED.lock().unwrap(); + *seed = Some(rng); + } + + fn ad_enabled() -> bool { + false + } +} diff --git a/burn-wgpu/src/context.rs b/burn-wgpu/src/context.rs new file mode 100644 index 000000000..67f056247 --- /dev/null +++ b/burn-wgpu/src/context.rs @@ -0,0 +1,256 @@ +use burn_common::id::IdGenerator; +use std::{ + any::TypeId, + borrow::Cow, + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use wgpu::{ + util::{BufferInitDescriptor, DeviceExt}, + Buffer, DeviceDescriptor, DeviceType, ShaderModule, ShaderModuleDescriptor, +}; + +use crate::{kernel::KernelGenerator, GraphicsAPI, WGPUDevice}; + +/// The context is the basic struct that allows to execute GPU kernel on devices. +/// +/// You can access a context for a [wgpu device](WGPUDevice) using [get_context](crate::pool::get_context). 
+#[derive(Debug)] +pub struct Context { + id: String, + queue: wgpu::Queue, + device_wgpu: wgpu::Device, + cache: Mutex>>, + pub(crate) device: WGPUDevice, +} + +#[derive(new, Clone, Debug)] +pub struct WorkGroup { + pub x: u32, + pub y: u32, + pub z: u32, +} + +impl Context { + pub(crate) fn new(device: &WGPUDevice) -> Self { + // Instantiates instance of WebGPU + let instance = wgpu::Instance::default(); + + // `request_adapter` instantiates the general connection to the GPU + let adapters = instance.enumerate_adapters(G::backend().into()); + let mut adapters = adapters + .filter(|adapter| { + let device_type = adapter.get_info().device_type; + match device { + WGPUDevice::DiscreteGPU(_) => device_type == DeviceType::DiscreteGpu, + WGPUDevice::IntegratedGPU(_) => device_type == DeviceType::IntegratedGpu, + WGPUDevice::VirtualGPU(_) => device_type == DeviceType::VirtualGpu, + WGPUDevice::CPU => device_type == DeviceType::Cpu, + } + }) + .collect::>(); + + let adapter = match device { + WGPUDevice::DiscreteGPU(num) => { + assert!(adapters.len() > *num, "No Discrete GPU device found"); + adapters.remove(*num) + } + WGPUDevice::IntegratedGPU(num) => { + assert!(adapters.len() > *num, "No Integrated GPU device found"); + adapters.remove(*num) + } + WGPUDevice::VirtualGPU(num) => { + assert!(adapters.len() > *num, "No Virtual GPU device found"); + adapters.remove(*num) + } + WGPUDevice::CPU => { + assert!(!adapters.is_empty(), "No CPU device found"); + adapters.remove(0) + } + }; + + let device_wgpu = device.clone(); + let (device, queue) = pollster::block_on(adapter.request_device( + &DeviceDescriptor { + label: None, + features: wgpu::Features::empty(), + limits: wgpu::Limits::downlevel_defaults(), + }, + None, + )) + .expect("Unable to request the device with the adapter"); + + Self { + id: IdGenerator::generate(), + queue, + device_wgpu: device, + device: device_wgpu, + cache: Mutex::new(HashMap::new()), + } + } + + /// Create a new buffer with the provided size. 
+ pub fn create_buffer(&self, size: usize) -> Buffer { + self.device_wgpu.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: size as u64, + usage: wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }) + } + + /// Create a new buffer initialized with the provided bytes. + pub fn create_buffer_with_data(&self, data: &[u8]) -> Buffer { + let buffer_src = self.device_wgpu.create_buffer_init(&BufferInitDescriptor { + label: Some("Buffer Src"), + contents: data, + usage: wgpu::BufferUsages::COPY_SRC, + }); + + let buffer = self.create_buffer(buffer_src.size() as usize); + + // Create a command encoder + let mut encoder = + self.device_wgpu + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Command Encoder"), + }); + + // Copy data from the staging buffer to the target buffer + encoder.copy_buffer_to_buffer(&buffer_src, 0, &buffer, 0, buffer_src.size()); + + // Submit the command encoder to the queue + self.queue.submit(std::iter::once(encoder.finish())); + + buffer + } + + /// Read a buffer from the GPU and return its content as bytes. 
+ pub fn buffer_to_data(&self, buffer: &Buffer) -> Vec { + let size = buffer.size(); + + let buffer_dest = self.device_wgpu.create_buffer(&wgpu::BufferDescriptor { + label: None, + size, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + // Create a command encoder + let mut encoder = + self.device_wgpu + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Command Encoder"), + }); + + encoder.copy_buffer_to_buffer(buffer, 0, &buffer_dest, 0, size); + + self.queue.submit(std::iter::once(encoder.finish())); + + let buffer_slice = buffer_dest.slice(..); + let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| { + sender + .send(v) + .expect("Unable to send buffer slice result to async channel.") + }); + + self.device_wgpu.poll(wgpu::Maintain::Wait); + + let result = pollster::block_on(receiver.receive()); + + if let Some(Ok(())) = result { + let data = buffer_slice.get_mapped_range(); + let result = bytemuck::cast_slice(&data).to_vec(); + + drop(data); + buffer_dest.unmap(); + result + } else { + panic!("Unable to read buffer {:?}", result) + } + } + + /// Compile a kernel template if not present in the cache. + pub fn compile(&self) -> Arc { + let mut cache = self.cache.lock().unwrap(); + let template_id = TypeId::of::(); + + if let Some(module) = cache.get(&template_id) { + return module.clone(); + } + + let source = K::generate(); + + let module = self + .device_wgpu + .create_shader_module(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source.as_ref())), + }); + let module = Arc::new(module); + + cache.insert(template_id, module.clone()); + + module + } + + /// Execute a kernel using the provided buffers. + /// + /// # Notes + /// + /// This function isn't safe, buffer can be mutated by the GPU. 
The users must ensure that a + /// buffer can be mutated when launching a compute shader with write access to a buffer. + /// + /// Buffer positions are used as bindings when launching a compute kernel. + pub fn execute(&self, work_group: &WorkGroup, kernel: &ShaderModule, buffers: &[&Buffer]) { + let pipeline = self + .device_wgpu + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: kernel, + entry_point: "main", + }); + + let group_layout = pipeline.get_bind_group_layout(0); + + let entries = buffers + .iter() + .enumerate() + .map(|(i, buffer)| wgpu::BindGroupEntry { + binding: i as u32, + resource: buffer.as_entire_binding(), + }) + .collect::>(); + + let bind_group = self + .device_wgpu + .create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &group_layout, + entries: &entries, + }); + + let mut encoder = self + .device_wgpu + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + let mut compute = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); + compute.set_pipeline(&pipeline); + compute.set_bind_group(0, &bind_group, &[]); + + compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z); + std::mem::drop(compute); + + self.queue.submit(Some(encoder.finish())); + } +} + +impl PartialEq for Context { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} diff --git a/burn-wgpu/src/device.rs b/burn-wgpu/src/device.rs new file mode 100644 index 000000000..45889b174 --- /dev/null +++ b/burn-wgpu/src/device.rs @@ -0,0 +1,25 @@ +/// The device struct when using the `wgpu` backend. +/// +/// Note that you need to provide the device index when using a GPU backend. +/// +/// # Example +/// +/// ```no_run +/// use burn_wgpu::WGPUDevice; +/// +/// let device_gpu_1 = WGPUDevice::DiscreteGPU(0); // First discrete GPU found. +/// let device_gpu_2 = WGPUDevice::DiscreteGPU(1); // Second discrete GPU found. 
+/// ``` +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub enum WGPUDevice { + DiscreteGPU(usize), + IntegratedGPU(usize), + VirtualGPU(usize), + CPU, +} + +impl Default for WGPUDevice { + fn default() -> Self { + Self::CPU + } +} diff --git a/burn-wgpu/src/element.rs b/burn-wgpu/src/element.rs new file mode 100644 index 000000000..b176a8b84 --- /dev/null +++ b/burn-wgpu/src/element.rs @@ -0,0 +1,66 @@ +use burn_tensor::Element; + +pub trait WGPUElement: core::fmt::Debug + 'static + Clone +where + Self: Sized, +{ + fn type_name() -> &'static str; + fn as_bytes(slice: &[Self]) -> &[u8]; + fn from_bytes(bytes: &[u8]) -> &[Self]; +} + +pub trait FloatElement: WGPUElement + Element {} + +pub trait IntElement: WGPUElement + Element {} + +impl WGPUElement for u32 { + fn type_name() -> &'static str { + "u32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } +} + +impl WGPUElement for i32 { + fn type_name() -> &'static str { + "i32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } +} + +impl WGPUElement for i64 { + fn type_name() -> &'static str { + "i64" // NOTE(review): WGSL appears to have no 64-bit integer scalar type; confirm that kernels templated with this name compile. + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } +} + +impl WGPUElement for f32 { + fn type_name() -> &'static str { + "f32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } +} + +impl FloatElement for f32 {} +impl IntElement for i32 {} +impl IntElement for i64 {} diff --git a/burn-wgpu/src/graphics.rs b/burn-wgpu/src/graphics.rs new file mode 100644 index 000000000..6f5cc6573 --- /dev/null +++ b/burn-wgpu/src/graphics.rs @@ -0,0 +1,61 @@ +/// The basic trait to specify which graphics API to use as 
Backend. +/// +/// Options are: +/// - [Vulkan](Vulkan) +/// - [Metal](Metal) +/// - [OpenGL](OpenGL) +/// - [DirectX 11](Dx11) +/// - [DirectX 12](Dx12) +/// - [WebGPU](WebGPU) +pub trait GraphicsAPI: Send + Sync + core::fmt::Debug + Default + Clone + 'static { + fn backend() -> wgpu::Backend; +} + +#[derive(Default, Debug, Clone)] +pub struct Vulkan; +#[derive(Default, Debug, Clone)] +pub struct Metal; +#[derive(Default, Debug, Clone)] +pub struct OpenGL; +#[derive(Default, Debug, Clone)] +pub struct Dx11; +#[derive(Default, Debug, Clone)] +pub struct Dx12; +#[derive(Default, Debug, Clone)] +pub struct WebGPU; + +impl GraphicsAPI for Vulkan { + fn backend() -> wgpu::Backend { + wgpu::Backend::Vulkan + } +} + +impl GraphicsAPI for Metal { + fn backend() -> wgpu::Backend { + wgpu::Backend::Metal + } +} + +impl GraphicsAPI for OpenGL { + fn backend() -> wgpu::Backend { + wgpu::Backend::Gl + } +} + +impl GraphicsAPI for Dx11 { + fn backend() -> wgpu::Backend { + wgpu::Backend::Dx11 + } +} + +impl GraphicsAPI for Dx12 { + fn backend() -> wgpu::Backend { + wgpu::Backend::Dx12 + } +} + +impl GraphicsAPI for WebGPU { + fn backend() -> wgpu::Backend { + wgpu::Backend::BrowserWebGpu + } +} diff --git a/burn-wgpu/src/kernel/base.rs b/burn-wgpu/src/kernel/base.rs new file mode 100644 index 000000000..310d03e9a --- /dev/null +++ b/burn-wgpu/src/kernel/base.rs @@ -0,0 +1,87 @@ +use crate::element::WGPUElement; +use std::marker::PhantomData; + +/// Generate wgpu kernel source code to create [compute shader modules](wgpu::ShaderModule). +pub trait KernelGenerator: 'static { + /// Source code concrete type. + type Source: AsRef; + + /// Generate the source code. + fn generate() -> Self::Source; +} + +#[macro_export] +macro_rules! 
kernel_wgsl { + ( + $struct:ident, + $file:expr + ) => { + #[derive(new)] + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = &'static str; + + fn generate() -> Self::Source { + include_str!($file) + } + } + }; +} + +/// Generate kernel source code from the wrapped kernel `K` by plain substring replacement of the `WORKGROUP_SIZE_X`/`WORKGROUP_SIZE_Y`/`WORKGROUP_SIZE_Z`, `elem` and `int` placeholders. +pub struct KernelSettings< + K: KernelGenerator, + E: WGPUElement, + I: WGPUElement, + const WORKGROUP_X_SIZE: usize, + const WORKGROUP_Y_SIZE: usize, + const WORKGROUP_Z_SIZE: usize, +> { + _k: PhantomData, + _e: PhantomData, + _i: PhantomData, +} + +impl< + K: KernelGenerator, + E: WGPUElement, + I: WGPUElement, + const WORKGROUP_X_SIZE: usize, + const WORKGROUP_Y_SIZE: usize, + const WORKGROUP_Z_SIZE: usize, + > KernelGenerator + for KernelSettings +{ + type Source = String; + + fn generate() -> String { + let mut source = K::generate().as_ref().to_string(); + + source = source.replace("WORKGROUP_SIZE_X", &WORKGROUP_X_SIZE.to_string()); + source = source.replace("WORKGROUP_SIZE_Y", &WORKGROUP_Y_SIZE.to_string()); + source = source.replace("WORKGROUP_SIZE_Z", &WORKGROUP_Z_SIZE.to_string()); + source = source.replace("elem", E::type_name()); + source = source.replace("int", I::type_name()); + + source + } +} + +#[cfg(test)] +mod tests { + use super::*; + use core::any::TypeId; + + #[test] + fn test_kernel_type_id() { + kernel_wgsl!(Add, "../template/binary_elemwise.wgsl"); + + let type_id_1 = TypeId::of::>(); + let type_id_2 = TypeId::of::>(); + let type_id_3 = TypeId::of::>(); + + assert_ne!(type_id_1, type_id_2); + assert_eq!(type_id_1, type_id_3); + } +} diff --git a/burn-wgpu/src/kernel/binary_elemwise.rs b/burn-wgpu/src/kernel/binary_elemwise.rs new file mode 100644 index 000000000..5e9afc2de --- /dev/null +++ b/burn-wgpu/src/kernel/binary_elemwise.rs @@ -0,0 +1,159 @@ +use super::{KernelGenerator, KernelSettings}; +use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor}; +use burn_tensor::Shape; 
+use num_traits::ToPrimitive; +use std::sync::Arc; + +kernel_wgsl!(BinaryElemwiseRaw, "../template/binary_elemwise.wgsl"); +kernel_wgsl!( + BinaryElemwiseInplaceRaw, + "../template/binary_elemwise_inplace.wgsl" +); + +#[macro_export] +macro_rules! binary_elemwise { + ( + $struct:ident, + $ops:expr + ) => { + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = String; + + fn generate() -> Self::Source { + let source = $crate::kernel::BinaryElemwiseRaw::generate().to_string(); + let body = format!( + "output[global_id.x] = lhs[index_lhs] {} rhs[index_rhs]", + $ops + ); + source.replace("BODY", &body) + } + } + }; +} + +#[macro_export] +macro_rules! binary_elemwise_inplace { + ( + $struct:ident, + $ops:expr + ) => { + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = String; + + fn generate() -> Self::Source { + let source = $crate::kernel::BinaryElemwiseInplaceRaw::generate().to_string(); + let body = format!( + "lhs[global_id.x] = lhs[global_id.x] {} rhs[index_rhs];", + $ops + ); + source.replace("BODY", &body) + } + } + }; +} + +pub fn binary_elemwise( + lhs: WGPUTensor, + rhs: WGPUTensor, +) -> WGPUTensor { + lhs.assert_is_on_save_device(&rhs); + + let mut shape_out = [0; D]; + lhs.shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + + let shape_out = Shape::new(shape_out); + + let buffer = lhs + .context + .create_buffer(shape_out.num_elements() * core::mem::size_of::()); + let output = WGPUTensor::new(lhs.context.clone(), shape_out, Arc::new(buffer)); + + let kernel = lhs + .context + .compile::>(); + let mut info: Vec = vec![D.to_u32().unwrap()]; + + lhs.strides + .into_iter() + .for_each(|v| info.push(v.to_u32().unwrap())); + rhs.strides + .into_iter() + .for_each(|v| info.push(v.to_u32().unwrap())); + lhs.shape + .dims + .into_iter() + .for_each(|v| 
info.push(v.to_u32().unwrap())); + rhs.shape + .dims + .into_iter() + .for_each(|v| info.push(v.to_u32().unwrap())); + let info_buffers = lhs + .context + .create_buffer_with_data(bytemuck::cast_slice(&info)); + + lhs.context.execute( + &WorkGroup::new( + f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32, + 1, + 1, + ), + &kernel, + &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers], + ); + + output +} +pub fn binary_elemwise_inplace( + lhs: WGPUTensor, + rhs: WGPUTensor, +) -> WGPUTensor { + lhs.assert_is_on_save_device(&rhs); + + let mut shape_out = [0; D]; + lhs.shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + + let kernel = lhs + .context + .compile::>(); + let mut info: Vec = vec![D.to_u32().unwrap()]; + rhs.strides + .into_iter() + .for_each(|v| info.push(v.to_u32().unwrap())); + rhs.shape + .dims + .into_iter() + .for_each(|v| info.push(v.to_u32().unwrap())); + let info_buffers = lhs + .context + .create_buffer_with_data(bytemuck::cast_slice(&info)); + + lhs.context.execute( + &WorkGroup::new( + f32::ceil(lhs.shape.num_elements() as f32 / 256_f32) as u32, + 1, + 1, + ), + &kernel, + &[&lhs.buffer, &rhs.buffer, &info_buffers], + ); + + lhs +} diff --git a/burn-wgpu/src/kernel/mod.rs b/burn-wgpu/src/kernel/mod.rs new file mode 100644 index 000000000..420c171e6 --- /dev/null +++ b/burn-wgpu/src/kernel/mod.rs @@ -0,0 +1,9 @@ +mod base; +mod binary_elemwise; +mod unary; +mod unary_scalar; + +pub use base::*; +pub use binary_elemwise::*; +pub use unary::*; +pub use unary_scalar::*; diff --git a/burn-wgpu/src/kernel/unary.rs b/burn-wgpu/src/kernel/unary.rs new file mode 100644 index 000000000..06cbdcae3 --- /dev/null +++ b/burn-wgpu/src/kernel/unary.rs @@ -0,0 +1,120 @@ +use super::{KernelGenerator, KernelSettings}; +use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor}; +use 
std::sync::Arc; + +kernel_wgsl!(UnaryRaw, "../template/unary.wgsl"); +kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl"); + +#[macro_export] +macro_rules! unary { + ( + $struct:ident, + func $func:expr + ) => { + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = String; + + fn generate() -> Self::Source { + let source = $crate::kernel::UnaryRaw::generate().to_string(); + let body = format!("output[global_id.x] = {}(input[global_id.x]);", $func); + source.replace("BODY", &body) + } + } + }; + ( + $struct:ident, + body $body:expr + ) => { + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = String; + + fn generate() -> Self::Source { + let source = $crate::kernel::UnaryRaw::generate().to_string(); + source.replace("BODY", $body) + } + } + }; +} + +#[macro_export] +macro_rules! unary_inplace { + ( + $struct:ident, + func $func:expr + ) => { + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = String; + + fn generate() -> Self::Source { + let source = $crate::kernel::UnaryInplaceRaw::generate().to_string(); + let body = format!("input[global_id.x] = {}(input[global_id.x]);", $func); + source.replace("BODY", &body) + } + } + }; + ( + $struct:ident, + body $body:expr + ) => { + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = String; + + fn generate() -> Self::Source { + let source = $crate::kernel::UnaryInplaceRaw::generate().to_string(); + source.replace("BODY", $body) + } + } + }; +} + +pub fn unary( + input: WGPUTensor, +) -> WGPUTensor { + let buffer = input + .context + .create_buffer(input.shape.num_elements() * core::mem::size_of::()); + let output = WGPUTensor::new(input.context.clone(), input.shape, Arc::new(buffer)); + let kernel = input + .context + .compile::>(); + + input.context.execute( + &WorkGroup::new( + f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32, + 1, + 1, + 
), + &kernel, + &[&input.buffer, &output.buffer], + ); + + output +} + +pub fn unary_inplace( + input: WGPUTensor, +) -> WGPUTensor { + let kernel = input + .context + .compile::>(); + + input.context.execute( + &WorkGroup::new( + f32::ceil(input.shape.num_elements() as f32 / 256_f32) as u32, + 1, + 1, + ), + &kernel, + &[&input.buffer], + ); + + input +} diff --git a/burn-wgpu/src/kernel/unary_scalar.rs b/burn-wgpu/src/kernel/unary_scalar.rs new file mode 100644 index 000000000..872dc326b --- /dev/null +++ b/burn-wgpu/src/kernel/unary_scalar.rs @@ -0,0 +1,135 @@ +use super::{KernelGenerator, KernelSettings}; +use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor}; +use std::sync::Arc; + +kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl"); +kernel_wgsl!( + UnaryScalarInplaceRaw, + "../template/unary_scalar_inplace.wgsl" +); + +#[macro_export] +macro_rules! unary_scalar { + ( + $struct:ident, + ops $ops:expr + ) => { + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = String; + + fn generate() -> Self::Source { + let source = $crate::kernel::UnaryScalarRaw::generate().to_string(); + let body = format!("output[global_id.x] = lhs[global_id.x] {} rhs;", $ops); + + source.replace("BODY", &body) + } + } + }; + + ( + $struct:ident, + func $func:expr + ) => { + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = String; + + fn generate() -> Self::Source { + let source = $crate::kernel::UnaryScalarRaw::generate().to_string(); + let body = format!("output[global_id.x] = {}(lhs[global_id.x], rhs);", $func); + + source.replace("BODY", &body) + } + } + }; +} + +#[macro_export] +macro_rules! 
unary_scalar_inplace { + ( + $struct:ident, + ops $ops:expr + ) => { + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = String; + + fn generate() -> Self::Source { + let source = $crate::kernel::UnaryScalarInplaceRaw::generate().to_string(); + let body = format!("lhs[global_id.x] = lhs[global_id.x] {} rhs;", $ops); + + source.replace("BODY", &body) + } + } + }; + + ( + $struct:ident, + func $func:expr + ) => { + pub struct $struct; + + impl $crate::kernel::KernelGenerator for $struct { + type Source = String; + + fn generate() -> Self::Source { + let source = $crate::kernel::UnaryScalarInplaceRaw::generate().to_string(); + let body = format!("lhs[global_id.x] = {}(lhs[global_id.x], rhs);", $func); + + source.replace("BODY", &body) + } + } + }; +} + +pub fn unary_scalar( + lhs: WGPUTensor, + scalar: E, +) -> WGPUTensor { + let buffer = lhs + .context + .create_buffer(lhs.shape.num_elements() * core::mem::size_of::()); + let output = WGPUTensor::new(lhs.context.clone(), lhs.shape, Arc::new(buffer)); + let kernel = lhs + .context + .compile::>(); + let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar])); + + lhs.context.execute( + &WorkGroup::new( + f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32, + 1, + 1, + ), + &kernel, + &[&lhs.buffer, &rhs_buffer, &output.buffer], + ); + + output +} + +pub fn unary_scalar_inplace( + lhs: WGPUTensor, + scalar: E, +) -> WGPUTensor { + let kernel = lhs + .context + .compile::>(); + let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar])); + + lhs.context.execute( + &WorkGroup::new( + f32::ceil(lhs.shape.num_elements() as f32 / 256_f32) as u32, + 1, + 1, + ), + &kernel, + &[&lhs.buffer, &rhs_buffer], + ); + + lhs +} diff --git a/burn-wgpu/src/lib.rs b/burn-wgpu/src/lib.rs new file mode 100644 index 000000000..7b8324206 --- /dev/null +++ b/burn-wgpu/src/lib.rs @@ -0,0 +1,39 @@ +#[macro_use] +extern crate derive_new; + +mod ops; + 
+pub(crate) mod context; +pub(crate) mod element; +pub(crate) mod kernel; +pub(crate) mod pool; +pub(crate) mod tensor; + +mod device; +pub use device::*; + +mod backend; +pub use backend::*; + +mod graphics; +pub use graphics::*; + +#[cfg(test)] +mod tests { + type TestBackend = crate::WGPUBackend; + + burn_tensor::testgen_add!(); + burn_tensor::testgen_sub!(); + burn_tensor::testgen_div!(); + burn_tensor::testgen_mul!(); + burn_tensor::testgen_neg!(); + burn_tensor::testgen_powf!(); + burn_tensor::testgen_exp!(); + burn_tensor::testgen_log!(); + burn_tensor::testgen_relu!(); + + // Once all operations will be implemented. + // type TestTensor = burn_tensor::Tensor; + // type TestTensorInt = burn_tensor::Tensor; + // burn_tensor::testgen_all!(); +} diff --git a/burn-wgpu/src/ops/activation_ops.rs b/burn-wgpu/src/ops/activation_ops.rs new file mode 100644 index 000000000..e833f4681 --- /dev/null +++ b/burn-wgpu/src/ops/activation_ops.rs @@ -0,0 +1,27 @@ +use burn_tensor::ops::ActivationOps; + +use crate::{ + element::{FloatElement, IntElement}, + kernel::{unary, unary_inplace}, + unary, unary_inplace, GraphicsAPI, WGPUBackend, +}; + +use super::FloatTensor; + +impl ActivationOps> for WGPUBackend +where + G: GraphicsAPI + 'static, + F: FloatElement, + I: IntElement, +{ + fn relu(tensor: FloatTensor) -> FloatTensor { + unary!(Relu, body "output[global_id.x] = max(input[global_id.x], 0.0);"); + unary_inplace!(ReluInplace, body "input[global_id.x] = max(input[global_id.x], 0.0);"); + + if tensor.can_mut() { + return unary_inplace::(tensor); + } + + unary::(tensor) + } +} diff --git a/burn-wgpu/src/ops/base.rs b/burn-wgpu/src/ops/base.rs new file mode 100644 index 000000000..1b0dbc480 --- /dev/null +++ b/burn-wgpu/src/ops/base.rs @@ -0,0 +1,59 @@ +use std::{marker::PhantomData, sync::Arc}; + +use burn_tensor::{backend::Backend, Data, Shape}; + +use crate::{element::WGPUElement, pool::get_context, tensor::WGPUTensor, GraphicsAPI, WGPUDevice}; + +pub type FloatElem = 
::FloatElem; +pub type Device = ::Device; + +pub type FloatTensor = ::TensorPrimitive; + +pub type IntElem = ::IntElem; +pub type IntTensor = ::IntTensorPrimitive; +pub type BoolTensor = ::BoolTensorPrimitive; + +pub struct BaseOps { + _g: PhantomData, +} + +impl BaseOps { + pub fn from_data( + data: Data, + device: &WGPUDevice, + ) -> WGPUTensor { + let context = get_context::(device); + let buffer = context.create_buffer_with_data(E::as_bytes(&data.value)); + + WGPUTensor::new(context, data.shape, Arc::new(buffer)) + } + + pub fn to_data(tensor: &WGPUTensor) -> Data { + let bytes = tensor.context.buffer_to_data(&tensor.buffer); + let values = E::from_bytes(&bytes); + + Data::new(values.to_vec(), tensor.shape.clone()) + } + + pub fn to_device( + tensor: WGPUTensor, + device: &WGPUDevice, + ) -> WGPUTensor { + if &tensor.context.device == device { + return tensor; + } + + let context = get_context::(device); + tensor.to_context(context) + } + + pub fn empty( + shape: Shape, + device: &WGPUDevice, + ) -> WGPUTensor { + let context = get_context::(device); + let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::()); + + WGPUTensor::new(context, shape, Arc::new(buffer)) + } +} diff --git a/burn-wgpu/src/ops/bool_ops.rs b/burn-wgpu/src/ops/bool_ops.rs new file mode 100644 index 000000000..b2c7ea069 --- /dev/null +++ b/burn-wgpu/src/ops/bool_ops.rs @@ -0,0 +1,106 @@ +use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape}; + +use crate::{ + element::{FloatElement, IntElement}, + GraphicsAPI, WGPUBackend, +}; + +use super::{BaseOps, BoolTensor, Device, IntTensor}; + +impl BoolTensorOps> for WGPUBackend +where + G: GraphicsAPI + 'static, + F: FloatElement, + I: IntElement, +{ + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + BaseOps::::empty(shape, device) + } + + fn bool_shape(tensor: &BoolTensor) -> Shape { + tensor.shape.clone() + } + + fn bool_into_data(tensor: BoolTensor) -> Data { + let data = 
BaseOps::::to_data(&tensor); + + Data::new(data.value.into_iter().map(|i| i != 0).collect(), data.shape) + } + + fn bool_from_data( + data: Data, + device: &Device, + ) -> BoolTensor { + let data: Data = Data::new( + data.value + .into_iter() + .map(|c| match c { + true => 1, + false => 0, + }) + .collect(), + data.shape, + ); + BaseOps::::from_data(data, device) + } + + fn bool_into_int(_tensor: BoolTensor) -> IntTensor { + todo!() + } + + fn bool_device( + _tensor: & as Backend>::BoolTensorPrimitive, + ) -> as Backend>::Device { + todo!() + } + + fn bool_to_device( + tensor: BoolTensor, + device: &Device, + ) -> BoolTensor { + BaseOps::::to_device(tensor, device) + } + + fn bool_reshape( + _tensor: as Backend>::BoolTensorPrimitive, + _shape: Shape, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn bool_index( + _tensor: as Backend>::BoolTensorPrimitive, + _indexes: [std::ops::Range; D2], + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn bool_index_assign( + _tensor: as Backend>::BoolTensorPrimitive, + _indexes: [std::ops::Range; D2], + _value: as Backend>::BoolTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn bool_cat( + _tensors: Vec< as Backend>::BoolTensorPrimitive>, + _dim: usize, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn bool_equal( + _lhs: as Backend>::BoolTensorPrimitive, + _rhs: as Backend>::BoolTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn bool_equal_elem( + _lhs: as Backend>::BoolTensorPrimitive, + _rhs: bool, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } +} diff --git a/burn-wgpu/src/ops/float_ops.rs b/burn-wgpu/src/ops/float_ops.rs new file mode 100644 index 000000000..6234c2adb --- /dev/null +++ b/burn-wgpu/src/ops/float_ops.rs @@ -0,0 +1,406 @@ +use super::numeric::NumericOps; +use super::{BaseOps, Device, FloatElem, FloatTensor}; +use crate::kernel::{unary, unary_inplace, unary_scalar, unary_scalar_inplace}; +use crate::{ + 
element::{FloatElement, IntElement}, + unary, unary_inplace, GraphicsAPI, WGPUBackend, SEED, +}; +use crate::{unary_scalar, unary_scalar_inplace}; +use burn_common::rand::get_seeded_rng; +use burn_tensor::ElementConversion; +use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, Shape}; + +impl TensorOps> for WGPUBackend +where + G: GraphicsAPI + 'static, + F: FloatElement, + I: IntElement, +{ + fn from_data( + data: Data, D>, + device: &Device, + ) -> FloatTensor { + BaseOps::::from_data(data, device) + } + + fn random( + shape: Shape, + distribution: Distribution>, + device: &Device, + ) -> as Backend>::TensorPrimitive { + let mut seed = SEED.lock().unwrap(); + let mut rng = if let Some(rng_seeded) = seed.as_ref() { + rng_seeded.clone() + } else { + get_seeded_rng() + }; + let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device); + *seed = Some(rng); + tensor + } + + fn shape(tensor: &FloatTensor) -> Shape { + tensor.shape.clone() + } + + fn to_data(tensor: &FloatTensor) -> Data, D> { + BaseOps::::to_data(tensor) + } + + fn device(tensor: &FloatTensor) -> Device { + tensor.context.device.clone() + } + + fn to_device( + tensor: FloatTensor, + device: &Device, + ) -> FloatTensor { + BaseOps::::to_device(tensor, device) + } + + fn empty(shape: Shape, device: &Device) -> FloatTensor { + BaseOps::::empty(shape, device) + } + + fn add( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + NumericOps::add(lhs, rhs) + } + + fn add_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + NumericOps::add_scalar(lhs, rhs) + } + + fn sub( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + NumericOps::sub(lhs, rhs) + } + + fn sub_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + NumericOps::sub_scalar(lhs, rhs) + } + + fn mul( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + NumericOps::mul(lhs, rhs) + } + + fn mul_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> 
FloatTensor { + NumericOps::mul_scalar(lhs, rhs) + } + + fn div( + lhs: FloatTensor, + rhs: FloatTensor, + ) -> FloatTensor { + NumericOps::div(lhs, rhs) + } + + fn div_scalar( + lhs: FloatTensor, + rhs: FloatElem, + ) -> FloatTensor { + NumericOps::div_scalar(lhs, rhs) + } + + fn matmul( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn swap_dims( + _tensor: as Backend>::TensorPrimitive, + _dim1: usize, + _dim2: usize, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn reshape( + _tensor: as Backend>::TensorPrimitive, + _shape: Shape, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn gather( + _dim: usize, + _tensor: as Backend>::TensorPrimitive, + _indexes: as Backend>::IntTensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn scatter( + _dim: usize, + _tensor: as Backend>::TensorPrimitive, + _indexes: as Backend>::IntTensorPrimitive, + _value: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn index_select( + _tensor: as Backend>::TensorPrimitive, + _dim: usize, + _indexes: as Backend>::IntTensorPrimitive<1>, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn index_select_assign( + _tensor: as Backend>::TensorPrimitive, + _dim: usize, + _indexes: as Backend>::IntTensorPrimitive<1>, + _value: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn index( + _tensor: as Backend>::TensorPrimitive, + _indexes: [std::ops::Range; D2], + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn index_assign( + _tensor: as Backend>::TensorPrimitive, + _indexes: [std::ops::Range; D2], + _value: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn mask_scatter( + _tensor: as Backend>::TensorPrimitive, + _mask: as Backend>::BoolTensorPrimitive, + _source: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + 
fn mask_fill( + _tensor: as Backend>::TensorPrimitive, + _mask: as Backend>::BoolTensorPrimitive, + _value: as Backend>::FloatElem, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn equal( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::TensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn equal_elem( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::FloatElem, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn greater( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::TensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn greater_elem( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::FloatElem, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn greater_equal( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::TensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn greater_equal_elem( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::FloatElem, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn lower( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::TensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn lower_elem( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::FloatElem, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn lower_equal( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::TensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn lower_equal_elem( + _lhs: as Backend>::TensorPrimitive, + _rhs: as Backend>::FloatElem, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn sum( + _tensor: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive<1> { + todo!() + } + + fn sum_dim( + _tensor: as Backend>::TensorPrimitive, + _dim: usize, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn mean( + _tensor: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive<1> { + 
todo!() + } + + fn mean_dim( + _tensor: as Backend>::TensorPrimitive, + _dim: usize, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn to_full_precision( + _tensor: & as Backend>::TensorPrimitive, + ) -> < as Backend>::FullPrecisionBackend as Backend>::TensorPrimitive + { + todo!() + } + + fn from_full_precision( + _tensor: < as Backend>::FullPrecisionBackend as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn exp(lhs: FloatTensor) -> FloatTensor { + unary!(Exp, func "exp"); + unary_inplace!(ExpInplace, func "exp"); + + if lhs.can_mut() { + return unary_inplace::(lhs); + } + + unary::(lhs) + } + + fn log(tensor: FloatTensor) -> FloatTensor { + unary!(Log, func "log"); + unary_inplace!(LogInplace, func "log"); + + if tensor.can_mut() { + return unary_inplace::(tensor); + } + + unary::(tensor) + } + + fn log1p( + _tensor: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { + unary_scalar!(Powf, func "pow"); + unary_scalar_inplace!(PowfInplace, func "pow"); + + if lhs.can_mut() { + return unary_scalar_inplace::(lhs, rhs.elem()); + } + + unary_scalar::(lhs, rhs.elem()) + } + + fn sqrt( + _tensor: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn cos( + _tensor: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn sin( + _tensor: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn tanh( + _tensor: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn erf( + _tensor: as Backend>::TensorPrimitive, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn cat( + _tensors: Vec< as Backend>::TensorPrimitive>, + _dim: usize, + ) -> as Backend>::TensorPrimitive { + todo!() + } + + fn argmax( + _tensor: as Backend>::TensorPrimitive, + _dim: usize, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } 
+ + fn argmin( + _tensor: as Backend>::TensorPrimitive, + _dim: usize, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } +} diff --git a/burn-wgpu/src/ops/int_ops.rs b/burn-wgpu/src/ops/int_ops.rs new file mode 100644 index 000000000..5c70d277c --- /dev/null +++ b/burn-wgpu/src/ops/int_ops.rs @@ -0,0 +1,314 @@ +use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Shape}; + +use crate::{ + element::{FloatElement, IntElement}, + GraphicsAPI, WGPUBackend, +}; + +use super::{numeric::NumericOps, BaseOps, Device, IntElem, IntTensor}; + +impl IntTensorOps> for WGPUBackend +where + G: GraphicsAPI + 'static, + F: FloatElement, + I: IntElement, +{ + fn int_empty(shape: Shape, device: &Device) -> IntTensor { + BaseOps::::empty(shape, device) + } + + fn int_shape( + _tensor: & as Backend>::IntTensorPrimitive, + ) -> Shape { + todo!() + } + + fn int_into_data(tensor: IntTensor) -> Data { + BaseOps::::to_data(&tensor) + } + + fn int_from_data( + data: Data, + device: &Device, + ) -> IntTensor { + BaseOps::::from_data(data, device) + } + + fn int_device( + _tensor: & as Backend>::IntTensorPrimitive, + ) -> as Backend>::Device { + todo!() + } + + fn int_to_device( + tensor: IntTensor, + device: &Device, + ) -> IntTensor { + BaseOps::::to_device(tensor, device) + } + + fn int_reshape( + _tensor: as Backend>::IntTensorPrimitive, + _shape: Shape, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_index( + _tensor: as Backend>::IntTensorPrimitive, + _indexes: [std::ops::Range; D2], + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_index_assign( + _tensor: as Backend>::IntTensorPrimitive, + _indexes: [std::ops::Range; D2], + _value: as Backend>::IntTensorPrimitive, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_mask_scatter( + _tensor: as Backend>::IntTensorPrimitive, + _mask: as Backend>::BoolTensorPrimitive, + _source: as Backend>::IntTensorPrimitive, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn 
int_mask_fill( + _tensor: as Backend>::IntTensorPrimitive, + _mask: as Backend>::BoolTensorPrimitive, + _value: as Backend>::IntElem, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_gather( + _dim: usize, + _tensor: as Backend>::IntTensorPrimitive, + _indexes: as Backend>::IntTensorPrimitive, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_scatter( + _dim: usize, + _tensor: as Backend>::IntTensorPrimitive, + _indexes: as Backend>::IntTensorPrimitive, + _value: as Backend>::IntTensorPrimitive, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_index_select_dim( + _tensor: as Backend>::IntTensorPrimitive, + _dim: usize, + _indexes: as Backend>::IntTensorPrimitive<1>, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_index_select_dim_assign( + _tensor: as Backend>::IntTensorPrimitive, + _dim: usize, + _indexes: as Backend>::IntTensorPrimitive<1>, + _value: as Backend>::IntTensorPrimitive, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_cat( + _tensors: Vec< as Backend>::IntTensorPrimitive>, + _dim: usize, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_equal( + _lhs: as Backend>::IntTensorPrimitive, + _rhs: as Backend>::IntTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn int_equal_elem( + _lhs: as Backend>::IntTensorPrimitive, + _rhs: as Backend>::IntElem, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn int_greater( + _lhs: as Backend>::IntTensorPrimitive, + _rhs: as Backend>::IntTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn int_greater_elem( + _lhs: as Backend>::IntTensorPrimitive, + _rhs: as Backend>::IntElem, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn int_greater_equal( + _lhs: as Backend>::IntTensorPrimitive, + _rhs: as Backend>::IntTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn int_greater_equal_elem( + _lhs: as 
Backend>::IntTensorPrimitive, + _rhs: as Backend>::IntElem, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn int_lower( + _lhs: as Backend>::IntTensorPrimitive, + _rhs: as Backend>::IntTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn int_lower_elem( + _lhs: as Backend>::IntTensorPrimitive, + _rhs: as Backend>::IntElem, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn int_lower_equal( + _lhs: as Backend>::IntTensorPrimitive, + _rhs: as Backend>::IntTensorPrimitive, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn int_lower_equal_elem( + _lhs: as Backend>::IntTensorPrimitive, + _rhs: as Backend>::IntElem, + ) -> as Backend>::BoolTensorPrimitive { + todo!() + } + + fn int_add( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + NumericOps::add::(lhs, rhs) + } + + fn int_add_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + NumericOps::add_scalar(lhs, rhs) + } + + fn int_sub( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + NumericOps::sub(lhs, rhs) + } + + fn int_sub_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + NumericOps::sub_scalar(lhs, rhs) + } + + fn int_mul( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + NumericOps::mul(lhs, rhs) + } + + fn int_mul_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + NumericOps::mul_scalar(lhs, rhs) + } + + fn int_div( + lhs: IntTensor, + rhs: IntTensor, + ) -> IntTensor { + NumericOps::div(lhs, rhs) + } + + fn int_div_scalar( + lhs: IntTensor, + rhs: IntElem, + ) -> IntTensor { + NumericOps::div_scalar(lhs, rhs) + } + + fn int_neg( + _tensor: as Backend>::IntTensorPrimitive, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_zeros( + _shape: Shape, + _device: & as Backend>::Device, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_ones( + _shape: Shape, + _device: & as Backend>::Device, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_sum( + 
_tensor: as Backend>::IntTensorPrimitive, + ) -> as Backend>::IntTensorPrimitive<1> { + todo!() + } + + fn int_sum_dim( + _tensor: as Backend>::IntTensorPrimitive, + _dim: usize, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_mean( + _tensor: as Backend>::IntTensorPrimitive, + ) -> as Backend>::IntTensorPrimitive<1> { + todo!() + } + + fn int_mean_dim( + _tensor: as Backend>::IntTensorPrimitive, + _dim: usize, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_argmax( + _tensor: as Backend>::IntTensorPrimitive, + _dim: usize, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } + + fn int_argmin( + _tensor: as Backend>::IntTensorPrimitive, + _dim: usize, + ) -> as Backend>::IntTensorPrimitive { + todo!() + } +} diff --git a/burn-wgpu/src/ops/mod.rs b/burn-wgpu/src/ops/mod.rs new file mode 100644 index 000000000..69338f67e --- /dev/null +++ b/burn-wgpu/src/ops/mod.rs @@ -0,0 +1,10 @@ +mod activation_ops; +mod bool_ops; +mod float_ops; +mod int_ops; +mod module_ops; + +mod base; +pub(crate) use base::*; + +pub(crate) mod numeric; diff --git a/burn-wgpu/src/ops/module_ops.rs b/burn-wgpu/src/ops/module_ops.rs new file mode 100644 index 000000000..e3dd61c4d --- /dev/null +++ b/burn-wgpu/src/ops/module_ops.rs @@ -0,0 +1,94 @@ +use burn_tensor::{backend::Backend, ops::ModuleOps}; + +use crate::{ + element::{FloatElement, IntElement}, + GraphicsAPI, WGPUBackend, +}; + +impl ModuleOps> for WGPUBackend +where + G: GraphicsAPI + 'static, + F: FloatElement, + I: IntElement, +{ + fn embedding( + _weights: as Backend>::TensorPrimitive<2>, + _indexes: as Backend>::IntTensorPrimitive<2>, + ) -> as Backend>::TensorPrimitive<3> { + todo!() + } + + fn embedding_backward( + _weights: as Backend>::TensorPrimitive<2>, + _output: as Backend>::TensorPrimitive<3>, + _indexes: as Backend>::IntTensorPrimitive<2>, + ) -> as Backend>::TensorPrimitive<2> { + todo!() + } + + fn conv2d( + _x: as Backend>::TensorPrimitive<4>, + _weight: as 
Backend>::TensorPrimitive<4>, + _bias: Option< as Backend>::TensorPrimitive<1>>, + _options: burn_tensor::ops::ConvOptions<2>, + ) -> as Backend>::TensorPrimitive<4> { + todo!() + } + + fn conv_transpose2d( + _x: as Backend>::TensorPrimitive<4>, + _weight: as Backend>::TensorPrimitive<4>, + _bias: Option< as Backend>::TensorPrimitive<1>>, + _options: burn_tensor::ops::ConvTransposeOptions<2>, + ) -> as Backend>::TensorPrimitive<4> { + todo!() + } + + fn avg_pool2d( + _x: as Backend>::TensorPrimitive<4>, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + ) -> as Backend>::TensorPrimitive<4> { + todo!() + } + + fn avg_pool2d_backward( + _x: as Backend>::TensorPrimitive<4>, + _grad: as Backend>::TensorPrimitive<4>, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + ) -> as Backend>::TensorPrimitive<4> { + todo!() + } + + fn max_pool2d( + _x: as Backend>::TensorPrimitive<4>, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + ) -> as Backend>::TensorPrimitive<4> { + todo!() + } + + fn max_pool2d_with_indexes( + _x: as Backend>::TensorPrimitive<4>, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + ) -> burn_tensor::ops::MaxPool2dWithIndexes> { + todo!() + } + + fn max_pool2d_with_indexes_backward( + _x: as Backend>::TensorPrimitive<4>, + _kernel_size: [usize; 2], + _stride: [usize; 2], + _padding: [usize; 2], + _output_grad: as Backend>::TensorPrimitive<4>, + _indexes: as Backend>::IntTensorPrimitive<4>, + ) -> burn_tensor::ops::MaxPool2dBackward> { + todo!() + } +} diff --git a/burn-wgpu/src/ops/numeric.rs b/burn-wgpu/src/ops/numeric.rs new file mode 100644 index 000000000..aae0952ac --- /dev/null +++ b/burn-wgpu/src/ops/numeric.rs @@ -0,0 +1,129 @@ +use crate::kernel::{binary_elemwise, binary_elemwise_inplace, unary_scalar, unary_scalar_inplace}; +use crate::{ + binary_elemwise, binary_elemwise_inplace, element::WGPUElement, tensor::WGPUTensor, + unary_scalar, 
unary_scalar_inplace, +}; + +pub struct NumericOps; + +impl NumericOps { + pub fn add( + lhs: WGPUTensor, + rhs: WGPUTensor, + ) -> WGPUTensor { + binary_elemwise!(Add, "+"); + binary_elemwise_inplace!(AddInplace, "+"); + + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace::(lhs, rhs); + } + + if rhs.can_mut_broadcast(&lhs) { + return binary_elemwise_inplace::(rhs, lhs); + } + + binary_elemwise::(lhs, rhs) + } + + pub fn add_scalar( + lhs: WGPUTensor, + rhs: E, + ) -> WGPUTensor { + unary_scalar!(AddScalar, ops "+"); + unary_scalar_inplace!(AddScalarInplace, ops "+"); + + if lhs.can_mut() { + return unary_scalar_inplace::(lhs, rhs); + } + + unary_scalar::(lhs, rhs) + } + + pub fn sub( + lhs: WGPUTensor, + rhs: WGPUTensor, + ) -> WGPUTensor { + binary_elemwise!(Sub, "-"); + binary_elemwise_inplace!(SubInplace, "-"); + + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace::(lhs, rhs); + } + + binary_elemwise::(lhs, rhs) + } + + pub fn sub_scalar( + lhs: WGPUTensor, + rhs: E, + ) -> WGPUTensor { + unary_scalar!(SubScalar, ops "-"); + unary_scalar_inplace!(SubScalarInplace, ops "-"); + + if lhs.can_mut() { + return unary_scalar_inplace::(lhs, rhs); + } + + unary_scalar::(lhs, rhs) + } + + pub fn mul( + lhs: WGPUTensor, + rhs: WGPUTensor, + ) -> WGPUTensor { + binary_elemwise!(Mul, "*"); + binary_elemwise_inplace!(MulInplace, "*"); + + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace::(lhs, rhs); + } + + if rhs.can_mut_broadcast(&lhs) { + return binary_elemwise_inplace::(rhs, lhs); + } + + binary_elemwise::(lhs, rhs) + } + + pub fn mul_scalar( + lhs: WGPUTensor, + rhs: E, + ) -> WGPUTensor { + unary_scalar!(MulScalar, ops "*"); + unary_scalar_inplace!(MulScalarInplace, ops "*"); + + if lhs.can_mut() { + return unary_scalar_inplace::(lhs, rhs); + } + + unary_scalar::(lhs, rhs) + } + + pub fn div( + lhs: WGPUTensor, + rhs: WGPUTensor, + ) -> WGPUTensor { + binary_elemwise!(Div, "/"); + binary_elemwise_inplace!(DivInplace, 
"/"); + + if lhs.can_mut_broadcast(&rhs) { + return binary_elemwise_inplace::(lhs, rhs); + } + + binary_elemwise::(lhs, rhs) + } + + pub fn div_scalar( + lhs: WGPUTensor, + rhs: E, + ) -> WGPUTensor { + unary_scalar!(DivScalar, ops "/"); + unary_scalar_inplace!(DivScalarInplace, ops "/"); + + if lhs.can_mut() { + return unary_scalar_inplace::(lhs, rhs); + } + + unary_scalar::(lhs, rhs) + } +} diff --git a/burn-wgpu/src/pool.rs b/burn-wgpu/src/pool.rs new file mode 100644 index 000000000..de19e57a6 --- /dev/null +++ b/burn-wgpu/src/pool.rs @@ -0,0 +1,63 @@ +use crate::{context::Context, GraphicsAPI, WGPUDevice}; +use std::{ + any::TypeId, + collections::HashMap, + sync::{Arc, Mutex}, +}; + +static POOL_CONTEXT: Mutex> = Mutex::new(None); + +#[derive(Default)] +struct ContextPool { + contexts: HashMap>, +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +struct Key { + api_id: TypeId, + device: WGPUDevice, +} + +impl Key { + fn new(device: &WGPUDevice) -> Self { + Self { + api_id: TypeId::of::(), + device: device.clone(), + } + } +} + +/// Get a [context](Context) for the given [device](WGPUDevice). +/// +/// # Notes +/// +/// If a context already exist for the current [device](WGPUDevice), the same instance will be +/// returned. 
+pub fn get_context(device: &WGPUDevice) -> Arc { + let mut pool = POOL_CONTEXT.lock().unwrap(); + + let context = if let Some(pool) = pool.as_mut() { + // Fetch device in pool + match pool.contexts.get(&Key::new::(device)) { + Some(context) => context.clone(), + None => { + // Init new device + let context = Arc::new(Context::new::(device)); + pool.contexts.insert(Key::new::(device), context.clone()); + context + } + } + } else { + // Initialize pool + let context = Arc::new(Context::new::(device)); + let mut new_pool = ContextPool::default(); + + new_pool + .contexts + .insert(Key::new::(device), context.clone()); + *pool = Some(new_pool); + context + }; + + context +} diff --git a/burn-wgpu/src/template/binary_elemwise.wgsl b/burn-wgpu/src/template/binary_elemwise.wgsl new file mode 100644 index 000000000..2afcf4118 --- /dev/null +++ b/burn-wgpu/src/template/binary_elemwise.wgsl @@ -0,0 +1,35 @@ +@group(0) +@binding(0) +var lhs: array; + +@group(0) +@binding(1) +var rhs: array; + +@group(0) +@binding(2) +var output: array; + +@group(0) +@binding(3) +var info: array; + +@compute +@workgroup_size(WORKGROUP_SIZE_X, 1, 1) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let dim: u32 = info[0]; + var index_lhs: u32 = 0u; + var index_rhs: u32 = 0u; + + for (var i: u32 = 0u; i < dim; i++) { + let stride_lhs = info[i + 1u]; + let stride_rhs = info[i + 1u * dim + 1u]; + let shape_lhs = info[i + 2u * dim + 1u]; + let shape_rhs = info[i + 3u * dim + 1u]; + + index_lhs += global_id.x / stride_lhs % shape_lhs * stride_lhs; + index_rhs += global_id.x / stride_rhs % shape_rhs * stride_rhs; + } + + BODY +} diff --git a/burn-wgpu/src/template/binary_elemwise_inplace.wgsl b/burn-wgpu/src/template/binary_elemwise_inplace.wgsl new file mode 100644 index 000000000..4bc56c9e8 --- /dev/null +++ b/burn-wgpu/src/template/binary_elemwise_inplace.wgsl @@ -0,0 +1,27 @@ +@group(0) +@binding(0) +var lhs: array; + +@group(0) +@binding(1) +var rhs: array; + +@group(0) +@binding(2) 
+var info: array; + +@compute +@workgroup_size(WORKGROUP_SIZE_X, 1, 1) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let dim: u32 = info[0]; + var index_rhs: u32 = 0u; + + for (var i: u32 = 0u; i < dim; i++) { + let stride_rhs = info[i + 1u]; + let shape_rhs = info[i + 1u * dim + 1u]; + + index_rhs += global_id.x / stride_rhs % shape_rhs * stride_rhs; + } + + BODY +} diff --git a/burn-wgpu/src/template/unary.wgsl b/burn-wgpu/src/template/unary.wgsl new file mode 100644 index 000000000..46611e6fb --- /dev/null +++ b/burn-wgpu/src/template/unary.wgsl @@ -0,0 +1,13 @@ +@group(0) +@binding(0) +var input: array; + +@group(0) +@binding(1) +var output: array; + +@compute +@workgroup_size(WORKGROUP_SIZE_X, 1, 1) +fn main(@builtin(global_invocation_id) global_id: vec3) { + BODY +} diff --git a/burn-wgpu/src/template/unary_inplace.wgsl b/burn-wgpu/src/template/unary_inplace.wgsl new file mode 100644 index 000000000..ece008e7b --- /dev/null +++ b/burn-wgpu/src/template/unary_inplace.wgsl @@ -0,0 +1,9 @@ +@group(0) +@binding(0) +var input: array; + +@compute +@workgroup_size(WORKGROUP_SIZE_X, 1, 1) +fn main(@builtin(global_invocation_id) global_id: vec3) { + BODY +} diff --git a/burn-wgpu/src/template/unary_scalar.wgsl b/burn-wgpu/src/template/unary_scalar.wgsl new file mode 100644 index 000000000..785599d46 --- /dev/null +++ b/burn-wgpu/src/template/unary_scalar.wgsl @@ -0,0 +1,17 @@ +@group(0) +@binding(0) +var lhs: array; + +@group(0) +@binding(1) +var rhs: elem; + +@group(0) +@binding(2) +var output: array; + +@compute +@workgroup_size(WORKGROUP_SIZE_X, 1, 1) +fn main(@builtin(global_invocation_id) global_id: vec3) { + BODY +} diff --git a/burn-wgpu/src/template/unary_scalar_inplace.wgsl b/burn-wgpu/src/template/unary_scalar_inplace.wgsl new file mode 100644 index 000000000..5bf35e4b5 --- /dev/null +++ b/burn-wgpu/src/template/unary_scalar_inplace.wgsl @@ -0,0 +1,13 @@ +@group(0) +@binding(0) +var lhs: array; + +@group(0) +@binding(1) +var rhs: elem; + 
+@compute +@workgroup_size(WORKGROUP_SIZE_X, 1, 1) +fn main(@builtin(global_invocation_id) global_id: vec3) { + BODY +} diff --git a/burn-wgpu/src/tensor/base.rs b/burn-wgpu/src/tensor/base.rs new file mode 100644 index 000000000..b0bdc494f --- /dev/null +++ b/burn-wgpu/src/tensor/base.rs @@ -0,0 +1,82 @@ +use burn_tensor::Shape; +use std::{marker::PhantomData, sync::Arc}; +use wgpu::Buffer; + +use crate::{context::Context, element::WGPUElement}; + +#[derive(Debug, Clone)] +pub struct WGPUTensor { + pub(crate) context: Arc, + pub(crate) buffer: Arc, + pub(crate) shape: Shape, + pub(crate) strides: [usize; D], + elem: PhantomData, +} + +impl WGPUTensor { + pub fn new(context: Arc, shape: Shape, buffer: Arc) -> Self { + let mut strides = [0; D]; + + let mut current = 1; + shape + .dims + .iter() + .enumerate() + .rev() + .for_each(|(index, val)| { + strides[index] = current; + current *= val; + }); + + Self { + context, + buffer, + shape, + strides, + elem: PhantomData::default(), + } + } + pub fn to_context(&self, context: Arc) -> Self { + let data = self.context.buffer_to_data(&self.buffer); + let buffer = Arc::new(context.create_buffer_with_data(&data)); + + Self { + context, + buffer, + shape: self.shape.clone(), + strides: self.strides, + elem: PhantomData::default(), + } + } + pub fn can_mut_broadcast(&self, tensor_other: &WGPUTensor) -> bool { + if Arc::strong_count(&self.buffer) > 1 { + return false; + } + + for i in 0..D { + // Output tensor will be different from the mutable tensor. 
+ if self.shape.dims[i] < tensor_other.shape.dims[i] { + return false; + } + } + + true + } + + pub fn can_mut(&self) -> bool { + if Arc::strong_count(&self.buffer) > 1 { + return false; + } + + true + } + + pub fn assert_is_on_save_device(&self, other: &Self) { + if self.context.device != other.context.device { + panic!( + "Both tensors should be on the same device {:?} != {:?}", + self.context.device, other.context.device + ); + } + } +} diff --git a/burn-wgpu/src/tensor/mod.rs b/burn-wgpu/src/tensor/mod.rs new file mode 100644 index 000000000..096c94ead --- /dev/null +++ b/burn-wgpu/src/tensor/mod.rs @@ -0,0 +1,2 @@ +mod base; +pub use base::*;