mirror of https://github.com/tracel-ai/burn.git
Feat/wgpu backend setup (#376)
This commit is contained in:
parent
483f9acca5
commit
974fdfaba1
|
@ -14,6 +14,7 @@ members = [
|
|||
"burn-ndarray",
|
||||
"burn-no-std-tests",
|
||||
"burn-tch",
|
||||
"burn-wgpu",
|
||||
"burn-tensor-testgen",
|
||||
"burn-tensor",
|
||||
"burn-train",
|
||||
|
@ -53,6 +54,11 @@ syn = "2.0"
|
|||
tempfile = "3.5.0"
|
||||
thiserror = "1.0.40"
|
||||
topological-sort = "0.2.2"
|
||||
|
||||
# WGPU stuff
|
||||
wgpu = "0.16.0"
|
||||
futures-intrusive = "0.5"
|
||||
pollster = "0.3"
|
||||
#
|
||||
# The following packages disable the "std" feature for no_std compatibility
|
||||
#
|
||||
|
|
|
@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
|
|||
|
||||
mask = mask.to_device(device).repeat(0, batch_size);
|
||||
|
||||
mask.equal_elem(1_i64)
|
||||
mask.equal_elem(1_i64.elem::<i64>())
|
||||
}
|
||||
|
||||
pub struct GeneratePaddingMask<B: Backend> {
|
||||
|
|
|
@ -256,12 +256,6 @@ where
|
|||
K::equal(self.primitive, other.primitive)
|
||||
}
|
||||
|
||||
/// Applies element wise equal comparison and returns a boolean tensor.
|
||||
pub fn equal_elem<E: Into<K::Elem>>(self, other: E) -> Tensor<B, D, Bool> {
|
||||
let elem: K::Elem = other.into();
|
||||
K::equal_elem::<D>(self.primitive, elem)
|
||||
}
|
||||
|
||||
/// Concatenates all tensors into a new one along the given dimension.
|
||||
///
|
||||
/// # Panics
|
||||
|
@ -400,7 +394,6 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
|
|||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Primitive<D>,
|
||||
) -> Tensor<B, D, Bool>;
|
||||
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
|
||||
fn elem_type_name() -> &'static str {
|
||||
core::any::type_name::<Self::Elem>()
|
||||
}
|
||||
|
@ -478,10 +471,6 @@ impl<B: Backend> BasicOps<B> for Float {
|
|||
) -> Tensor<B, D, Bool> {
|
||||
Tensor::new(B::equal(lhs, rhs))
|
||||
}
|
||||
|
||||
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
|
||||
Tensor::new(B::equal_elem(lhs, rhs))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> BasicOps<B> for Int {
|
||||
|
@ -553,10 +542,6 @@ impl<B: Backend> BasicOps<B> for Int {
|
|||
Tensor::new(B::int_equal(lhs, rhs))
|
||||
}
|
||||
|
||||
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
|
||||
Tensor::new(B::int_equal_elem(lhs, rhs))
|
||||
}
|
||||
|
||||
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
|
||||
B::int_cat(vectors, dim)
|
||||
}
|
||||
|
@ -631,10 +616,6 @@ impl<B: Backend> BasicOps<B> for Bool {
|
|||
Tensor::new(B::bool_equal(lhs, rhs))
|
||||
}
|
||||
|
||||
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
|
||||
Tensor::new(B::bool_equal_elem(lhs, rhs))
|
||||
}
|
||||
|
||||
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
|
||||
B::bool_cat(vectors, dim)
|
||||
}
|
||||
|
|
|
@ -133,6 +133,11 @@ where
|
|||
Self::new(K::sum_dim(self.primitive, dim))
|
||||
}
|
||||
|
||||
/// Applies element wise equal comparison and returns a boolean tensor.
|
||||
pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
|
||||
K::equal_elem::<D>(self.primitive, other.elem())
|
||||
}
|
||||
|
||||
/// Applies element wise greater comparison and returns a boolean tensor.
|
||||
///
|
||||
/// # Panics
|
||||
|
@ -413,6 +418,7 @@ where
|
|||
fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
|
||||
fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
|
||||
fn mean_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
|
||||
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
|
||||
fn greater<const D: usize>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Primitive<D>,
|
||||
|
@ -559,6 +565,9 @@ impl<B: Backend> Numeric<B> for Int {
|
|||
B::int_mean_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
|
||||
Tensor::new(B::int_equal_elem(lhs, rhs))
|
||||
}
|
||||
fn greater<const D: usize>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Primitive<D>,
|
||||
|
@ -777,6 +786,9 @@ impl<B: Backend> Numeric<B> for Float {
|
|||
B::mean_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
|
||||
Tensor::new(B::equal_elem(lhs, rhs))
|
||||
}
|
||||
fn greater<const D: usize>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Primitive<D>,
|
||||
|
|
|
@ -79,7 +79,7 @@ pub trait Backend:
|
|||
/// Tensor primitive to be used for all int operations.
|
||||
type IntTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
|
||||
/// Int element type.
|
||||
type IntElem: Element + From<i64> + Into<i64>;
|
||||
type IntElem: Element;
|
||||
|
||||
/// Tensor primitive to be used for all bool operations.
|
||||
type BoolTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
|
||||
|
|
|
@ -22,7 +22,9 @@ pub trait TensorOps<B: Backend> {
|
|||
}
|
||||
fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>;
|
||||
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
|
||||
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
|
||||
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D> {
|
||||
Self::to_data(&tensor)
|
||||
}
|
||||
fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
|
||||
fn to_device<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
|
@ -102,7 +104,9 @@ pub trait TensorOps<B: Backend> {
|
|||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
fn neg<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
fn neg<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
|
||||
Self::mul_scalar(tensor, (-1.0_f32).elem::<B::FloatElem>())
|
||||
}
|
||||
fn transpose<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
|
||||
Self::swap_dims(tensor, D - 2, D - 1)
|
||||
}
|
||||
|
|
|
@ -15,4 +15,17 @@ mod tests {
|
|||
let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_broadcast() {
|
||||
let data_1 = Data::from([[0.0, 1.0, 2.0]]);
|
||||
let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);
|
||||
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1);
|
||||
let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2);
|
||||
|
||||
let data_actual = (tensor_1 + tensor_2).into_data();
|
||||
|
||||
let data_expected = Data::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ mod tests {
|
|||
use burn_tensor::{Data, Tensor};
|
||||
|
||||
#[test]
|
||||
fn should_support_exp_ops() {
|
||||
fn should_support_log_ops() {
|
||||
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "WGPU backend for burn"
|
||||
edition = "2021"
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license = "MIT/Apache-2.0"
|
||||
name = "burn-wgpu"
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/burn-rs/burn/tree/main/burn-wgpu"
|
||||
version = "0.8.0"
|
||||
|
||||
[dependencies]
|
||||
burn-tensor = {path = "../burn-tensor", version = "0.8.0"}
|
||||
burn-common = {path = "../burn-common", version = "0.8.0"}
|
||||
derive-new = {workspace = true}
|
||||
bytemuck = {workspace = true}
|
||||
rand = {workspace = true}
|
||||
num-traits = {workspace = true}
|
||||
|
||||
# WGPU stuff
|
||||
wgpu = {workspace = true}
|
||||
futures-intrusive = {workspace = true}
|
||||
pollster = {workspace = true}
|
||||
|
||||
[dev-dependencies]
|
||||
burn-autodiff = {path = "../burn-autodiff", version = "0.8.0", default-features = false, features = [
|
||||
"export_tests",
|
||||
]}
|
||||
burn-tensor = {path = "../burn-tensor", version = "0.8.0", default-features = false, features = [
|
||||
"export_tests",
|
||||
]}
|
|
@ -0,0 +1 @@
|
|||
../LICENSE-APACHE
|
|
@ -0,0 +1 @@
|
|||
../LICENSE-MIT
|
|
@ -0,0 +1,3 @@
|
|||
# Burn WGPU Backend
|
||||
|
||||
[Burn](https://github.com/burn-rs/burn) WGPU backend
|
|
@ -0,0 +1,45 @@
|
|||
use burn_tensor::backend::Backend;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
tensor::WGPUTensor,
|
||||
GraphicsAPI, WGPUDevice,
|
||||
};
|
||||
use std::{marker::PhantomData, sync::Mutex};
|
||||
|
||||
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct WGPUBackend<G: GraphicsAPI, F: FloatElement, I: IntElement> {
|
||||
_g: PhantomData<G>,
|
||||
_f: PhantomData<F>,
|
||||
_i: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<G: GraphicsAPI + 'static, F: FloatElement, I: IntElement> Backend for WGPUBackend<G, F, I> {
|
||||
type Device = WGPUDevice;
|
||||
type FullPrecisionBackend = WGPUBackend<G, f32, i32>;
|
||||
|
||||
type FullPrecisionElem = f32;
|
||||
type FloatElem = F;
|
||||
type IntElem = I;
|
||||
|
||||
type TensorPrimitive<const D: usize> = WGPUTensor<F, D>;
|
||||
type IntTensorPrimitive<const D: usize> = WGPUTensor<I, D>;
|
||||
type BoolTensorPrimitive<const D: usize> = WGPUTensor<u32, D>;
|
||||
|
||||
fn name() -> String {
|
||||
String::from("wgpu")
|
||||
}
|
||||
|
||||
fn seed(seed: u64) {
|
||||
let rng = StdRng::seed_from_u64(seed);
|
||||
let mut seed = SEED.lock().unwrap();
|
||||
*seed = Some(rng);
|
||||
}
|
||||
|
||||
fn ad_enabled() -> bool {
|
||||
false
|
||||
}
|
||||
}
|
|
@ -0,0 +1,256 @@
|
|||
use burn_common::id::IdGenerator;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use wgpu::{
|
||||
util::{BufferInitDescriptor, DeviceExt},
|
||||
Buffer, DeviceDescriptor, DeviceType, ShaderModule, ShaderModuleDescriptor,
|
||||
};
|
||||
|
||||
use crate::{kernel::KernelGenerator, GraphicsAPI, WGPUDevice};
|
||||
|
||||
/// The context is the basic struct that allows to execute GPU kernel on devices.
|
||||
///
|
||||
/// You can access a context for a [wgpu device](WGPUDevice) using [get_context](crate::pool::get_context).
|
||||
#[derive(Debug)]
|
||||
pub struct Context {
|
||||
id: String,
|
||||
queue: wgpu::Queue,
|
||||
device_wgpu: wgpu::Device,
|
||||
cache: Mutex<HashMap<TypeId, Arc<ShaderModule>>>,
|
||||
pub(crate) device: WGPUDevice,
|
||||
}
|
||||
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct WorkGroup {
|
||||
pub x: u32,
|
||||
pub y: u32,
|
||||
pub z: u32,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub(crate) fn new<G: GraphicsAPI>(device: &WGPUDevice) -> Self {
|
||||
// Instantiates instance of WebGPU
|
||||
let instance = wgpu::Instance::default();
|
||||
|
||||
// `request_adapter` instantiates the general connection to the GPU
|
||||
let adapters = instance.enumerate_adapters(G::backend().into());
|
||||
let mut adapters = adapters
|
||||
.filter(|adapter| {
|
||||
let device_type = adapter.get_info().device_type;
|
||||
match device {
|
||||
WGPUDevice::DiscreteGPU(_) => device_type == DeviceType::DiscreteGpu,
|
||||
WGPUDevice::IntegratedGPU(_) => device_type == DeviceType::IntegratedGpu,
|
||||
WGPUDevice::VirtualGPU(_) => device_type == DeviceType::VirtualGpu,
|
||||
WGPUDevice::CPU => device_type == DeviceType::Cpu,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let adapter = match device {
|
||||
WGPUDevice::DiscreteGPU(num) => {
|
||||
assert!(adapters.len() > *num, "No Discrete GPU device found");
|
||||
adapters.remove(*num)
|
||||
}
|
||||
WGPUDevice::IntegratedGPU(num) => {
|
||||
assert!(adapters.len() > *num, "No Integrated GPU device found");
|
||||
adapters.remove(*num)
|
||||
}
|
||||
WGPUDevice::VirtualGPU(num) => {
|
||||
assert!(adapters.len() > *num, "No Virtual GPU device found");
|
||||
adapters.remove(*num)
|
||||
}
|
||||
WGPUDevice::CPU => {
|
||||
assert!(!adapters.is_empty(), "No CPU device found");
|
||||
adapters.remove(0)
|
||||
}
|
||||
};
|
||||
|
||||
let device_wgpu = device.clone();
|
||||
let (device, queue) = pollster::block_on(adapter.request_device(
|
||||
&DeviceDescriptor {
|
||||
label: None,
|
||||
features: wgpu::Features::empty(),
|
||||
limits: wgpu::Limits::downlevel_defaults(),
|
||||
},
|
||||
None,
|
||||
))
|
||||
.expect("Unable to request the device with the adapter");
|
||||
|
||||
Self {
|
||||
id: IdGenerator::generate(),
|
||||
queue,
|
||||
device_wgpu: device,
|
||||
device: device_wgpu,
|
||||
cache: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new buffer with the provided size.
|
||||
pub fn create_buffer(&self, size: usize) -> Buffer {
|
||||
self.device_wgpu.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: None,
|
||||
size: size as u64,
|
||||
usage: wgpu::BufferUsages::COPY_DST
|
||||
| wgpu::BufferUsages::STORAGE
|
||||
| wgpu::BufferUsages::COPY_SRC,
|
||||
mapped_at_creation: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new buffer initialized with the provided bytes.
|
||||
pub fn create_buffer_with_data(&self, data: &[u8]) -> Buffer {
|
||||
let buffer_src = self.device_wgpu.create_buffer_init(&BufferInitDescriptor {
|
||||
label: Some("Buffer Src"),
|
||||
contents: data,
|
||||
usage: wgpu::BufferUsages::COPY_SRC,
|
||||
});
|
||||
|
||||
let buffer = self.create_buffer(buffer_src.size() as usize);
|
||||
|
||||
// Create a command encoder
|
||||
let mut encoder =
|
||||
self.device_wgpu
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||
label: Some("Command Encoder"),
|
||||
});
|
||||
|
||||
// Copy data from the staging buffer to the target buffer
|
||||
encoder.copy_buffer_to_buffer(&buffer_src, 0, &buffer, 0, buffer_src.size());
|
||||
|
||||
// Submit the command encoder to the queue
|
||||
self.queue.submit(std::iter::once(encoder.finish()));
|
||||
|
||||
buffer
|
||||
}
|
||||
|
||||
/// Read a buffer from the GPU and return its content as bytes.
|
||||
pub fn buffer_to_data(&self, buffer: &Buffer) -> Vec<u8> {
|
||||
let size = buffer.size();
|
||||
|
||||
let buffer_dest = self.device_wgpu.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: None,
|
||||
size,
|
||||
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
// Create a command encoder
|
||||
let mut encoder =
|
||||
self.device_wgpu
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||
label: Some("Command Encoder"),
|
||||
});
|
||||
|
||||
encoder.copy_buffer_to_buffer(buffer, 0, &buffer_dest, 0, size);
|
||||
|
||||
self.queue.submit(std::iter::once(encoder.finish()));
|
||||
|
||||
let buffer_slice = buffer_dest.slice(..);
|
||||
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
|
||||
sender
|
||||
.send(v)
|
||||
.expect("Unable to send buffer slice result to async channel.")
|
||||
});
|
||||
|
||||
self.device_wgpu.poll(wgpu::Maintain::Wait);
|
||||
|
||||
let result = pollster::block_on(receiver.receive());
|
||||
|
||||
if let Some(Ok(())) = result {
|
||||
let data = buffer_slice.get_mapped_range();
|
||||
let result = bytemuck::cast_slice(&data).to_vec();
|
||||
|
||||
drop(data);
|
||||
buffer_dest.unmap();
|
||||
result
|
||||
} else {
|
||||
panic!("Unable to read buffer {:?}", result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile a kernel template if not present in the cache.
|
||||
pub fn compile<K: KernelGenerator>(&self) -> Arc<ShaderModule> {
|
||||
let mut cache = self.cache.lock().unwrap();
|
||||
let template_id = TypeId::of::<K>();
|
||||
|
||||
if let Some(module) = cache.get(&template_id) {
|
||||
return module.clone();
|
||||
}
|
||||
|
||||
let source = K::generate();
|
||||
|
||||
let module = self
|
||||
.device_wgpu
|
||||
.create_shader_module(ShaderModuleDescriptor {
|
||||
label: None,
|
||||
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source.as_ref())),
|
||||
});
|
||||
let module = Arc::new(module);
|
||||
|
||||
cache.insert(template_id, module.clone());
|
||||
|
||||
module
|
||||
}
|
||||
|
||||
/// Execute a kernel using the provided buffers.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This function isn't safe, buffer can be mutated by the GPU. The users must ensure that a
|
||||
/// buffer can be mutated when lauching a compute shaders with write access to a buffer.
|
||||
///
|
||||
/// Buffer positions are used as bindings when lauching a compute kernel.
|
||||
pub fn execute(&self, work_group: &WorkGroup, kernel: &ShaderModule, buffers: &[&Buffer]) {
|
||||
let pipeline = self
|
||||
.device_wgpu
|
||||
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: None,
|
||||
module: kernel,
|
||||
entry_point: "main",
|
||||
});
|
||||
|
||||
let group_layout = pipeline.get_bind_group_layout(0);
|
||||
|
||||
let entries = buffers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, buffer)| wgpu::BindGroupEntry {
|
||||
binding: i as u32,
|
||||
resource: buffer.as_entire_binding(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let bind_group = self
|
||||
.device_wgpu
|
||||
.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
label: None,
|
||||
layout: &group_layout,
|
||||
entries: &entries,
|
||||
});
|
||||
|
||||
let mut encoder = self
|
||||
.device_wgpu
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
|
||||
let mut compute = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
|
||||
compute.set_pipeline(&pipeline);
|
||||
compute.set_bind_group(0, &bind_group, &[]);
|
||||
|
||||
compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z);
|
||||
std::mem::drop(compute);
|
||||
|
||||
self.queue.submit(Some(encoder.finish()));
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Context {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.id == other.id
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
/// The device struct when using the `wgpu` backend.
|
||||
///
|
||||
/// Note that you need to provide the device index when using a GPU backend.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use burn_wgpu::WGPUDevice;
|
||||
///
|
||||
/// let device_gpu_1 = WGPUDevice::DiscreteGPU(0); // First discrete GPU found.
|
||||
/// let device_gpu_2 = WGPUDevice::DiscreteGPU(1); // Second discrete GPU found.
|
||||
/// ```
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||
pub enum WGPUDevice {
|
||||
DiscreteGPU(usize),
|
||||
IntegratedGPU(usize),
|
||||
VirtualGPU(usize),
|
||||
CPU,
|
||||
}
|
||||
|
||||
impl Default for WGPUDevice {
|
||||
fn default() -> Self {
|
||||
Self::CPU
|
||||
}
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
use burn_tensor::Element;
|
||||
|
||||
pub trait WGPUElement: core::fmt::Debug + 'static + Clone
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
fn type_name() -> &'static str;
|
||||
fn as_bytes(slice: &[Self]) -> &[u8];
|
||||
fn from_bytes(bytes: &[u8]) -> &[Self];
|
||||
}
|
||||
|
||||
pub trait FloatElement: WGPUElement + Element {}
|
||||
|
||||
pub trait IntElement: WGPUElement + Element {}
|
||||
|
||||
impl WGPUElement for u32 {
|
||||
fn type_name() -> &'static str {
|
||||
"u32"
|
||||
}
|
||||
fn as_bytes(slice: &[Self]) -> &[u8] {
|
||||
bytemuck::cast_slice(slice)
|
||||
}
|
||||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl WGPUElement for i32 {
|
||||
fn type_name() -> &'static str {
|
||||
"i32"
|
||||
}
|
||||
fn as_bytes(slice: &[Self]) -> &[u8] {
|
||||
bytemuck::cast_slice(slice)
|
||||
}
|
||||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl WGPUElement for i64 {
|
||||
fn type_name() -> &'static str {
|
||||
"i64"
|
||||
}
|
||||
fn as_bytes(slice: &[Self]) -> &[u8] {
|
||||
bytemuck::cast_slice(slice)
|
||||
}
|
||||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl WGPUElement for f32 {
|
||||
fn type_name() -> &'static str {
|
||||
"f32"
|
||||
}
|
||||
fn as_bytes(slice: &[Self]) -> &[u8] {
|
||||
bytemuck::cast_slice(slice)
|
||||
}
|
||||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl FloatElement for f32 {}
|
||||
impl IntElement for i32 {}
|
||||
impl IntElement for i64 {}
|
|
@ -0,0 +1,61 @@
|
|||
/// The basic trait to specify which graphics API to use as Backend.
|
||||
///
|
||||
/// Options are:
|
||||
/// - [Vulkan](Vulkan)
|
||||
/// - [Metal](Metal)
|
||||
/// - [OpenGL](OpenGL)
|
||||
/// - [DirectX 11](Dx11)
|
||||
/// - [DirectX 12](Dx12)
|
||||
/// - [WebGPU](WebGPU)
|
||||
pub trait GraphicsAPI: Send + Sync + core::fmt::Debug + Default + Clone + 'static {
|
||||
fn backend() -> wgpu::Backend;
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Vulkan;
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Metal;
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct OpenGL;
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Dx11;
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Dx12;
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct WebGPU;
|
||||
|
||||
impl GraphicsAPI for Vulkan {
|
||||
fn backend() -> wgpu::Backend {
|
||||
wgpu::Backend::Vulkan
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphicsAPI for Metal {
|
||||
fn backend() -> wgpu::Backend {
|
||||
wgpu::Backend::Metal
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphicsAPI for OpenGL {
|
||||
fn backend() -> wgpu::Backend {
|
||||
wgpu::Backend::Gl
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphicsAPI for Dx11 {
|
||||
fn backend() -> wgpu::Backend {
|
||||
wgpu::Backend::Dx11
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphicsAPI for Dx12 {
|
||||
fn backend() -> wgpu::Backend {
|
||||
wgpu::Backend::Dx12
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphicsAPI for WebGPU {
|
||||
fn backend() -> wgpu::Backend {
|
||||
wgpu::Backend::BrowserWebGpu
|
||||
}
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
use crate::element::WGPUElement;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Generate wgpu kernel source code to create [compute shader modules](wgpu::ShaderModule).
|
||||
pub trait KernelGenerator: 'static {
|
||||
/// Source code concrete type.
|
||||
type Source: AsRef<str>;
|
||||
|
||||
/// Generate the source code.
|
||||
fn generate() -> Self::Source;
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! kernel_wgsl {
|
||||
(
|
||||
$struct:ident,
|
||||
$file:expr
|
||||
) => {
|
||||
#[derive(new)]
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = &'static str;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
include_str!($file)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Generate kernel source code by replacing some information using templating.
|
||||
pub struct KernelSettings<
|
||||
K: KernelGenerator,
|
||||
E: WGPUElement,
|
||||
I: WGPUElement,
|
||||
const WORKGROUP_X_SIZE: usize,
|
||||
const WORKGROUP_Y_SIZE: usize,
|
||||
const WORKGROUP_Z_SIZE: usize,
|
||||
> {
|
||||
_k: PhantomData<K>,
|
||||
_e: PhantomData<E>,
|
||||
_i: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<
|
||||
K: KernelGenerator,
|
||||
E: WGPUElement,
|
||||
I: WGPUElement,
|
||||
const WORKGROUP_X_SIZE: usize,
|
||||
const WORKGROUP_Y_SIZE: usize,
|
||||
const WORKGROUP_Z_SIZE: usize,
|
||||
> KernelGenerator
|
||||
for KernelSettings<K, E, I, WORKGROUP_X_SIZE, WORKGROUP_Y_SIZE, WORKGROUP_Z_SIZE>
|
||||
{
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> String {
|
||||
let mut source = K::generate().as_ref().to_string();
|
||||
|
||||
source = source.replace("WORKGROUP_SIZE_X", &WORKGROUP_X_SIZE.to_string());
|
||||
source = source.replace("WORKGROUP_SIZE_Y", &WORKGROUP_Y_SIZE.to_string());
|
||||
source = source.replace("WORKGROUP_SIZE_Z", &WORKGROUP_Y_SIZE.to_string());
|
||||
source = source.replace("elem", E::type_name());
|
||||
source = source.replace("int", I::type_name());
|
||||
|
||||
source
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use core::any::TypeId;
|
||||
|
||||
#[test]
|
||||
fn test_kernel_type_id() {
|
||||
kernel_wgsl!(Add, "../template/binary_elemwise.wgsl");
|
||||
|
||||
let type_id_1 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 4>>();
|
||||
let type_id_2 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 5>>();
|
||||
let type_id_3 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 4>>();
|
||||
|
||||
assert_ne!(type_id_1, type_id_2);
|
||||
assert_eq!(type_id_1, type_id_3);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
use super::{KernelGenerator, KernelSettings};
|
||||
use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor};
|
||||
use burn_tensor::Shape;
|
||||
use num_traits::ToPrimitive;
|
||||
use std::sync::Arc;
|
||||
|
||||
kernel_wgsl!(BinaryElemwiseRaw, "../template/binary_elemwise.wgsl");
|
||||
kernel_wgsl!(
|
||||
BinaryElemwiseInplaceRaw,
|
||||
"../template/binary_elemwise_inplace.wgsl"
|
||||
);
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! binary_elemwise {
|
||||
(
|
||||
$struct:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
let source = $crate::kernel::BinaryElemwiseRaw::generate().to_string();
|
||||
let body = format!(
|
||||
"output[global_id.x] = lhs[index_lhs] {} rhs[index_rhs]",
|
||||
$ops
|
||||
);
|
||||
source.replace("BODY", &body)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! binary_elemwise_inplace {
|
||||
(
|
||||
$struct:ident,
|
||||
$ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
let source = $crate::kernel::BinaryElemwiseInplaceRaw::generate().to_string();
|
||||
let body = format!(
|
||||
"lhs[global_id.x] = lhs[global_id.x] {} rhs[index_rhs];",
|
||||
$ops
|
||||
);
|
||||
source.replace("BODY", &body)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn binary_elemwise<K: KernelGenerator, E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
rhs: WGPUTensor<E, D>,
|
||||
) -> WGPUTensor<E, D> {
|
||||
lhs.assert_is_on_save_device(&rhs);
|
||||
|
||||
let mut shape_out = [0; D];
|
||||
lhs.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(rhs.shape.dims.iter())
|
||||
.enumerate()
|
||||
.for_each(|(index, (dim_lhs, dim_rhs))| {
|
||||
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
|
||||
});
|
||||
|
||||
let shape_out = Shape::new(shape_out);
|
||||
|
||||
let buffer = lhs
|
||||
.context
|
||||
.create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
|
||||
let output = WGPUTensor::new(lhs.context.clone(), shape_out, Arc::new(buffer));
|
||||
|
||||
let kernel = lhs
|
||||
.context
|
||||
.compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
|
||||
let mut info: Vec<u32> = vec![D.to_u32().unwrap()];
|
||||
|
||||
lhs.strides
|
||||
.into_iter()
|
||||
.for_each(|v| info.push(v.to_u32().unwrap()));
|
||||
rhs.strides
|
||||
.into_iter()
|
||||
.for_each(|v| info.push(v.to_u32().unwrap()));
|
||||
lhs.shape
|
||||
.dims
|
||||
.into_iter()
|
||||
.for_each(|v| info.push(v.to_u32().unwrap()));
|
||||
rhs.shape
|
||||
.dims
|
||||
.into_iter()
|
||||
.for_each(|v| info.push(v.to_u32().unwrap()));
|
||||
let info_buffers = lhs
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
lhs.context.execute(
|
||||
&WorkGroup::new(
|
||||
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
&kernel,
|
||||
&[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
pub fn binary_elemwise_inplace<K: KernelGenerator, E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
rhs: WGPUTensor<E, D>,
|
||||
) -> WGPUTensor<E, D> {
|
||||
lhs.assert_is_on_save_device(&rhs);
|
||||
|
||||
let mut shape_out = [0; D];
|
||||
lhs.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(rhs.shape.dims.iter())
|
||||
.enumerate()
|
||||
.for_each(|(index, (dim_lhs, dim_rhs))| {
|
||||
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
|
||||
});
|
||||
|
||||
let kernel = lhs
|
||||
.context
|
||||
.compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
|
||||
let mut info: Vec<u32> = vec![D.to_u32().unwrap()];
|
||||
rhs.strides
|
||||
.into_iter()
|
||||
.for_each(|v| info.push(v.to_u32().unwrap()));
|
||||
rhs.shape
|
||||
.dims
|
||||
.into_iter()
|
||||
.for_each(|v| info.push(v.to_u32().unwrap()));
|
||||
let info_buffers = lhs
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
lhs.context.execute(
|
||||
&WorkGroup::new(
|
||||
f32::ceil(lhs.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
&kernel,
|
||||
&[&lhs.buffer, &rhs.buffer, &info_buffers],
|
||||
);
|
||||
|
||||
lhs
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
mod base;
|
||||
mod binary_elemwise;
|
||||
mod unary;
|
||||
mod unary_scalar;
|
||||
|
||||
pub use base::*;
|
||||
pub use binary_elemwise::*;
|
||||
pub use unary::*;
|
||||
pub use unary_scalar::*;
|
|
@ -0,0 +1,120 @@
|
|||
use super::{KernelGenerator, KernelSettings};
|
||||
use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor};
|
||||
use std::sync::Arc;
|
||||
|
||||
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
|
||||
kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! unary {
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
let source = $crate::kernel::UnaryRaw::generate().to_string();
|
||||
let body = format!("output[global_id.x] = {}(input[global_id.x]);", $func);
|
||||
source.replace("BODY", &body)
|
||||
}
|
||||
}
|
||||
};
|
||||
(
|
||||
$struct:ident,
|
||||
body $body:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
let source = $crate::kernel::UnaryRaw::generate().to_string();
|
||||
source.replace("BODY", $body)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! unary_inplace {
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
let source = $crate::kernel::UnaryInplaceRaw::generate().to_string();
|
||||
let body = format!("input[global_id.x] = {}(input[global_id.x]);", $func);
|
||||
source.replace("BODY", &body)
|
||||
}
|
||||
}
|
||||
};
|
||||
(
|
||||
$struct:ident,
|
||||
body $body:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
let source = $crate::kernel::UnaryInplaceRaw::generate().to_string();
|
||||
source.replace("BODY", $body)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn unary<K: KernelGenerator, E: WGPUElement, const D: usize>(
|
||||
input: WGPUTensor<E, D>,
|
||||
) -> WGPUTensor<E, D> {
|
||||
let buffer = input
|
||||
.context
|
||||
.create_buffer(input.shape.num_elements() * core::mem::size_of::<E>());
|
||||
let output = WGPUTensor::new(input.context.clone(), input.shape, Arc::new(buffer));
|
||||
let kernel = input
|
||||
.context
|
||||
.compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
|
||||
|
||||
input.context.execute(
|
||||
&WorkGroup::new(
|
||||
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
&kernel,
|
||||
&[&input.buffer, &output.buffer],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn unary_inplace<K: KernelGenerator, E: WGPUElement, const D: usize>(
|
||||
input: WGPUTensor<E, D>,
|
||||
) -> WGPUTensor<E, D> {
|
||||
let kernel = input
|
||||
.context
|
||||
.compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
|
||||
|
||||
input.context.execute(
|
||||
&WorkGroup::new(
|
||||
f32::ceil(input.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
&kernel,
|
||||
&[&input.buffer],
|
||||
);
|
||||
|
||||
input
|
||||
}
|
|
@ -0,0 +1,135 @@
|
|||
use super::{KernelGenerator, KernelSettings};
|
||||
use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor};
|
||||
use std::sync::Arc;
|
||||
|
||||
kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl");
|
||||
kernel_wgsl!(
|
||||
UnaryScalarInplaceRaw,
|
||||
"../template/unary_scalar_inplace.wgsl"
|
||||
);
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! unary_scalar {
|
||||
(
|
||||
$struct:ident,
|
||||
ops $ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
let source = $crate::kernel::UnaryScalarRaw::generate().to_string();
|
||||
let body = format!("output[global_id.x] = lhs[global_id.x] {} rhs;", $ops);
|
||||
|
||||
source.replace("BODY", &body)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
let source = $crate::kernel::UnaryScalarRaw::generate().to_string();
|
||||
let body = format!("output[global_id.x] = {}(lhs[global_id.x], rhs);", $func);
|
||||
|
||||
source.replace("BODY", &body)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! unary_scalar_inplace {
|
||||
(
|
||||
$struct:ident,
|
||||
ops $ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
let source = $crate::kernel::UnaryScalarInplaceRaw::generate().to_string();
|
||||
let body = format!("lhs[global_id.x] = lhs[global_id.x] {} rhs;", $ops);
|
||||
|
||||
source.replace("BODY", &body)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::KernelGenerator for $struct {
|
||||
type Source = String;
|
||||
|
||||
fn generate() -> Self::Source {
|
||||
let source = $crate::kernel::UnaryScalarInplaceRaw::generate().to_string();
|
||||
let body = format!("lhs[global_id.x] = {}(lhs[global_id.x], rhs);", $func);
|
||||
|
||||
source.replace("BODY", &body)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn unary_scalar<K: KernelGenerator, E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
scalar: E,
|
||||
) -> WGPUTensor<E, D> {
|
||||
let buffer = lhs
|
||||
.context
|
||||
.create_buffer(lhs.shape.num_elements() * core::mem::size_of::<E>());
|
||||
let output = WGPUTensor::new(lhs.context.clone(), lhs.shape, Arc::new(buffer));
|
||||
let kernel = lhs
|
||||
.context
|
||||
.compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
|
||||
let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar]));
|
||||
|
||||
lhs.context.execute(
|
||||
&WorkGroup::new(
|
||||
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
&kernel,
|
||||
&[&lhs.buffer, &rhs_buffer, &output.buffer],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn unary_scalar_inplace<K: KernelGenerator, E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
scalar: E,
|
||||
) -> WGPUTensor<E, D> {
|
||||
let kernel = lhs
|
||||
.context
|
||||
.compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
|
||||
let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar]));
|
||||
|
||||
lhs.context.execute(
|
||||
&WorkGroup::new(
|
||||
f32::ceil(lhs.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
&kernel,
|
||||
&[&lhs.buffer, &rhs_buffer],
|
||||
);
|
||||
|
||||
lhs
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
mod ops;
|
||||
|
||||
pub(crate) mod context;
|
||||
pub(crate) mod element;
|
||||
pub(crate) mod kernel;
|
||||
pub(crate) mod pool;
|
||||
pub(crate) mod tensor;
|
||||
|
||||
mod device;
|
||||
pub use device::*;
|
||||
|
||||
mod backend;
|
||||
pub use backend::*;
|
||||
|
||||
mod graphics;
|
||||
pub use graphics::*;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
type TestBackend = crate::WGPUBackend<crate::Vulkan, f32, i64>;
|
||||
|
||||
burn_tensor::testgen_add!();
|
||||
burn_tensor::testgen_sub!();
|
||||
burn_tensor::testgen_div!();
|
||||
burn_tensor::testgen_mul!();
|
||||
burn_tensor::testgen_neg!();
|
||||
burn_tensor::testgen_powf!();
|
||||
burn_tensor::testgen_exp!();
|
||||
burn_tensor::testgen_log!();
|
||||
burn_tensor::testgen_relu!();
|
||||
|
||||
// Once all operations will be implemented.
|
||||
// type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
|
||||
// type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
|
||||
// burn_tensor::testgen_all!();
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
use burn_tensor::ops::ActivationOps;
|
||||
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
kernel::{unary, unary_inplace},
|
||||
unary, unary_inplace, GraphicsAPI, WGPUBackend,
|
||||
};
|
||||
|
||||
use super::FloatTensor;
|
||||
|
||||
impl<G, F, I> ActivationOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
|
||||
where
|
||||
G: GraphicsAPI + 'static,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn relu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Relu, body "output[global_id.x] = max(input[global_id.x], 0.0);");
|
||||
unary_inplace!(ReluInplace, body "input[global_id.x] = max(input[global_id.x], 0.0);");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace::<ReluInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary::<Relu, F, D>(tensor)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
use burn_tensor::{backend::Backend, Data, Shape};
|
||||
|
||||
use crate::{element::WGPUElement, pool::get_context, tensor::WGPUTensor, GraphicsAPI, WGPUDevice};
|
||||
|
||||
pub type FloatElem<B> = <B as Backend>::FloatElem;
|
||||
pub type Device<B> = <B as Backend>::Device;
|
||||
|
||||
pub type FloatTensor<B, const D: usize> = <B as Backend>::TensorPrimitive<D>;
|
||||
|
||||
pub type IntElem<B> = <B as Backend>::IntElem;
|
||||
pub type IntTensor<B, const D: usize> = <B as Backend>::IntTensorPrimitive<D>;
|
||||
pub type BoolTensor<B, const D: usize> = <B as Backend>::BoolTensorPrimitive<D>;
|
||||
|
||||
pub struct BaseOps<G: GraphicsAPI> {
|
||||
_g: PhantomData<G>,
|
||||
}
|
||||
|
||||
impl<G: GraphicsAPI> BaseOps<G> {
|
||||
pub fn from_data<E: WGPUElement, const D: usize>(
|
||||
data: Data<E, D>,
|
||||
device: &WGPUDevice,
|
||||
) -> WGPUTensor<E, D> {
|
||||
let context = get_context::<G>(device);
|
||||
let buffer = context.create_buffer_with_data(E::as_bytes(&data.value));
|
||||
|
||||
WGPUTensor::new(context, data.shape, Arc::new(buffer))
|
||||
}
|
||||
|
||||
pub fn to_data<E: WGPUElement, const D: usize>(tensor: &WGPUTensor<E, D>) -> Data<E, D> {
|
||||
let bytes = tensor.context.buffer_to_data(&tensor.buffer);
|
||||
let values = E::from_bytes(&bytes);
|
||||
|
||||
Data::new(values.to_vec(), tensor.shape.clone())
|
||||
}
|
||||
|
||||
pub fn to_device<E: WGPUElement, const D: usize>(
|
||||
tensor: WGPUTensor<E, D>,
|
||||
device: &WGPUDevice,
|
||||
) -> WGPUTensor<E, D> {
|
||||
if &tensor.context.device == device {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
let context = get_context::<G>(device);
|
||||
tensor.to_context(context)
|
||||
}
|
||||
|
||||
pub fn empty<E: WGPUElement, const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: &WGPUDevice,
|
||||
) -> WGPUTensor<E, D> {
|
||||
let context = get_context::<G>(device);
|
||||
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
|
||||
|
||||
WGPUTensor::new(context, shape, Arc::new(buffer))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
|
||||
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
GraphicsAPI, WGPUBackend,
|
||||
};
|
||||
|
||||
use super::{BaseOps, BoolTensor, Device, IntTensor};
|
||||
|
||||
impl<G, F, I> BoolTensorOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
|
||||
where
|
||||
G: GraphicsAPI + 'static,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
|
||||
BaseOps::<G>::empty(shape, device)
|
||||
}
|
||||
|
||||
fn bool_shape<const D: usize>(tensor: &BoolTensor<Self, D>) -> Shape<D> {
|
||||
tensor.shape.clone()
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Data<bool, D> {
|
||||
let data = BaseOps::<G>::to_data(&tensor);
|
||||
|
||||
Data::new(data.value.into_iter().map(|i| i != 0).collect(), data.shape)
|
||||
}
|
||||
|
||||
fn bool_from_data<const D: usize>(
|
||||
data: Data<bool, D>,
|
||||
device: &Device<Self>,
|
||||
) -> BoolTensor<Self, D> {
|
||||
let data: Data<u32, D> = Data::new(
|
||||
data.value
|
||||
.into_iter()
|
||||
.map(|c| match c {
|
||||
true => 1,
|
||||
false => 0,
|
||||
})
|
||||
.collect(),
|
||||
data.shape,
|
||||
);
|
||||
BaseOps::<G>::from_data(data, device)
|
||||
}
|
||||
|
||||
fn bool_into_int<const D: usize>(_tensor: BoolTensor<Self, D>) -> IntTensor<Self, D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn bool_device<const D: usize>(
|
||||
_tensor: &<WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::Device {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: BoolTensor<Self, D>,
|
||||
device: &Device<Self>,
|
||||
) -> BoolTensor<Self, D> {
|
||||
BaseOps::<G>::to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn bool_reshape<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1>,
|
||||
_shape: Shape<D2>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D2> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1>,
|
||||
_indexes: [std::ops::Range<usize>; D2],
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1>,
|
||||
_indexes: [std::ops::Range<usize>; D2],
|
||||
_value: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn bool_cat<const D: usize>(
|
||||
_tensors: Vec<<WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn bool_equal<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn bool_equal_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
|
||||
_rhs: bool,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,406 @@
|
|||
use super::numeric::NumericOps;
|
||||
use super::{BaseOps, Device, FloatElem, FloatTensor};
|
||||
use crate::kernel::{unary, unary_inplace, unary_scalar, unary_scalar_inplace};
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
unary, unary_inplace, GraphicsAPI, WGPUBackend, SEED,
|
||||
};
|
||||
use crate::{unary_scalar, unary_scalar_inplace};
|
||||
use burn_common::rand::get_seeded_rng;
|
||||
use burn_tensor::ElementConversion;
|
||||
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, Shape};
|
||||
|
||||
impl<G, F, I> TensorOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
|
||||
where
|
||||
G: GraphicsAPI + 'static,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn from_data<const D: usize>(
|
||||
data: Data<FloatElem<Self>, D>,
|
||||
device: &Device<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
BaseOps::<G>::from_data(data, device)
|
||||
}
|
||||
|
||||
fn random<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
distribution: Distribution<FloatElem<Self>>,
|
||||
device: &Device<Self>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
let mut seed = SEED.lock().unwrap();
|
||||
let mut rng = if let Some(rng_seeded) = seed.as_ref() {
|
||||
rng_seeded.clone()
|
||||
} else {
|
||||
get_seeded_rng()
|
||||
};
|
||||
let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
|
||||
*seed = Some(rng);
|
||||
tensor
|
||||
}
|
||||
|
||||
fn shape<const D: usize>(tensor: &FloatTensor<Self, D>) -> Shape<D> {
|
||||
tensor.shape.clone()
|
||||
}
|
||||
|
||||
fn to_data<const D: usize>(tensor: &FloatTensor<Self, D>) -> Data<FloatElem<Self>, D> {
|
||||
BaseOps::<G>::to_data(tensor)
|
||||
}
|
||||
|
||||
fn device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
|
||||
tensor.context.device.clone()
|
||||
}
|
||||
|
||||
fn to_device<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
device: &Device<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
BaseOps::<G>::to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
|
||||
BaseOps::<G>::empty(shape, device)
|
||||
}
|
||||
|
||||
fn add<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatTensor<Self, D>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
NumericOps::add(lhs, rhs)
|
||||
}
|
||||
|
||||
fn add_scalar<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatElem<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
NumericOps::add_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn sub<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatTensor<Self, D>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
NumericOps::sub(lhs, rhs)
|
||||
}
|
||||
|
||||
fn sub_scalar<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatElem<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
NumericOps::sub_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn mul<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatTensor<Self, D>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
NumericOps::mul(lhs, rhs)
|
||||
}
|
||||
|
||||
fn mul_scalar<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatElem<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
NumericOps::mul_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn div<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatTensor<Self, D>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
NumericOps::div(lhs, rhs)
|
||||
}
|
||||
|
||||
fn div_scalar<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatElem<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
NumericOps::div_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn matmul<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn swap_dims<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_dim1: usize,
|
||||
_dim2: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn reshape<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
|
||||
_shape: Shape<D2>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D2> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn gather<const D: usize>(
|
||||
_dim: usize,
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn scatter<const D: usize>(
|
||||
_dim: usize,
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_value: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn index_select<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
|
||||
_dim: usize,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
|
||||
_value: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D2>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
|
||||
_indexes: [std::ops::Range<usize>; D2],
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
|
||||
_indexes: [std::ops::Range<usize>; D2],
|
||||
_value: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn mask_scatter<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
|
||||
_source: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn mask_fill<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
|
||||
_value: <WGPUBackend<G, F, I> as Backend>::FloatElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn equal<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn equal_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn greater<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn greater_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn greater_equal<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn greater_equal_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn lower<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn lower_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn lower_equal<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn lower_equal_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn sum<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn sum_dim<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn mean<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn mean_dim<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn to_full_precision<const D: usize>(
|
||||
_tensor: &<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <<WGPUBackend<G, F, I> as Backend>::FullPrecisionBackend as Backend>::TensorPrimitive<D>
|
||||
{
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_full_precision<const D: usize>(
|
||||
_tensor: <<WGPUBackend<G, F, I> as Backend>::FullPrecisionBackend as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Exp, func "exp");
|
||||
unary_inplace!(ExpInplace, func "exp");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_inplace::<ExpInplace, F, D>(lhs);
|
||||
}
|
||||
|
||||
unary::<Exp, F, D>(lhs)
|
||||
}
|
||||
|
||||
fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Log, func "log");
|
||||
unary_inplace!(LogInplace, func "log");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace::<LogInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary::<Log, F, D>(tensor)
|
||||
}
|
||||
|
||||
fn log1p<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
|
||||
unary_scalar!(Powf, func "pow");
|
||||
unary_scalar_inplace!(PowfInplace, func "pow");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_scalar_inplace::<PowfInplace, F, D>(lhs, rhs.elem());
|
||||
}
|
||||
|
||||
unary_scalar::<Powf, F, D>(lhs, rhs.elem())
|
||||
}
|
||||
|
||||
fn sqrt<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn cos<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn sin<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn tanh<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn erf<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn cat<const D: usize>(
|
||||
_tensors: Vec<<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn argmax<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn argmin<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,314 @@
|
|||
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Shape};
|
||||
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
GraphicsAPI, WGPUBackend,
|
||||
};
|
||||
|
||||
use super::{numeric::NumericOps, BaseOps, Device, IntElem, IntTensor};
|
||||
|
||||
impl<G, F, I> IntTensorOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
|
||||
where
|
||||
G: GraphicsAPI + 'static,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
|
||||
BaseOps::<G>::empty(shape, device)
|
||||
}
|
||||
|
||||
fn int_shape<const D: usize>(
|
||||
_tensor: &<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> Shape<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Data<I, D> {
|
||||
BaseOps::<G>::to_data(&tensor)
|
||||
}
|
||||
|
||||
fn int_from_data<const D: usize>(
|
||||
data: Data<I, D>,
|
||||
device: &Device<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
BaseOps::<G>::from_data(data, device)
|
||||
}
|
||||
|
||||
fn int_device<const D: usize>(
|
||||
_tensor: &<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::Device {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_to_device<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
device: &Device<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
BaseOps::<G>::to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn int_reshape<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
|
||||
_shape: Shape<D2>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D2> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_index<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
|
||||
_indexes: [std::ops::Range<usize>; D2],
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_index_assign<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
|
||||
_indexes: [std::ops::Range<usize>; D2],
|
||||
_value: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_mask_scatter<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
|
||||
_source: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_mask_fill<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
|
||||
_value: <WGPUBackend<G, F, I> as Backend>::IntElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_gather<const D: usize>(
|
||||
_dim: usize,
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_scatter<const D: usize>(
|
||||
_dim: usize,
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_value: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_index_select_dim<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
|
||||
_dim: usize,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
|
||||
_value: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D2>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_cat<const D: usize>(
|
||||
_tensors: Vec<<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_equal<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_equal_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_greater<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_greater_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_greater_equal<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_greater_equal_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_lower<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_lower_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_lower_equal<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_lower_equal_elem<const D: usize>(
|
||||
_lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_add<const D: usize>(
|
||||
lhs: IntTensor<Self, D>,
|
||||
rhs: IntTensor<Self, D>,
|
||||
) -> IntTensor<Self, D> {
|
||||
NumericOps::add::<I, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_add_scalar<const D: usize>(
|
||||
lhs: IntTensor<Self, D>,
|
||||
rhs: IntElem<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
NumericOps::add_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_sub<const D: usize>(
|
||||
lhs: IntTensor<Self, D>,
|
||||
rhs: IntTensor<Self, D>,
|
||||
) -> IntTensor<Self, D> {
|
||||
NumericOps::sub(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_sub_scalar<const D: usize>(
|
||||
lhs: IntTensor<Self, D>,
|
||||
rhs: IntElem<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
NumericOps::sub_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_mul<const D: usize>(
|
||||
lhs: IntTensor<Self, D>,
|
||||
rhs: IntTensor<Self, D>,
|
||||
) -> IntTensor<Self, D> {
|
||||
NumericOps::mul(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_mul_scalar<const D: usize>(
|
||||
lhs: IntTensor<Self, D>,
|
||||
rhs: IntElem<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
NumericOps::mul_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_div<const D: usize>(
|
||||
lhs: IntTensor<Self, D>,
|
||||
rhs: IntTensor<Self, D>,
|
||||
) -> IntTensor<Self, D> {
|
||||
NumericOps::div(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_div_scalar<const D: usize>(
|
||||
lhs: IntTensor<Self, D>,
|
||||
rhs: IntElem<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
NumericOps::div_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_neg<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_zeros<const D: usize>(
|
||||
_shape: Shape<D>,
|
||||
_device: &<WGPUBackend<G, F, I> as Backend>::Device,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_ones<const D: usize>(
|
||||
_shape: Shape<D>,
|
||||
_device: &<WGPUBackend<G, F, I> as Backend>::Device,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_sum<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_sum_dim<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_mean<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_mean_dim<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_argmax<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn int_argmin<const D: usize>(
|
||||
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
mod activation_ops;
|
||||
mod bool_ops;
|
||||
mod float_ops;
|
||||
mod int_ops;
|
||||
mod module_ops;
|
||||
|
||||
mod base;
|
||||
pub(crate) use base::*;
|
||||
|
||||
pub(crate) mod numeric;
|
|
@ -0,0 +1,94 @@
|
|||
use burn_tensor::{backend::Backend, ops::ModuleOps};
|
||||
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
GraphicsAPI, WGPUBackend,
|
||||
};
|
||||
|
||||
impl<G, F, I> ModuleOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
|
||||
where
|
||||
G: GraphicsAPI + 'static,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn embedding(
|
||||
_weights: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<2>,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<2>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<3> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn embedding_backward(
|
||||
_weights: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<2>,
|
||||
_output: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<3>,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<2>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<2> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_weight: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_bias: Option<<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1>>,
|
||||
_options: burn_tensor::ops::ConvOptions<2>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_weight: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_bias: Option<<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1>>,
|
||||
_options: burn_tensor::ops::ConvTransposeOptions<2>,
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn avg_pool2d(
|
||||
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn avg_pool2d_backward(
|
||||
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_grad: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indexes(
|
||||
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
) -> burn_tensor::ops::MaxPool2dWithIndexes<WGPUBackend<G, F, I>> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indexes_backward(
|
||||
_x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
_output_grad: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<4>,
|
||||
) -> burn_tensor::ops::MaxPool2dBackward<WGPUBackend<G, F, I>> {
|
||||
todo!()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
use crate::kernel::{binary_elemwise, binary_elemwise_inplace, unary_scalar, unary_scalar_inplace};
|
||||
use crate::{
|
||||
binary_elemwise, binary_elemwise_inplace, element::WGPUElement, tensor::WGPUTensor,
|
||||
unary_scalar, unary_scalar_inplace,
|
||||
};
|
||||
|
||||
pub struct NumericOps;
|
||||
|
||||
impl NumericOps {
|
||||
pub fn add<E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
rhs: WGPUTensor<E, D>,
|
||||
) -> WGPUTensor<E, D> {
|
||||
binary_elemwise!(Add, "+");
|
||||
binary_elemwise_inplace!(AddInplace, "+");
|
||||
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
return binary_elemwise_inplace::<AddInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
if rhs.can_mut_broadcast(&lhs) {
|
||||
return binary_elemwise_inplace::<AddInplace, E, D>(rhs, lhs);
|
||||
}
|
||||
|
||||
binary_elemwise::<Add, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn add_scalar<E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WGPUTensor<E, D> {
|
||||
unary_scalar!(AddScalar, ops "+");
|
||||
unary_scalar_inplace!(AddScalarInplace, ops "+");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_scalar_inplace::<AddScalarInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
unary_scalar::<AddScalar, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn sub<E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
rhs: WGPUTensor<E, D>,
|
||||
) -> WGPUTensor<E, D> {
|
||||
binary_elemwise!(Sub, "-");
|
||||
binary_elemwise_inplace!(SubInplace, "-");
|
||||
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
return binary_elemwise_inplace::<SubInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
binary_elemwise::<Sub, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn sub_scalar<E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WGPUTensor<E, D> {
|
||||
unary_scalar!(SubScalar, ops "-");
|
||||
unary_scalar_inplace!(SubScalarInplace, ops "-");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_scalar_inplace::<SubScalarInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
unary_scalar::<SubScalar, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn mul<E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
rhs: WGPUTensor<E, D>,
|
||||
) -> WGPUTensor<E, D> {
|
||||
binary_elemwise!(Mul, "*");
|
||||
binary_elemwise_inplace!(MulInplace, "*");
|
||||
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
return binary_elemwise_inplace::<MulInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
if rhs.can_mut_broadcast(&lhs) {
|
||||
return binary_elemwise_inplace::<MulInplace, E, D>(rhs, lhs);
|
||||
}
|
||||
|
||||
binary_elemwise::<Mul, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn mul_scalar<E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WGPUTensor<E, D> {
|
||||
unary_scalar!(MulScalar, ops "*");
|
||||
unary_scalar_inplace!(MulScalarInplace, ops "*");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_scalar_inplace::<MulScalarInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
unary_scalar::<MulScalar, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn div<E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
rhs: WGPUTensor<E, D>,
|
||||
) -> WGPUTensor<E, D> {
|
||||
binary_elemwise!(Div, "/");
|
||||
binary_elemwise_inplace!(DivInplace, "/");
|
||||
|
||||
if lhs.can_mut_broadcast(&rhs) {
|
||||
return binary_elemwise_inplace::<DivInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
binary_elemwise::<Div, E, D>(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn div_scalar<E: WGPUElement, const D: usize>(
|
||||
lhs: WGPUTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WGPUTensor<E, D> {
|
||||
unary_scalar!(DivScalar, ops "/");
|
||||
unary_scalar_inplace!(DivScalarInplace, ops "/");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_scalar_inplace::<DivScalarInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
unary_scalar::<DivScalar, E, D>(lhs, rhs)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
use crate::{context::Context, GraphicsAPI, WGPUDevice};
|
||||
use std::{
|
||||
any::TypeId,
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
static POOL_CONTEXT: Mutex<Option<ContextPool>> = Mutex::new(None);
|
||||
|
||||
#[derive(Default)]
|
||||
struct ContextPool {
|
||||
contexts: HashMap<Key, Arc<Context>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||
struct Key {
|
||||
api_id: TypeId,
|
||||
device: WGPUDevice,
|
||||
}
|
||||
|
||||
impl Key {
|
||||
fn new<G: GraphicsAPI>(device: &WGPUDevice) -> Self {
|
||||
Self {
|
||||
api_id: TypeId::of::<G>(),
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a [context](Context) for the given [device](WGPUDevice).
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// If a context already exist for the current [device](WGPUDevice), the same instance will be
|
||||
/// returned.
|
||||
pub fn get_context<G: GraphicsAPI>(device: &WGPUDevice) -> Arc<Context> {
|
||||
let mut pool = POOL_CONTEXT.lock().unwrap();
|
||||
|
||||
let context = if let Some(pool) = pool.as_mut() {
|
||||
// Fetch device in pool
|
||||
match pool.contexts.get(&Key::new::<G>(device)) {
|
||||
Some(context) => context.clone(),
|
||||
None => {
|
||||
// Init new device
|
||||
let context = Arc::new(Context::new::<G>(device));
|
||||
pool.contexts.insert(Key::new::<G>(device), context.clone());
|
||||
context
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Initialize pool
|
||||
let context = Arc::new(Context::new::<G>(device));
|
||||
let mut new_pool = ContextPool::default();
|
||||
|
||||
new_pool
|
||||
.contexts
|
||||
.insert(Key::new::<G>(device), context.clone());
|
||||
*pool = Some(new_pool);
|
||||
context
|
||||
};
|
||||
|
||||
context
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<elem>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<elem>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<elem>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
@compute
|
||||
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
let dim: u32 = info[0];
|
||||
var index_lhs: u32 = 0u;
|
||||
var index_rhs: u32 = 0u;
|
||||
|
||||
for (var i: u32 = 0u; i < dim; i++) {
|
||||
let stride_lhs = info[i + 1u];
|
||||
let stride_rhs = info[i + 1u * dim + 1u];
|
||||
let shape_lhs = info[i + 2u * dim + 1u];
|
||||
let shape_rhs = info[i + 3u * dim + 1u];
|
||||
|
||||
index_lhs += global_id.x / stride_lhs % shape_lhs * stride_lhs;
|
||||
index_rhs += global_id.x / stride_rhs % shape_rhs * stride_rhs;
|
||||
}
|
||||
|
||||
BODY
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> lhs: array<elem>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<elem>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
@compute
|
||||
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
let dim: u32 = info[0];
|
||||
var index_rhs: u32 = 0u;
|
||||
|
||||
for (var i: u32 = 0u; i < dim; i++) {
|
||||
let stride_rhs = info[i + 1u];
|
||||
let shape_rhs = info[i + 1u * dim + 1u];
|
||||
|
||||
index_rhs += global_id.x / stride_rhs % shape_rhs * stride_rhs;
|
||||
}
|
||||
|
||||
BODY
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> input: array<elem>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read_write> output: array<elem>;
|
||||
|
||||
@compute
|
||||
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
BODY
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> input: array<elem>;
|
||||
|
||||
@compute
|
||||
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
BODY
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<elem>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: elem;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<elem>;
|
||||
|
||||
@compute
|
||||
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
BODY
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> lhs: array<elem>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: elem;
|
||||
|
||||
@compute
|
||||
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
BODY
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
use burn_tensor::Shape;
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
use wgpu::Buffer;
|
||||
|
||||
use crate::{context::Context, element::WGPUElement};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WGPUTensor<E: WGPUElement, const D: usize> {
|
||||
pub(crate) context: Arc<Context>,
|
||||
pub(crate) buffer: Arc<Buffer>,
|
||||
pub(crate) shape: Shape<D>,
|
||||
pub(crate) strides: [usize; D],
|
||||
elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<E: WGPUElement, const D: usize> WGPUTensor<E, D> {
|
||||
pub fn new(context: Arc<Context>, shape: Shape<D>, buffer: Arc<Buffer>) -> Self {
|
||||
let mut strides = [0; D];
|
||||
|
||||
let mut current = 1;
|
||||
shape
|
||||
.dims
|
||||
.iter()
|
||||
.enumerate()
|
||||
.rev()
|
||||
.for_each(|(index, val)| {
|
||||
strides[index] = current;
|
||||
current *= val;
|
||||
});
|
||||
|
||||
Self {
|
||||
context,
|
||||
buffer,
|
||||
shape,
|
||||
strides,
|
||||
elem: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
pub fn to_context(&self, context: Arc<Context>) -> Self {
|
||||
let data = self.context.buffer_to_data(&self.buffer);
|
||||
let buffer = Arc::new(context.create_buffer_with_data(&data));
|
||||
|
||||
Self {
|
||||
context,
|
||||
buffer,
|
||||
shape: self.shape.clone(),
|
||||
strides: self.strides,
|
||||
elem: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
pub fn can_mut_broadcast(&self, tensor_other: &WGPUTensor<E, D>) -> bool {
|
||||
if Arc::strong_count(&self.buffer) > 1 {
|
||||
return false;
|
||||
}
|
||||
|
||||
for i in 0..D {
|
||||
// Output tensor will be different from the mutable tensor.
|
||||
if self.shape.dims[i] < tensor_other.shape.dims[i] {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn can_mut(&self) -> bool {
|
||||
if Arc::strong_count(&self.buffer) > 1 {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn assert_is_on_save_device(&self, other: &Self) {
|
||||
if self.context.device != other.context.device {
|
||||
panic!(
|
||||
"Both tensors should be on the same device {:?} != {:?}",
|
||||
self.context.device, other.context.device
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
mod base;
|
||||
pub use base::*;
|
Loading…
Reference in New Issue