Feat/wgpu backend setup (#376)

This commit is contained in:
Nathaniel Simard 2023-06-02 11:52:47 -04:00 committed by GitHub
parent 483f9acca5
commit 974fdfaba1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 2485 additions and 24 deletions

View File

@ -14,6 +14,7 @@ members = [
"burn-ndarray", "burn-ndarray",
"burn-no-std-tests", "burn-no-std-tests",
"burn-tch", "burn-tch",
"burn-wgpu",
"burn-tensor-testgen", "burn-tensor-testgen",
"burn-tensor", "burn-tensor",
"burn-train", "burn-train",
@ -53,6 +54,11 @@ syn = "2.0"
tempfile = "3.5.0" tempfile = "3.5.0"
thiserror = "1.0.40" thiserror = "1.0.40"
topological-sort = "0.2.2" topological-sort = "0.2.2"
# WGPU stuff
wgpu = "0.16.0"
futures-intrusive = "0.5"
pollster = "0.3"
# #
# The following packages disable the "std" feature for no_std compatibility # The following packages disable the "std" feature for no_std compatibility
# #

View File

@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
mask = mask.to_device(device).repeat(0, batch_size); mask = mask.to_device(device).repeat(0, batch_size);
mask.equal_elem(1_i64) mask.equal_elem(1_i64.elem::<i64>())
} }
pub struct GeneratePaddingMask<B: Backend> { pub struct GeneratePaddingMask<B: Backend> {

View File

@ -256,12 +256,6 @@ where
K::equal(self.primitive, other.primitive) K::equal(self.primitive, other.primitive)
} }
/// Applies element wise equal comparison and returns a boolean tensor.
pub fn equal_elem<E: Into<K::Elem>>(self, other: E) -> Tensor<B, D, Bool> {
let elem: K::Elem = other.into();
K::equal_elem::<D>(self.primitive, elem)
}
/// Concatenates all tensors into a new one along the given dimension. /// Concatenates all tensors into a new one along the given dimension.
/// ///
/// # Panics /// # Panics
@ -400,7 +394,6 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
lhs: Self::Primitive<D>, lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>, rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>; ) -> Tensor<B, D, Bool>;
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
fn elem_type_name() -> &'static str { fn elem_type_name() -> &'static str {
core::any::type_name::<Self::Elem>() core::any::type_name::<Self::Elem>()
} }
@ -478,10 +471,6 @@ impl<B: Backend> BasicOps<B> for Float {
) -> Tensor<B, D, Bool> { ) -> Tensor<B, D, Bool> {
Tensor::new(B::equal(lhs, rhs)) Tensor::new(B::equal(lhs, rhs))
} }
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
Tensor::new(B::equal_elem(lhs, rhs))
}
} }
impl<B: Backend> BasicOps<B> for Int { impl<B: Backend> BasicOps<B> for Int {
@ -553,10 +542,6 @@ impl<B: Backend> BasicOps<B> for Int {
Tensor::new(B::int_equal(lhs, rhs)) Tensor::new(B::int_equal(lhs, rhs))
} }
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
Tensor::new(B::int_equal_elem(lhs, rhs))
}
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> { fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
B::int_cat(vectors, dim) B::int_cat(vectors, dim)
} }
@ -631,10 +616,6 @@ impl<B: Backend> BasicOps<B> for Bool {
Tensor::new(B::bool_equal(lhs, rhs)) Tensor::new(B::bool_equal(lhs, rhs))
} }
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
Tensor::new(B::bool_equal_elem(lhs, rhs))
}
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> { fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
B::bool_cat(vectors, dim) B::bool_cat(vectors, dim)
} }

View File

@ -133,6 +133,11 @@ where
Self::new(K::sum_dim(self.primitive, dim)) Self::new(K::sum_dim(self.primitive, dim))
} }
/// Applies element wise equal comparison and returns a boolean tensor.
pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
K::equal_elem::<D>(self.primitive, other.elem())
}
/// Applies element wise greater comparison and returns a boolean tensor. /// Applies element wise greater comparison and returns a boolean tensor.
/// ///
/// # Panics /// # Panics
@ -413,6 +418,7 @@ where
fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>; fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>; fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
fn mean_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>; fn mean_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
fn greater<const D: usize>( fn greater<const D: usize>(
lhs: Self::Primitive<D>, lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>, rhs: Self::Primitive<D>,
@ -559,6 +565,9 @@ impl<B: Backend> Numeric<B> for Int {
B::int_mean_dim(tensor, dim) B::int_mean_dim(tensor, dim)
} }
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
Tensor::new(B::int_equal_elem(lhs, rhs))
}
fn greater<const D: usize>( fn greater<const D: usize>(
lhs: Self::Primitive<D>, lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>, rhs: Self::Primitive<D>,
@ -777,6 +786,9 @@ impl<B: Backend> Numeric<B> for Float {
B::mean_dim(tensor, dim) B::mean_dim(tensor, dim)
} }
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
Tensor::new(B::equal_elem(lhs, rhs))
}
fn greater<const D: usize>( fn greater<const D: usize>(
lhs: Self::Primitive<D>, lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>, rhs: Self::Primitive<D>,

View File

@ -79,7 +79,7 @@ pub trait Backend:
/// Tensor primitive to be used for all int operations. /// Tensor primitive to be used for all int operations.
type IntTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug; type IntTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
/// Int element type. /// Int element type.
type IntElem: Element + From<i64> + Into<i64>; type IntElem: Element;
/// Tensor primitive to be used for all bool operations. /// Tensor primitive to be used for all bool operations.
type BoolTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug; type BoolTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;

View File

@ -22,7 +22,9 @@ pub trait TensorOps<B: Backend> {
} }
fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>; fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>;
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::FloatElem, D>; fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D>; fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D> {
Self::to_data(&tensor)
}
fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device; fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
fn to_device<const D: usize>( fn to_device<const D: usize>(
tensor: B::TensorPrimitive<D>, tensor: B::TensorPrimitive<D>,
@ -102,7 +104,9 @@ pub trait TensorOps<B: Backend> {
lhs: B::TensorPrimitive<D>, lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>, rhs: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>; ) -> B::TensorPrimitive<D>;
fn neg<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>; fn neg<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
Self::mul_scalar(tensor, (-1.0_f32).elem::<B::FloatElem>())
}
fn transpose<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> { fn transpose<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
Self::swap_dims(tensor, D - 2, D - 1) Self::swap_dims(tensor, D - 2, D - 1)
} }

View File

@ -15,4 +15,17 @@ mod tests {
let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]); let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]);
assert_eq!(data_expected, data_actual); assert_eq!(data_expected, data_actual);
} }
#[test]
fn test_add_broadcast() {
let data_1 = Data::from([[0.0, 1.0, 2.0]]);
let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1);
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2);
let data_actual = (tensor_1 + tensor_2).into_data();
let data_expected = Data::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]);
assert_eq!(data_expected, data_actual);
}
} }

View File

@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{Data, Tensor}; use burn_tensor::{Data, Tensor};
#[test] #[test]
fn should_support_exp_ops() { fn should_support_log_ops() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data); let tensor = Tensor::<TestBackend, 2>::from_data(data);

32
burn-wgpu/Cargo.toml Normal file
View File

@ -0,0 +1,32 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "WGPU backend for burn"
edition = "2021"
keywords = ["deep-learning", "machine-learning", "data"]
license = "MIT/Apache-2.0"
name = "burn-wgpu"
readme = "README.md"
repository = "https://github.com/burn-rs/burn/tree/main/burn-wgpu"
version = "0.8.0"
[dependencies]
burn-tensor = {path = "../burn-tensor", version = "0.8.0"}
burn-common = {path = "../burn-common", version = "0.8.0"}
derive-new = {workspace = true}
bytemuck = {workspace = true}
rand = {workspace = true}
num-traits = {workspace = true}
# WGPU stuff
wgpu = {workspace = true}
futures-intrusive = {workspace = true}
pollster = {workspace = true}
[dev-dependencies]
burn-autodiff = {path = "../burn-autodiff", version = "0.8.0", default-features = false, features = [
"export_tests",
]}
burn-tensor = {path = "../burn-tensor", version = "0.8.0", default-features = false, features = [
"export_tests",
]}

1
burn-wgpu/LICENSE-APACHE Symbolic link
View File

@ -0,0 +1 @@
../LICENSE-APACHE

1
burn-wgpu/LICENSE-MIT Symbolic link
View File

@ -0,0 +1 @@
../LICENSE-MIT

3
burn-wgpu/README.md Normal file
View File

@ -0,0 +1,3 @@
# Burn WGPU Backend
[Burn](https://github.com/burn-rs/burn) WGPU backend

45
burn-wgpu/src/backend.rs Normal file
View File

@ -0,0 +1,45 @@
use burn_tensor::backend::Backend;
use rand::{rngs::StdRng, SeedableRng};
use crate::{
element::{FloatElement, IntElement},
tensor::WGPUTensor,
GraphicsAPI, WGPUDevice,
};
use std::{marker::PhantomData, sync::Mutex};
/// Process-wide random number generator shared by the wgpu backend.
///
/// Starts as `None`; `Backend::seed` stores a freshly seeded [`StdRng`] here.
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
/// Zero-sized backend marker for running tensor operations through `wgpu`.
///
/// The type parameters only select behavior at compile time and carry no data:
/// - `G`: graphics API used to look up adapters.
/// - `F`: float element type of float tensors.
/// - `I`: int element type of int tensors.
#[derive(Debug, Default, Clone)]
pub struct WGPUBackend<G: GraphicsAPI, F: FloatElement, I: IntElement> {
_g: PhantomData<G>,
_f: PhantomData<F>,
_i: PhantomData<I>,
}
impl<G: GraphicsAPI + 'static, F: FloatElement, I: IntElement> Backend for WGPUBackend<G, F, I> {
    type Device = WGPUDevice;
    type FullPrecisionBackend = WGPUBackend<G, f32, i32>;

    type FullPrecisionElem = f32;
    type FloatElem = F;
    type IntElem = I;

    // Float and int tensors share the same tensor type; booleans are stored as `u32`.
    type TensorPrimitive<const D: usize> = WGPUTensor<F, D>;
    type IntTensorPrimitive<const D: usize> = WGPUTensor<I, D>;
    type BoolTensorPrimitive<const D: usize> = WGPUTensor<u32, D>;

    /// Name of this backend.
    fn name() -> String {
        "wgpu".to_string()
    }

    /// Replaces the shared random number generator with one seeded from `seed`.
    fn seed(seed: u64) {
        let seeded = StdRng::seed_from_u64(seed);
        let mut shared = SEED.lock().unwrap();
        *shared = Some(seeded);
    }

    /// This backend does not provide automatic differentiation by itself.
    fn ad_enabled() -> bool {
        false
    }
}

256
burn-wgpu/src/context.rs Normal file
View File

@ -0,0 +1,256 @@
use burn_common::id::IdGenerator;
use std::{
any::TypeId,
borrow::Cow,
collections::HashMap,
sync::{Arc, Mutex},
};
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
Buffer, DeviceDescriptor, DeviceType, ShaderModule, ShaderModuleDescriptor,
};
use crate::{kernel::KernelGenerator, GraphicsAPI, WGPUDevice};
/// The context is the basic struct used to execute GPU kernels on a device.
///
/// You can access a context for a [wgpu device](WGPUDevice) using [get_context](crate::pool::get_context).
///
/// Two contexts compare equal only when they share the same generated `id`.
#[derive(Debug)]
pub struct Context {
/// Identifier generated at construction; the only field used for equality.
id: String,
/// Queue on which copy and compute commands are submitted.
queue: wgpu::Queue,
/// The underlying `wgpu` device used to create buffers, shader modules and pipelines.
device_wgpu: wgpu::Device,
/// Compiled shader modules, keyed by the `TypeId` of the kernel generator that produced them.
cache: Mutex<HashMap<TypeId, Arc<ShaderModule>>>,
/// The backend-level device description this context was created for.
pub(crate) device: WGPUDevice,
}
/// Number of workgroups to dispatch along each axis when executing a kernel.
///
/// The three values are passed unchanged to `dispatch_workgroups` by `Context::execute`.
#[derive(new, Clone, Debug)]
pub struct WorkGroup {
/// Workgroup count along the x axis.
pub x: u32,
/// Workgroup count along the y axis.
pub y: u32,
/// Workgroup count along the z axis.
pub z: u32,
}
impl Context {
/// Creates a context for `device`, looking up adapters through the graphics API `G`.
///
/// Adapters are first narrowed down to the hardware category requested by `device`
/// (discrete, integrated, virtual or CPU); the index carried by the variant then selects
/// one adapter within that narrowed list. The CPU variant always takes the first match.
///
/// # Panics
///
/// - If no adapter of the requested category exists at the requested index.
/// - If the selected adapter refuses to provide a device.
pub(crate) fn new<G: GraphicsAPI>(device: &WGPUDevice) -> Self {
    let instance = wgpu::Instance::default();

    // Hardware category that the requested device maps to.
    let wanted_type = match device {
        WGPUDevice::DiscreteGPU(_) => DeviceType::DiscreteGpu,
        WGPUDevice::IntegratedGPU(_) => DeviceType::IntegratedGpu,
        WGPUDevice::VirtualGPU(_) => DeviceType::VirtualGpu,
        WGPUDevice::CPU => DeviceType::Cpu,
    };

    let mut candidates = instance
        .enumerate_adapters(G::backend().into())
        .filter(|adapter| adapter.get_info().device_type == wanted_type)
        .collect::<Vec<_>>();

    // Index inside the narrowed list, plus the message used when it is out of range.
    let (index, not_found) = match device {
        WGPUDevice::DiscreteGPU(num) => (*num, "No Discrete GPU device found"),
        WGPUDevice::IntegratedGPU(num) => (*num, "No Integrated GPU device found"),
        WGPUDevice::VirtualGPU(num) => (*num, "No Virtual GPU device found"),
        WGPUDevice::CPU => (0, "No CPU device found"),
    };
    assert!(candidates.len() > index, "{}", not_found);
    let adapter = candidates.remove(index);

    let descriptor = DeviceDescriptor {
        label: None,
        features: wgpu::Features::empty(),
        limits: wgpu::Limits::downlevel_defaults(),
    };
    let (gpu, queue) = pollster::block_on(adapter.request_device(&descriptor, None))
        .expect("Unable to request the device with the adapter");

    Self {
        id: IdGenerator::generate(),
        queue,
        device_wgpu: gpu,
        device: device.clone(),
        cache: Mutex::new(HashMap::new()),
    }
}
/// Allocates a new GPU buffer of `size` bytes.
///
/// The buffer can be bound as storage and used as both source and destination of copies.
/// It is not mapped at creation and no initial contents are written by this function.
pub fn create_buffer(&self, size: usize) -> Buffer {
    let usage = wgpu::BufferUsages::COPY_DST
        | wgpu::BufferUsages::STORAGE
        | wgpu::BufferUsages::COPY_SRC;
    let descriptor = wgpu::BufferDescriptor {
        label: None,
        size: size as u64,
        usage,
        mapped_at_creation: false,
    };

    self.device_wgpu.create_buffer(&descriptor)
}
/// Allocates a new GPU buffer and fills it with `data`.
///
/// The bytes are first uploaded into a temporary copy-source buffer, then copied on the
/// queue into a buffer created with [`create_buffer`](Self::create_buffer), which is returned.
pub fn create_buffer_with_data(&self, data: &[u8]) -> Buffer {
    let staging = self.device_wgpu.create_buffer_init(&BufferInitDescriptor {
        label: Some("Buffer Src"),
        contents: data,
        usage: wgpu::BufferUsages::COPY_SRC,
    });
    let byte_len = staging.size();
    let target = self.create_buffer(byte_len as usize);

    // Record the staging -> target copy and hand it to the queue.
    let encoder_descriptor = wgpu::CommandEncoderDescriptor {
        label: Some("Command Encoder"),
    };
    let mut encoder = self.device_wgpu.create_command_encoder(&encoder_descriptor);
    encoder.copy_buffer_to_buffer(&staging, 0, &target, 0, byte_len);
    self.queue.submit(std::iter::once(encoder.finish()));

    target
}
/// Copies the whole content of a GPU buffer back to the host and returns it as bytes.
///
/// The data is copied into a temporary mappable buffer, which is then mapped for reading.
/// This call blocks until the mapping has completed.
///
/// # Panics
///
/// If the mapping fails or its result cannot be delivered.
pub fn buffer_to_data(&self, buffer: &Buffer) -> Vec<u8> {
    let size = buffer.size();

    let readback = self.device_wgpu.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let encoder_descriptor = wgpu::CommandEncoderDescriptor {
        label: Some("Command Encoder"),
    };
    let mut encoder = self.device_wgpu.create_command_encoder(&encoder_descriptor);
    encoder.copy_buffer_to_buffer(buffer, 0, &readback, 0, size);
    self.queue.submit(std::iter::once(encoder.finish()));

    // Ask for the mapping; its outcome is delivered through a one-shot channel.
    let mapped_slice = readback.slice(..);
    let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
    mapped_slice.map_async(wgpu::MapMode::Read, move |outcome| {
        sender
            .send(outcome)
            .expect("Unable to send buffer slice result to async channel.")
    });

    // Block on the device so the mapping callback has run before the channel is read.
    self.device_wgpu.poll(wgpu::Maintain::Wait);

    match pollster::block_on(receiver.receive()) {
        Some(Ok(())) => {
            let view = mapped_slice.get_mapped_range();
            let bytes = bytemuck::cast_slice(&view).to_vec();

            // The view must be released before the buffer can be unmapped.
            drop(view);
            readback.unmap();

            bytes
        }
        failure => panic!("Unable to read buffer {:?}", failure),
    }
}
/// Returns the shader module for kernel `K`, compiling it on first use.
///
/// Modules are cached per concrete type `K`, so each distinct kernel type is generated
/// and compiled at most once per context. The cache lock is held for the whole call.
pub fn compile<K: KernelGenerator>(&self) -> Arc<ShaderModule> {
    let key = TypeId::of::<K>();
    let mut cache = self.cache.lock().unwrap();

    match cache.get(&key) {
        Some(cached) => Arc::clone(cached),
        None => {
            let source = K::generate();
            let descriptor = ShaderModuleDescriptor {
                label: None,
                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source.as_ref())),
            };
            let module = Arc::new(self.device_wgpu.create_shader_module(descriptor));

            cache.insert(key, Arc::clone(&module));
            module
        }
    }
}
/// Runs a compiled kernel with the given buffers.
///
/// The kernel's `main` entry point is dispatched with `work_group` workgroups per axis and
/// the resulting commands are submitted to the queue.
///
/// # Notes
///
/// Nothing here prevents the GPU from writing to a buffer that is still used elsewhere.
/// Callers must make sure a buffer may be mutated before launching a compute shader that
/// has write access to it.
///
/// The position of each buffer in `buffers` is used as its binding index in bind group 0.
pub fn execute(&self, work_group: &WorkGroup, kernel: &ShaderModule, buffers: &[&Buffer]) {
    let pipeline = self
        .device_wgpu
        .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: None,
            layout: None,
            module: kernel,
            entry_point: "main",
        });

    // Bind every buffer at the index matching its position in the slice.
    let layout = pipeline.get_bind_group_layout(0);
    let mut entries = Vec::with_capacity(buffers.len());
    for (position, buffer) in buffers.iter().enumerate() {
        entries.push(wgpu::BindGroupEntry {
            binding: position as u32,
            resource: buffer.as_entire_binding(),
        });
    }
    let bind_group = self
        .device_wgpu
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &layout,
            entries: &entries,
        });

    let mut encoder = self
        .device_wgpu
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        // The pass borrows the encoder, so it has to end before `finish` is called.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(work_group.x, work_group.y, work_group.z);
    }

    self.queue.submit(Some(encoder.finish()));
}
}
impl PartialEq for Context {
    /// Two contexts are equal when they carry the same generated identifier.
    fn eq(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.id, &other.id);
        mine == theirs
    }
}

25
burn-wgpu/src/device.rs Normal file
View File

@ -0,0 +1,25 @@
/// The device description used with the `wgpu` backend.
///
/// Note that you need to provide the device index when using a GPU backend.
/// The index selects one adapter among those of the same category.
///
/// # Example
///
/// ```no_run
/// use burn_wgpu::WGPUDevice;
///
/// let device_gpu_1 = WGPUDevice::DiscreteGPU(0); // First discrete GPU found.
/// let device_gpu_2 = WGPUDevice::DiscreteGPU(1); // Second discrete GPU found.
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum WGPUDevice {
/// Discrete GPU at the given index.
DiscreteGPU(usize),
/// Integrated GPU at the given index.
IntegratedGPU(usize),
/// Virtual GPU at the given index.
VirtualGPU(usize),
/// CPU adapter; no index is needed, the first one found is used.
CPU,
}
impl Default for WGPUDevice {
fn default() -> Self {
Self::CPU
}
}

66
burn-wgpu/src/element.rs Normal file
View File

@ -0,0 +1,66 @@
use burn_tensor::Element;
/// Element types that can be stored in GPU buffers and named in kernel source code.
pub trait WGPUElement: core::fmt::Debug + 'static + Clone
where
Self: Sized,
{
/// Name of the type as it is written into generated kernel source.
fn type_name() -> &'static str;
/// Reinterprets a slice of elements as raw bytes.
fn as_bytes(slice: &[Self]) -> &[u8];
/// Reinterprets raw bytes as a slice of elements.
fn from_bytes(bytes: &[u8]) -> &[Self];
}
/// Marker for element types usable as the float element of the wgpu backend.
pub trait FloatElement: WGPUElement + Element {}
/// Marker for element types usable as the int element of the wgpu backend.
pub trait IntElement: WGPUElement + Element {}
impl WGPUElement for u32 {
    /// Name substituted into kernel templates for this element type.
    fn type_name() -> &'static str {
        "u32"
    }

    /// Views `u32` values as their underlying bytes without copying.
    fn as_bytes(slice: &[Self]) -> &[u8] {
        bytemuck::cast_slice::<Self, u8>(slice)
    }

    /// Views bytes as `u32` values without copying.
    fn from_bytes(bytes: &[u8]) -> &[Self] {
        bytemuck::cast_slice::<u8, Self>(bytes)
    }
}
impl WGPUElement for i32 {
    /// Name substituted into kernel templates for this element type.
    fn type_name() -> &'static str {
        "i32"
    }

    /// Views `i32` values as their underlying bytes without copying.
    fn as_bytes(slice: &[Self]) -> &[u8] {
        bytemuck::cast_slice::<Self, u8>(slice)
    }

    /// Views bytes as `i32` values without copying.
    fn from_bytes(bytes: &[u8]) -> &[Self] {
        bytemuck::cast_slice::<u8, Self>(bytes)
    }
}
impl WGPUElement for i64 {
    /// Name substituted into kernel templates for this element type.
    ///
    /// NOTE(review): this name ends up in WGSL shader source; WGSL does not appear to define
    /// a 64-bit integer scalar type, so kernels instantiated with `i64` may fail to compile.
    /// Confirm how `i64` tensors are meant to be supported.
    fn type_name() -> &'static str {
        "i64"
    }

    /// Views `i64` values as their underlying bytes without copying.
    fn as_bytes(slice: &[Self]) -> &[u8] {
        bytemuck::cast_slice::<Self, u8>(slice)
    }

    /// Views bytes as `i64` values without copying.
    fn from_bytes(bytes: &[u8]) -> &[Self] {
        bytemuck::cast_slice::<u8, Self>(bytes)
    }
}
impl WGPUElement for f32 {
    /// Name substituted into kernel templates for this element type.
    fn type_name() -> &'static str {
        "f32"
    }

    /// Views `f32` values as their underlying bytes without copying.
    fn as_bytes(slice: &[Self]) -> &[u8] {
        bytemuck::cast_slice::<Self, u8>(slice)
    }

    /// Views bytes as `f32` values without copying.
    fn from_bytes(bytes: &[u8]) -> &[Self] {
        bytemuck::cast_slice::<u8, Self>(bytes)
    }
}
// Element types accepted by the backend: `f32` for floats, `i32` and `i64` for ints.
impl FloatElement for f32 {}
impl IntElement for i32 {}
impl IntElement for i64 {}

61
burn-wgpu/src/graphics.rs Normal file
View File

@ -0,0 +1,61 @@
/// The basic trait to specify which graphics API to use as Backend.
///
/// Options are:
/// - [Vulkan](Vulkan)
/// - [Metal](Metal)
/// - [OpenGL](OpenGL)
/// - [DirectX 11](Dx11)
/// - [DirectX 12](Dx12)
/// - [WebGPU](WebGPU)
pub trait GraphicsAPI: Send + Sync + core::fmt::Debug + Default + Clone + 'static {
/// The `wgpu` backend that corresponds to this graphics API.
fn backend() -> wgpu::Backend;
}
/// Marker type selecting the Vulkan graphics API.
#[derive(Default, Debug, Clone)]
pub struct Vulkan;
/// Marker type selecting the Metal graphics API.
#[derive(Default, Debug, Clone)]
pub struct Metal;
/// Marker type selecting the OpenGL graphics API.
#[derive(Default, Debug, Clone)]
pub struct OpenGL;
/// Marker type selecting the DirectX 11 graphics API.
#[derive(Default, Debug, Clone)]
pub struct Dx11;
/// Marker type selecting the DirectX 12 graphics API.
#[derive(Default, Debug, Clone)]
pub struct Dx12;
/// Marker type selecting the browser WebGPU graphics API.
#[derive(Default, Debug, Clone)]
pub struct WebGPU;
// Each marker type maps to exactly one `wgpu::Backend` variant.
impl GraphicsAPI for Vulkan {
fn backend() -> wgpu::Backend {
wgpu::Backend::Vulkan
}
}
impl GraphicsAPI for Metal {
fn backend() -> wgpu::Backend {
wgpu::Backend::Metal
}
}
impl GraphicsAPI for OpenGL {
fn backend() -> wgpu::Backend {
wgpu::Backend::Gl
}
}
impl GraphicsAPI for Dx11 {
fn backend() -> wgpu::Backend {
wgpu::Backend::Dx11
}
}
impl GraphicsAPI for Dx12 {
fn backend() -> wgpu::Backend {
wgpu::Backend::Dx12
}
}
impl GraphicsAPI for WebGPU {
fn backend() -> wgpu::Backend {
wgpu::Backend::BrowserWebGpu
}
}

View File

@ -0,0 +1,87 @@
use crate::element::WGPUElement;
use std::marker::PhantomData;
/// Generate wgpu kernel source code to create [compute shader modules](wgpu::ShaderModule).
///
/// All methods are associated functions: a kernel is identified by its type, not by a value.
pub trait KernelGenerator: 'static {
/// Source code concrete type.
type Source: AsRef<str>;
/// Generate the source code.
fn generate() -> Self::Source;
}
/// Declares a unit struct named `$struct` whose [`KernelGenerator`] implementation returns
/// the content of the file `$file`, embedded at compile time with `include_str!`.
///
/// The generated struct derives `new`, so `derive_new` must be available where the macro
/// is expanded.
#[macro_export]
macro_rules! kernel_wgsl {
(
$struct:ident,
$file:expr
) => {
#[derive(new)]
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = &'static str;
fn generate() -> Self::Source {
include_str!($file)
}
}
};
}
/// Generate kernel source code by replacing some information using templating.
///
/// - `K`: kernel generator providing the template source.
/// - `E`: element type whose name replaces `elem` in the template.
/// - `I`: integer type whose name replaces `int` in the template.
/// - `WORKGROUP_*_SIZE`: values substituted for the `WORKGROUP_SIZE_*` placeholders.
///
/// The struct only carries type information and is never required to hold data.
pub struct KernelSettings<
K: KernelGenerator,
E: WGPUElement,
I: WGPUElement,
const WORKGROUP_X_SIZE: usize,
const WORKGROUP_Y_SIZE: usize,
const WORKGROUP_Z_SIZE: usize,
> {
_k: PhantomData<K>,
_e: PhantomData<E>,
_i: PhantomData<I>,
}
impl<
        K: KernelGenerator,
        E: WGPUElement,
        I: WGPUElement,
        const WORKGROUP_X_SIZE: usize,
        const WORKGROUP_Y_SIZE: usize,
        const WORKGROUP_Z_SIZE: usize,
    > KernelGenerator
    for KernelSettings<K, E, I, WORKGROUP_X_SIZE, WORKGROUP_Y_SIZE, WORKGROUP_Z_SIZE>
{
    type Source = String;

    /// Instantiates the template produced by `K` with the settings carried by this type.
    ///
    /// Replacements are applied in this order:
    /// 1. `WORKGROUP_SIZE_X`, `WORKGROUP_SIZE_Y`, `WORKGROUP_SIZE_Z` become the matching
    ///    workgroup size constants.
    /// 2. `elem` becomes the name of the element type `E`.
    /// 3. `int` becomes the name of the integer type `I`.
    ///
    /// NOTE(review): `elem` and `int` are replaced as plain substrings, so a template must not
    /// contain those letters inside another identifier — confirm the templates respect this.
    fn generate() -> String {
        let template = K::generate();
        let template: &str = template.as_ref();

        template
            .replace("WORKGROUP_SIZE_X", &WORKGROUP_X_SIZE.to_string())
            .replace("WORKGROUP_SIZE_Y", &WORKGROUP_Y_SIZE.to_string())
            // The Z placeholder takes the Z size (it previously received the Y size).
            .replace("WORKGROUP_SIZE_Z", &WORKGROUP_Z_SIZE.to_string())
            .replace("elem", E::type_name())
            .replace("int", I::type_name())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::any::TypeId;

    /// Kernel settings that differ in any const parameter must be distinct types, because
    /// the type id is what identifies a compiled kernel in the cache.
    #[test]
    fn test_kernel_type_id() {
        kernel_wgsl!(Add, "../template/binary_elemwise.wgsl");

        let reference = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 4>>();
        let different_z = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 5>>();
        let same_again = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 4>>();

        assert_ne!(reference, different_z);
        assert_eq!(reference, same_again);
    }
}

View File

@ -0,0 +1,159 @@
use super::{KernelGenerator, KernelSettings};
use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor};
use burn_tensor::Shape;
use num_traits::ToPrimitive;
use std::sync::Arc;
// Raw templates for binary element-wise kernels; their `BODY` placeholder is filled in by
// the `binary_elemwise!` and `binary_elemwise_inplace!` macros.
kernel_wgsl!(BinaryElemwiseRaw, "../template/binary_elemwise.wgsl");
kernel_wgsl!(
BinaryElemwiseInplaceRaw,
"../template/binary_elemwise_inplace.wgsl"
);
/// Declares a kernel struct `$struct` that applies the binary operator `$ops` element-wise
/// and writes the result into a separate output buffer.
///
/// The source is the `BinaryElemwiseRaw` template with `BODY` replaced by the statement
/// `output[global_id.x] = lhs[index_lhs] $ops rhs[index_rhs]`.
///
/// NOTE(review): unlike the in-place variant, the generated statement has no trailing `;` —
/// presumably the template supplies it; confirm against `binary_elemwise.wgsl`.
#[macro_export]
macro_rules! binary_elemwise {
(
$struct:ident,
$ops:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
let source = $crate::kernel::BinaryElemwiseRaw::generate().to_string();
let body = format!(
"output[global_id.x] = lhs[index_lhs] {} rhs[index_rhs]",
$ops
);
source.replace("BODY", &body)
}
}
};
}
/// Declares a kernel struct `$struct` that applies the binary operator `$ops` element-wise
/// and stores the result back into the left-hand side buffer.
///
/// The source is the `BinaryElemwiseInplaceRaw` template with `BODY` replaced by the
/// statement `lhs[global_id.x] = lhs[global_id.x] $ops rhs[index_rhs];`.
#[macro_export]
macro_rules! binary_elemwise_inplace {
(
$struct:ident,
$ops:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
let source = $crate::kernel::BinaryElemwiseInplaceRaw::generate().to_string();
let body = format!(
"lhs[global_id.x] = lhs[global_id.x] {} rhs[index_rhs];",
$ops
);
source.replace("BODY", &body)
}
}
};
}
/// Runs the binary element-wise kernel `K` on `lhs` and `rhs` and returns a new tensor.
///
/// The output shape takes, for every dimension, the larger of the two operand sizes.
/// Besides the two inputs and the output, the kernel receives an info buffer laid out as
/// `[D, lhs strides.., rhs strides.., lhs dims.., rhs dims..]` (all `u32`).
///
/// One workgroup is dispatched per 256 output elements, rounded up.
///
/// # Panics
///
/// - If the tensors fail the device check done by `assert_is_on_save_device`.
/// - If `D`, a stride or a dimension does not fit in a `u32`.
pub fn binary_elemwise<K: KernelGenerator, E: WGPUElement, const D: usize>(
    lhs: WGPUTensor<E, D>,
    rhs: WGPUTensor<E, D>,
) -> WGPUTensor<E, D> {
    lhs.assert_is_on_save_device(&rhs);

    let mut dims_out = [0; D];
    for (axis, (dim_lhs, dim_rhs)) in lhs
        .shape
        .dims
        .iter()
        .zip(rhs.shape.dims.iter())
        .enumerate()
    {
        dims_out[axis] = usize::max(*dim_lhs, *dim_rhs);
    }
    let shape_out = Shape::new(dims_out);

    let byte_len = shape_out.num_elements() * core::mem::size_of::<E>();
    let buffer = lhs.context.create_buffer(byte_len);
    let output = WGPUTensor::new(lhs.context.clone(), shape_out, Arc::new(buffer));

    let kernel = lhs
        .context
        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();

    // Metadata consumed by the kernel to map an output index to each operand's index.
    let mut info: Vec<u32> = vec![D.to_u32().unwrap()];
    info.extend(lhs.strides.into_iter().map(|v| v.to_u32().unwrap()));
    info.extend(rhs.strides.into_iter().map(|v| v.to_u32().unwrap()));
    info.extend(lhs.shape.dims.into_iter().map(|v| v.to_u32().unwrap()));
    info.extend(rhs.shape.dims.into_iter().map(|v| v.to_u32().unwrap()));
    let info_buffers = lhs
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));

    let num_groups = (output.shape.num_elements() as f32 / 256_f32).ceil() as u32;
    let work_group = WorkGroup::new(num_groups, 1, 1);

    lhs.context.execute(
        &work_group,
        &kernel,
        &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
    );

    output
}
/// Runs the in-place binary element-wise kernel `K`, letting it write into `lhs`'s buffer.
///
/// No output buffer is allocated: `lhs` is returned with its shape unchanged. The kernel
/// receives `lhs`, `rhs` and an info buffer laid out as `[D, rhs strides.., rhs dims..]`
/// (all `u32`). One workgroup is dispatched per 256 elements of `lhs`, rounded up.
///
/// NOTE(review): because the result keeps the shape of `lhs`, this presumably must only be
/// used when `rhs` does not broadcast `lhs` to a larger shape — confirm at the call sites.
///
/// # Panics
///
/// - If the tensors fail the device check done by `assert_is_on_save_device`.
/// - If `D`, a stride or a dimension of `rhs` does not fit in a `u32`.
pub fn binary_elemwise_inplace<K: KernelGenerator, E: WGPUElement, const D: usize>(
    lhs: WGPUTensor<E, D>,
    rhs: WGPUTensor<E, D>,
) -> WGPUTensor<E, D> {
    lhs.assert_is_on_save_device(&rhs);

    let kernel = lhs
        .context
        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();

    // Only the rhs layout is needed: the lhs is addressed directly by the global id.
    let mut info: Vec<u32> = vec![D.to_u32().unwrap()];
    info.extend(rhs.strides.into_iter().map(|v| v.to_u32().unwrap()));
    info.extend(rhs.shape.dims.into_iter().map(|v| v.to_u32().unwrap()));
    let info_buffers = lhs
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));

    lhs.context.execute(
        &WorkGroup::new(
            f32::ceil(lhs.shape.num_elements() as f32 / 256_f32) as u32,
            1,
            1,
        ),
        &kernel,
        &[&lhs.buffer, &rhs.buffer, &info_buffers],
    );

    lhs
}

View File

@ -0,0 +1,9 @@
mod base;
mod binary_elemwise;
mod unary;
mod unary_scalar;
pub use base::*;
pub use binary_elemwise::*;
pub use unary::*;
pub use unary_scalar::*;

View File

@ -0,0 +1,120 @@
use super::{KernelGenerator, KernelSettings};
use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor};
use std::sync::Arc;
// Raw templates for unary kernels; their `BODY` placeholder is filled in by the `unary!`
// and `unary_inplace!` macros.
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
/// Declares a kernel struct `$struct` built from the `UnaryRaw` template.
///
/// Two forms are accepted:
/// - `func $func`: `BODY` becomes `output[global_id.x] = $func(input[global_id.x]);`.
/// - `body $body`: `BODY` is replaced by `$body` as given.
#[macro_export]
macro_rules! unary {
(
$struct:ident,
func $func:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
let source = $crate::kernel::UnaryRaw::generate().to_string();
let body = format!("output[global_id.x] = {}(input[global_id.x]);", $func);
source.replace("BODY", &body)
}
}
};
(
$struct:ident,
body $body:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
let source = $crate::kernel::UnaryRaw::generate().to_string();
source.replace("BODY", $body)
}
}
};
}
/// Declares a kernel struct `$struct` built from the `UnaryInplaceRaw` template.
///
/// Two forms are accepted:
/// - `func $func`: `BODY` becomes `input[global_id.x] = $func(input[global_id.x]);`.
/// - `body $body`: `BODY` is replaced by `$body` as given.
#[macro_export]
macro_rules! unary_inplace {
(
$struct:ident,
func $func:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
let source = $crate::kernel::UnaryInplaceRaw::generate().to_string();
let body = format!("input[global_id.x] = {}(input[global_id.x]);", $func);
source.replace("BODY", &body)
}
}
};
(
$struct:ident,
body $body:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
let source = $crate::kernel::UnaryInplaceRaw::generate().to_string();
source.replace("BODY", $body)
}
}
};
}
/// Runs the unary kernel `K` on `input` and returns the result as a new tensor.
///
/// A fresh output buffer with room for as many elements as `input` is allocated, and the
/// kernel is given the input buffer at binding 0 and the output buffer at binding 1.
/// One workgroup is dispatched per 256 output elements, rounded up.
pub fn unary<K: KernelGenerator, E: WGPUElement, const D: usize>(
    input: WGPUTensor<E, D>,
) -> WGPUTensor<E, D> {
    let byte_len = input.shape.num_elements() * core::mem::size_of::<E>();
    let buffer = input.context.create_buffer(byte_len);
    let output = WGPUTensor::new(input.context.clone(), input.shape, Arc::new(buffer));

    let kernel = input
        .context
        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();

    let num_groups = (output.shape.num_elements() as f32 / 256_f32).ceil() as u32;
    let work_group = WorkGroup::new(num_groups, 1, 1);

    input
        .context
        .execute(&work_group, &kernel, &[&input.buffer, &output.buffer]);

    output
}
/// Runs the unary kernel `K` with the input buffer as its only binding.
///
/// No new buffer is allocated; the same tensor is handed back to the caller once the
/// kernel has been submitted. One workgroup is dispatched per 256 elements, rounded up.
pub fn unary_inplace<K: KernelGenerator, E: WGPUElement, const D: usize>(
    input: WGPUTensor<E, D>,
) -> WGPUTensor<E, D> {
    let kernel = input
        .context
        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();

    let num_groups = (input.shape.num_elements() as f32 / 256_f32).ceil() as u32;
    let work_group = WorkGroup::new(num_groups, 1, 1);

    input
        .context
        .execute(&work_group, &kernel, &[&input.buffer]);

    input
}

View File

@ -0,0 +1,135 @@
use super::{KernelGenerator, KernelSettings};
use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor};
use std::sync::Arc;
// Raw templates for tensor/scalar kernels; their `BODY` placeholder is filled in by the
// `unary_scalar!` and `unary_scalar_inplace!` macros.
kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl");
kernel_wgsl!(
UnaryScalarInplaceRaw,
"../template/unary_scalar_inplace.wgsl"
);
/// Declares a kernel struct `$struct` built from the `UnaryScalarRaw` template.
///
/// Two forms are accepted:
/// - `ops $ops`: `BODY` becomes `output[global_id.x] = lhs[global_id.x] $ops rhs;`.
/// - `func $func`: `BODY` becomes `output[global_id.x] = $func(lhs[global_id.x], rhs);`.
#[macro_export]
macro_rules! unary_scalar {
(
$struct:ident,
ops $ops:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
let source = $crate::kernel::UnaryScalarRaw::generate().to_string();
let body = format!("output[global_id.x] = lhs[global_id.x] {} rhs;", $ops);
source.replace("BODY", &body)
}
}
};
(
$struct:ident,
func $func:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
let source = $crate::kernel::UnaryScalarRaw::generate().to_string();
let body = format!("output[global_id.x] = {}(lhs[global_id.x], rhs);", $func);
source.replace("BODY", &body)
}
}
};
}
/// Declares a kernel struct `$struct` built from the `UnaryScalarInplaceRaw` template.
///
/// Two forms are accepted:
/// - `ops $ops`: `BODY` becomes `lhs[global_id.x] = lhs[global_id.x] $ops rhs;`.
/// - `func $func`: `BODY` becomes `lhs[global_id.x] = $func(lhs[global_id.x], rhs);`.
#[macro_export]
macro_rules! unary_scalar_inplace {
(
$struct:ident,
ops $ops:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
let source = $crate::kernel::UnaryScalarInplaceRaw::generate().to_string();
let body = format!("lhs[global_id.x] = lhs[global_id.x] {} rhs;", $ops);
source.replace("BODY", &body)
}
}
};
(
$struct:ident,
func $func:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
let source = $crate::kernel::UnaryScalarInplaceRaw::generate().to_string();
let body = format!("lhs[global_id.x] = {}(lhs[global_id.x], rhs);", $func);
source.replace("BODY", &body)
}
}
};
}
/// Runs the tensor/scalar kernel `K` on `lhs` and `scalar` and returns a new tensor.
///
/// The scalar is uploaded into its own one-element buffer. The kernel is given the tensor
/// buffer at binding 0, the scalar buffer at binding 1 and a fresh output buffer at
/// binding 2. One workgroup is dispatched per 256 output elements, rounded up.
pub fn unary_scalar<K: KernelGenerator, E: WGPUElement, const D: usize>(
    lhs: WGPUTensor<E, D>,
    scalar: E,
) -> WGPUTensor<E, D> {
    let byte_len = lhs.shape.num_elements() * core::mem::size_of::<E>();
    let buffer = lhs.context.create_buffer(byte_len);
    let output = WGPUTensor::new(lhs.context.clone(), lhs.shape, Arc::new(buffer));

    let kernel = lhs
        .context
        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
    let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar]));

    let num_groups = (output.shape.num_elements() as f32 / 256_f32).ceil() as u32;
    let work_group = WorkGroup::new(num_groups, 1, 1);

    lhs.context.execute(
        &work_group,
        &kernel,
        &[&lhs.buffer, &rhs_buffer, &output.buffer],
    );

    output
}
/// Runs the in-place scalar kernel `K` over every element of `lhs`,
/// overwriting its buffer, and returns `lhs`.
///
/// `scalar` is uploaded into its own single-element buffer, which is bound
/// after the tensor buffer.
pub fn unary_scalar_inplace<K: KernelGenerator, E: WGPUElement, const D: usize>(
    lhs: WGPUTensor<E, D>,
    scalar: E,
) -> WGPUTensor<E, D> {
    // Elements covered by one workgroup; mirrors the `256` given to
    // `KernelSettings` below.
    const WORKGROUP_SIZE: usize = 256;

    let kernel = lhs
        .context
        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
    let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar]));

    // Integer ceiling division. Going through `f32` loses precision above
    // 2^24 elements and could yield one workgroup too few, leaving the tail
    // of the tensor unprocessed.
    let num_workgroups = (lhs.shape.num_elements() + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;

    lhs.context.execute(
        // NOTE(review): plain cast; counts above `u32::MAX` are not expected here.
        &WorkGroup::new(num_workgroups as u32, 1, 1),
        &kernel,
        &[&lhs.buffer, &rhs_buffer],
    );

    lhs
}

39
burn-wgpu/src/lib.rs Normal file
View File

@ -0,0 +1,39 @@
// Makes the macros exported by `derive_new` available crate-wide.
#[macro_use]
extern crate derive_new;
// Implementations of the backend operation traits.
mod ops;
// Crate-internal building blocks.
pub(crate) mod context;
pub(crate) mod element;
pub(crate) mod kernel;
pub(crate) mod pool;
pub(crate) mod tensor;
// Public API: everything in these modules is re-exported at the crate root.
mod device;
pub use device::*;
mod backend;
pub use backend::*;
mod graphics;
pub use graphics::*;
#[cfg(test)]
mod tests {
// Backend under test; presumably picked up by name by the `testgen_*` macros below.
type TestBackend = crate::WGPUBackend<crate::Vulkan, f32, i64>;
// Test suites for the operations that are implemented so far.
burn_tensor::testgen_add!();
burn_tensor::testgen_sub!();
burn_tensor::testgen_div!();
burn_tensor::testgen_mul!();
burn_tensor::testgen_neg!();
burn_tensor::testgen_powf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_log!();
burn_tensor::testgen_relu!();
// To be enabled once all operations are implemented.
// type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
// type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
// burn_tensor::testgen_all!();
}

View File

@ -0,0 +1,27 @@
use burn_tensor::ops::ActivationOps;
use crate::{
element::{FloatElement, IntElement},
kernel::{unary, unary_inplace},
unary, unary_inplace, GraphicsAPI, WGPUBackend,
};
use super::FloatTensor;
impl<G, F, I> ActivationOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
where
    G: GraphicsAPI + 'static,
    F: FloatElement,
    I: IntElement,
{
    /// Rectified linear unit: every element becomes `max(x, 0.0)`.
    ///
    /// The in-place kernel is used when `tensor.can_mut()` reports that the
    /// tensor may be mutated; otherwise the out-of-place kernel is used.
    fn relu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        unary!(Relu, body "output[global_id.x] = max(input[global_id.x], 0.0);");
        unary_inplace!(ReluInplace, body "input[global_id.x] = max(input[global_id.x], 0.0);");

        if tensor.can_mut() {
            unary_inplace::<ReluInplace, F, D>(tensor)
        } else {
            unary::<Relu, F, D>(tensor)
        }
    }
}

59
burn-wgpu/src/ops/base.rs Normal file
View File

@ -0,0 +1,59 @@
use std::{marker::PhantomData, sync::Arc};
use burn_tensor::{backend::Backend, Data, Shape};
use crate::{element::WGPUElement, pool::get_context, tensor::WGPUTensor, GraphicsAPI, WGPUDevice};
/// Float element type of backend `B`.
pub type FloatElem<B> = <B as Backend>::FloatElem;
/// Device type of backend `B`.
pub type Device<B> = <B as Backend>::Device;
/// Float tensor primitive of backend `B` with `D` dimensions.
pub type FloatTensor<B, const D: usize> = <B as Backend>::TensorPrimitive<D>;
/// Integer element type of backend `B`.
pub type IntElem<B> = <B as Backend>::IntElem;
/// Integer tensor primitive of backend `B` with `D` dimensions.
pub type IntTensor<B, const D: usize> = <B as Backend>::IntTensorPrimitive<D>;
/// Boolean tensor primitive of backend `B` with `D` dimensions.
pub type BoolTensor<B, const D: usize> = <B as Backend>::BoolTensorPrimitive<D>;
/// Namespace for element-type-generic tensor helpers (creation, readback,
/// device transfer) that are parameterized only by the graphics API `G`.
pub struct BaseOps<G: GraphicsAPI> {
// Zero-sized marker so the otherwise unused `G` parameter is accepted.
_g: PhantomData<G>,
}
impl<G: GraphicsAPI> BaseOps<G> {
    /// Uploads the values of `data` into a buffer on `device` and wraps it in
    /// a tensor carrying the shape of `data`.
    pub fn from_data<E: WGPUElement, const D: usize>(
        data: Data<E, D>,
        device: &WGPUDevice,
    ) -> WGPUTensor<E, D> {
        let context = get_context::<G>(device);
        let bytes = E::as_bytes(&data.value);
        let buffer = Arc::new(context.create_buffer_with_data(bytes));

        WGPUTensor::new(context, data.shape, buffer)
    }

    /// Reads the tensor's buffer back through its context and converts the
    /// bytes into a `Data` with the tensor's shape.
    pub fn to_data<E: WGPUElement, const D: usize>(tensor: &WGPUTensor<E, D>) -> Data<E, D> {
        let raw = tensor.context.buffer_to_data(&tensor.buffer);
        let elements = E::from_bytes(&raw).to_vec();

        Data::new(elements, tensor.shape.clone())
    }

    /// Returns `tensor` unchanged when it already lives on `device`;
    /// otherwise moves it to the context associated with `device`.
    pub fn to_device<E: WGPUElement, const D: usize>(
        tensor: WGPUTensor<E, D>,
        device: &WGPUDevice,
    ) -> WGPUTensor<E, D> {
        if &tensor.context.device == device {
            tensor
        } else {
            let target = get_context::<G>(device);
            tensor.to_context(target)
        }
    }

    /// Allocates a buffer on `device` large enough for `shape` and wraps it
    /// in a tensor; no data is written to the buffer here.
    pub fn empty<E: WGPUElement, const D: usize>(
        shape: Shape<D>,
        device: &WGPUDevice,
    ) -> WGPUTensor<E, D> {
        let context = get_context::<G>(device);
        let size_in_bytes = shape.num_elements() * core::mem::size_of::<E>();
        let buffer = Arc::new(context.create_buffer(size_in_bytes));

        WGPUTensor::new(context, shape, buffer)
    }
}

View File

@ -0,0 +1,106 @@
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
use crate::{
element::{FloatElement, IntElement},
GraphicsAPI, WGPUBackend,
};
use super::{BaseOps, BoolTensor, Device, IntTensor};
/// Boolean tensor operations for the WGPU backend.
///
/// Booleans are stored as `u32` values (`0` = false, non-zero = true), as the
/// conversions in `bool_from_data` / `bool_into_data` show. Operations whose
/// body is `todo!()` are not implemented yet and panic when called.
impl<G, F, I> BoolTensorOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
where
    G: GraphicsAPI + 'static,
    F: FloatElement,
    I: IntElement,
{
    /// Allocates a boolean tensor of the given shape on `device`.
    fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
        BaseOps::<G>::empty(shape, device)
    }

    /// Returns a copy of the tensor's shape.
    fn bool_shape<const D: usize>(tensor: &BoolTensor<Self, D>) -> Shape<D> {
        tensor.shape.clone()
    }

    /// Reads the tensor back and maps every non-zero value to `true`.
    fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Data<bool, D> {
        let data = BaseOps::<G>::to_data(&tensor);

        Data::new(data.value.into_iter().map(|i| i != 0).collect(), data.shape)
    }

    /// Uploads `data` to `device`, encoding `true` as `1` and `false` as `0`.
    fn bool_from_data<const D: usize>(
        data: Data<bool, D>,
        device: &Device<Self>,
    ) -> BoolTensor<Self, D> {
        let data: Data<u32, D> = Data::new(
            data.value.into_iter().map(u32::from).collect(),
            data.shape,
        );

        BaseOps::<G>::from_data(data, device)
    }

    fn bool_into_int<const D: usize>(_tensor: BoolTensor<Self, D>) -> IntTensor<Self, D> {
        todo!()
    }

    /// Returns the device of the context the tensor belongs to, mirroring the
    /// float `device` operation.
    fn bool_device<const D: usize>(
        tensor: &<Self as Backend>::BoolTensorPrimitive<D>,
    ) -> <Self as Backend>::Device {
        tensor.context.device.clone()
    }

    /// Moves the tensor to `device` (no-op when it is already there).
    fn bool_to_device<const D: usize>(
        tensor: BoolTensor<Self, D>,
        device: &Device<Self>,
    ) -> BoolTensor<Self, D> {
        BaseOps::<G>::to_device(tensor, device)
    }

    fn bool_reshape<const D1: usize, const D2: usize>(
        _tensor: <Self as Backend>::BoolTensorPrimitive<D1>,
        _shape: Shape<D2>,
    ) -> <Self as Backend>::BoolTensorPrimitive<D2> {
        todo!()
    }

    fn bool_index<const D1: usize, const D2: usize>(
        _tensor: <Self as Backend>::BoolTensorPrimitive<D1>,
        _indexes: [std::ops::Range<usize>; D2],
    ) -> <Self as Backend>::BoolTensorPrimitive<D1> {
        todo!()
    }

    fn bool_index_assign<const D1: usize, const D2: usize>(
        _tensor: <Self as Backend>::BoolTensorPrimitive<D1>,
        _indexes: [std::ops::Range<usize>; D2],
        _value: <Self as Backend>::BoolTensorPrimitive<D1>,
    ) -> <Self as Backend>::BoolTensorPrimitive<D1> {
        todo!()
    }

    fn bool_cat<const D: usize>(
        _tensors: Vec<<Self as Backend>::BoolTensorPrimitive<D>>,
        _dim: usize,
    ) -> <Self as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn bool_equal<const D: usize>(
        _lhs: <Self as Backend>::BoolTensorPrimitive<D>,
        _rhs: <Self as Backend>::BoolTensorPrimitive<D>,
    ) -> <Self as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn bool_equal_elem<const D: usize>(
        _lhs: <Self as Backend>::BoolTensorPrimitive<D>,
        _rhs: bool,
    ) -> <Self as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }
}

View File

@ -0,0 +1,406 @@
use super::numeric::NumericOps;
use super::{BaseOps, Device, FloatElem, FloatTensor};
use crate::kernel::{unary, unary_inplace, unary_scalar, unary_scalar_inplace};
use crate::{
element::{FloatElement, IntElement},
unary, unary_inplace, GraphicsAPI, WGPUBackend, SEED,
};
use crate::{unary_scalar, unary_scalar_inplace};
use burn_common::rand::get_seeded_rng;
use burn_tensor::ElementConversion;
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, Shape};
impl<G, F, I> TensorOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
where
G: GraphicsAPI + 'static,
F: FloatElement,
I: IntElement,
{
fn from_data<const D: usize>(
data: Data<FloatElem<Self>, D>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
BaseOps::<G>::from_data(data, device)
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<FloatElem<Self>>,
device: &Device<Self>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
let mut seed = SEED.lock().unwrap();
let mut rng = if let Some(rng_seeded) = seed.as_ref() {
rng_seeded.clone()
} else {
get_seeded_rng()
};
let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
*seed = Some(rng);
tensor
}
fn shape<const D: usize>(tensor: &FloatTensor<Self, D>) -> Shape<D> {
tensor.shape.clone()
}
fn to_data<const D: usize>(tensor: &FloatTensor<Self, D>) -> Data<FloatElem<Self>, D> {
BaseOps::<G>::to_data(tensor)
}
fn device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
tensor.context.device.clone()
}
fn to_device<const D: usize>(
tensor: FloatTensor<Self, D>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
BaseOps::<G>::to_device(tensor, device)
}
fn empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
BaseOps::<G>::empty(shape, device)
}
fn add<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
NumericOps::add(lhs, rhs)
}
fn add_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
NumericOps::add_scalar(lhs, rhs)
}
fn sub<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
NumericOps::sub(lhs, rhs)
}
fn sub_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
NumericOps::sub_scalar(lhs, rhs)
}
fn mul<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
NumericOps::mul(lhs, rhs)
}
fn mul_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
NumericOps::mul_scalar(lhs, rhs)
}
fn div<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
NumericOps::div(lhs, rhs)
}
fn div_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
NumericOps::div_scalar(lhs, rhs)
}
fn matmul<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn swap_dims<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_dim1: usize,
_dim2: usize,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn reshape<const D1: usize, const D2: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
_shape: Shape<D2>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D2> {
todo!()
}
fn gather<const D: usize>(
_dim: usize,
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn scatter<const D: usize>(
_dim: usize,
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_value: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn index_select<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_dim: usize,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn index_select_assign<const D1: usize, const D2: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
_dim: usize,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
_value: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D2>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1> {
todo!()
}
fn index<const D1: usize, const D2: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
_indexes: [std::ops::Range<usize>; D2],
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1> {
todo!()
}
fn index_assign<const D1: usize, const D2: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
_indexes: [std::ops::Range<usize>; D2],
_value: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1> {
todo!()
}
fn mask_scatter<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
_source: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn mask_fill<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
_value: <WGPUBackend<G, F, I> as Backend>::FloatElem,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn equal<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn equal_elem<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn greater<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn greater_elem<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn greater_equal<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn greater_equal_elem<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn lower<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn lower_elem<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn lower_equal<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn lower_equal_elem<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn sum<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1> {
todo!()
}
fn sum_dim<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn mean<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1> {
todo!()
}
fn mean_dim<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn to_full_precision<const D: usize>(
_tensor: &<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <<WGPUBackend<G, F, I> as Backend>::FullPrecisionBackend as Backend>::TensorPrimitive<D>
{
todo!()
}
fn from_full_precision<const D: usize>(
_tensor: <<WGPUBackend<G, F, I> as Backend>::FullPrecisionBackend as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Exp, func "exp");
unary_inplace!(ExpInplace, func "exp");
if lhs.can_mut() {
return unary_inplace::<ExpInplace, F, D>(lhs);
}
unary::<Exp, F, D>(lhs)
}
fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Log, func "log");
unary_inplace!(LogInplace, func "log");
if tensor.can_mut() {
return unary_inplace::<LogInplace, F, D>(tensor);
}
unary::<Log, F, D>(tensor)
}
fn log1p<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
unary_scalar!(Powf, func "pow");
unary_scalar_inplace!(PowfInplace, func "pow");
if lhs.can_mut() {
return unary_scalar_inplace::<PowfInplace, F, D>(lhs, rhs.elem());
}
unary_scalar::<Powf, F, D>(lhs, rhs.elem())
}
fn sqrt<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn cos<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn sin<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn tanh<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn erf<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn cat<const D: usize>(
_tensors: Vec<<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
}
fn argmax<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn argmin<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
}

View File

@ -0,0 +1,314 @@
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Shape};
use crate::{
element::{FloatElement, IntElement},
GraphicsAPI, WGPUBackend,
};
use super::{numeric::NumericOps, BaseOps, Device, IntElem, IntTensor};
impl<G, F, I> IntTensorOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
where
G: GraphicsAPI + 'static,
F: FloatElement,
I: IntElement,
{
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
BaseOps::<G>::empty(shape, device)
}
fn int_shape<const D: usize>(
_tensor: &<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> Shape<D> {
todo!()
}
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Data<I, D> {
BaseOps::<G>::to_data(&tensor)
}
fn int_from_data<const D: usize>(
data: Data<I, D>,
device: &Device<Self>,
) -> IntTensor<Self, D> {
BaseOps::<G>::from_data(data, device)
}
fn int_device<const D: usize>(
_tensor: &<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::Device {
todo!()
}
fn int_to_device<const D: usize>(
tensor: IntTensor<Self, D>,
device: &Device<Self>,
) -> IntTensor<Self, D> {
BaseOps::<G>::to_device(tensor, device)
}
fn int_reshape<const D1: usize, const D2: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
_shape: Shape<D2>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D2> {
todo!()
}
fn int_index<const D1: usize, const D2: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
_indexes: [std::ops::Range<usize>; D2],
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
todo!()
}
fn int_index_assign<const D1: usize, const D2: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
_indexes: [std::ops::Range<usize>; D2],
_value: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
todo!()
}
fn int_mask_scatter<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
_source: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_mask_fill<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
_value: <WGPUBackend<G, F, I> as Backend>::IntElem,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_gather<const D: usize>(
_dim: usize,
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_scatter<const D: usize>(
_dim: usize,
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_value: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_index_select_dim<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_dim: usize,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
_dim: usize,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
_value: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D2>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
todo!()
}
fn int_cat<const D: usize>(
_tensors: Vec<<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_equal<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn int_equal_elem<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn int_greater<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn int_greater_elem<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn int_greater_equal<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn int_greater_equal_elem<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn int_lower<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn int_lower_elem<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn int_lower_equal<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn int_lower_equal_elem<const D: usize>(
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
todo!()
}
fn int_add<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
NumericOps::add::<I, D>(lhs, rhs)
}
fn int_add_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
NumericOps::add_scalar(lhs, rhs)
}
fn int_sub<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
NumericOps::sub(lhs, rhs)
}
fn int_sub_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
NumericOps::sub_scalar(lhs, rhs)
}
fn int_mul<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
NumericOps::mul(lhs, rhs)
}
fn int_mul_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
NumericOps::mul_scalar(lhs, rhs)
}
fn int_div<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
NumericOps::div(lhs, rhs)
}
fn int_div_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
NumericOps::div_scalar(lhs, rhs)
}
fn int_neg<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_zeros<const D: usize>(
_shape: Shape<D>,
_device: &<WGPUBackend<G, F, I> as Backend>::Device,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_ones<const D: usize>(
_shape: Shape<D>,
_device: &<WGPUBackend<G, F, I> as Backend>::Device,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_sum<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1> {
todo!()
}
fn int_sum_dim<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_mean<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1> {
todo!()
}
fn int_mean_dim<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_argmax<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_argmin<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
}

10
burn-wgpu/src/ops/mod.rs Normal file
View File

@ -0,0 +1,10 @@
// One module per backend operation trait implemented for `WGPUBackend`.
mod activation_ops;
mod bool_ops;
mod float_ops;
mod int_ops;
mod module_ops;
// Shared helpers and type aliases, re-exported for the sibling modules.
mod base;
pub(crate) use base::*;
// Arithmetic shared by the float and int operation implementations.
pub(crate) mod numeric;

View File

@ -0,0 +1,94 @@
use burn_tensor::{backend::Backend, ops::ModuleOps};
use crate::{
element::{FloatElement, IntElement},
GraphicsAPI, WGPUBackend,
};
impl<G, F, I> ModuleOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
where
G: GraphicsAPI + 'static,
F: FloatElement,
I: IntElement,
{
fn embedding(
_weights: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<2>,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<2>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<3> {
todo!()
}
fn embedding_backward(
_weights: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<2>,
_output: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<3>,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<2>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<2> {
todo!()
}
fn conv2d(
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_weight: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_bias: Option<<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1>>,
_options: burn_tensor::ops::ConvOptions<2>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
todo!()
}
fn conv_transpose2d(
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_weight: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_bias: Option<<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1>>,
_options: burn_tensor::ops::ConvTransposeOptions<2>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
todo!()
}
fn avg_pool2d(
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
todo!()
}
fn avg_pool2d_backward(
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_grad: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
todo!()
}
fn max_pool2d(
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
todo!()
}
fn max_pool2d_with_indexes(
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
) -> burn_tensor::ops::MaxPool2dWithIndexes<WGPUBackend<G, F, I>> {
todo!()
}
fn max_pool2d_with_indexes_backward(
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
_output_grad: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<4>,
) -> burn_tensor::ops::MaxPool2dBackward<WGPUBackend<G, F, I>> {
todo!()
}
}

View File

@ -0,0 +1,129 @@
use crate::kernel::{binary_elemwise, binary_elemwise_inplace, unary_scalar, unary_scalar_inplace};
use crate::{
binary_elemwise, binary_elemwise_inplace, element::WGPUElement, tensor::WGPUTensor,
unary_scalar, unary_scalar_inplace,
};
/// Namespace for arithmetic kernels that are generic over the element type,
/// so the float and int operation implementations can share them.
pub struct NumericOps;
impl NumericOps {
pub fn add<E: WGPUElement, const D: usize>(
lhs: WGPUTensor<E, D>,
rhs: WGPUTensor<E, D>,
) -> WGPUTensor<E, D> {
binary_elemwise!(Add, "+");
binary_elemwise_inplace!(AddInplace, "+");
if lhs.can_mut_broadcast(&rhs) {
return binary_elemwise_inplace::<AddInplace, E, D>(lhs, rhs);
}
if rhs.can_mut_broadcast(&lhs) {
return binary_elemwise_inplace::<AddInplace, E, D>(rhs, lhs);
}
binary_elemwise::<Add, E, D>(lhs, rhs)
}
pub fn add_scalar<E: WGPUElement, const D: usize>(
lhs: WGPUTensor<E, D>,
rhs: E,
) -> WGPUTensor<E, D> {
unary_scalar!(AddScalar, ops "+");
unary_scalar_inplace!(AddScalarInplace, ops "+");
if lhs.can_mut() {
return unary_scalar_inplace::<AddScalarInplace, E, D>(lhs, rhs);
}
unary_scalar::<AddScalar, E, D>(lhs, rhs)
}
pub fn sub<E: WGPUElement, const D: usize>(
lhs: WGPUTensor<E, D>,
rhs: WGPUTensor<E, D>,
) -> WGPUTensor<E, D> {
binary_elemwise!(Sub, "-");
binary_elemwise_inplace!(SubInplace, "-");
if lhs.can_mut_broadcast(&rhs) {
return binary_elemwise_inplace::<SubInplace, E, D>(lhs, rhs);
}
binary_elemwise::<Sub, E, D>(lhs, rhs)
}
pub fn sub_scalar<E: WGPUElement, const D: usize>(
lhs: WGPUTensor<E, D>,
rhs: E,
) -> WGPUTensor<E, D> {
unary_scalar!(SubScalar, ops "-");
unary_scalar_inplace!(SubScalarInplace, ops "-");
if lhs.can_mut() {
return unary_scalar_inplace::<SubScalarInplace, E, D>(lhs, rhs);
}
unary_scalar::<SubScalar, E, D>(lhs, rhs)
}
pub fn mul<E: WGPUElement, const D: usize>(
lhs: WGPUTensor<E, D>,
rhs: WGPUTensor<E, D>,
) -> WGPUTensor<E, D> {
binary_elemwise!(Mul, "*");
binary_elemwise_inplace!(MulInplace, "*");
if lhs.can_mut_broadcast(&rhs) {
return binary_elemwise_inplace::<MulInplace, E, D>(lhs, rhs);
}
if rhs.can_mut_broadcast(&lhs) {
return binary_elemwise_inplace::<MulInplace, E, D>(rhs, lhs);
}
binary_elemwise::<Mul, E, D>(lhs, rhs)
}
pub fn mul_scalar<E: WGPUElement, const D: usize>(
lhs: WGPUTensor<E, D>,
rhs: E,
) -> WGPUTensor<E, D> {
unary_scalar!(MulScalar, ops "*");
unary_scalar_inplace!(MulScalarInplace, ops "*");
if lhs.can_mut() {
return unary_scalar_inplace::<MulScalarInplace, E, D>(lhs, rhs);
}
unary_scalar::<MulScalar, E, D>(lhs, rhs)
}
pub fn div<E: WGPUElement, const D: usize>(
lhs: WGPUTensor<E, D>,
rhs: WGPUTensor<E, D>,
) -> WGPUTensor<E, D> {
binary_elemwise!(Div, "/");
binary_elemwise_inplace!(DivInplace, "/");
if lhs.can_mut_broadcast(&rhs) {
return binary_elemwise_inplace::<DivInplace, E, D>(lhs, rhs);
}
binary_elemwise::<Div, E, D>(lhs, rhs)
}
pub fn div_scalar<E: WGPUElement, const D: usize>(
lhs: WGPUTensor<E, D>,
rhs: E,
) -> WGPUTensor<E, D> {
unary_scalar!(DivScalar, ops "/");
unary_scalar_inplace!(DivScalarInplace, ops "/");
if lhs.can_mut() {
return unary_scalar_inplace::<DivScalarInplace, E, D>(lhs, rhs);
}
unary_scalar::<DivScalar, E, D>(lhs, rhs)
}
}

63
burn-wgpu/src/pool.rs Normal file
View File

@ -0,0 +1,63 @@
use crate::{context::Context, GraphicsAPI, WGPUDevice};
use std::{
any::TypeId,
collections::HashMap,
sync::{Arc, Mutex},
};
/// Process-wide pool of contexts, guarded by a mutex.
///
/// Starts as `None`; the pool is created on the first call to `get_context`.
static POOL_CONTEXT: Mutex<Option<ContextPool>> = Mutex::new(None);
/// Holds the contexts created so far, shared through `Arc`.
#[derive(Default)]
struct ContextPool {
// One entry per (graphics API type, device) pair, see `Key`.
contexts: HashMap<Key, Arc<Context>>,
}
/// Lookup key of the context pool: the graphics API type paired with the device.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct Key {
    api_id: TypeId,
    device: WGPUDevice,
}

impl Key {
    /// Builds the key for graphics API `G` on `device`; the device is cloned into the key.
    fn new<G: GraphicsAPI>(device: &WGPUDevice) -> Self {
        let api_id = TypeId::of::<G>();
        let device = device.clone();

        Self { api_id, device }
    }
}
/// Get a [context](Context) for the given [device](WGPUDevice).
///
/// # Notes
///
/// If a context already exists for the given [device](WGPUDevice) and graphics API `G`, the
/// same instance is returned. Otherwise a new context is created, stored in the global pool and
/// returned. The pool lock is held while the new context is being created.
///
/// # Panics
///
/// Panics if the global pool mutex is poisoned, i.e. a previous holder of the lock panicked.
pub fn get_context<G: GraphicsAPI>(device: &WGPUDevice) -> Arc<Context> {
    let mut pool = POOL_CONTEXT.lock().unwrap();

    // Create the pool lazily on first use, then look the key up exactly once: the entry API
    // only runs the constructor closure when no context is registered for this key yet.
    let context = pool
        .get_or_insert_with(ContextPool::default)
        .contexts
        .entry(Key::new::<G>(device))
        .or_insert_with(|| Arc::new(Context::new::<G>(device)));

    Arc::clone(context)
}

View File

@ -0,0 +1,35 @@
// Kernel template: two read-only operands and a separate writable output.
// The upper-case identifiers and the array item type are presumably placeholders that get
// substituted before the shader is compiled (TODO confirm against the Rust side).
@group(0)
@binding(0)
var<storage, read> lhs: array<elem>;
@group(0)
@binding(1)
var<storage, read> rhs: array<elem>;
@group(0)
@binding(2)
var<storage, read_write> output: array<elem>;
// Layout of `info` as read below, with dim = info[0]:
//   info[1 ..= dim]               strides of lhs
//   info[dim + 1 ..= 2 * dim]     strides of rhs
//   info[2 * dim + 1 ..= 3 * dim] shape of lhs
//   info[3 * dim + 1 ..= 4 * dim] shape of rhs
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let dim: u32 = info[0];
// Flat read positions inside each operand for this invocation.
var index_lhs: u32 = 0u;
var index_rhs: u32 = 0u;
for (var i: u32 = 0u; i < dim; i++) {
let stride_lhs = info[i + 1u];
let stride_rhs = info[i + 1u * dim + 1u];
let shape_lhs = info[i + 2u * dim + 1u];
let shape_rhs = info[i + 3u * dim + 1u];
// Per axis: coordinate = (id / stride) % shape, then scaled back by the stride.
// An axis of size 1 always contributes 0, so that operand is re-read along it.
// NOTE(review): the coordinate is derived from each operand's own stride, not from the
// stride of the output; this looks like it can pick the wrong item when an operand is
// broadcast along a non-leading axis (e.g. shape [2, 1] against [2, 3]). Confirm.
index_lhs += global_id.x / stride_lhs % shape_lhs * stride_lhs;
index_rhs += global_id.x / stride_rhs % shape_rhs * stride_rhs;
}
// Operation-specific statements are injected here; they can use index_lhs and index_rhs.
BODY
}

View File

@ -0,0 +1,27 @@
// Kernel template: two-operand kernel where lhs is writable, so it can serve as the output.
// The upper-case identifiers and the array item type are presumably placeholders that get
// substituted before the shader is compiled (TODO confirm against the Rust side).
@group(0)
@binding(0)
var<storage, read_write> lhs: array<elem>;
@group(0)
@binding(1)
var<storage, read> rhs: array<elem>;
// Layout of `info` as read below, with dim = info[0]:
//   info[1 ..= dim]           strides of rhs
//   info[dim + 1 ..= 2 * dim] shape of rhs
@group(0)
@binding(2)
var<storage, read> info: array<u32>;
@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let dim: u32 = info[0];
// Flat read position inside rhs for this invocation; no index is computed for lhs here.
var index_rhs: u32 = 0u;
for (var i: u32 = 0u; i < dim; i++) {
let stride_rhs = info[i + 1u];
let shape_rhs = info[i + 1u * dim + 1u];
// Per axis: coordinate = (id / stride) % shape, then scaled back by the stride.
// An axis of size 1 always contributes 0, so rhs is re-read along it.
index_rhs += global_id.x / stride_rhs % shape_rhs * stride_rhs;
}
// Operation-specific statements are injected here; they can use index_rhs.
BODY
}

View File

@ -0,0 +1,13 @@
// Kernel template: one read-only input and a separate writable output.
// The upper-case identifiers and the array item type are presumably placeholders that get
// substituted before the shader is compiled (TODO confirm against the Rust side).
@group(0)
@binding(0)
var<storage, read> input: array<elem>;
@group(0)
@binding(1)
var<storage, read_write> output: array<elem>;
@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
// Operation-specific statements are injected here.
BODY
}

View File

@ -0,0 +1,9 @@
// Kernel template: a single writable buffer that is both read and written.
// The upper-case identifiers and the array item type are presumably placeholders that get
// substituted before the shader is compiled (TODO confirm against the Rust side).
@group(0)
@binding(0)
var<storage, read_write> input: array<elem>;
@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
// Operation-specific statements are injected here.
BODY
}

View File

@ -0,0 +1,17 @@
// Kernel template: an array operand combined with one scalar, written to a separate output.
// The upper-case identifiers and the item type are presumably placeholders that get
// substituted before the shader is compiled (TODO confirm against the Rust side).
@group(0)
@binding(0)
var<storage, read> lhs: array<elem>;
// rhs is a single value, not an array.
@group(0)
@binding(1)
var<storage, read> rhs: elem;
@group(0)
@binding(2)
var<storage, read_write> output: array<elem>;
@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
// Operation-specific statements are injected here.
BODY
}

View File

@ -0,0 +1,13 @@
// Kernel template: a writable array operand combined with one scalar, with no separate output.
// The upper-case identifiers and the item type are presumably placeholders that get
// substituted before the shader is compiled (TODO confirm against the Rust side).
@group(0)
@binding(0)
var<storage, read_write> lhs: array<elem>;
// rhs is a single value, not an array.
@group(0)
@binding(1)
var<storage, read> rhs: elem;
@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
// Operation-specific statements are injected here.
BODY
}

View File

@ -0,0 +1,82 @@
use burn_tensor::Shape;
use std::{marker::PhantomData, sync::Arc};
use wgpu::Buffer;
use crate::{context::Context, element::WGPUElement};
/// Tensor of the WGPU backend: a shared buffer plus the shape and strides describing it.
///
/// Cloning the tensor clones the `Arc`s, so clones share the same buffer; `can_mut` and
/// `can_mut_broadcast` rely on that reference count.
#[derive(Debug, Clone)]
pub struct WGPUTensor<E: WGPUElement, const D: usize> {
    pub(crate) context: Arc<Context>,
    pub(crate) buffer: Arc<Buffer>,
    pub(crate) shape: Shape<D>,
    pub(crate) strides: [usize; D],
    elem: PhantomData<E>,
}

impl<E: WGPUElement, const D: usize> WGPUTensor<E, D> {
    /// Creates a tensor over `buffer` with the given `shape`.
    ///
    /// Strides are derived from the shape: the last axis gets stride 1 and every earlier axis
    /// gets the product of the sizes of all axes after it.
    pub fn new(context: Arc<Context>, shape: Shape<D>, buffer: Arc<Buffer>) -> Self {
        let mut strides = [0; D];
        let mut running = 1;

        // Walk the axes from last to first, accumulating the product of the sizes seen so far.
        for axis in (0..D).rev() {
            strides[axis] = running;
            running *= shape.dims[axis];
        }

        Self {
            context,
            buffer,
            shape,
            strides,
            elem: PhantomData,
        }
    }

    /// Returns a copy of this tensor that lives in `context`.
    ///
    /// The buffer content is read back through the current context (`buffer_to_data`) and a new
    /// buffer is created from it in the target context (`create_buffer_with_data`). Shape and
    /// strides are carried over unchanged.
    pub fn to_context(&self, context: Arc<Context>) -> Self {
        let host_copy = self.context.buffer_to_data(&self.buffer);
        let buffer = Arc::new(context.create_buffer_with_data(&host_copy));

        Self {
            context,
            buffer,
            shape: self.shape.clone(),
            strides: self.strides,
            elem: PhantomData,
        }
    }

    /// Tells whether this tensor can be used as the in-place output of a binary operation with
    /// `tensor_other`.
    ///
    /// That requires the buffer not to be shared and every dimension of this tensor to be at
    /// least as large as the matching dimension of `tensor_other`; otherwise the output tensor
    /// would be different from the mutable tensor.
    pub fn can_mut_broadcast(&self, tensor_other: &WGPUTensor<E, D>) -> bool {
        if !self.can_mut() {
            return false;
        }

        self.shape
            .dims
            .iter()
            .zip(tensor_other.shape.dims.iter())
            .all(|(own, other)| own >= other)
    }

    /// Tells whether the buffer can be mutated in place, i.e. no other strong reference to it
    /// exists.
    pub fn can_mut(&self) -> bool {
        Arc::strong_count(&self.buffer) <= 1
    }

    /// Checks that both tensors live on the same device.
    ///
    /// NOTE(review): the name presumably means "same" rather than "save"; it is kept as is for
    /// callers.
    ///
    /// # Panics
    ///
    /// Panics when the devices of the two contexts differ.
    pub fn assert_is_on_save_device(&self, other: &Self) {
        let ours = &self.context.device;
        let theirs = &other.context.device;

        if ours != theirs {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                ours, theirs
            );
        }
    }
}

View File

@ -0,0 +1,2 @@
mod base;
pub use base::*;