mirror of https://github.com/tracel-ai/burn.git

Feat/wgpu backend setup (#376)

parent 483f9acca5
commit 974fdfaba1
@@ -14,6 +14,7 @@ members = [
     "burn-ndarray",
     "burn-no-std-tests",
     "burn-tch",
+    "burn-wgpu",
     "burn-tensor-testgen",
     "burn-tensor",
     "burn-train",
@@ -53,6 +54,11 @@ syn = "2.0"
 tempfile = "3.5.0"
 thiserror = "1.0.40"
 topological-sort = "0.2.2"
+
+# WGPU stuff
+wgpu = "0.16.0"
+futures-intrusive = "0.5"
+pollster = "0.3"
 #
 # The following packages disable the "std" feature for no_std compatibility
 #
@@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask<B: Backend>(

     mask = mask.to_device(device).repeat(0, batch_size);

-    mask.equal_elem(1_i64)
+    mask.equal_elem(1_i64.elem::<i64>())
 }

 pub struct GeneratePaddingMask<B: Backend> {
@@ -256,12 +256,6 @@ where
         K::equal(self.primitive, other.primitive)
     }

-    /// Applies element wise equal comparison and returns a boolean tensor.
-    pub fn equal_elem<E: Into<K::Elem>>(self, other: E) -> Tensor<B, D, Bool> {
-        let elem: K::Elem = other.into();
-        K::equal_elem::<D>(self.primitive, elem)
-    }
-
     /// Concatenates all tensors into a new one along the given dimension.
     ///
     /// # Panics
@@ -400,7 +394,6 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
         lhs: Self::Primitive<D>,
         rhs: Self::Primitive<D>,
     ) -> Tensor<B, D, Bool>;
-    fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
     fn elem_type_name() -> &'static str {
         core::any::type_name::<Self::Elem>()
     }
@@ -478,10 +471,6 @@ impl<B: Backend> BasicOps<B> for Float {
     ) -> Tensor<B, D, Bool> {
         Tensor::new(B::equal(lhs, rhs))
     }
-
-    fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
-        Tensor::new(B::equal_elem(lhs, rhs))
-    }
 }

 impl<B: Backend> BasicOps<B> for Int {
@@ -553,10 +542,6 @@ impl<B: Backend> BasicOps<B> for Int {
         Tensor::new(B::int_equal(lhs, rhs))
     }
-
-    fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
-        Tensor::new(B::int_equal_elem(lhs, rhs))
-    }

     fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
         B::int_cat(vectors, dim)
     }
@@ -631,10 +616,6 @@ impl<B: Backend> BasicOps<B> for Bool {
         Tensor::new(B::bool_equal(lhs, rhs))
     }
-
-    fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
-        Tensor::new(B::bool_equal_elem(lhs, rhs))
-    }

     fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
         B::bool_cat(vectors, dim)
     }
@@ -133,6 +133,11 @@ where
         Self::new(K::sum_dim(self.primitive, dim))
     }

+    /// Applies element wise equal comparison and returns a boolean tensor.
+    pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
+        K::equal_elem::<D>(self.primitive, other.elem())
+    }
+
     /// Applies element wise greater comparison and returns a boolean tensor.
     ///
     /// # Panics
@@ -413,6 +418,7 @@ where
     fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
     fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
     fn mean_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
+    fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
     fn greater<const D: usize>(
         lhs: Self::Primitive<D>,
         rhs: Self::Primitive<D>,
@@ -559,6 +565,9 @@ impl<B: Backend> Numeric<B> for Int {
         B::int_mean_dim(tensor, dim)
     }

+    fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
+        Tensor::new(B::int_equal_elem(lhs, rhs))
+    }
     fn greater<const D: usize>(
         lhs: Self::Primitive<D>,
         rhs: Self::Primitive<D>,
@@ -777,6 +786,9 @@ impl<B: Backend> Numeric<B> for Float {
         B::mean_dim(tensor, dim)
     }

+    fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool> {
+        Tensor::new(B::equal_elem(lhs, rhs))
+    }
     fn greater<const D: usize>(
         lhs: Self::Primitive<D>,
         rhs: Self::Primitive<D>,
@@ -79,7 +79,7 @@ pub trait Backend:
     /// Tensor primitive to be used for all int operations.
     type IntTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
     /// Int element type.
-    type IntElem: Element + From<i64> + Into<i64>;
+    type IntElem: Element;

     /// Tensor primitive to be used for all bool operations.
     type BoolTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
@@ -22,7 +22,9 @@ pub trait TensorOps<B: Backend> {
     }
     fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>;
     fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
-    fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
+    fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D> {
+        Self::to_data(&tensor)
+    }
     fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
     fn to_device<const D: usize>(
         tensor: B::TensorPrimitive<D>,
@@ -102,7 +104,9 @@ pub trait TensorOps<B: Backend> {
         lhs: B::TensorPrimitive<D>,
         rhs: B::TensorPrimitive<D>,
     ) -> B::TensorPrimitive<D>;
-    fn neg<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
+    fn neg<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
+        Self::mul_scalar(tensor, (-1.0_f32).elem::<B::FloatElem>())
+    }
     fn transpose<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
         Self::swap_dims(tensor, D - 2, D - 1)
     }
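The two hunks above turn `into_data` and `neg` into provided methods of `TensorOps`, so a new backend only has to implement `to_data` and `mul_scalar` to get working defaults, and can still override them later. A self-contained sketch of the same provided-method pattern (the trait and type names below are illustrative, not burn's API):

// Illustrative stand-in for the provided-method pattern; not burn's API.
trait MiniOps {
    fn to_data(values: &[f32]) -> Vec<f32>;
    fn mul_scalar(values: Vec<f32>, rhs: f32) -> Vec<f32>;

    // Provided methods: every implementor gets these for free and may
    // override them with a specialized version.
    fn into_data(values: Vec<f32>) -> Vec<f32> {
        Self::to_data(&values)
    }
    fn neg(values: Vec<f32>) -> Vec<f32> {
        Self::mul_scalar(values, -1.0)
    }
}

struct CpuBackend;

impl MiniOps for CpuBackend {
    fn to_data(values: &[f32]) -> Vec<f32> {
        values.to_vec()
    }
    fn mul_scalar(values: Vec<f32>, rhs: f32) -> Vec<f32> {
        values.into_iter().map(|v| v * rhs).collect()
    }
}

fn main() {
    // Both calls resolve to the default bodies, mirroring the hunks above.
    assert_eq!(CpuBackend::neg(vec![1.0, -2.0]), vec![-1.0, 2.0]);
    assert_eq!(CpuBackend::into_data(vec![3.0]), vec![3.0]);
}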
@@ -15,4 +15,17 @@ mod tests {
         let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]);
         assert_eq!(data_expected, data_actual);
     }
+
+    #[test]
+    fn test_add_broadcast() {
+        let data_1 = Data::from([[0.0, 1.0, 2.0]]);
+        let data_2 = Data::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);
+        let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1);
+        let tensor_2 = Tensor::<TestBackend, 2>::from_data(data_2);
+
+        let data_actual = (tensor_1 + tensor_2).into_data();
+
+        let data_expected = Data::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]);
+        assert_eq!(data_expected, data_actual);
+    }
 }
@@ -4,7 +4,7 @@ mod tests {
     use burn_tensor::{Data, Tensor};

     #[test]
-    fn should_support_exp_ops() {
+    fn should_support_log_ops() {
         let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
         let tensor = Tensor::<TestBackend, 2>::from_data(data);

@@ -0,0 +1,32 @@
+[package]
+authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
+categories = ["science"]
+description = "WGPU backend for burn"
+edition = "2021"
+keywords = ["deep-learning", "machine-learning", "data"]
+license = "MIT/Apache-2.0"
+name = "burn-wgpu"
+readme = "README.md"
+repository = "https://github.com/burn-rs/burn/tree/main/burn-wgpu"
+version = "0.8.0"
+
+[dependencies]
+burn-tensor = {path = "../burn-tensor", version = "0.8.0"}
+burn-common = {path = "../burn-common", version = "0.8.0"}
+derive-new = {workspace = true}
+bytemuck = {workspace = true}
+rand = {workspace = true}
+num-traits = {workspace = true}
+
+# WGPU stuff
+wgpu = {workspace = true}
+futures-intrusive = {workspace = true}
+pollster = {workspace = true}
+
+[dev-dependencies]
+burn-autodiff = {path = "../burn-autodiff", version = "0.8.0", default-features = false, features = [
+    "export_tests",
+]}
+burn-tensor = {path = "../burn-tensor", version = "0.8.0", default-features = false, features = [
+    "export_tests",
+]}
@@ -0,0 +1 @@
+../LICENSE-APACHE
@@ -0,0 +1 @@
+../LICENSE-MIT
@@ -0,0 +1,3 @@
+# Burn WGPU Backend
+
+[Burn](https://github.com/burn-rs/burn) WGPU backend
@@ -0,0 +1,45 @@
+use burn_tensor::backend::Backend;
+use rand::{rngs::StdRng, SeedableRng};
+
+use crate::{
+    element::{FloatElement, IntElement},
+    tensor::WGPUTensor,
+    GraphicsAPI, WGPUDevice,
+};
+use std::{marker::PhantomData, sync::Mutex};
+
+pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
+
+#[derive(Debug, Default, Clone)]
+pub struct WGPUBackend<G: GraphicsAPI, F: FloatElement, I: IntElement> {
+    _g: PhantomData<G>,
+    _f: PhantomData<F>,
+    _i: PhantomData<I>,
+}
+
+impl<G: GraphicsAPI + 'static, F: FloatElement, I: IntElement> Backend for WGPUBackend<G, F, I> {
+    type Device = WGPUDevice;
+    type FullPrecisionBackend = WGPUBackend<G, f32, i32>;
+
+    type FullPrecisionElem = f32;
+    type FloatElem = F;
+    type IntElem = I;
+
+    type TensorPrimitive<const D: usize> = WGPUTensor<F, D>;
+    type IntTensorPrimitive<const D: usize> = WGPUTensor<I, D>;
+    type BoolTensorPrimitive<const D: usize> = WGPUTensor<u32, D>;
+
+    fn name() -> String {
+        String::from("wgpu")
+    }
+
+    fn seed(seed: u64) {
+        let rng = StdRng::seed_from_u64(seed);
+        let mut seed = SEED.lock().unwrap();
+        *seed = Some(rng);
+    }
+
+    fn ad_enabled() -> bool {
+        false
+    }
+}
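For orientation, a minimal sketch of how this backend is meant to be selected from user code. The alias name is illustrative; the only combinations confirmed by this diff are the test setup later on (`WGPUBackend<Vulkan, f32, i64>`) and the element impls for `f32`, `i32` and `i64`:

use burn_tensor::backend::Backend;
use burn_wgpu::{Vulkan, WGPUBackend, WGPUDevice};

// Illustrative alias: Vulkan graphics API, f32 floats, i32 ints.
type Wgpu = WGPUBackend<Vulkan, f32, i32>;

fn main() {
    // First discrete GPU exposed through the chosen graphics API.
    let _device = WGPUDevice::DiscreteGPU(0);

    // Seeding stores an StdRng in the backend's global SEED mutex.
    <Wgpu as Backend>::seed(42);
    assert_eq!(<Wgpu as Backend>::name(), "wgpu");
}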
@@ -0,0 +1,256 @@
+use burn_common::id::IdGenerator;
+use std::{
+    any::TypeId,
+    borrow::Cow,
+    collections::HashMap,
+    sync::{Arc, Mutex},
+};
+
+use wgpu::{
+    util::{BufferInitDescriptor, DeviceExt},
+    Buffer, DeviceDescriptor, DeviceType, ShaderModule, ShaderModuleDescriptor,
+};
+
+use crate::{kernel::KernelGenerator, GraphicsAPI, WGPUDevice};
+
+/// The context is the basic struct that allows executing GPU kernels on devices.
+///
+/// You can access a context for a [wgpu device](WGPUDevice) using [get_context](crate::pool::get_context).
+#[derive(Debug)]
+pub struct Context {
+    id: String,
+    queue: wgpu::Queue,
+    device_wgpu: wgpu::Device,
+    cache: Mutex<HashMap<TypeId, Arc<ShaderModule>>>,
+    pub(crate) device: WGPUDevice,
+}
+
+#[derive(new, Clone, Debug)]
+pub struct WorkGroup {
+    pub x: u32,
+    pub y: u32,
+    pub z: u32,
+}
+
+impl Context {
+    pub(crate) fn new<G: GraphicsAPI>(device: &WGPUDevice) -> Self {
+        // Instantiates instance of WebGPU
+        let instance = wgpu::Instance::default();
+
+        // `request_adapter` instantiates the general connection to the GPU
+        let adapters = instance.enumerate_adapters(G::backend().into());
+        let mut adapters = adapters
+            .filter(|adapter| {
+                let device_type = adapter.get_info().device_type;
+                match device {
+                    WGPUDevice::DiscreteGPU(_) => device_type == DeviceType::DiscreteGpu,
+                    WGPUDevice::IntegratedGPU(_) => device_type == DeviceType::IntegratedGpu,
+                    WGPUDevice::VirtualGPU(_) => device_type == DeviceType::VirtualGpu,
+                    WGPUDevice::CPU => device_type == DeviceType::Cpu,
+                }
+            })
+            .collect::<Vec<_>>();
+
+        let adapter = match device {
+            WGPUDevice::DiscreteGPU(num) => {
+                assert!(adapters.len() > *num, "No Discrete GPU device found");
+                adapters.remove(*num)
+            }
+            WGPUDevice::IntegratedGPU(num) => {
+                assert!(adapters.len() > *num, "No Integrated GPU device found");
+                adapters.remove(*num)
+            }
+            WGPUDevice::VirtualGPU(num) => {
+                assert!(adapters.len() > *num, "No Virtual GPU device found");
+                adapters.remove(*num)
+            }
+            WGPUDevice::CPU => {
+                assert!(!adapters.is_empty(), "No CPU device found");
+                adapters.remove(0)
+            }
+        };
+
+        let device_wgpu = device.clone();
+        let (device, queue) = pollster::block_on(adapter.request_device(
+            &DeviceDescriptor {
+                label: None,
+                features: wgpu::Features::empty(),
+                limits: wgpu::Limits::downlevel_defaults(),
+            },
+            None,
+        ))
+        .expect("Unable to request the device with the adapter");
+
+        Self {
+            id: IdGenerator::generate(),
+            queue,
+            device_wgpu: device,
+            device: device_wgpu,
+            cache: Mutex::new(HashMap::new()),
+        }
+    }
+
+    /// Create a new buffer with the provided size.
+    pub fn create_buffer(&self, size: usize) -> Buffer {
+        self.device_wgpu.create_buffer(&wgpu::BufferDescriptor {
+            label: None,
+            size: size as u64,
+            usage: wgpu::BufferUsages::COPY_DST
+                | wgpu::BufferUsages::STORAGE
+                | wgpu::BufferUsages::COPY_SRC,
+            mapped_at_creation: false,
+        })
+    }
+
+    /// Create a new buffer initialized with the provided bytes.
+    pub fn create_buffer_with_data(&self, data: &[u8]) -> Buffer {
+        let buffer_src = self.device_wgpu.create_buffer_init(&BufferInitDescriptor {
+            label: Some("Buffer Src"),
+            contents: data,
+            usage: wgpu::BufferUsages::COPY_SRC,
+        });
+
+        let buffer = self.create_buffer(buffer_src.size() as usize);
+
+        // Create a command encoder
+        let mut encoder =
+            self.device_wgpu
+                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
+                    label: Some("Command Encoder"),
+                });
+
+        // Copy data from the staging buffer to the target buffer
+        encoder.copy_buffer_to_buffer(&buffer_src, 0, &buffer, 0, buffer_src.size());
+
+        // Submit the command encoder to the queue
+        self.queue.submit(std::iter::once(encoder.finish()));
+
+        buffer
+    }
+
+    /// Read a buffer from the GPU and return its content as bytes.
+    pub fn buffer_to_data(&self, buffer: &Buffer) -> Vec<u8> {
+        let size = buffer.size();
+
+        let buffer_dest = self.device_wgpu.create_buffer(&wgpu::BufferDescriptor {
+            label: None,
+            size,
+            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
+            mapped_at_creation: false,
+        });
+
+        // Create a command encoder
+        let mut encoder =
+            self.device_wgpu
+                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
+                    label: Some("Command Encoder"),
+                });
+
+        encoder.copy_buffer_to_buffer(buffer, 0, &buffer_dest, 0, size);
+
+        self.queue.submit(std::iter::once(encoder.finish()));
+
+        let buffer_slice = buffer_dest.slice(..);
+        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
+        buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
+            sender
+                .send(v)
+                .expect("Unable to send buffer slice result to async channel.")
+        });
+
+        self.device_wgpu.poll(wgpu::Maintain::Wait);
+
+        let result = pollster::block_on(receiver.receive());
+
+        if let Some(Ok(())) = result {
+            let data = buffer_slice.get_mapped_range();
+            let result = bytemuck::cast_slice(&data).to_vec();
+
+            drop(data);
+            buffer_dest.unmap();
+            result
+        } else {
+            panic!("Unable to read buffer {:?}", result)
+        }
+    }
+
+    /// Compile a kernel template if not present in the cache.
+    pub fn compile<K: KernelGenerator>(&self) -> Arc<ShaderModule> {
+        let mut cache = self.cache.lock().unwrap();
+        let template_id = TypeId::of::<K>();
+
+        if let Some(module) = cache.get(&template_id) {
+            return module.clone();
+        }
+
+        let source = K::generate();
+
+        let module = self
+            .device_wgpu
+            .create_shader_module(ShaderModuleDescriptor {
+                label: None,
+                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source.as_ref())),
+            });
+        let module = Arc::new(module);
+
+        cache.insert(template_id, module.clone());
+
+        module
+    }
+
+    /// Execute a kernel using the provided buffers.
+    ///
+    /// # Notes
+    ///
+    /// This function isn't safe, buffers can be mutated by the GPU. Users must ensure that a
+    /// buffer can be mutated when launching a compute shader with write access to that buffer.
+    ///
+    /// Buffer positions are used as bindings when launching a compute kernel.
+    pub fn execute(&self, work_group: &WorkGroup, kernel: &ShaderModule, buffers: &[&Buffer]) {
+        let pipeline = self
+            .device_wgpu
+            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
+                label: None,
+                layout: None,
+                module: kernel,
+                entry_point: "main",
+            });
+
+        let group_layout = pipeline.get_bind_group_layout(0);
+
+        let entries = buffers
+            .iter()
+            .enumerate()
+            .map(|(i, buffer)| wgpu::BindGroupEntry {
+                binding: i as u32,
+                resource: buffer.as_entire_binding(),
+            })
+            .collect::<Vec<_>>();
+
+        let bind_group = self
+            .device_wgpu
+            .create_bind_group(&wgpu::BindGroupDescriptor {
+                label: None,
+                layout: &group_layout,
+                entries: &entries,
+            });
+
+        let mut encoder = self
+            .device_wgpu
+            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
+        let mut compute = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
+        compute.set_pipeline(&pipeline);
+        compute.set_bind_group(0, &bind_group, &[]);
+
+        compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z);
+        std::mem::drop(compute);
+
+        self.queue.submit(Some(encoder.finish()));
+    }
+}
+
+impl PartialEq for Context {
+    fn eq(&self, other: &Self) -> bool {
+        self.id == other.id
+    }
+}
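To show how these pieces fit together, here is a crate-internal sketch (Context and the kernel modules are pub(crate), so this would live inside burn-wgpu, e.g. in a test). The `Double` kernel and its WGSL file are hypothetical; only the Context, KernelSettings, WorkGroup and pool calls come from this diff:

use crate::{
    context::WorkGroup, kernel::KernelSettings, kernel_wgsl, pool::get_context, Vulkan, WGPUDevice,
};

kernel_wgsl!(Double, "../template/double.wgsl"); // hypothetical WGSL template

fn double_on_gpu(values: &[f32]) -> Vec<f32> {
    let context = get_context::<Vulkan>(&WGPUDevice::DiscreteGPU(0));

    // Upload the input and compile the templated kernel (cached by TypeId).
    let buffer = context.create_buffer_with_data(bytemuck::cast_slice(values));
    let kernel = context.compile::<KernelSettings<Double, f32, i32, 256, 1, 1>>();

    // One workgroup of 256 invocations per 256 elements, mutating the buffer in place.
    let workgroups = f32::ceil(values.len() as f32 / 256.0) as u32;
    context.execute(&WorkGroup::new(workgroups, 1, 1), &kernel, &[&buffer]);

    // Read the result back as bytes and reinterpret as f32.
    bytemuck::cast_slice(&context.buffer_to_data(&buffer)).to_vec()
}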
@@ -0,0 +1,25 @@
+/// The device struct when using the `wgpu` backend.
+///
+/// Note that you need to provide the device index when using a GPU backend.
+///
+/// # Example
+///
+/// ```no_run
+/// use burn_wgpu::WGPUDevice;
+///
+/// let device_gpu_1 = WGPUDevice::DiscreteGPU(0); // First discrete GPU found.
+/// let device_gpu_2 = WGPUDevice::DiscreteGPU(1); // Second discrete GPU found.
+/// ```
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub enum WGPUDevice {
+    DiscreteGPU(usize),
+    IntegratedGPU(usize),
+    VirtualGPU(usize),
+    CPU,
+}
+
+impl Default for WGPUDevice {
+    fn default() -> Self {
+        Self::CPU
+    }
+}
@@ -0,0 +1,66 @@
+use burn_tensor::Element;
+
+pub trait WGPUElement: core::fmt::Debug + 'static + Clone
+where
+    Self: Sized,
+{
+    fn type_name() -> &'static str;
+    fn as_bytes(slice: &[Self]) -> &[u8];
+    fn from_bytes(bytes: &[u8]) -> &[Self];
+}
+
+pub trait FloatElement: WGPUElement + Element {}
+
+pub trait IntElement: WGPUElement + Element {}
+
+impl WGPUElement for u32 {
+    fn type_name() -> &'static str {
+        "u32"
+    }
+    fn as_bytes(slice: &[Self]) -> &[u8] {
+        bytemuck::cast_slice(slice)
+    }
+    fn from_bytes(bytes: &[u8]) -> &[Self] {
+        bytemuck::cast_slice(bytes)
+    }
+}
+
+impl WGPUElement for i32 {
+    fn type_name() -> &'static str {
+        "i32"
+    }
+    fn as_bytes(slice: &[Self]) -> &[u8] {
+        bytemuck::cast_slice(slice)
+    }
+    fn from_bytes(bytes: &[u8]) -> &[Self] {
+        bytemuck::cast_slice(bytes)
+    }
+}
+
+impl WGPUElement for i64 {
+    fn type_name() -> &'static str {
+        "i64"
+    }
+    fn as_bytes(slice: &[Self]) -> &[u8] {
+        bytemuck::cast_slice(slice)
+    }
+    fn from_bytes(bytes: &[u8]) -> &[Self] {
+        bytemuck::cast_slice(bytes)
+    }
+}
+
+impl WGPUElement for f32 {
+    fn type_name() -> &'static str {
+        "f32"
+    }
+    fn as_bytes(slice: &[Self]) -> &[u8] {
+        bytemuck::cast_slice(slice)
+    }
+    fn from_bytes(bytes: &[u8]) -> &[Self] {
+        bytemuck::cast_slice(bytes)
+    }
+}
+
+impl FloatElement for f32 {}
+impl IntElement for i32 {}
+impl IntElement for i64 {}
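The `as_bytes`/`from_bytes` pair is what the later `from_data`/`to_data` ops rely on when moving values to and from GPU buffers. A tiny standalone round-trip sketch (not part of the diff) that type-checks against the trait above:

// Round-trips any supported element type through its GPU byte representation.
fn round_trip<E: WGPUElement + PartialEq>(values: &[E]) -> bool {
    E::from_bytes(E::as_bytes(values)) == values
}

// round_trip(&[1.0_f32, 2.0, 3.0]) and round_trip(&[1_u32, 2, 3]) both hold.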
@@ -0,0 +1,61 @@
+/// The basic trait to specify which graphics API to use as Backend.
+///
+/// Options are:
+/// - [Vulkan](Vulkan)
+/// - [Metal](Metal)
+/// - [OpenGL](OpenGL)
+/// - [DirectX 11](Dx11)
+/// - [DirectX 12](Dx12)
+/// - [WebGPU](WebGPU)
+pub trait GraphicsAPI: Send + Sync + core::fmt::Debug + Default + Clone + 'static {
+    fn backend() -> wgpu::Backend;
+}
+
+#[derive(Default, Debug, Clone)]
+pub struct Vulkan;
+#[derive(Default, Debug, Clone)]
+pub struct Metal;
+#[derive(Default, Debug, Clone)]
+pub struct OpenGL;
+#[derive(Default, Debug, Clone)]
+pub struct Dx11;
+#[derive(Default, Debug, Clone)]
+pub struct Dx12;
+#[derive(Default, Debug, Clone)]
+pub struct WebGPU;
+
+impl GraphicsAPI for Vulkan {
+    fn backend() -> wgpu::Backend {
+        wgpu::Backend::Vulkan
+    }
+}
+
+impl GraphicsAPI for Metal {
+    fn backend() -> wgpu::Backend {
+        wgpu::Backend::Metal
+    }
+}
+
+impl GraphicsAPI for OpenGL {
+    fn backend() -> wgpu::Backend {
+        wgpu::Backend::Gl
+    }
+}
+
+impl GraphicsAPI for Dx11 {
+    fn backend() -> wgpu::Backend {
+        wgpu::Backend::Dx11
+    }
+}
+
+impl GraphicsAPI for Dx12 {
+    fn backend() -> wgpu::Backend {
+        wgpu::Backend::Dx12
+    }
+}
+
+impl GraphicsAPI for WebGPU {
+    fn backend() -> wgpu::Backend {
+        wgpu::Backend::BrowserWebGpu
+    }
+}
@@ -0,0 +1,87 @@
+use crate::element::WGPUElement;
+use std::marker::PhantomData;
+
+/// Generate wgpu kernel source code to create [compute shader modules](wgpu::ShaderModule).
+pub trait KernelGenerator: 'static {
+    /// Source code concrete type.
+    type Source: AsRef<str>;
+
+    /// Generate the source code.
+    fn generate() -> Self::Source;
+}
+
+#[macro_export]
+macro_rules! kernel_wgsl {
+    (
+        $struct:ident,
+        $file:expr
+    ) => {
+        #[derive(new)]
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = &'static str;
+
+            fn generate() -> Self::Source {
+                include_str!($file)
+            }
+        }
+    };
+}
+
+/// Generate kernel source code by replacing some information using templating.
+pub struct KernelSettings<
+    K: KernelGenerator,
+    E: WGPUElement,
+    I: WGPUElement,
+    const WORKGROUP_X_SIZE: usize,
+    const WORKGROUP_Y_SIZE: usize,
+    const WORKGROUP_Z_SIZE: usize,
+> {
+    _k: PhantomData<K>,
+    _e: PhantomData<E>,
+    _i: PhantomData<I>,
+}
+
+impl<
+        K: KernelGenerator,
+        E: WGPUElement,
+        I: WGPUElement,
+        const WORKGROUP_X_SIZE: usize,
+        const WORKGROUP_Y_SIZE: usize,
+        const WORKGROUP_Z_SIZE: usize,
+    > KernelGenerator
+    for KernelSettings<K, E, I, WORKGROUP_X_SIZE, WORKGROUP_Y_SIZE, WORKGROUP_Z_SIZE>
+{
+    type Source = String;
+
+    fn generate() -> String {
+        let mut source = K::generate().as_ref().to_string();
+
+        source = source.replace("WORKGROUP_SIZE_X", &WORKGROUP_X_SIZE.to_string());
+        source = source.replace("WORKGROUP_SIZE_Y", &WORKGROUP_Y_SIZE.to_string());
+        source = source.replace("WORKGROUP_SIZE_Z", &WORKGROUP_Z_SIZE.to_string());
+        source = source.replace("elem", E::type_name());
+        source = source.replace("int", I::type_name());
+
+        source
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use core::any::TypeId;
+
+    #[test]
+    fn test_kernel_type_id() {
+        kernel_wgsl!(Add, "../template/binary_elemwise.wgsl");
+
+        let type_id_1 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 4>>();
+        let type_id_2 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 5>>();
+        let type_id_3 = TypeId::of::<KernelSettings<Add, f32, i32, 2, 3, 4>>();
+
+        assert_ne!(type_id_1, type_id_2);
+        assert_eq!(type_id_1, type_id_3);
+    }
+}
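Putting the two together: `kernel_wgsl!` embeds a static WGSL file, and `KernelSettings` specializes it at the type level. A short sketch, where `Scale` and its template path are hypothetical names used only for illustration:

kernel_wgsl!(Scale, "../template/scale.wgsl"); // hypothetical template file

// 256x1x1 workgroup, f32 for the "elem" placeholder, i32 for "int".
type ScaleKernel = KernelSettings<Scale, f32, i32, 256, 1, 1>;

fn generated_source() -> String {
    // WORKGROUP_SIZE_*, "elem" and "int" are substituted into the raw WGSL here;
    // Context::compile::<ScaleKernel>() would then cache the resulting module by TypeId.
    <ScaleKernel as KernelGenerator>::generate()
}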
@@ -0,0 +1,159 @@
+use super::{KernelGenerator, KernelSettings};
+use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor};
+use burn_tensor::Shape;
+use num_traits::ToPrimitive;
+use std::sync::Arc;
+
+kernel_wgsl!(BinaryElemwiseRaw, "../template/binary_elemwise.wgsl");
+kernel_wgsl!(
+    BinaryElemwiseInplaceRaw,
+    "../template/binary_elemwise_inplace.wgsl"
+);
+
+#[macro_export]
+macro_rules! binary_elemwise {
+    (
+        $struct:ident,
+        $ops:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                let source = $crate::kernel::BinaryElemwiseRaw::generate().to_string();
+                let body = format!(
+                    "output[global_id.x] = lhs[index_lhs] {} rhs[index_rhs]",
+                    $ops
+                );
+                source.replace("BODY", &body)
+            }
+        }
+    };
+}
+
+#[macro_export]
+macro_rules! binary_elemwise_inplace {
+    (
+        $struct:ident,
+        $ops:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                let source = $crate::kernel::BinaryElemwiseInplaceRaw::generate().to_string();
+                let body = format!(
+                    "lhs[global_id.x] = lhs[global_id.x] {} rhs[index_rhs];",
+                    $ops
+                );
+                source.replace("BODY", &body)
+            }
+        }
+    };
+}
+
+pub fn binary_elemwise<K: KernelGenerator, E: WGPUElement, const D: usize>(
+    lhs: WGPUTensor<E, D>,
+    rhs: WGPUTensor<E, D>,
+) -> WGPUTensor<E, D> {
+    lhs.assert_is_on_save_device(&rhs);
+
+    let mut shape_out = [0; D];
+    lhs.shape
+        .dims
+        .iter()
+        .zip(rhs.shape.dims.iter())
+        .enumerate()
+        .for_each(|(index, (dim_lhs, dim_rhs))| {
+            shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
+        });
+
+    let shape_out = Shape::new(shape_out);
+
+    let buffer = lhs
+        .context
+        .create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
+    let output = WGPUTensor::new(lhs.context.clone(), shape_out, Arc::new(buffer));
+
+    let kernel = lhs
+        .context
+        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
+    let mut info: Vec<u32> = vec![D.to_u32().unwrap()];
+
+    lhs.strides
+        .into_iter()
+        .for_each(|v| info.push(v.to_u32().unwrap()));
+    rhs.strides
+        .into_iter()
+        .for_each(|v| info.push(v.to_u32().unwrap()));
+    lhs.shape
+        .dims
+        .into_iter()
+        .for_each(|v| info.push(v.to_u32().unwrap()));
+    rhs.shape
+        .dims
+        .into_iter()
+        .for_each(|v| info.push(v.to_u32().unwrap()));
+    let info_buffers = lhs
+        .context
+        .create_buffer_with_data(bytemuck::cast_slice(&info));
+
+    lhs.context.execute(
+        &WorkGroup::new(
+            f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
+            1,
+            1,
+        ),
+        &kernel,
+        &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
+    );
+
+    output
+}
+
+pub fn binary_elemwise_inplace<K: KernelGenerator, E: WGPUElement, const D: usize>(
+    lhs: WGPUTensor<E, D>,
+    rhs: WGPUTensor<E, D>,
+) -> WGPUTensor<E, D> {
+    lhs.assert_is_on_save_device(&rhs);
+
+    let mut shape_out = [0; D];
+    lhs.shape
+        .dims
+        .iter()
+        .zip(rhs.shape.dims.iter())
+        .enumerate()
+        .for_each(|(index, (dim_lhs, dim_rhs))| {
+            shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
+        });
+
+    let kernel = lhs
+        .context
+        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
+    let mut info: Vec<u32> = vec![D.to_u32().unwrap()];
+    rhs.strides
+        .into_iter()
+        .for_each(|v| info.push(v.to_u32().unwrap()));
+    rhs.shape
+        .dims
+        .into_iter()
+        .for_each(|v| info.push(v.to_u32().unwrap()));
+    let info_buffers = lhs
+        .context
+        .create_buffer_with_data(bytemuck::cast_slice(&info));
+
+    lhs.context.execute(
+        &WorkGroup::new(
+            f32::ceil(lhs.shape.num_elements() as f32 / 256_f32) as u32,
+            1,
+            1,
+        ),
+        &kernel,
+        &[&lhs.buffer, &rhs.buffer, &info_buffers],
+    );
+
+    lhs
+}
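A sketch of how these macros are meant to be consumed by the numeric ops. The `Add`/`AddInplace` kernel names are illustrative; the concrete wiring lives in code that is not shown in this truncated diff:

binary_elemwise!(Add, "+");
binary_elemwise_inplace!(AddInplace, "+");

fn add_tensors<const D: usize>(
    lhs: WGPUTensor<f32, D>,
    rhs: WGPUTensor<f32, D>,
) -> WGPUTensor<f32, D> {
    // Allocates an output tensor; when `lhs` may be mutated, the in-place
    // variant `binary_elemwise_inplace::<AddInplace, f32, D>(lhs, rhs)` reuses
    // its buffer instead.
    binary_elemwise::<Add, f32, D>(lhs, rhs)
}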
@@ -0,0 +1,9 @@
+mod base;
+mod binary_elemwise;
+mod unary;
+mod unary_scalar;
+
+pub use base::*;
+pub use binary_elemwise::*;
+pub use unary::*;
+pub use unary_scalar::*;
@@ -0,0 +1,120 @@
+use super::{KernelGenerator, KernelSettings};
+use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor};
+use std::sync::Arc;
+
+kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
+kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
+
+#[macro_export]
+macro_rules! unary {
+    (
+        $struct:ident,
+        func $func:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                let source = $crate::kernel::UnaryRaw::generate().to_string();
+                let body = format!("output[global_id.x] = {}(input[global_id.x]);", $func);
+                source.replace("BODY", &body)
+            }
+        }
+    };
+    (
+        $struct:ident,
+        body $body:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                let source = $crate::kernel::UnaryRaw::generate().to_string();
+                source.replace("BODY", $body)
+            }
+        }
+    };
+}
+
+#[macro_export]
+macro_rules! unary_inplace {
+    (
+        $struct:ident,
+        func $func:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                let source = $crate::kernel::UnaryInplaceRaw::generate().to_string();
+                let body = format!("input[global_id.x] = {}(input[global_id.x]);", $func);
+                source.replace("BODY", &body)
+            }
+        }
+    };
+    (
+        $struct:ident,
+        body $body:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                let source = $crate::kernel::UnaryInplaceRaw::generate().to_string();
+                source.replace("BODY", $body)
+            }
+        }
+    };
+}
+
+pub fn unary<K: KernelGenerator, E: WGPUElement, const D: usize>(
+    input: WGPUTensor<E, D>,
+) -> WGPUTensor<E, D> {
+    let buffer = input
+        .context
+        .create_buffer(input.shape.num_elements() * core::mem::size_of::<E>());
+    let output = WGPUTensor::new(input.context.clone(), input.shape, Arc::new(buffer));
+    let kernel = input
+        .context
+        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
+
+    input.context.execute(
+        &WorkGroup::new(
+            f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
+            1,
+            1,
+        ),
+        &kernel,
+        &[&input.buffer, &output.buffer],
+    );
+
+    output
+}
+
+pub fn unary_inplace<K: KernelGenerator, E: WGPUElement, const D: usize>(
+    input: WGPUTensor<E, D>,
+) -> WGPUTensor<E, D> {
+    let kernel = input
+        .context
+        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
+
+    input.context.execute(
+        &WorkGroup::new(
+            f32::ceil(input.shape.num_elements() as f32 / 256_f32) as u32,
+            1,
+            1,
+        ),
+        &kernel,
+        &[&input.buffer],
+    );
+
+    input
+}
@@ -0,0 +1,135 @@
+use super::{KernelGenerator, KernelSettings};
+use crate::{context::WorkGroup, element::WGPUElement, kernel_wgsl, tensor::WGPUTensor};
+use std::sync::Arc;
+
+kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl");
+kernel_wgsl!(
+    UnaryScalarInplaceRaw,
+    "../template/unary_scalar_inplace.wgsl"
+);
+
+#[macro_export]
+macro_rules! unary_scalar {
+    (
+        $struct:ident,
+        ops $ops:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                let source = $crate::kernel::UnaryScalarRaw::generate().to_string();
+                let body = format!("output[global_id.x] = lhs[global_id.x] {} rhs;", $ops);
+
+                source.replace("BODY", &body)
+            }
+        }
+    };
+
+    (
+        $struct:ident,
+        func $func:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                let source = $crate::kernel::UnaryScalarRaw::generate().to_string();
+                let body = format!("output[global_id.x] = {}(lhs[global_id.x], rhs);", $func);
+
+                source.replace("BODY", &body)
+            }
+        }
+    };
+}
+
+#[macro_export]
+macro_rules! unary_scalar_inplace {
+    (
+        $struct:ident,
+        ops $ops:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                let source = $crate::kernel::UnaryScalarInplaceRaw::generate().to_string();
+                let body = format!("lhs[global_id.x] = lhs[global_id.x] {} rhs;", $ops);
+
+                source.replace("BODY", &body)
+            }
+        }
+    };
+
+    (
+        $struct:ident,
+        func $func:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                let source = $crate::kernel::UnaryScalarInplaceRaw::generate().to_string();
+                let body = format!("lhs[global_id.x] = {}(lhs[global_id.x], rhs);", $func);
+
+                source.replace("BODY", &body)
+            }
+        }
+    };
+}
+
+pub fn unary_scalar<K: KernelGenerator, E: WGPUElement, const D: usize>(
+    lhs: WGPUTensor<E, D>,
+    scalar: E,
+) -> WGPUTensor<E, D> {
+    let buffer = lhs
+        .context
+        .create_buffer(lhs.shape.num_elements() * core::mem::size_of::<E>());
+    let output = WGPUTensor::new(lhs.context.clone(), lhs.shape, Arc::new(buffer));
+    let kernel = lhs
+        .context
+        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
+    let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar]));
+
+    lhs.context.execute(
+        &WorkGroup::new(
+            f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
+            1,
+            1,
+        ),
+        &kernel,
+        &[&lhs.buffer, &rhs_buffer, &output.buffer],
+    );
+
+    output
+}
+
+pub fn unary_scalar_inplace<K: KernelGenerator, E: WGPUElement, const D: usize>(
+    lhs: WGPUTensor<E, D>,
+    scalar: E,
+) -> WGPUTensor<E, D> {
+    let kernel = lhs
+        .context
+        .compile::<KernelSettings<K, E, i32, 256, 1, 1>>();
+    let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar]));
+
+    lhs.context.execute(
+        &WorkGroup::new(
+            f32::ceil(lhs.shape.num_elements() as f32 / 256_f32) as u32,
+            1,
+            1,
+        ),
+        &kernel,
+        &[&lhs.buffer, &rhs_buffer],
+    );
+
+    lhs
+}
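Likewise for the scalar variants, an illustrative wiring; the kernel names below are not from this diff:

unary_scalar!(AddScalar, ops "+");
unary_scalar_inplace!(AddScalarInplace, ops "+");

fn add_scalar<const D: usize>(lhs: WGPUTensor<f32, D>, rhs: f32) -> WGPUTensor<f32, D> {
    // The scalar is uploaded as a one-element buffer bound as `rhs` in the shader;
    // `unary_scalar_inplace::<AddScalarInplace, f32, D>(lhs, rhs)` would mutate lhs instead.
    unary_scalar::<AddScalar, f32, D>(lhs, rhs)
}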
@@ -0,0 +1,39 @@
+#[macro_use]
+extern crate derive_new;
+
+mod ops;
+
+pub(crate) mod context;
+pub(crate) mod element;
+pub(crate) mod kernel;
+pub(crate) mod pool;
+pub(crate) mod tensor;
+
+mod device;
+pub use device::*;
+
+mod backend;
+pub use backend::*;
+
+mod graphics;
+pub use graphics::*;
+
+#[cfg(test)]
+mod tests {
+    type TestBackend = crate::WGPUBackend<crate::Vulkan, f32, i64>;
+
+    burn_tensor::testgen_add!();
+    burn_tensor::testgen_sub!();
+    burn_tensor::testgen_div!();
+    burn_tensor::testgen_mul!();
+    burn_tensor::testgen_neg!();
+    burn_tensor::testgen_powf!();
+    burn_tensor::testgen_exp!();
+    burn_tensor::testgen_log!();
+    burn_tensor::testgen_relu!();
+
+    // Once all operations are implemented:
+    // type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
+    // type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
+    // burn_tensor::testgen_all!();
+}
@@ -0,0 +1,27 @@
+use burn_tensor::ops::ActivationOps;
+
+use crate::{
+    element::{FloatElement, IntElement},
+    kernel::{unary, unary_inplace},
+    unary, unary_inplace, GraphicsAPI, WGPUBackend,
+};
+
+use super::FloatTensor;
+
+impl<G, F, I> ActivationOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
+where
+    G: GraphicsAPI + 'static,
+    F: FloatElement,
+    I: IntElement,
+{
+    fn relu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        unary!(Relu, body "output[global_id.x] = max(input[global_id.x], 0.0);");
+        unary_inplace!(ReluInplace, body "input[global_id.x] = max(input[global_id.x], 0.0);");
+
+        if tensor.can_mut() {
+            return unary_inplace::<ReluInplace, F, D>(tensor);
+        }
+
+        unary::<Relu, F, D>(tensor)
+    }
+}
@@ -0,0 +1,59 @@
+use std::{marker::PhantomData, sync::Arc};
+
+use burn_tensor::{backend::Backend, Data, Shape};
+
+use crate::{element::WGPUElement, pool::get_context, tensor::WGPUTensor, GraphicsAPI, WGPUDevice};
+
+pub type FloatElem<B> = <B as Backend>::FloatElem;
+pub type Device<B> = <B as Backend>::Device;
+
+pub type FloatTensor<B, const D: usize> = <B as Backend>::TensorPrimitive<D>;
+
+pub type IntElem<B> = <B as Backend>::IntElem;
+pub type IntTensor<B, const D: usize> = <B as Backend>::IntTensorPrimitive<D>;
+pub type BoolTensor<B, const D: usize> = <B as Backend>::BoolTensorPrimitive<D>;
+
+pub struct BaseOps<G: GraphicsAPI> {
+    _g: PhantomData<G>,
+}
+
+impl<G: GraphicsAPI> BaseOps<G> {
+    pub fn from_data<E: WGPUElement, const D: usize>(
+        data: Data<E, D>,
+        device: &WGPUDevice,
+    ) -> WGPUTensor<E, D> {
+        let context = get_context::<G>(device);
+        let buffer = context.create_buffer_with_data(E::as_bytes(&data.value));
+
+        WGPUTensor::new(context, data.shape, Arc::new(buffer))
+    }
+
+    pub fn to_data<E: WGPUElement, const D: usize>(tensor: &WGPUTensor<E, D>) -> Data<E, D> {
+        let bytes = tensor.context.buffer_to_data(&tensor.buffer);
+        let values = E::from_bytes(&bytes);
+
+        Data::new(values.to_vec(), tensor.shape.clone())
+    }
+
+    pub fn to_device<E: WGPUElement, const D: usize>(
+        tensor: WGPUTensor<E, D>,
+        device: &WGPUDevice,
+    ) -> WGPUTensor<E, D> {
+        if &tensor.context.device == device {
+            return tensor;
+        }
+
+        let context = get_context::<G>(device);
+        tensor.to_context(context)
+    }
+
+    pub fn empty<E: WGPUElement, const D: usize>(
+        shape: Shape<D>,
+        device: &WGPUDevice,
+    ) -> WGPUTensor<E, D> {
+        let context = get_context::<G>(device);
+        let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
+
+        WGPUTensor::new(context, shape, Arc::new(buffer))
+    }
+}
@@ -0,0 +1,106 @@
+use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
+
+use crate::{
+    element::{FloatElement, IntElement},
+    GraphicsAPI, WGPUBackend,
+};
+
+use super::{BaseOps, BoolTensor, Device, IntTensor};
+
+impl<G, F, I> BoolTensorOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
+where
+    G: GraphicsAPI + 'static,
+    F: FloatElement,
+    I: IntElement,
+{
+    fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
+        BaseOps::<G>::empty(shape, device)
+    }
+
+    fn bool_shape<const D: usize>(tensor: &BoolTensor<Self, D>) -> Shape<D> {
+        tensor.shape.clone()
+    }
+
+    fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Data<bool, D> {
+        let data = BaseOps::<G>::to_data(&tensor);
+
+        Data::new(data.value.into_iter().map(|i| i != 0).collect(), data.shape)
+    }
+
+    fn bool_from_data<const D: usize>(
+        data: Data<bool, D>,
+        device: &Device<Self>,
+    ) -> BoolTensor<Self, D> {
+        let data: Data<u32, D> = Data::new(
+            data.value
+                .into_iter()
+                .map(|c| match c {
+                    true => 1,
+                    false => 0,
+                })
+                .collect(),
+            data.shape,
+        );
+        BaseOps::<G>::from_data(data, device)
+    }
+
+    fn bool_into_int<const D: usize>(_tensor: BoolTensor<Self, D>) -> IntTensor<Self, D> {
+        todo!()
+    }
+
+    fn bool_device<const D: usize>(
+        _tensor: &<WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
+    ) -> <WGPUBackend<G, F, I> as Backend>::Device {
+        todo!()
+    }
+
+    fn bool_to_device<const D: usize>(
+        tensor: BoolTensor<Self, D>,
+        device: &Device<Self>,
+    ) -> BoolTensor<Self, D> {
+        BaseOps::<G>::to_device(tensor, device)
+    }
+
+    fn bool_reshape<const D1: usize, const D2: usize>(
+        _tensor: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1>,
+        _shape: Shape<D2>,
+    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D2> {
+        todo!()
+    }
+
+    fn bool_index<const D1: usize, const D2: usize>(
+        _tensor: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1>,
+        _indexes: [std::ops::Range<usize>; D2],
+    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1> {
+        todo!()
+    }
+
+    fn bool_index_assign<const D1: usize, const D2: usize>(
+        _tensor: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1>,
+        _indexes: [std::ops::Range<usize>; D2],
+        _value: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1>,
+    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D1> {
+        todo!()
+    }
+
+    fn bool_cat<const D: usize>(
+        _tensors: Vec<<WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>>,
+        _dim: usize,
+    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
+        todo!()
+    }
+
+    fn bool_equal<const D: usize>(
+        _lhs: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
+        _rhs: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
+    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
+        todo!()
+    }
+
+    fn bool_equal_elem<const D: usize>(
+        _lhs: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
+        _rhs: bool,
+    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
+        todo!()
+    }
+}
@@ -0,0 +1,406 @@
use super::numeric::NumericOps;
use super::{BaseOps, Device, FloatElem, FloatTensor};
use crate::kernel::{unary, unary_inplace, unary_scalar, unary_scalar_inplace};
use crate::{
    element::{FloatElement, IntElement},
    unary, unary_inplace, GraphicsAPI, WGPUBackend, SEED,
};
use crate::{unary_scalar, unary_scalar_inplace};
use burn_common::rand::get_seeded_rng;
use burn_tensor::ElementConversion;
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, Shape};

impl<G, F, I> TensorOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
where
    G: GraphicsAPI + 'static,
    F: FloatElement,
    I: IntElement,
{
    fn from_data<const D: usize>(
        data: Data<FloatElem<Self>, D>,
        device: &Device<Self>,
    ) -> FloatTensor<Self, D> {
        BaseOps::<G>::from_data(data, device)
    }

    fn random<const D: usize>(
        shape: Shape<D>,
        distribution: Distribution<FloatElem<Self>>,
        device: &Device<Self>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        // Reuse the global RNG state if it has already been seeded, otherwise seed a new one.
        let mut seed = SEED.lock().unwrap();
        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
            rng_seeded.clone()
        } else {
            get_seeded_rng()
        };
        let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
        // Store the advanced RNG state back so the next call continues the same sequence.
        *seed = Some(rng);
        tensor
    }

    fn shape<const D: usize>(tensor: &FloatTensor<Self, D>) -> Shape<D> {
        tensor.shape.clone()
    }

    fn to_data<const D: usize>(tensor: &FloatTensor<Self, D>) -> Data<FloatElem<Self>, D> {
        BaseOps::<G>::to_data(tensor)
    }

    fn device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
        tensor.context.device.clone()
    }

    fn to_device<const D: usize>(
        tensor: FloatTensor<Self, D>,
        device: &Device<Self>,
    ) -> FloatTensor<Self, D> {
        BaseOps::<G>::to_device(tensor, device)
    }

    fn empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
        BaseOps::<G>::empty(shape, device)
    }

    fn add<const D: usize>(
        lhs: FloatTensor<Self, D>,
        rhs: FloatTensor<Self, D>,
    ) -> FloatTensor<Self, D> {
        NumericOps::add(lhs, rhs)
    }

    fn add_scalar<const D: usize>(
        lhs: FloatTensor<Self, D>,
        rhs: FloatElem<Self>,
    ) -> FloatTensor<Self, D> {
        NumericOps::add_scalar(lhs, rhs)
    }

    fn sub<const D: usize>(
        lhs: FloatTensor<Self, D>,
        rhs: FloatTensor<Self, D>,
    ) -> FloatTensor<Self, D> {
        NumericOps::sub(lhs, rhs)
    }

    fn sub_scalar<const D: usize>(
        lhs: FloatTensor<Self, D>,
        rhs: FloatElem<Self>,
    ) -> FloatTensor<Self, D> {
        NumericOps::sub_scalar(lhs, rhs)
    }

    fn mul<const D: usize>(
        lhs: FloatTensor<Self, D>,
        rhs: FloatTensor<Self, D>,
    ) -> FloatTensor<Self, D> {
        NumericOps::mul(lhs, rhs)
    }

    fn mul_scalar<const D: usize>(
        lhs: FloatTensor<Self, D>,
        rhs: FloatElem<Self>,
    ) -> FloatTensor<Self, D> {
        NumericOps::mul_scalar(lhs, rhs)
    }

    fn div<const D: usize>(
        lhs: FloatTensor<Self, D>,
        rhs: FloatTensor<Self, D>,
    ) -> FloatTensor<Self, D> {
        NumericOps::div(lhs, rhs)
    }

    fn div_scalar<const D: usize>(
        lhs: FloatTensor<Self, D>,
        rhs: FloatElem<Self>,
    ) -> FloatTensor<Self, D> {
        NumericOps::div_scalar(lhs, rhs)
    }

    fn matmul<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn swap_dims<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _dim1: usize,
        _dim2: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn reshape<const D1: usize, const D2: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
        _shape: Shape<D2>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D2> {
        todo!()
    }

    fn gather<const D: usize>(
        _dim: usize,
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn scatter<const D: usize>(
        _dim: usize,
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _value: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn index_select<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _dim: usize,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn index_select_assign<const D1: usize, const D2: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
        _dim: usize,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
        _value: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D2>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1> {
        todo!()
    }

    fn index<const D1: usize, const D2: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
        _indexes: [std::ops::Range<usize>; D2],
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1> {
        todo!()
    }

    fn index_assign<const D1: usize, const D2: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
        _indexes: [std::ops::Range<usize>; D2],
        _value: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D1> {
        todo!()
    }

    fn mask_scatter<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
        _source: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn mask_fill<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
        _value: <WGPUBackend<G, F, I> as Backend>::FloatElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn equal<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn equal_elem<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn greater<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn greater_elem<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn greater_equal<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn greater_equal_elem<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn lower<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn lower_elem<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn lower_equal<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn lower_equal_elem<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::FloatElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn sum<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1> {
        todo!()
    }

    fn sum_dim<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _dim: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn mean<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1> {
        todo!()
    }

    fn mean_dim<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _dim: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn to_full_precision<const D: usize>(
        _tensor: &<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <<WGPUBackend<G, F, I> as Backend>::FullPrecisionBackend as Backend>::TensorPrimitive<D>
    {
        todo!()
    }

    fn from_full_precision<const D: usize>(
        _tensor: <<WGPUBackend<G, F, I> as Backend>::FullPrecisionBackend as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        // Generate the out-of-place and in-place WGSL kernels around the `exp` function.
        unary!(Exp, func "exp");
        unary_inplace!(ExpInplace, func "exp");

        // Mutate the buffer in place when it has no other owners, otherwise allocate an output.
        if lhs.can_mut() {
            return unary_inplace::<ExpInplace, F, D>(lhs);
        }

        unary::<Exp, F, D>(lhs)
    }

    fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        unary!(Log, func "log");
        unary_inplace!(LogInplace, func "log");

        if tensor.can_mut() {
            return unary_inplace::<LogInplace, F, D>(tensor);
        }

        unary::<Log, F, D>(tensor)
    }

    fn log1p<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
        unary_scalar!(Powf, func "pow");
        unary_scalar_inplace!(PowfInplace, func "pow");

        if lhs.can_mut() {
            return unary_scalar_inplace::<PowfInplace, F, D>(lhs, rhs.elem());
        }

        unary_scalar::<Powf, F, D>(lhs, rhs.elem())
    }

    fn sqrt<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn cos<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn sin<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn tanh<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn erf<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn cat<const D: usize>(
        _tensors: Vec<<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>>,
        _dim: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
        todo!()
    }

    fn argmax<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _dim: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn argmin<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
        _dim: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }
}
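Note: `sqrt`, `cos`, `sin`, `tanh` and `erf` above are still stubbed out with `todo!()`. A minimal sketch of how `sqrt` could be filled in by mirroring the `exp`/`log` pattern in this file, assuming WGSL's built-in `sqrt` function can be passed to the `unary!`/`unary_inplace!` macros the same way (this sketch is not part of the commit):

    fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        // Generate out-of-place and in-place kernels around the WGSL `sqrt` builtin (assumed).
        unary!(Sqrt, func "sqrt");
        unary_inplace!(SqrtInplace, func "sqrt");

        // Reuse the input buffer when it has no other owners, otherwise allocate an output.
        if tensor.can_mut() {
            return unary_inplace::<SqrtInplace, F, D>(tensor);
        }

        unary::<Sqrt, F, D>(tensor)
    }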
@@ -0,0 +1,314 @@
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Shape};

use crate::{
    element::{FloatElement, IntElement},
    GraphicsAPI, WGPUBackend,
};

use super::{numeric::NumericOps, BaseOps, Device, IntElem, IntTensor};

impl<G, F, I> IntTensorOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
where
    G: GraphicsAPI + 'static,
    F: FloatElement,
    I: IntElement,
{
    fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
        BaseOps::<G>::empty(shape, device)
    }

    fn int_shape<const D: usize>(
        _tensor: &<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> Shape<D> {
        todo!()
    }

    fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Data<I, D> {
        BaseOps::<G>::to_data(&tensor)
    }

    fn int_from_data<const D: usize>(
        data: Data<I, D>,
        device: &Device<Self>,
    ) -> IntTensor<Self, D> {
        BaseOps::<G>::from_data(data, device)
    }

    fn int_device<const D: usize>(
        _tensor: &<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::Device {
        todo!()
    }

    fn int_to_device<const D: usize>(
        tensor: IntTensor<Self, D>,
        device: &Device<Self>,
    ) -> IntTensor<Self, D> {
        BaseOps::<G>::to_device(tensor, device)
    }

    fn int_reshape<const D1: usize, const D2: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
        _shape: Shape<D2>,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D2> {
        todo!()
    }

    fn int_index<const D1: usize, const D2: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
        _indexes: [std::ops::Range<usize>; D2],
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
        todo!()
    }

    fn int_index_assign<const D1: usize, const D2: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
        _indexes: [std::ops::Range<usize>; D2],
        _value: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
        todo!()
    }

    fn int_mask_scatter<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
        _source: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_mask_fill<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _mask: <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
        _value: <WGPUBackend<G, F, I> as Backend>::IntElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_gather<const D: usize>(
        _dim: usize,
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_scatter<const D: usize>(
        _dim: usize,
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _value: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_index_select_dim<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _dim: usize,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
        _dim: usize,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
        _value: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D2>,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
        todo!()
    }

    fn int_cat<const D: usize>(
        _tensors: Vec<<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>>,
        _dim: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_equal<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn int_equal_elem<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn int_greater<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn int_greater_elem<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn int_greater_equal<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn int_greater_equal_elem<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn int_lower<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn int_lower_elem<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn int_lower_equal<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn int_lower_equal_elem<const D: usize>(
        _lhs: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _rhs: <WGPUBackend<G, F, I> as Backend>::IntElem,
    ) -> <WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D> {
        todo!()
    }

    fn int_add<const D: usize>(
        lhs: IntTensor<Self, D>,
        rhs: IntTensor<Self, D>,
    ) -> IntTensor<Self, D> {
        NumericOps::add::<I, D>(lhs, rhs)
    }

    fn int_add_scalar<const D: usize>(
        lhs: IntTensor<Self, D>,
        rhs: IntElem<Self>,
    ) -> IntTensor<Self, D> {
        NumericOps::add_scalar(lhs, rhs)
    }

    fn int_sub<const D: usize>(
        lhs: IntTensor<Self, D>,
        rhs: IntTensor<Self, D>,
    ) -> IntTensor<Self, D> {
        NumericOps::sub(lhs, rhs)
    }

    fn int_sub_scalar<const D: usize>(
        lhs: IntTensor<Self, D>,
        rhs: IntElem<Self>,
    ) -> IntTensor<Self, D> {
        NumericOps::sub_scalar(lhs, rhs)
    }

    fn int_mul<const D: usize>(
        lhs: IntTensor<Self, D>,
        rhs: IntTensor<Self, D>,
    ) -> IntTensor<Self, D> {
        NumericOps::mul(lhs, rhs)
    }

    fn int_mul_scalar<const D: usize>(
        lhs: IntTensor<Self, D>,
        rhs: IntElem<Self>,
    ) -> IntTensor<Self, D> {
        NumericOps::mul_scalar(lhs, rhs)
    }

    fn int_div<const D: usize>(
        lhs: IntTensor<Self, D>,
        rhs: IntTensor<Self, D>,
    ) -> IntTensor<Self, D> {
        NumericOps::div(lhs, rhs)
    }

    fn int_div_scalar<const D: usize>(
        lhs: IntTensor<Self, D>,
        rhs: IntElem<Self>,
    ) -> IntTensor<Self, D> {
        NumericOps::div_scalar(lhs, rhs)
    }

    fn int_neg<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_zeros<const D: usize>(
        _shape: Shape<D>,
        _device: &<WGPUBackend<G, F, I> as Backend>::Device,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_ones<const D: usize>(
        _shape: Shape<D>,
        _device: &<WGPUBackend<G, F, I> as Backend>::Device,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_sum<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1> {
        todo!()
    }

    fn int_sum_dim<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _dim: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_mean<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<1> {
        todo!()
    }

    fn int_mean_dim<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _dim: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_argmax<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _dim: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }

    fn int_argmin<const D: usize>(
        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
        _dim: usize,
    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
        todo!()
    }
}
@@ -0,0 +1,10 @@
mod activation_ops;
mod bool_ops;
mod float_ops;
mod int_ops;
mod module_ops;

mod base;
pub(crate) use base::*;

pub(crate) mod numeric;
@@ -0,0 +1,94 @@
use burn_tensor::{backend::Backend, ops::ModuleOps};

use crate::{
    element::{FloatElement, IntElement},
    GraphicsAPI, WGPUBackend,
};

impl<G, F, I> ModuleOps<WGPUBackend<G, F, I>> for WGPUBackend<G, F, I>
where
    G: GraphicsAPI + 'static,
    F: FloatElement,
    I: IntElement,
{
    fn embedding(
        _weights: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<2>,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<2>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<3> {
        todo!()
    }

    fn embedding_backward(
        _weights: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<2>,
        _output: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<3>,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<2>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<2> {
        todo!()
    }

    fn conv2d(
        _x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _weight: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _bias: Option<<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1>>,
        _options: burn_tensor::ops::ConvOptions<2>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
        todo!()
    }

    fn conv_transpose2d(
        _x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _weight: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _bias: Option<<WGPUBackend<G, F, I> as Backend>::TensorPrimitive<1>>,
        _options: burn_tensor::ops::ConvTransposeOptions<2>,
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
        todo!()
    }

    fn avg_pool2d(
        _x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
        todo!()
    }

    fn avg_pool2d_backward(
        _x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _grad: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
        todo!()
    }

    fn max_pool2d(
        _x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4> {
        todo!()
    }

    fn max_pool2d_with_indexes(
        _x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
    ) -> burn_tensor::ops::MaxPool2dWithIndexes<WGPUBackend<G, F, I>> {
        todo!()
    }

    fn max_pool2d_with_indexes_backward(
        _x: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
        _output_grad: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<4>,
        _indexes: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<4>,
    ) -> burn_tensor::ops::MaxPool2dBackward<WGPUBackend<G, F, I>> {
        todo!()
    }
}
@@ -0,0 +1,129 @@
use crate::kernel::{binary_elemwise, binary_elemwise_inplace, unary_scalar, unary_scalar_inplace};
use crate::{
    binary_elemwise, binary_elemwise_inplace, element::WGPUElement, tensor::WGPUTensor,
    unary_scalar, unary_scalar_inplace,
};

pub struct NumericOps;

impl NumericOps {
    pub fn add<E: WGPUElement, const D: usize>(
        lhs: WGPUTensor<E, D>,
        rhs: WGPUTensor<E, D>,
    ) -> WGPUTensor<E, D> {
        binary_elemwise!(Add, "+");
        binary_elemwise_inplace!(AddInplace, "+");

        if lhs.can_mut_broadcast(&rhs) {
            return binary_elemwise_inplace::<AddInplace, E, D>(lhs, rhs);
        }

        // Addition is commutative, so the rhs buffer can also be reused as the output.
        if rhs.can_mut_broadcast(&lhs) {
            return binary_elemwise_inplace::<AddInplace, E, D>(rhs, lhs);
        }

        binary_elemwise::<Add, E, D>(lhs, rhs)
    }

    pub fn add_scalar<E: WGPUElement, const D: usize>(
        lhs: WGPUTensor<E, D>,
        rhs: E,
    ) -> WGPUTensor<E, D> {
        unary_scalar!(AddScalar, ops "+");
        unary_scalar_inplace!(AddScalarInplace, ops "+");

        if lhs.can_mut() {
            return unary_scalar_inplace::<AddScalarInplace, E, D>(lhs, rhs);
        }

        unary_scalar::<AddScalar, E, D>(lhs, rhs)
    }

    pub fn sub<E: WGPUElement, const D: usize>(
        lhs: WGPUTensor<E, D>,
        rhs: WGPUTensor<E, D>,
    ) -> WGPUTensor<E, D> {
        binary_elemwise!(Sub, "-");
        binary_elemwise_inplace!(SubInplace, "-");

        if lhs.can_mut_broadcast(&rhs) {
            return binary_elemwise_inplace::<SubInplace, E, D>(lhs, rhs);
        }

        binary_elemwise::<Sub, E, D>(lhs, rhs)
    }

    pub fn sub_scalar<E: WGPUElement, const D: usize>(
        lhs: WGPUTensor<E, D>,
        rhs: E,
    ) -> WGPUTensor<E, D> {
        unary_scalar!(SubScalar, ops "-");
        unary_scalar_inplace!(SubScalarInplace, ops "-");

        if lhs.can_mut() {
            return unary_scalar_inplace::<SubScalarInplace, E, D>(lhs, rhs);
        }

        unary_scalar::<SubScalar, E, D>(lhs, rhs)
    }

    pub fn mul<E: WGPUElement, const D: usize>(
        lhs: WGPUTensor<E, D>,
        rhs: WGPUTensor<E, D>,
    ) -> WGPUTensor<E, D> {
        binary_elemwise!(Mul, "*");
        binary_elemwise_inplace!(MulInplace, "*");

        if lhs.can_mut_broadcast(&rhs) {
            return binary_elemwise_inplace::<MulInplace, E, D>(lhs, rhs);
        }

        // Multiplication is commutative as well, so try the rhs buffer too.
        if rhs.can_mut_broadcast(&lhs) {
            return binary_elemwise_inplace::<MulInplace, E, D>(rhs, lhs);
        }

        binary_elemwise::<Mul, E, D>(lhs, rhs)
    }

    pub fn mul_scalar<E: WGPUElement, const D: usize>(
        lhs: WGPUTensor<E, D>,
        rhs: E,
    ) -> WGPUTensor<E, D> {
        unary_scalar!(MulScalar, ops "*");
        unary_scalar_inplace!(MulScalarInplace, ops "*");

        if lhs.can_mut() {
            return unary_scalar_inplace::<MulScalarInplace, E, D>(lhs, rhs);
        }

        unary_scalar::<MulScalar, E, D>(lhs, rhs)
    }

    pub fn div<E: WGPUElement, const D: usize>(
        lhs: WGPUTensor<E, D>,
        rhs: WGPUTensor<E, D>,
    ) -> WGPUTensor<E, D> {
        binary_elemwise!(Div, "/");
        binary_elemwise_inplace!(DivInplace, "/");

        if lhs.can_mut_broadcast(&rhs) {
            return binary_elemwise_inplace::<DivInplace, E, D>(lhs, rhs);
        }

        binary_elemwise::<Div, E, D>(lhs, rhs)
    }

    pub fn div_scalar<E: WGPUElement, const D: usize>(
        lhs: WGPUTensor<E, D>,
        rhs: E,
    ) -> WGPUTensor<E, D> {
        unary_scalar!(DivScalar, ops "/");
        unary_scalar_inplace!(DivScalarInplace, ops "/");

        if lhs.can_mut() {
            return unary_scalar_inplace::<DivScalarInplace, E, D>(lhs, rhs);
        }

        unary_scalar::<DivScalar, E, D>(lhs, rhs)
    }
}
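The in-place fast paths above are gated by `can_mut` and `can_mut_broadcast`, which (see `WGPUTensor` later in this diff) only allow buffer reuse when the output shape fits and the underlying buffer has a single owner. A self-contained illustration of that ownership rule, using only the standard library (the buffer contents are placeholder values):

    use std::sync::Arc;

    // A buffer may only be written in place while it has exactly one strong reference.
    fn can_reuse<T>(buffer: &Arc<T>) -> bool {
        Arc::strong_count(buffer) == 1
    }

    fn main() {
        let buffer = Arc::new(vec![0.0_f32; 4]);
        assert!(can_reuse(&buffer)); // sole owner: safe to mutate in place

        let shared = Arc::clone(&buffer);
        assert!(!can_reuse(&buffer)); // another handle exists: a new output buffer is needed
        drop(shared);

        assert!(can_reuse(&buffer)); // ownership is unique again
    }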
@@ -0,0 +1,63 @@
use crate::{context::Context, GraphicsAPI, WGPUDevice};
use std::{
    any::TypeId,
    collections::HashMap,
    sync::{Arc, Mutex},
};

static POOL_CONTEXT: Mutex<Option<ContextPool>> = Mutex::new(None);

#[derive(Default)]
struct ContextPool {
    contexts: HashMap<Key, Arc<Context>>,
}

#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct Key {
    api_id: TypeId,
    device: WGPUDevice,
}

impl Key {
    fn new<G: GraphicsAPI>(device: &WGPUDevice) -> Self {
        Self {
            api_id: TypeId::of::<G>(),
            device: device.clone(),
        }
    }
}

/// Get a [context](Context) for the given [device](WGPUDevice).
///
/// # Notes
///
/// If a context already exists for the current [device](WGPUDevice), the same instance will be
/// returned.
pub fn get_context<G: GraphicsAPI>(device: &WGPUDevice) -> Arc<Context> {
    let mut pool = POOL_CONTEXT.lock().unwrap();

    let context = if let Some(pool) = pool.as_mut() {
        // Fetch device in pool
        match pool.contexts.get(&Key::new::<G>(device)) {
            Some(context) => context.clone(),
            None => {
                // Init new device
                let context = Arc::new(Context::new::<G>(device));
                pool.contexts.insert(Key::new::<G>(device), context.clone());
                context
            }
        }
    } else {
        // Initialize pool
        let context = Arc::new(Context::new::<G>(device));
        let mut new_pool = ContextPool::default();

        new_pool
            .contexts
            .insert(Key::new::<G>(device), context.clone());
        *pool = Some(new_pool);
        context
    };

    context
}
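A minimal usage sketch for the pool (caller code assumed, not part of this diff): requesting a context twice for the same graphics API and device should hand back the same shared instance, so resources created through it are shared rather than duplicated.

    use std::sync::Arc;

    // Hypothetical helper showing the expected pooling behaviour.
    fn assert_context_is_shared<G: GraphicsAPI>(device: &WGPUDevice) {
        let first = get_context::<G>(device);
        let second = get_context::<G>(device);

        // Both handles point at the exact same pooled `Context`.
        assert!(Arc::ptr_eq(&first, &second));
    }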
@@ -0,0 +1,35 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<elem>;

@group(0)
@binding(1)
var<storage, read> rhs: array<elem>;

@group(0)
@binding(2)
var<storage, read_write> output: array<elem>;

@group(0)
@binding(3)
var<storage, read> info: array<u32>;

@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let dim: u32 = info[0];
    var index_lhs: u32 = 0u;
    var index_rhs: u32 = 0u;

    for (var i: u32 = 0u; i < dim; i++) {
        let stride_lhs = info[i + 1u];
        let stride_rhs = info[i + 1u * dim + 1u];
        let shape_lhs = info[i + 2u * dim + 1u];
        let shape_rhs = info[i + 3u * dim + 1u];

        index_lhs += global_id.x / stride_lhs % shape_lhs * stride_lhs;
        index_rhs += global_id.x / stride_rhs % shape_rhs * stride_rhs;
    }

    BODY
}
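For reference, the `info` buffer consumed by this template is laid out as `[dim, lhs strides, rhs strides, lhs shape, rhs shape]`. A CPU-side Rust mirror of the index arithmetic above, useful for reasoning about the broadcasting behaviour (an illustrative sketch, not part of the diff):

    // Maps a flat output index onto the lhs and rhs inputs, wrapping on broadcast dimensions.
    fn broadcast_indices(info: &[u32], global_id: u32) -> (u32, u32) {
        let dim = info[0] as usize;
        let (mut index_lhs, mut index_rhs) = (0u32, 0u32);

        for i in 0..dim {
            let stride_lhs = info[i + 1];
            let stride_rhs = info[i + dim + 1];
            let shape_lhs = info[i + 2 * dim + 1];
            let shape_rhs = info[i + 3 * dim + 1];

            index_lhs += global_id / stride_lhs % shape_lhs * stride_lhs;
            index_rhs += global_id / stride_rhs % shape_rhs * stride_rhs;
        }

        (index_lhs, index_rhs)
    }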
@@ -0,0 +1,27 @@
@group(0)
@binding(0)
var<storage, read_write> lhs: array<elem>;

@group(0)
@binding(1)
var<storage, read> rhs: array<elem>;

@group(0)
@binding(2)
var<storage, read> info: array<u32>;

@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let dim: u32 = info[0];
    var index_rhs: u32 = 0u;

    for (var i: u32 = 0u; i < dim; i++) {
        let stride_rhs = info[i + 1u];
        let shape_rhs = info[i + 1u * dim + 1u];

        index_rhs += global_id.x / stride_rhs % shape_rhs * stride_rhs;
    }

    BODY
}
@@ -0,0 +1,13 @@
@group(0)
@binding(0)
var<storage, read> input: array<elem>;

@group(0)
@binding(1)
var<storage, read_write> output: array<elem>;

@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    BODY
}
@@ -0,0 +1,9 @@
@group(0)
@binding(0)
var<storage, read_write> input: array<elem>;

@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    BODY
}
@@ -0,0 +1,17 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<elem>;

@group(0)
@binding(1)
var<storage, read> rhs: elem;

@group(0)
@binding(2)
var<storage, read_write> output: array<elem>;

@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    BODY
}
@@ -0,0 +1,13 @@
@group(0)
@binding(0)
var<storage, read_write> lhs: array<elem>;

@group(0)
@binding(1)
var<storage, read> rhs: elem;

@compute
@workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    BODY
}
@@ -0,0 +1,82 @@
use burn_tensor::Shape;
use std::{marker::PhantomData, sync::Arc};
use wgpu::Buffer;

use crate::{context::Context, element::WGPUElement};

#[derive(Debug, Clone)]
pub struct WGPUTensor<E: WGPUElement, const D: usize> {
    pub(crate) context: Arc<Context>,
    pub(crate) buffer: Arc<Buffer>,
    pub(crate) shape: Shape<D>,
    pub(crate) strides: [usize; D],
    elem: PhantomData<E>,
}

impl<E: WGPUElement, const D: usize> WGPUTensor<E, D> {
    pub fn new(context: Arc<Context>, shape: Shape<D>, buffer: Arc<Buffer>) -> Self {
        let mut strides = [0; D];

        let mut current = 1;
        shape
            .dims
            .iter()
            .enumerate()
            .rev()
            .for_each(|(index, val)| {
                strides[index] = current;
                current *= val;
            });

        Self {
            context,
            buffer,
            shape,
            strides,
            elem: PhantomData::default(),
        }
    }

    pub fn to_context(&self, context: Arc<Context>) -> Self {
        let data = self.context.buffer_to_data(&self.buffer);
        let buffer = Arc::new(context.create_buffer_with_data(&data));

        Self {
            context,
            buffer,
            shape: self.shape.clone(),
            strides: self.strides,
            elem: PhantomData::default(),
        }
    }

    pub fn can_mut_broadcast(&self, tensor_other: &WGPUTensor<E, D>) -> bool {
        if Arc::strong_count(&self.buffer) > 1 {
            return false;
        }

        for i in 0..D {
            // Output tensor will be different from the mutable tensor.
            if self.shape.dims[i] < tensor_other.shape.dims[i] {
                return false;
            }
        }

        true
    }

    pub fn can_mut(&self) -> bool {
        if Arc::strong_count(&self.buffer) > 1 {
            return false;
        }

        true
    }

    pub fn assert_is_on_save_device(&self, other: &Self) {
        if self.context.device != other.context.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.context.device, other.context.device
            );
        }
    }
}
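`WGPUTensor::new` above derives contiguous row-major strides from the shape: the last dimension varies fastest. A small standalone check of that layout (the example shape is an arbitrary assumption):

    // Same stride rule as `WGPUTensor::new`, extracted for illustration.
    fn contiguous_strides<const D: usize>(dims: [usize; D]) -> [usize; D] {
        let mut strides = [0; D];
        let mut current = 1;
        for index in (0..D).rev() {
            strides[index] = current;
            current *= dims[index];
        }
        strides
    }

    fn main() {
        // A 2 x 3 x 4 tensor is stored with strides [12, 4, 1].
        assert_eq!(contiguous_strides([2, 3, 4]), [12, 4, 1]);
    }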
@@ -0,0 +1,2 @@
mod base;
pub use base::*;