From 24014aca33a130d8e235b94a64f78e5227cc9c7a Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Wed, 15 Nov 2023 15:13:37 -0500 Subject: [PATCH] WGPU: Support elemwise operation fusion (#948) --- backend-comparison/Cargo.toml | 1 + backend-comparison/src/lib.rs | 8 + burn-fusion/src/backend.rs | 14 +- burn-fusion/src/client/base.rs | 24 +- burn-fusion/src/client/mutex.rs | 47 +- burn-fusion/src/graph/execution.rs | 9 +- burn-fusion/src/graph/ops.rs | 34 ++ burn-fusion/src/handle.rs | 171 ++------ burn-fusion/src/ops/boolean.rs | 25 +- burn-fusion/src/ops/float.rs | 133 +++--- burn-fusion/src/ops/int.rs | 107 ++--- burn-fusion/src/ops/module.rs | 40 +- burn-fusion/src/server.rs | 28 +- burn-fusion/src/tensor.rs | 7 +- burn-tensor/src/tensor/ops/modules/base.rs | 4 +- burn-wgpu/Cargo.toml | 5 + burn-wgpu/benches/fused_elemwise.rs | 74 ++++ burn-wgpu/src/compute/server.rs | 4 +- burn-wgpu/src/{fusion.rs => fusion/base.rs} | 27 +- burn-wgpu/src/fusion/codegen/body.rs | 27 ++ burn-wgpu/src/fusion/codegen/function.rs | 71 +++ burn-wgpu/src/fusion/codegen/mod.rs | 11 + burn-wgpu/src/fusion/codegen/operator.rs | 146 ++++++ burn-wgpu/src/fusion/codegen/shader.rs | 201 +++++++++ burn-wgpu/src/fusion/codegen/variable.rs | 21 + burn-wgpu/src/fusion/elemwise/mod.rs | 3 + burn-wgpu/src/fusion/elemwise/ops.rs | 463 ++++++++++++++++++++ burn-wgpu/src/fusion/kernel.rs | 320 ++++++++++++++ burn-wgpu/src/fusion/mod.rs | 8 + burn-wgpu/src/kernel/base.rs | 52 ++- 30 files changed, 1718 insertions(+), 367 deletions(-) create mode 100644 burn-wgpu/benches/fused_elemwise.rs rename burn-wgpu/src/{fusion.rs => fusion/base.rs} (83%) create mode 100644 burn-wgpu/src/fusion/codegen/body.rs create mode 100644 burn-wgpu/src/fusion/codegen/function.rs create mode 100644 burn-wgpu/src/fusion/codegen/mod.rs create mode 100644 burn-wgpu/src/fusion/codegen/operator.rs create mode 100644 burn-wgpu/src/fusion/codegen/shader.rs create mode 100644 burn-wgpu/src/fusion/codegen/variable.rs create mode 100644 burn-wgpu/src/fusion/elemwise/mod.rs create mode 100644 burn-wgpu/src/fusion/elemwise/ops.rs create mode 100644 burn-wgpu/src/fusion/kernel.rs create mode 100644 burn-wgpu/src/fusion/mod.rs diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 9c9e2baa0..432f6ab93 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -21,6 +21,7 @@ ndarray-blas-openblas = ["burn/ndarray-blas-openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] +wgpu-fusion = ["burn/wgpu", "burn/fusion"] [dependencies] burn = { path = "../burn" } diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index b78742df7..065b50f41 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -1,6 +1,14 @@ #[macro_export] macro_rules! 
bench_on_backend { () => { + #[cfg(feature = "wgpu-fusion")] + { + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; + use burn::backend::Fusion; + + bench::>>(&WgpuDevice::default()); + } + #[cfg(feature = "wgpu")] { use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; diff --git a/burn-fusion/src/backend.rs b/burn-fusion/src/backend.rs index 7aed859f6..1a9a30604 100644 --- a/burn-fusion/src/backend.rs +++ b/burn-fusion/src/backend.rs @@ -2,7 +2,7 @@ use crate::{ client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor, HandleContainer, }; -use burn_tensor::{backend::Backend, Shape}; +use burn_tensor::{backend::Backend, Device, Shape}; use core::marker::PhantomData; use std::sync::Arc; @@ -36,12 +36,18 @@ impl Backend for Fusion { type BoolTensorPrimitive = FusionTensor; fn name() -> String { - format!("Fusion<{}>", B::name()) + format!("fusion<{}>", B::name()) } fn seed(seed: u64) { B::seed(seed); } + + fn sync(device: &Self::Device) { + let client = CLIENTS.client::(&device.clone().into()); + client.drain_graph(); + B::sync(device) + } } /// The status of a [fusion ops](FusionOps). @@ -116,14 +122,14 @@ pub trait FusionBackend: Backend { /// The device type that can return an ID. /// /// It can be the same as (Backend::Device), but must implement (FusionDevice). - type FusionDevice: FusionDevice + From + Into; + type FusionDevice: FusionDevice + From + Into + core::fmt::Debug; /// The type that can be used to point to a tensor of any kind. type Handle: Sync + Send + Clone; /// What kind of client should be used. type FusionClient: FusionClient; /// The list of operations that will be used to optimize the computational graph. - fn operations() -> Vec>>; + fn operations(device: &Device) -> Vec>>; /// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive). fn float_tensor( diff --git a/burn-fusion/src/client/base.rs b/burn-fusion/src/client/base.rs index c0a557a34..778c71030 100644 --- a/burn-fusion/src/client/base.rs +++ b/burn-fusion/src/client/base.rs @@ -1,6 +1,6 @@ use crate::{ graph::{GraphExecution, TensorOpsDescription}, - FusionBackend, FusionTensor, TensorDescription, TensorId, + FusionBackend, FusionTensor, Handle, TensorDescription, TensorId, }; use burn_tensor::{ ops::{FloatElem, IntElem}, @@ -18,26 +18,18 @@ pub trait FusionClient: Send + Sync + Clone { fn new(device: ::FusionDevice) -> Self; /// Register a new [tensor operation description](TensorOpsDescription). fn register(&self, ops: TensorOpsDescription); - /// Sync the computation. - fn sync(&self); + /// Register all lazy computation. + fn drain_graph(&self); /// Get the current device used by all operations handled by this client. fn device(&self) -> &::FusionDevice; - /// Create an empty tensor. - fn create_tensor_empty(&self, shape: Vec) -> FusionTensor; - /// Create a float tensor with the given values. - fn create_tensor_float( + /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it. + fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor; + /// Create a tensor with the given handle and shape. + fn register_tensor( &self, - values: Vec>, + handle: Handle, shape: Vec, ) -> FusionTensor; - /// Create an integer tensor with the given values. - fn create_tensor_int( - &self, - values: Vec>, - shape: Vec, - ) -> FusionTensor; - /// Create a bool tensor with the given values. - fn create_tensor_bool(&self, values: Vec, shape: Vec) -> FusionTensor; /// Read the values contained by a float tensor. 
fn read_tensor_float( &self, diff --git a/burn-fusion/src/client/mutex.rs b/burn-fusion/src/client/mutex.rs index 2d0c823a8..db4bceb55 100644 --- a/burn-fusion/src/client/mutex.rs +++ b/burn-fusion/src/client/mutex.rs @@ -1,7 +1,7 @@ use super::FusionClient; use crate::{ graph::{GraphExecution, TensorOpsDescription}, - FusionBackend, FusionServer, FusionTensor, + FusionBackend, FusionServer, FusionTensor, Handle, }; use burn_tensor::ops::FloatElem; use spin::Mutex; @@ -49,10 +49,11 @@ where self.server.lock().register(ops); } - fn sync(&self) { - self.server.lock().sync(); + fn drain_graph(&self) { + self.server.lock().drain_graph(); } - fn create_tensor_empty(&self, shape: Vec) -> FusionTensor { + + fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor { let id = self.server.lock().create_empty_handle(); FusionTensor::new(id, shape, self.clone()) @@ -61,6 +62,18 @@ where fn device(&self) -> &::FusionDevice { &self.device } + fn register_tensor( + &self, + handle: Handle, + shape: Vec, + ) -> FusionTensor { + let mut server = self.server.lock(); + let id = server.create_empty_handle(); + server.handles.register_handle(id.as_ref().clone(), handle); + core::mem::drop(server); + + FusionTensor::new(id, shape, self.clone()) + } fn read_tensor_float( &self, @@ -69,32 +82,6 @@ where self.server.lock().read_float(tensor) } - fn create_tensor_float( - &self, - values: Vec>, - shape: Vec, - ) -> FusionTensor { - let id = self.server.lock().create_float_handle(values); - - FusionTensor::new(id, shape, self.clone()) - } - - fn create_tensor_int( - &self, - values: Vec>, - shape: Vec, - ) -> FusionTensor { - let id = self.server.lock().create_int_handle(values); - - FusionTensor::new(id, shape, self.clone()) - } - - fn create_tensor_bool(&self, values: Vec, shape: Vec) -> FusionTensor { - let id = self.server.lock().create_bool_handle(values); - - FusionTensor::new(id, shape, self.clone()) - } - fn read_tensor_int( &self, tensor: crate::TensorDescription, diff --git a/burn-fusion/src/graph/execution.rs b/burn-fusion/src/graph/execution.rs index d9466ce0a..36cbf1a6d 100644 --- a/burn-fusion/src/graph/execution.rs +++ b/burn-fusion/src/graph/execution.rs @@ -32,8 +32,13 @@ impl GraphExecution for GreedyGraphExecution { } match find_best_optimization_index(optimizations) { - Some(index) => graph.execute_optimization(handles, index, optimizations), - None => graph.execute(handles), + Some(index) => { + graph.execute_optimization(handles, index, optimizations); + } + None => { + graph.execute(handles); + optimizations.iter_mut().for_each(|ops| ops.reset()); + } } if graph.is_empty() { diff --git a/burn-fusion/src/graph/ops.rs b/burn-fusion/src/graph/ops.rs index 5d7b60d41..3f437bd1f 100644 --- a/burn-fusion/src/graph/ops.rs +++ b/burn-fusion/src/graph/ops.rs @@ -5,6 +5,7 @@ use burn_tensor::{ ops::{ConvOptions, ConvTransposeOptions}, Distribution, Element, }; +use core::hash::Hash; use std::ops::Range; /// General trait to abstract how a single operation is executed. @@ -652,6 +653,7 @@ pub enum BoolOpsDescription { ), } +#[derive(Hash)] /// Swap dim operation description. pub struct SwapDimsDescription { /// Input tensor description. 
@@ -664,6 +666,7 @@ pub struct SwapDimsDescription { pub dim2: usize, } +#[derive(Hash)] #[allow(missing_docs)] pub struct ReshapeDescription { pub input: TensorDescription, @@ -671,6 +674,7 @@ pub struct ReshapeDescription { pub shape: Vec, } +#[derive(Hash)] #[allow(missing_docs)] pub struct BinaryOpsDescription { pub lhs: TensorDescription, @@ -678,6 +682,7 @@ pub struct BinaryOpsDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct UnaryOpsDescription { pub input: TensorDescription, @@ -691,6 +696,7 @@ pub struct ScalarOpsDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct GatherOpsDescription { pub tensor: TensorDescription, @@ -699,6 +705,7 @@ pub struct GatherOpsDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct ScatterOpsDescription { pub tensor: TensorDescription, @@ -708,6 +715,7 @@ pub struct ScatterOpsDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct SelectOpsDescription { pub tensor: TensorDescription, @@ -716,6 +724,7 @@ pub struct SelectOpsDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct SelectAssignOpsDescription { pub tensor: TensorDescription, @@ -725,6 +734,7 @@ pub struct SelectAssignOpsDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct SliceOpsDescription { pub tensor: TensorDescription, @@ -732,6 +742,7 @@ pub struct SliceOpsDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct SliceAssignOpsDescription { pub tensor: TensorDescription, @@ -740,6 +751,7 @@ pub struct SliceAssignOpsDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct MaskWhereOpsDescription { pub tensor: TensorDescription, @@ -773,6 +785,7 @@ pub struct RepeatOpsDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct CatOpsDescription { pub tensors: Vec, @@ -780,6 +793,7 @@ pub struct CatOpsDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct ReduceDimWithIndicesDescription { pub tensor: TensorDescription, @@ -788,6 +802,7 @@ pub struct ReduceDimWithIndicesDescription { pub out_indices: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct EmbeddingDescription { pub weights: TensorDescription, @@ -795,6 +810,7 @@ pub struct EmbeddingDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct EmbeddingBackwardDescription { pub weights: TensorDescription, @@ -803,6 +819,7 @@ pub struct EmbeddingBackwardDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct Conv1dDescription { pub x: TensorDescription, @@ -812,6 +829,7 @@ pub struct Conv1dDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct Conv2dDescription { pub x: TensorDescription, @@ -821,6 +839,7 @@ pub struct Conv2dDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct ConvTranspose1dDescription { pub x: TensorDescription, @@ -830,6 +849,7 @@ pub struct ConvTranspose1dDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct ConvTranspose2dDescription { pub x: TensorDescription, @@ -839,6 +859,7 @@ pub struct ConvTranspose2dDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct 
AvgPool1dDescription { pub x: TensorDescription, @@ -849,6 +870,7 @@ pub struct AvgPool1dDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool2dDescription { pub x: TensorDescription, @@ -859,6 +881,7 @@ pub struct AvgPool2dDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool1dBackwardDescription { pub x: TensorDescription, @@ -870,6 +893,7 @@ pub struct AvgPool1dBackwardDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct AvgPool2dBackwardDescription { pub x: TensorDescription, @@ -881,6 +905,7 @@ pub struct AvgPool2dBackwardDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool1dDescription { pub x: TensorDescription, @@ -888,6 +913,7 @@ pub struct AdaptiveAvgPool1dDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool2dDescription { pub x: TensorDescription, @@ -895,6 +921,7 @@ pub struct AdaptiveAvgPool2dDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool1dBackwardDescription { pub x: TensorDescription, @@ -902,6 +929,7 @@ pub struct AdaptiveAvgPool1dBackwardDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct AdaptiveAvgPool2dBackwardDescription { pub x: TensorDescription, @@ -909,6 +937,7 @@ pub struct AdaptiveAvgPool2dBackwardDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool1dDescription { pub x: TensorDescription, @@ -919,6 +948,7 @@ pub struct MaxPool1dDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool1dWithIndicesDescription { pub x: TensorDescription, @@ -930,6 +960,7 @@ pub struct MaxPool1dWithIndicesDescription { pub out_indices: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool1dWithIndicesBackwardDescription { pub x: TensorDescription, @@ -942,6 +973,7 @@ pub struct MaxPool1dWithIndicesBackwardDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool2dDescription { pub x: TensorDescription, @@ -952,6 +984,7 @@ pub struct MaxPool2dDescription { pub out: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool2dWithIndicesDescription { pub x: TensorDescription, @@ -963,6 +996,7 @@ pub struct MaxPool2dWithIndicesDescription { pub out_indices: TensorDescription, } +#[derive(Hash)] #[allow(missing_docs)] pub struct MaxPool2dWithIndicesBackwardDescription { pub x: TensorDescription, diff --git a/burn-fusion/src/handle.rs b/burn-fusion/src/handle.rs index 1f362d24c..10a6dbbef 100644 --- a/burn-fusion/src/handle.rs +++ b/burn-fusion/src/handle.rs @@ -1,8 +1,5 @@ use crate::{FusionBackend, TensorDescription, TensorId, TensorStatus}; -use burn_tensor::{ - ops::{FloatElem, IntElem}, - Data, ElementConversion, Shape, -}; +use burn_tensor::Shape; use std::{collections::HashMap, sync::Arc}; /// Keep all [tensor handles](FusionBackend::Handle) in one place and ensure that all resources @@ -17,10 +14,7 @@ pub struct HandleContainer { } enum Handle { - Empty, - DataFloat(Vec>), - DataInt(Vec>), - DataBool(Vec), + NotInit, Existing(B::Handle), } @@ -34,47 +28,38 @@ impl HandleContainer { } } + /// Register a handle for the given [tensor id](TensorId). 
+ pub fn register_handle(&mut self, id: TensorId, handle: B::Handle) { + self.handles.insert(id, Handle::Existing(handle)); + } + + /// Get the handle for the given [tensor id](TensorId). + pub fn get_handle(&mut self, tensor: &TensorDescription) -> B::Handle { + let (id, handle) = self + .handles + .remove_entry(&tensor.id) + .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", tensor.id)); + + match handle { + Handle::Existing(handle) => match tensor.status { + TensorStatus::ReadOnly => { + self.handles.insert(id, Handle::Existing(handle.clone())); + handle + } + TensorStatus::ReadWrite => handle, + TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."), + }, + Handle::NotInit => panic!("Cannot get uninitialized handle."), + } + } + /// Get the [float tensor](burn_tensor::backend::Backend::TensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_float_tensor( &mut self, tensor: &TensorDescription, ) -> B::TensorPrimitive { - let (id, handle) = self - .handles - .remove_entry(&tensor.id) - .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", tensor.id)); - - if let Handle::Existing(handle) = handle { - match tensor.status { - TensorStatus::ReadOnly => { - self.handles.insert(id, Handle::Existing(handle.clone())); - return B::float_tensor(handle, Shape::from(tensor.shape.clone())); - } - TensorStatus::ReadWrite => { - return B::float_tensor(handle, Shape::from(tensor.shape.clone())); - } - TensorStatus::NotInit => panic!("Can't get uninitialized tensor."), - } - } - - let output = match handle { - Handle::Empty => B::empty(Shape::from(tensor.shape.clone()), &self.device), - Handle::DataFloat(values) => B::from_data( - Data::new(values, Shape::from(tensor.shape.clone())), - &self.device, - ), - Handle::Existing(_) => unreachable!(), - Handle::DataInt(_) => panic!("From int unsupported when getting float tensor."), - Handle::DataBool(_) => panic!("From bool unsupported when getting float tensor."), - }; - - if let TensorStatus::ReadOnly = tensor.status { - self.handles - .insert(id, Handle::Existing(B::float_tensor_handle(output.clone()))); - } - - output + B::float_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) } /// Get the [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) corresponding to the @@ -83,38 +68,7 @@ impl HandleContainer { &mut self, tensor: &TensorDescription, ) -> B::IntTensorPrimitive { - let (id, handle) = self.handles.remove_entry(&tensor.id).unwrap(); - - if let Handle::Existing(handle) = handle { - match tensor.status { - TensorStatus::ReadOnly => { - self.handles.insert(id, Handle::Existing(handle.clone())); - return B::int_tensor(handle, Shape::from(tensor.shape.clone())); - } - TensorStatus::ReadWrite => { - return B::int_tensor(handle, Shape::from(tensor.shape.clone())); - } - TensorStatus::NotInit => panic!("Can get uninitialized tensor."), - } - } - - let output = match handle { - Handle::Empty => B::int_empty(Shape::from(tensor.shape.clone()), &self.device), - Handle::DataInt(values) => B::int_from_data( - Data::new(values, Shape::from(tensor.shape.clone())), - &self.device, - ), - Handle::Existing(_) => unreachable!(), - Handle::DataFloat(_) => panic!("From float unsupported when getting int tensor."), - Handle::DataBool(_) => panic!("From bool unsupported when getting int tensor."), - }; - - if let TensorStatus::ReadOnly = tensor.status { - self.handles - .insert(id, Handle::Existing(B::int_tensor_handle(output.clone()))); - } - - output + 
B::int_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) } /// Get the [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) corresponding to the @@ -123,41 +77,7 @@ impl HandleContainer { &mut self, tensor: &TensorDescription, ) -> B::BoolTensorPrimitive { - let (id, handle) = self.handles.remove_entry(&tensor.id).unwrap(); - - if let Handle::Existing(handle) = handle { - match tensor.status { - TensorStatus::ReadOnly => { - self.handles.insert(id, Handle::Existing(handle.clone())); - return B::bool_tensor(handle, Shape::from(tensor.shape.clone())); - } - TensorStatus::ReadWrite => { - return B::bool_tensor(handle, Shape::from(tensor.shape.clone())); - } - TensorStatus::NotInit => panic!("Can get uninitialized tensor."), - } - } - - let output = match handle { - Handle::Empty => B::int_equal_elem( - B::int_empty(Shape::from(tensor.shape.clone()), &self.device), - 0.elem(), - ), - Handle::DataBool(data) => B::bool_from_data( - Data::new(data, Shape::from(tensor.shape.clone())), - &self.device, - ), - Handle::Existing(_) => unreachable!(), - Handle::DataFloat(_) => panic!("From float unsupported when getting bool tensor."), - Handle::DataInt(_) => panic!("From int unsupported when getting bool tensor."), - }; - - if let TensorStatus::ReadOnly = tensor.status { - self.handles - .insert(id, Handle::Existing(B::bool_tensor_handle(output.clone()))); - } - - output + B::bool_tensor(self.get_handle(tensor), Shape::from(&tensor.shape)) } /// Register a new [float tensor](burn_tensor::backend::Backend::TensorPrimitive) with the corresponding [tensor id](TensorId). @@ -191,37 +111,10 @@ impl HandleContainer { } /// Lazily create a new empty tensor and return its corresponding [tensor id](TensorId). - pub fn create_tensor_empty(&mut self) -> Arc { + pub fn create_tensor_uninit(&mut self) -> Arc { let id = TensorId::new(self.counter); self.counter += 1; - self.handles.insert(id.clone(), Handle::Empty); - - Arc::new(id) - } - - /// Lazily create a new float tensor and return its corresponding [tensor id](TensorId). - pub(crate) fn create_tensor_float(&mut self, values: Vec>) -> Arc { - let id = TensorId::new(self.counter); - self.counter += 1; - self.handles.insert(id.clone(), Handle::DataFloat(values)); - - Arc::new(id) - } - - /// Lazily create a new int tensor and return its corresponding [tensor id](TensorId). - pub(crate) fn create_tensor_int(&mut self, values: Vec>) -> Arc { - let id = TensorId::new(self.counter); - self.counter += 1; - self.handles.insert(id.clone(), Handle::DataInt(values)); - - Arc::new(id) - } - - /// Lazily create a new bool tensor and return its corresponding [tensor id](TensorId). 
- pub(crate) fn create_tensor_bool(&mut self, values: Vec) -> Arc { - let id = TensorId::new(self.counter); - self.counter += 1; - self.handles.insert(id.clone(), Handle::DataBool(values)); + self.handles.insert(id.clone(), Handle::NotInit); Arc::new(id) } diff --git a/burn-fusion/src/ops/boolean.rs b/burn-fusion/src/ops/boolean.rs index dccc96f3c..179db25d4 100644 --- a/burn-fusion/src/ops/boolean.rs +++ b/burn-fusion/src/ops/boolean.rs @@ -17,8 +17,9 @@ use burn_tensor::{ impl BoolTensorOps for Fusion { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { let client = get_client::(&device.clone().into()); + let tensor = B::bool_empty(shape.clone(), device); - client.create_tensor_empty(shape.dims.into()) + client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into()) } fn bool_shape(tensor: &BoolTensor) -> Shape { @@ -36,8 +37,10 @@ impl BoolTensorOps for Fusion { device: &Device, ) -> BoolTensor { let client = get_client::(&device.clone().into()); + let tensor = B::bool_from_data(data, device); + let shape = B::bool_shape(&tensor); - client.create_tensor_bool(data.value, data.shape.dims.into()) + client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into()) } fn bool_into_int( @@ -55,7 +58,7 @@ impl BoolTensorOps for Fusion { } } - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client .register(TensorOpsDescription::BoolOps(BoolOpsDescription::IntoInt( @@ -84,7 +87,7 @@ impl BoolTensorOps for Fusion { } } - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::BoolOps( BoolOpsDescription::IntoFloat( @@ -139,7 +142,7 @@ impl BoolTensorOps for Fusion { } let shape: Vec = shape.dims.into(); - let out = tensor.client.create_tensor_empty(shape.clone()); + let out = tensor.client.tensor_uninitialized(shape.clone()); tensor .client @@ -183,7 +186,7 @@ impl BoolTensorOps for Fusion { shape.push(tensor.shape[i]); } - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -227,7 +230,7 @@ impl BoolTensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -279,7 +282,7 @@ impl BoolTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Cat( CatOpsDescription { @@ -312,7 +315,7 @@ impl BoolTensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::BaseOpsBool( BaseOpsDescription::Equal( @@ -341,7 +344,7 @@ impl BoolTensorOps for Fusion { } } - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::BoolOps( crate::graph::BoolOpsDescription::Not( @@ -377,7 +380,7 @@ impl BoolTensorOps for Fusion { shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client diff --git 
a/burn-fusion/src/ops/float.rs b/burn-fusion/src/ops/float.rs index af56fc460..fc60d3ef4 100644 --- a/burn-fusion/src/ops/float.rs +++ b/burn-fusion/src/ops/float.rs @@ -26,8 +26,10 @@ impl TensorOps for Fusion { device: &Device, ) -> FloatTensor { let client = get_client::(&device.clone().into()); + let tensor = B::from_data(data, device); + let shape = B::shape(&tensor); - client.create_tensor_float(data.value, data.shape.dims.into()) + client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into()) } fn random( @@ -54,7 +56,7 @@ impl TensorOps for Fusion { let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone().into()); - let out = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Random( (out.to_description_out(), distribution), @@ -79,7 +81,7 @@ impl TensorOps for Fusion { let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone().into()); - let out = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::)), @@ -103,7 +105,7 @@ impl TensorOps for Fusion { let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone().into()); - let out = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::)), @@ -131,7 +133,7 @@ impl TensorOps for Fusion { let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone().into()); - let out = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Full( @@ -188,7 +190,7 @@ impl TensorOps for Fusion { } } - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::FloatOps( FloatOpsDescription::IntoInt( @@ -205,8 +207,9 @@ impl TensorOps for Fusion { fn empty(shape: Shape, device: &Device) -> FloatTensor { let client = get_client::(&device.clone().into()); + let tensor = B::empty(shape.clone(), device); - client.create_tensor_empty(shape.dims.into()) + client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into()) } fn add( @@ -217,7 +220,7 @@ impl TensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Add( @@ -239,7 +242,7 @@ impl TensorOps for Fusion { ) -> FloatTensor { scalar_float_ops!(AddOps, B::add_scalar); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::AddScalar( @@ -261,7 +264,7 @@ impl TensorOps for Fusion { ) -> FloatTensor { scalar_float_ops!(ClampMinOps, B::clamp_min); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::ClampMin( @@ -283,7 +286,7 @@ impl TensorOps for Fusion { ) -> FloatTensor { 
scalar_float_ops!(ClampMaxOps, B::clamp_max); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::ClampMax( @@ -317,7 +320,7 @@ impl TensorOps for Fusion { } } - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Clamp( @@ -342,7 +345,7 @@ impl TensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Sub( @@ -364,7 +367,7 @@ impl TensorOps for Fusion { ) -> FloatTensor { scalar_float_ops!(SubOps, B::sub_scalar); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::SubScalar( @@ -388,7 +391,7 @@ impl TensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Mul( @@ -410,7 +413,7 @@ impl TensorOps for Fusion { ) -> FloatTensor { scalar_float_ops!(MulOps, B::mul_scalar); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::MulScalar( @@ -434,7 +437,7 @@ impl TensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Div( @@ -456,7 +459,7 @@ impl TensorOps for Fusion { ) -> FloatTensor { scalar_float_ops!(DivOps, B::div_scalar); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::DivScalar( @@ -483,7 +486,7 @@ impl TensorOps for Fusion { shape[D - 2] = lhs.shape[D - 2]; shape[D - 1] = rhs.shape[D - 1]; - let out = lhs.client.create_tensor_empty(shape); + let out = lhs.client.tensor_uninitialized(shape); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Matmul( @@ -519,7 +522,7 @@ impl TensorOps for Fusion { shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -556,7 +559,7 @@ impl TensorOps for Fusion { } let shape: Vec = shape.dims.into(); - let out = tensor.client.create_tensor_empty(shape.clone()); + let out = tensor.client.tensor_uninitialized(shape.clone()); tensor .client @@ -595,7 +598,7 @@ impl TensorOps for Fusion { } let shape: Vec = indices.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -638,7 +641,7 @@ impl TensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = 
tensor.client.tensor_uninitialized(shape); tensor .client @@ -681,7 +684,7 @@ impl TensorOps for Fusion { let mut shape: Vec = tensor.shape.clone(); shape[dim] = indices.shape[0]; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -724,7 +727,7 @@ impl TensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -769,7 +772,7 @@ impl TensorOps for Fusion { shape.push(tensor.shape[i]); } - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -813,7 +816,7 @@ impl TensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -855,7 +858,7 @@ impl TensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -896,7 +899,7 @@ impl TensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -924,7 +927,7 @@ impl TensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::BaseOpsFloat( BaseOpsDescription::Equal( @@ -946,7 +949,7 @@ impl TensorOps for Fusion { ) -> BoolTensor { scalar_float_cmp_ops!(EqualElemOps, B::equal_elem); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::EqualElem( @@ -970,7 +973,7 @@ impl TensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Greater( @@ -992,7 +995,7 @@ impl TensorOps for Fusion { ) -> BoolTensor { scalar_float_cmp_ops!(GreaterElemOps, B::greater_elem); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::GreaterElem( @@ -1016,7 +1019,7 @@ impl TensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::GreaterEqual( @@ -1038,7 +1041,7 @@ impl TensorOps for Fusion { ) -> BoolTensor { scalar_float_cmp_ops!(GreaterEqualElemOps, B::greater_equal_elem); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::GreaterEqualElem( @@ -1062,7 +1065,7 @@ impl TensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsFloat( 
NumericOpsDescription::Lower( @@ -1084,7 +1087,7 @@ impl TensorOps for Fusion { ) -> BoolTensor { scalar_float_cmp_ops!(LowerElemOps, B::lower_elem); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::LowerElem( @@ -1108,7 +1111,7 @@ impl TensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::LowerEqual( @@ -1130,7 +1133,7 @@ impl TensorOps for Fusion { ) -> BoolTensor { scalar_float_cmp_ops!(LowerEqualElemOps, B::lower_equal_elem); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::LowerEqualElem( @@ -1149,7 +1152,7 @@ impl TensorOps for Fusion { fn sum(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(SumOps, B::sum); - let out = tensor.client.create_tensor_empty(vec![1]); + let out = tensor.client.tensor_uninitialized(vec![1]); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Sum( @@ -1169,7 +1172,7 @@ impl TensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::SumDim( @@ -1188,7 +1191,7 @@ impl TensorOps for Fusion { fn mean(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(MeanOps, B::mean); - let out = tensor.client.create_tensor_empty(vec![1]); + let out = tensor.client.tensor_uninitialized(vec![1]); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Mean( @@ -1208,7 +1211,7 @@ impl TensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::MeanDim( @@ -1239,7 +1242,7 @@ impl TensorOps for Fusion { fn exp(lhs: FloatTensor) -> FloatTensor { unary_float_ops!(ExpOps, B::exp); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Exp( @@ -1256,7 +1259,7 @@ impl TensorOps for Fusion { fn log(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(LogOps, B::log); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log( @@ -1273,7 +1276,7 @@ impl TensorOps for Fusion { fn log1p(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(Log1pOps, B::log1p); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log1p( @@ -1290,7 +1293,7 @@ impl TensorOps for Fusion { fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { scalar_float_ops!(PowfOps, B::powf, f32); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = 
lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Powf( @@ -1308,7 +1311,7 @@ impl TensorOps for Fusion { fn sqrt(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(SqrtOps, B::sqrt); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sqrt( @@ -1325,7 +1328,7 @@ impl TensorOps for Fusion { fn abs(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(AbsOps, B::abs); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Abs( @@ -1343,7 +1346,7 @@ impl TensorOps for Fusion { fn cos(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(CosOps, B::cos); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Cos( @@ -1360,7 +1363,7 @@ impl TensorOps for Fusion { fn sin(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(SinOps, B::sin); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sin( @@ -1377,7 +1380,7 @@ impl TensorOps for Fusion { fn tanh(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(TanhOps, B::tanh); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Tanh( @@ -1394,7 +1397,7 @@ impl TensorOps for Fusion { fn recip(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(Recip, B::recip); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip( UnaryOpsDescription { @@ -1409,7 +1412,7 @@ impl TensorOps for Fusion { fn erf(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(TanhOps, B::erf); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Erf( @@ -1452,7 +1455,7 @@ impl TensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::BaseOpsFloat(BaseOpsDescription::Cat( CatOpsDescription { @@ -1471,7 +1474,7 @@ impl TensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::ArgMax( @@ -1492,7 +1495,7 @@ impl TensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::ArgMin( @@ -1511,7 +1514,7 @@ impl TensorOps for Fusion { fn max(tensor: 
FloatTensor) -> FloatTensor { unary_float_ops!(MaxOps, B::max); - let out = tensor.client.create_tensor_empty(vec![1]); + let out = tensor.client.tensor_uninitialized(vec![1]); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Max( @@ -1531,7 +1534,7 @@ impl TensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::MaxDim( @@ -1568,8 +1571,8 @@ impl TensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.create_tensor_empty(shape.clone()); - let out_indices = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::MaxDimWithIndices( @@ -1589,7 +1592,7 @@ impl TensorOps for Fusion { fn min(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(MinOps, B::min); - let out = tensor.client.create_tensor_empty(vec![1]); + let out = tensor.client.tensor_uninitialized(vec![1]); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::Min( @@ -1609,7 +1612,7 @@ impl TensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::MinDim( @@ -1646,8 +1649,8 @@ impl TensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.create_tensor_empty(shape.clone()); - let out_indices = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::NumericOpsFloat( NumericOpsDescription::MinDimWithIndices( diff --git a/burn-fusion/src/ops/int.rs b/burn-fusion/src/ops/int.rs index 9f8d274b5..32d2d6a54 100644 --- a/burn-fusion/src/ops/int.rs +++ b/burn-fusion/src/ops/int.rs @@ -22,8 +22,9 @@ use core::ops::Range; impl IntTensorOps for Fusion { fn int_empty(shape: Shape, device: &Device) -> IntTensor { let client = get_client::(&device.clone().into()); + let tensor = B::int_empty(shape.clone(), device); - client.create_tensor_empty(shape.dims.into()) + client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into()) } fn int_shape(tensor: &IntTensor) -> Shape { @@ -39,8 +40,10 @@ impl IntTensorOps for Fusion { device: &Device, ) -> IntTensor { let client = get_client::(&device.clone().into()); + let tensor = B::int_from_data(data, device); + let shape = B::int_shape(&tensor); - client.create_tensor_int(data.value, data.shape.dims.into()) + client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into()) } fn int_device(tensor: &IntTensor) -> Device { @@ -83,7 +86,7 @@ impl IntTensorOps for Fusion { } let shape: Vec = shape.dims.into(); - let out = tensor.client.create_tensor_empty(shape.clone()); + let out = tensor.client.tensor_uninitialized(shape.clone()); tensor .client @@ -127,7 +130,7 @@ impl IntTensorOps for Fusion { shape.push(tensor.shape[i]); } - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -169,7 +172,7 @@ impl 
IntTensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -211,7 +214,7 @@ impl IntTensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -252,7 +255,7 @@ impl IntTensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -292,7 +295,7 @@ impl IntTensorOps for Fusion { } let shape: Vec = indices.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -335,7 +338,7 @@ impl IntTensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -378,7 +381,7 @@ impl IntTensorOps for Fusion { let mut shape: Vec = tensor.shape.clone(); shape[dim] = indices.shape[0]; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -421,7 +424,7 @@ impl IntTensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -471,7 +474,7 @@ impl IntTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Cat( CatOpsDescription { @@ -493,7 +496,7 @@ impl IntTensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client .register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Equal( @@ -514,7 +517,7 @@ impl IntTensorOps for Fusion { ) -> BoolTensor { scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::EqualElem( @@ -538,7 +541,7 @@ impl IntTensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::Greater( @@ -560,7 +563,7 @@ impl IntTensorOps for Fusion { ) -> BoolTensor { scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::GreaterElem( @@ -584,7 +587,7 @@ impl IntTensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::GreaterEqual( @@ -606,7 +609,7 @@ impl IntTensorOps for Fusion { ) -> BoolTensor { scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem); - let out = 
lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::GreaterEqualElem( @@ -630,7 +633,7 @@ impl IntTensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::Lower( @@ -652,7 +655,7 @@ impl IntTensorOps for Fusion { ) -> BoolTensor { scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::LowerElem( @@ -676,7 +679,7 @@ impl IntTensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::LowerEqual( @@ -698,7 +701,7 @@ impl IntTensorOps for Fusion { ) -> BoolTensor { scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::LowerEqualElem( @@ -722,7 +725,7 @@ impl IntTensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client .register(graph::TensorOpsDescription::NumericOpsInt( @@ -745,7 +748,7 @@ impl IntTensorOps for Fusion { ) -> IntTensor { scalar_int_ops!(AddOps, B::int_add_scalar); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client .register(graph::TensorOpsDescription::NumericOpsInt( @@ -770,7 +773,7 @@ impl IntTensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client .register(graph::TensorOpsDescription::NumericOpsInt( @@ -793,7 +796,7 @@ impl IntTensorOps for Fusion { ) -> IntTensor { scalar_int_ops!(SubOps, B::int_sub_scalar); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client .register(graph::TensorOpsDescription::NumericOpsInt( @@ -818,7 +821,7 @@ impl IntTensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client .register(graph::TensorOpsDescription::NumericOpsInt( @@ -841,7 +844,7 @@ impl IntTensorOps for Fusion { ) -> IntTensor { scalar_int_ops!(MulOps, B::int_mul_scalar); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client .register(graph::TensorOpsDescription::NumericOpsInt( @@ -866,7 +869,7 @@ impl IntTensorOps for Fusion { let out = lhs .client - .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape)); + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape)); out.client .register(graph::TensorOpsDescription::NumericOpsInt( @@ -889,7 +892,7 @@ impl IntTensorOps for Fusion { ) -> 
IntTensor { scalar_int_ops!(DivOps, B::int_div_scalar); - let out = lhs.client.create_tensor_empty(lhs.shape.clone()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone()); out.client .register(graph::TensorOpsDescription::NumericOpsInt( @@ -921,7 +924,7 @@ impl IntTensorOps for Fusion { let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone().into()); - let out = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::)), @@ -945,7 +948,7 @@ impl IntTensorOps for Fusion { let shape: Vec = shape.dims.into(); let client = get_client::(&device.clone().into()); - let out = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::)), @@ -957,7 +960,7 @@ impl IntTensorOps for Fusion { fn int_sum(tensor: IntTensor) -> IntTensor { unary_int_ops!(SumOps, B::int_sum); - let out = tensor.client.create_tensor_empty(vec![1]); + let out = tensor.client.tensor_uninitialized(vec![1]); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::Sum( @@ -977,7 +980,7 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::SumDim( @@ -996,7 +999,7 @@ impl IntTensorOps for Fusion { fn int_mean(tensor: IntTensor) -> IntTensor { unary_int_ops!(MeanOps, B::int_mean); - let out = tensor.client.create_tensor_empty(vec![1]); + let out = tensor.client.tensor_uninitialized(vec![1]); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::Mean( @@ -1016,7 +1019,7 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::MeanDim( @@ -1037,7 +1040,7 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::ArgMax( @@ -1058,7 +1061,7 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::ArgMin( @@ -1080,7 +1083,7 @@ impl IntTensorOps for Fusion { ) -> IntTensor { scalar_int_ops!(ClampMinOps, B::int_clamp_min); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::ClampMin( @@ -1102,7 +1105,7 @@ impl IntTensorOps for Fusion { ) -> IntTensor { scalar_int_ops!(ClampMaxOps, B::int_clamp_max); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::ClampMax( @@ 
-1136,7 +1139,7 @@ impl IntTensorOps for Fusion { } } - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::Clamp( @@ -1156,7 +1159,7 @@ impl IntTensorOps for Fusion { fn int_abs(tensor: IntTensor) -> IntTensor { unary_int_ops!(AbsOps, B::int_abs); - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::Abs( @@ -1184,7 +1187,7 @@ impl IntTensorOps for Fusion { } } - let out = tensor.client.create_tensor_empty(tensor.shape.clone()); + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); out.client.register(TensorOpsDescription::IntOps( graph::IntOpsDescription::IntoFloat( @@ -1220,7 +1223,7 @@ impl IntTensorOps for Fusion { shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); tensor .client @@ -1243,7 +1246,7 @@ impl IntTensorOps for Fusion { fn int_max(tensor: IntTensor) -> IntTensor { unary_int_ops!(MaxOps, B::int_max); - let out = tensor.client.create_tensor_empty(vec![1]); + let out = tensor.client.tensor_uninitialized(vec![1]); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::Max( @@ -1263,7 +1266,7 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::MaxDim( @@ -1300,8 +1303,8 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.create_tensor_empty(shape.clone()); - let out_indices = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::MaxDimWithIndices( @@ -1321,7 +1324,7 @@ impl IntTensorOps for Fusion { fn int_min(tensor: IntTensor) -> IntTensor { unary_int_ops!(MinOps, B::int_min); - let out = tensor.client.create_tensor_empty(vec![1]); + let out = tensor.client.tensor_uninitialized(vec![1]); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::Min( @@ -1341,7 +1344,7 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor.client.create_tensor_empty(shape); + let out = tensor.client.tensor_uninitialized(shape); out.client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::MinDim( @@ -1378,8 +1381,8 @@ impl IntTensorOps for Fusion { let mut shape = tensor.shape.clone(); shape[dim] = 1; let client = tensor.client.clone(); - let out = client.create_tensor_empty(shape.clone()); - let out_indices = client.create_tensor_empty(shape); + let out = client.tensor_uninitialized(shape.clone()); + let out_indices = client.tensor_uninitialized(shape); client.register(TensorOpsDescription::NumericOpsInt( NumericOpsDescription::MinDimWithIndices( diff --git a/burn-fusion/src/ops/module.rs b/burn-fusion/src/ops/module.rs index dd52be1e1..2eef4be4b 100644 --- a/burn-fusion/src/ops/module.rs +++ b/burn-fusion/src/ops/module.rs @@ -56,7 +56,7 
@@ impl ModuleOps> for Fusion { ); let shape = vec![x.shape[0], weight.shape[0], size]; - let out = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::Conv1d( @@ -115,7 +115,7 @@ impl ModuleOps> for Fusion { ); let shape = vec![x.shape[0], weight.shape[0], size_0, size_1]; - let out = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::Conv2d( @@ -168,7 +168,7 @@ impl ModuleOps> for Fusion { ); let shape = vec![x.shape[0], weight.shape[1] * options.groups, size]; - let out = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::ConvTranspose1d( @@ -229,7 +229,7 @@ impl ModuleOps> for Fusion { ); let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1]; - let out = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::ConvTranspose2d( @@ -275,7 +275,7 @@ impl ModuleOps> for Fusion { let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]); let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::AvgPool1d( @@ -326,7 +326,7 @@ impl ModuleOps> for Fusion { calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]); let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::AvgPool2d( @@ -374,7 +374,7 @@ impl ModuleOps> for Fusion { } } - let out = x.client.create_tensor_empty(x.shape.clone()); + let out = x.client.tensor_uninitialized(x.shape.clone()); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::AvgPool1dBackward( @@ -423,7 +423,7 @@ impl ModuleOps> for Fusion { } } - let out = x.client.create_tensor_empty(x.shape.clone()); + let out = x.client.tensor_uninitialized(x.shape.clone()); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::AvgPool2dBackward( @@ -472,7 +472,7 @@ impl ModuleOps> for Fusion { let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); let shape = vec![x.shape[0], x.shape[1], size]; - let out = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::MaxPool1d( @@ -533,7 +533,7 @@ impl ModuleOps> for Fusion { ); let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::MaxPool2d( @@ -581,8 +581,8 @@ impl ModuleOps> for Fusion { let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); let shape = vec![x.shape[0], x.shape[1], size]; - let out = 
x.client.create_tensor_empty(shape.clone()); - let out_indices = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape.clone()); + let out_indices = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::MaxPool1dWithIndices( @@ -645,8 +645,8 @@ impl ModuleOps> for Fusion { ); let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; - let out = x.client.create_tensor_empty(shape.clone()); - let out_indices = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape.clone()); + let out_indices = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::MaxPool2dWithIndices( @@ -698,7 +698,7 @@ impl ModuleOps> for Fusion { } } - let out = x.client.create_tensor_empty(x.shape.clone()); + let out = x.client.tensor_uninitialized(x.shape.clone()); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::MaxPool1dWithIndicesBackward( @@ -751,7 +751,7 @@ impl ModuleOps> for Fusion { } } - let out = x.client.create_tensor_empty(x.shape.clone()); + let out = x.client.tensor_uninitialized(x.shape.clone()); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::MaxPool2dWithIndicesBackward( @@ -787,7 +787,7 @@ impl ModuleOps> for Fusion { } let shape = vec![x.shape[0], x.shape[1], output_size]; - let out = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::AdaptiveAvgPool1d( @@ -821,7 +821,7 @@ impl ModuleOps> for Fusion { } let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; - let out = x.client.create_tensor_empty(shape); + let out = x.client.tensor_uninitialized(shape); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::AdaptiveAvgPool2d( @@ -855,7 +855,7 @@ impl ModuleOps> for Fusion { } } - let out = x.client.create_tensor_empty(x.shape.clone()); + let out = x.client.tensor_uninitialized(x.shape.clone()); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::AdaptiveAvgPool1dBackward( @@ -889,7 +889,7 @@ impl ModuleOps> for Fusion { } } - let out = x.client.create_tensor_empty(x.shape.clone()); + let out = x.client.tensor_uninitialized(x.shape.clone()); x.client.clone().register(TensorOpsDescription::ModuleOps( crate::graph::ModuleOpsDescription::AdaptiveAvgPool2dBackward( diff --git a/burn-fusion/src/server.rs b/burn-fusion/src/server.rs index 328ef8aa6..9b52d3829 100644 --- a/burn-fusion/src/server.rs +++ b/burn-fusion/src/server.rs @@ -23,7 +23,7 @@ where G: GraphExecution, { pub fn new(device: B::FusionDevice) -> Self { - let optimizations = B::operations() + let optimizations = B::operations(&device.clone().into()) .into_iter() .map(|ops| Optimization::new(ops, FusionStatus::Open(FusionProperties::default()))) .collect(); @@ -53,7 +53,11 @@ where ); } - pub fn sync(&mut self) { + pub fn drain_graph(&mut self) { + if self.graph.is_empty() { + return; + } + self.execution.maybe_execute( &mut self.graph, &mut self.handles, @@ -63,19 +67,7 @@ where } pub fn create_empty_handle(&mut self) -> Arc { - self.handles.create_tensor_empty() - } - - pub fn create_float_handle(&mut self, values: Vec>) -> Arc { - self.handles.create_tensor_float(values) - } - - pub fn 
create_int_handle(&mut self, values: Vec>) -> Arc { - self.handles.create_tensor_int(values) - } - - pub fn create_bool_handle(&mut self, values: Vec) -> Arc { - self.handles.create_tensor_bool(values) + self.handles.create_tensor_uninit() } pub fn read_float( @@ -84,7 +76,7 @@ where ) -> burn_tensor::Reader, D>> { // Make sure all registered operations are executed. // The underlying backend can still be async. - self.sync(); + self.drain_graph(); let tensor = self.handles.get_float_tensor(&tensor); B::into_data(tensor) @@ -96,7 +88,7 @@ where ) -> burn_tensor::Reader, D>> { // Make sure all registered operations are executed. // The underlying backend can still be async. - self.sync(); + self.drain_graph(); let tensor = self.handles.get_int_tensor(&tensor); B::int_into_data(tensor) @@ -108,7 +100,7 @@ where ) -> burn_tensor::Reader> { // Make sure all registered operations are executed. // The underlying backend can still be async. - self.sync(); + self.drain_graph(); let tensor = self.handles.get_bool_tensor(&tensor); B::bool_into_data(tensor) diff --git a/burn-fusion/src/tensor.rs b/burn-fusion/src/tensor.rs index a7c2005e4..70ffcf393 100644 --- a/burn-fusion/src/tensor.rs +++ b/burn-fusion/src/tensor.rs @@ -127,7 +127,7 @@ pub struct TensorId { } /// The status of the current tensor. -#[derive(Clone, Debug)] +#[derive(Hash, Clone, Debug, PartialEq, Eq)] pub enum TensorStatus { /// The tensor can be read, but not written. ReadOnly, @@ -147,7 +147,7 @@ pub enum TensorStatus { /// 2. Status::ReadOnly /// 3. Status::ReadOnly /// 4. Status::ReadWrite -#[derive(Debug)] +#[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct TensorDescription { /// The [tensor id](TensorId). pub id: TensorId, @@ -158,7 +158,8 @@ pub struct TensorDescription { } impl TensorId { - pub(crate) fn new(value: u64) -> Self { + /// Create a new tensor id. + pub fn new(value: u64) -> Self { Self { value } } } diff --git a/burn-tensor/src/tensor/ops/modules/base.rs b/burn-tensor/src/tensor/ops/modules/base.rs index ce1d3a310..4290fbca2 100644 --- a/burn-tensor/src/tensor/ops/modules/base.rs +++ b/burn-tensor/src/tensor/ops/modules/base.rs @@ -66,7 +66,7 @@ pub struct Conv1dBackward { } /// Convolution options. -#[derive(new, Debug, Clone)] +#[derive(new, Debug, Clone, Hash)] pub struct ConvOptions { /// Stride. pub stride: [usize; N], @@ -82,7 +82,7 @@ pub struct ConvOptions { } /// Transposed convolution options. -#[derive(new, Debug, Clone)] +#[derive(new, Debug, Clone, Hash)] pub struct ConvTransposeOptions { /// Stride. 
pub stride: [usize; N], diff --git a/burn-wgpu/Cargo.toml b/burn-wgpu/Cargo.toml index 4917e2423..7e794099d 100644 --- a/burn-wgpu/Cargo.toml +++ b/burn-wgpu/Cargo.toml @@ -53,11 +53,16 @@ burn-tensor = { path = "../burn-tensor", version = "0.11.0", default-features = burn-ndarray = { path = "../burn-ndarray", version = "0.11.0" } burn-fusion = { path = "../burn-fusion", version = "0.11.0" } serial_test = "2.0.0" +pretty_assertions = {workspace = true} [[bench]] name = "matmul" harness = false +[[bench]] +name = "fused_elemwise" +harness = false + [[bench]] name = "reduction" harness = false diff --git a/burn-wgpu/benches/fused_elemwise.rs b/burn-wgpu/benches/fused_elemwise.rs new file mode 100644 index 000000000..ff8443bd2 --- /dev/null +++ b/burn-wgpu/benches/fused_elemwise.rs @@ -0,0 +1,74 @@ +use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_fusion::Fusion; +use burn_tensor::backend::Backend; +use burn_tensor::{Distribution, Shape, Tensor}; +use burn_wgpu::Wgpu; +use burn_wgpu::WgpuDevice; +use derive_new::new; +use std::marker::PhantomData; + +#[derive(new)] +struct ElemWiseBenchmark { + shape: Shape<3>, + device: B::Device, + repeat: usize, + _b: PhantomData, +} + +impl Benchmark for ElemWiseBenchmark { + type Args = (Tensor, Tensor); + + fn name(&self) -> String { + format!( + "Backend {} Shape {:?} Repeat {}", + B::name(), + self.shape.dims, + self.repeat + ) + } + + fn num_samples(&self) -> usize { + 10 + } + + fn execute(&self, (lhs, rhs): Self::Args) { + for _ in 0..self.repeat { + let tmp_0 = lhs.clone() + rhs.clone(); + let tmp_1 = rhs.clone() * tmp_0.clone(); + let tmp_2 = rhs.clone().exp(); + let tmp_3 = tmp_0 * tmp_1; + let _tmp_4 = tmp_2 / tmp_3; + } + } + + fn prepare(&self) -> Self::Args { + B::seed(10); + let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); + let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device); + + (lhs, rhs) + } + + fn sync(&self) { + B::sync(&self.device) + } +} + +#[allow(dead_code)] +/// Runs the benchmarks for wgpu matmul implementations +pub fn bench(device: &WgpuDevice) { + run_benchmark(ElemWiseBenchmark::::new( + Shape::new([256, 256, 1024]), + device.clone(), + 10, + )); + run_benchmark(ElemWiseBenchmark::>::new( + Shape::new([256, 256, 1024]), + device.clone(), + 10, + )); +} + +fn main() { + bench(&WgpuDevice::BestAvailable) +} diff --git a/burn-wgpu/src/compute/server.rs b/burn-wgpu/src/compute/server.rs index 7c78bf439..1789f2f00 100644 --- a/burn-wgpu/src/compute/server.rs +++ b/burn-wgpu/src/compute/server.rs @@ -154,7 +154,9 @@ where return pipeline.clone(); } - let pipeline = self.compile_source(&kernel.source().complete()); + let source = kernel.source().complete(); + log::trace!("Compiling kernel {kernel_id}:\n {source}"); + let pipeline = self.compile_source(&source); self.pipelines.insert(kernel_id.clone(), pipeline.clone()); pipeline diff --git a/burn-wgpu/src/fusion.rs b/burn-wgpu/src/fusion/base.rs similarity index 83% rename from burn-wgpu/src/fusion.rs rename to burn-wgpu/src/fusion/base.rs index d22d4e577..b1e1f864c 100644 --- a/burn-wgpu/src/fusion.rs +++ b/burn-wgpu/src/fusion/base.rs @@ -1,6 +1,7 @@ use crate::{ compute::{WgpuComputeClient, WgpuHandle}, element::WgpuElement, + fusion::FloatElementWiseFusionOps, tensor::WgpuTensor, FloatElement, GraphicsApi, IntElement, Wgpu, WgpuDevice, }; @@ -32,8 +33,8 @@ where type Handle = WgpuFusionHandle; type FusionClient = MutexFusionClient; - fn operations() -> Vec>> { - Vec::new() + fn 
operations(device: &WgpuDevice) -> Vec>> { + vec![Box::new(FloatElementWiseFusionOps::new(device.clone()))] } fn float_tensor( @@ -70,7 +71,27 @@ where } } -#[derive(Debug, Clone)] +pub fn strides_dyn_rank(shape: &[usize]) -> Vec { + let mut strides = vec![0; shape.len()]; + + let mut current = 1; + shape.iter().enumerate().rev().for_each(|(index, val)| { + strides[index] = current; + current *= val; + }); + + strides +} + +pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize { + let mut num_elems = 1; + for i in shape.iter() { + num_elems *= i; + } + num_elems +} + +#[derive(new, Debug, Clone)] /// Handle to be used when fusing operations. pub struct WgpuFusionHandle { /// Compute client for wgpu. diff --git a/burn-wgpu/src/fusion/codegen/body.rs b/burn-wgpu/src/fusion/codegen/body.rs new file mode 100644 index 000000000..cab35bf75 --- /dev/null +++ b/burn-wgpu/src/fusion/codegen/body.rs @@ -0,0 +1,27 @@ +use super::Operator; +use std::fmt::Display; + +/// A body is composed of a list of [operators](Operator). +/// +/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size +/// X and Y, but with Z=1. +#[derive(Hash, new)] +pub struct Body { + operators: Vec, +} + +impl Display for Body { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str( + "let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n", + )?; + f.write_str("let rank: u32 = info[0];\n\n")?; + + for ops in self.operators.iter() { + f.write_fmt(format_args!("{ops}"))?; + f.write_str("\n")?; + } + + Ok(()) + } +} diff --git a/burn-wgpu/src/fusion/codegen/function.rs b/burn-wgpu/src/fusion/codegen/function.rs new file mode 100644 index 000000000..fceae4e39 --- /dev/null +++ b/burn-wgpu/src/fusion/codegen/function.rs @@ -0,0 +1,71 @@ +use super::Elem; +use std::fmt::Display; + +/// Not all functions are native to WGSL, so this struct allows to support more functions. +#[derive(Hash, PartialEq, Eq, Clone)] +pub enum Function { + Powf(Elem), + Erf(Elem), +} + +impl Display for Function { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Function::Powf(elem) => format_powf(f, elem), + Function::Erf(elem) => format_erf(f, elem), + } + } +} + +fn format_powf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result { + f.write_fmt(format_args!( + " +fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{ + let modulo = rhs % 2.0; + + if (modulo == 0.0) {{ + // Even number + return pow(abs(lhs), rhs); + }} else if (modulo == 1.0 && lhs < 0.0) {{ + // Odd number + return -1.0 * pow(-1.0 * lhs, rhs); + }} else {{ + // Float number + return pow(lhs, rhs); + }} +}} +" + )) +} + +fn format_erf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result { + f.write_fmt(format_args!( + " +/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations +/// +/// > (maximum error: 1.5×10−7) +/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x). 
+fn erf_positive(x: {elem}) -> {elem} {{ + let p = 0.3275911; + let a1 = 0.254829592; + let a2 = -0.284496736; + let a3 = 1.421413741; + let a4 = -1.453152027; + let a5 = 1.061405429; + + let t = 1.0 / (1.0 + p * abs(x)); + let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1; + + return 1.0 - (tmp * t * exp(-x * x)); +}} + +fn erf(x: {elem}) -> {elem} {{ + if (x < 0.0) {{ + return -1.0 * erf_positive(-1.0 * x); + }} + + return erf_positive(x); +}} +" + )) +} diff --git a/burn-wgpu/src/fusion/codegen/mod.rs b/burn-wgpu/src/fusion/codegen/mod.rs new file mode 100644 index 000000000..b9b568837 --- /dev/null +++ b/burn-wgpu/src/fusion/codegen/mod.rs @@ -0,0 +1,11 @@ +mod body; +mod function; +mod operator; +mod shader; +mod variable; + +pub use body::*; +pub use function::*; +pub use operator::*; +pub use shader::*; +pub use variable::*; diff --git a/burn-wgpu/src/fusion/codegen/operator.rs b/burn-wgpu/src/fusion/codegen/operator.rs new file mode 100644 index 000000000..ab0b193a8 --- /dev/null +++ b/burn-wgpu/src/fusion/codegen/operator.rs @@ -0,0 +1,146 @@ +use super::Variable; +use std::fmt::Display; + +/// All operators that can be fused in a WGSL compute shader. +#[derive(Debug, Hash, Clone)] +pub enum Operator { + Add { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Sub { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Mul { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Div { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Abs { + input: Variable, + out: Variable, + }, + Exp { + input: Variable, + out: Variable, + }, + Log { + input: Variable, + out: Variable, + }, + Log1p { + input: Variable, + out: Variable, + }, + Cos { + input: Variable, + out: Variable, + }, + Sin { + input: Variable, + out: Variable, + }, + Tanh { + input: Variable, + out: Variable, + }, + Powf { + lhs: Variable, + rhs: Variable, + out: Variable, + }, + Erf { + input: Variable, + out: Variable, + }, + AssignGlobal { + input: Variable, + out: Variable, + }, + ReadGlobal { + variable: Variable, + position: usize, + position_out: usize, + }, +} + +impl Display for Operator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Operator::Add { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = {lhs} + {rhs};")) + } + Operator::Sub { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = {lhs} - {rhs};")) + } + Operator::Mul { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = {lhs} * {rhs};")) + } + Operator::Div { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = {lhs} / {rhs};")) + } + Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")), + Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")), + Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")), + Operator::Powf { lhs, rhs, out } => { + f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});")) + } + Operator::Log1p { input, out } => { + f.write_fmt(format_args!("let {out} = log({input} + 1.0);")) + } + Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")), + Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")), + Operator::Tanh { input, out } => { + f.write_fmt(format_args!("let {out} = tanh({input});")) + } + Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")), + Operator::AssignGlobal { input, out } => { + 
f.write_fmt(format_args!("{out}_global[id] = {input};")) + } + Operator::ReadGlobal { + variable, + position, + position_out, + } => { + let (global, local) = match variable { + Variable::Input(number) => { + (format!("input_{number}_global"), format!("input_{number}")) + } + Variable::Local(_) => panic!("can't read globala local variable."), + Variable::Output(number) => ( + format!("output_{number}_global"), + format!("output_{number}"), + ), + Variable::Scalar(_, _) => panic!("Can't read global scalar variable."), + }; + + f.write_fmt(format_args!( + " +var index_{local}: u32 = 0u; + +for (var i: u32 = 1u; i <= rank; i++) {{ + let position = {position}u * (2u * rank); + let position_out = {position_out}u * (2u * rank); + + let stride = info[position + i]; + let stride_out = info[position_out + i]; + let shape = info[position + rank + i]; + + index_{local} += id / stride_out % shape * stride; +}} + +let {local} = {global}[index_{local}]; +" + )) + } + } + } +} diff --git a/burn-wgpu/src/fusion/codegen/shader.rs b/burn-wgpu/src/fusion/codegen/shader.rs new file mode 100644 index 000000000..8ce399978 --- /dev/null +++ b/burn-wgpu/src/fusion/codegen/shader.rs @@ -0,0 +1,201 @@ +use super::{Body, Function}; +use crate::kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT}; +use std::{ + collections::hash_map::DefaultHasher, + fmt::Display, + hash::{Hash, Hasher}, +}; + +#[derive(Hash, PartialEq, Eq)] +pub enum Location { + Storage, + #[allow(dead_code)] + Workgroup, +} + +#[derive(Hash, PartialEq, Eq)] +pub enum Visibility { + Read, + ReadWrite, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub enum Elem { + F32, + #[allow(dead_code)] + I32, + U32, +} + +#[derive(Hash, PartialEq, Eq)] +pub struct Binding { + pub location: Location, + pub visibility: Visibility, + pub elem: Elem, + pub size: Option, +} + +#[derive(Hash, PartialEq, Eq)] +pub struct WorkgroupSize { + pub x: usize, + pub y: usize, + pub z: usize, +} + +impl Default for WorkgroupSize { + fn default() -> Self { + Self { + x: WORKGROUP_DEFAULT, + y: WORKGROUP_DEFAULT, + z: 1, + } + } +} + +#[derive(Hash)] +pub struct ComputeShader { + pub inputs: Vec, + pub outputs: Vec, + pub named: Vec<(String, Binding)>, + pub workgroup_size: WorkgroupSize, + pub global_invocation_id: bool, + pub num_workgroups: bool, + pub body: Body, + pub functions: Vec, +} + +impl DynamicKernelSource for ComputeShader { + fn source(&self) -> SourceTemplate { + SourceTemplate::new(self.to_string()) + } + + fn id(&self) -> String { + let mut s = DefaultHasher::new(); + self.hash(&mut s); + + s.finish().to_string() + } +} + +impl Display for ComputeShader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Self::format_bindings(f, "input", &self.inputs, 0)?; + Self::format_bindings(f, "output", &self.outputs, self.inputs.len())?; + + for (i, (name, binding)) in self.named.iter().enumerate() { + Self::format_binding( + f, + name.as_str(), + binding, + self.inputs.len() + self.outputs.len() + i, + )?; + } + + f.write_fmt(format_args!( + "const WORKGROUP_SIZE_X = {}u; +const WORKGROUP_SIZE_Y = {}u; +const WORKGROUP_SIZE_Z = {}u;\n", + self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z + ))?; + + f.write_fmt(format_args!( + " +@compute +@workgroup_size({}, {}, {}) +fn main( +", + self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z + ))?; + + if self.global_invocation_id { + f.write_str(" @builtin(global_invocation_id) global_id: vec3,\n")?; + } + + if self.num_workgroups { + f.write_str(" 
@builtin(num_workgroups) num_workgroups: vec3,\n")?; + } + + f.write_fmt(format_args!( + ") {{ + {} +}}", + self.body + ))?; + + for function in self.functions.iter() { + f.write_fmt(format_args!("{function}\n\n"))?; + } + + Ok(()) + } +} + +impl ComputeShader { + fn format_bindings( + f: &mut core::fmt::Formatter<'_>, + prefix: &str, + bindings: &[Binding], + num_entry: usize, + ) -> core::fmt::Result { + for (i, binding) in bindings.iter().enumerate() { + Self::format_binding( + f, + format!("{prefix}_{i}_global").as_str(), + binding, + num_entry + i, + )?; + } + + Ok(()) + } + + fn format_binding( + f: &mut core::fmt::Formatter<'_>, + name: &str, + binding: &Binding, + num_entry: usize, + ) -> core::fmt::Result { + let ty = match binding.size { + Some(size) => format!("array<{}, {}>", binding.elem, size), + None => format!("array<{}>", binding.elem), + }; + + f.write_fmt(format_args!( + "@group(0) +@binding({}) +var<{}, {}> {}: {}; +\n", + num_entry, binding.location, binding.visibility, name, ty + ))?; + + Ok(()) + } +} + +impl Display for Location { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Location::Storage => f.write_str("storage"), + Location::Workgroup => f.write_str("workgroup"), + } + } +} + +impl Display for Elem { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Elem::F32 => f.write_str("f32"), + Elem::I32 => f.write_str("i32"), + Elem::U32 => f.write_str("u32"), + } + } +} + +impl Display for Visibility { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Visibility::Read => f.write_str("read"), + Visibility::ReadWrite => f.write_str("read_write"), + } + } +} diff --git a/burn-wgpu/src/fusion/codegen/variable.rs b/burn-wgpu/src/fusion/codegen/variable.rs new file mode 100644 index 000000000..b74c4dbb8 --- /dev/null +++ b/burn-wgpu/src/fusion/codegen/variable.rs @@ -0,0 +1,21 @@ +use super::Elem; +use std::fmt::Display; + +#[derive(Debug, Hash, Clone)] +pub enum Variable { + Input(u16), + Scalar(u16, Elem), + Local(u16), + Output(u16), +} + +impl Display for Variable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Variable::Input(number) => f.write_fmt(format_args!("input_{number}")), + Variable::Local(number) => f.write_fmt(format_args!("local_{number}")), + Variable::Output(number) => f.write_fmt(format_args!("output_{number}")), + Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")), + } + } +} diff --git a/burn-wgpu/src/fusion/elemwise/mod.rs b/burn-wgpu/src/fusion/elemwise/mod.rs new file mode 100644 index 000000000..971046d52 --- /dev/null +++ b/burn-wgpu/src/fusion/elemwise/mod.rs @@ -0,0 +1,3 @@ +mod ops; + +pub use ops::*; diff --git a/burn-wgpu/src/fusion/elemwise/ops.rs b/burn-wgpu/src/fusion/elemwise/ops.rs new file mode 100644 index 000000000..4a4eb45a9 --- /dev/null +++ b/burn-wgpu/src/fusion/elemwise/ops.rs @@ -0,0 +1,463 @@ +use crate::{ + fusion::codegen::{Elem, Operator, Variable}, + fusion::kernel::FusionKernel, + FloatElement, GraphicsApi, IntElement, Wgpu, +}; +use burn_fusion::{ + graph::{ + BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription, + TensorOpsDescription, UnaryOpsDescription, + }, + FusionBackend, FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, + TensorId, +}; +use burn_tensor::{Device, Element}; +use hashbrown::HashMap; +use std::sync::Arc; + +/// Fused element wise operations that 
are normally memory bound. +pub struct FloatElementWiseFusionOps +where + G: GraphicsApi, + F: FloatElement, + I: IntElement, +{ + pub(crate) inputs: Vec, + pub(crate) locals: HashMap, + pub(crate) tensors: HashMap, + pub(crate) scalars_f32: Vec, + pub(crate) operators: Vec, + pub(crate) properties: FusionProperties, + pub(crate) current_output_shape: Vec, + device: Device>, +} + +impl FusionOps> + for FloatElementWiseFusionOps +{ + fn register(&mut self, ops: Arc>>) -> FusionStatus { + match ops.as_ref() { + TensorOpsDescription::FloatOps(ops) => { + if !self.register_float(ops) { + return FusionStatus::Closed(self.properties); + } + } + TensorOpsDescription::NumericOpsFloat(ops) => { + if !self.register_numeric(ops) { + return FusionStatus::Closed(self.properties); + } + } + _ => { + return FusionStatus::Closed(self.properties); + } + }; + + self.properties.score += 1; + self.properties.ready = self.operators.len() > 1; + + FusionStatus::Open(self.properties) + } + + fn execute(&mut self, handles: &mut HandleContainer>) { + let inputs = self.input_descriptions(); + let outputs = self.output_descriptions(); + let locals = outputs + .iter() + .map(|out| *self.locals.get(&out.id).unwrap()) + .collect::>(); + + FusionKernel::new(&self.device) + .inputs(&inputs, &self.scalars_f32) + .body(&self.operators) + .outputs(&outputs, &locals) + .execute(handles); + } + + fn reset(&mut self) { + self.inputs.clear(); + self.locals.drain(); + self.tensors.clear(); + self.scalars_f32.clear(); + self.operators.clear(); + self.properties = FusionProperties::default(); + self.current_output_shape.clear(); + } + + fn len(&self) -> usize { + self.operators.len() + } +} + +impl FloatElementWiseFusionOps +where + G: GraphicsApi, + F: FloatElement, + I: IntElement, +{ + pub fn new(device: Device>) -> Self { + Self { + inputs: Vec::new(), + locals: HashMap::new(), + tensors: HashMap::new(), + scalars_f32: Vec::new(), + operators: Vec::new(), + current_output_shape: Vec::new(), + properties: FusionProperties::default(), + device, + } + } + + fn input_descriptions(&self) -> Vec<&TensorDescription> { + self.inputs + .iter() + .map(|input| { + let updated_tensor = self.tensors.get(&input.id).unwrap(); + updated_tensor + }) + .collect::>() + } + + fn output_descriptions(&self) -> Vec<&TensorDescription> { + let mut outputs = Vec::new(); + let mut local_tensor_ids_input = Vec::new(); + let mut local_tensor_ids_output = Vec::new(); + + // Mark a variable to the provided list of tensor ids using the variable list. + // + // Only local variables can become outputs. + let mark = |var: &Variable, list: &mut Vec| { + if let Variable::Local(index) = var { + if let Some((id, _)) = self + .locals + .iter() + .find(|(_id, position)| *position == index) + { + if !list.contains(id) { + list.push(id.clone()); + } + } + } + }; + + // For all operators, mark their local tensor id in the proper set. + for ops in self.operators.iter() { + match ops { + Operator::AssignGlobal { input: _, out: _ } => { + // Nothing to do here. + } + Operator::ReadGlobal { + variable: _, + position: _, + position_out: _, + } => { + // Nothing to do here. 
+ } + Operator::Add { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Sub { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Mul { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Div { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Exp { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Abs { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Erf { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Log { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Log1p { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Cos { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Sin { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Tanh { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + Operator::Powf { lhs, rhs, out } => { + mark(lhs, &mut local_tensor_ids_input); + mark(rhs, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } + } + } + + // All output tensors that are never read by a following operation should be written to + // since they are essentially the "logical" output of the shader. + for out in local_tensor_ids_output { + let is_read = local_tensor_ids_input.contains(&out); + + if !is_read { + outputs.push(self.tensors.get(&out).unwrap()); + } + } + + // All tensors where their latest description is read only should be written to since they + // are going to be used after the fused kernel by other operations. + for tensor in self.tensors.values() { + if let burn_fusion::TensorStatus::ReadOnly = tensor.status { + if self.locals.contains_key(&tensor.id) { + outputs.push(tensor); + } + } + } + + outputs + } + + fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable { + let already_exists = self.tensors.contains_key(&tensor.id); + + let variable = match already_exists { + false => { + // New input + let var = Variable::Input(self.inputs.len() as u16); + self.inputs.push(tensor.clone()); + var + } + true => match self.locals.get(&tensor.id) { + // Is a local variable. + Some(local_index) => Variable::Local(*local_index), + // Isn't a local variable, so must be an existing input. + None => { + let input = self + .inputs + .iter() + .enumerate() + .find(|(_, input)| input.id == tensor.id) + .unwrap(); + let input_index = input.0; + Variable::Input(input_index as u16) + } + }, + }; + + // Update the tensor description with the new version. + self.tensors.insert(tensor.id.clone(), tensor.clone()); + + variable + } + + fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable { + // Update the tensor description to the new version. 
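// Illustrative sketch, not part of the original patch: `input_to_var` above and
// `output_to_var` here decide which shader variable each tensor maps to. For a
// small fused graph such as `t3 = t1 + t2; t4 = t3 * t2` (tensor names hypothetical):
//   t1 -> Variable::Input(0)   t2 -> Variable::Input(1)   (read from global buffers)
//   t3 -> Variable::Local(0)   t4 -> Variable::Local(1)   (kernel-local values)
// so the registered operators end up as
//   Add { lhs: Input(0), rhs: Input(1), out: Local(0) }
//   Mul { lhs: Local(0), rhs: Input(1), out: Local(1) }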
+ self.tensors.insert(tensor.id.clone(), tensor.clone()); + + // Output already registered as a local variable. + if let Some(index) = self.locals.get(&tensor.id) { + return Variable::Local(*index); + } + + // New local variable. + let local_index = self.locals.len() as u16; + self.locals.insert(tensor.id.clone(), local_index); + Variable::Local(local_index) + } + + fn register_float(&mut self, ops: &FloatOpsDescription) -> bool { + match ops { + FloatOpsDescription::Exp(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Exp { input, out }) + } + FloatOpsDescription::Log(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Log { input, out }) + } + FloatOpsDescription::Log1p(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Log1p { input, out }) + } + FloatOpsDescription::Cos(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Cos { input, out }) + } + FloatOpsDescription::Sin(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Sin { input, out }) + } + FloatOpsDescription::Powf(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out }) + } + FloatOpsDescription::Tanh(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out }) + } + FloatOpsDescription::Erf(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Erf { input, out }) + } + _ => false, + } + } + + fn register_numeric( + &mut self, + ops: &NumericOpsDescription, + ) -> bool { + match ops { + NumericOpsDescription::Add(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out }) + } + NumericOpsDescription::AddScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out }) + } + NumericOpsDescription::Sub(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out }) + } + NumericOpsDescription::SubScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out }) + } + NumericOpsDescription::Mul(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out }) + } + NumericOpsDescription::MulScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out }) + } + NumericOpsDescription::Div(desc, _) => { + self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out }) + } + NumericOpsDescription::DivScalar(desc, _) => { + self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out }) + } + NumericOpsDescription::Abs(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Abs { input, out }) + } + _ => false, + } + } + + fn register_binary_ops(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool + where + Func: Fn(Variable, Variable, Variable) -> Operator, + { + if !self.output_is_compatible(&desc.out) { + return false; + } + + let lhs = self.input_to_var(&desc.lhs); + let rhs = self.input_to_var(&desc.rhs); + let out = self.output_to_var(&desc.out); + + self.operators.push(func(lhs, rhs, out)); + + true + } + + fn register_unary_ops(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool + where + Func: Fn(Variable, Variable) -> Operator, + { + if !self.output_is_compatible(&desc.out) { + return false; + } + + let input = self.input_to_var(&desc.input); + let out = self.output_to_var(&desc.out); + + self.operators.push(func(input, out)); + + true + } + + fn register_scalar_ops( + &mut self, + desc: 
&ScalarOpsDescription, + func: Func, + ) -> bool + where + Func: Fn(Variable, Variable, Variable) -> Operator, + { + if !self.output_is_compatible(&desc.out) { + return false; + } + + let lhs = self.input_to_var(&desc.lhs); + let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32); + self.scalars_f32.push(desc.rhs.elem()); + let out = self.output_to_var(&desc.out); + + self.operators.push(func(lhs, rhs, out)); + + true + } + + fn output_is_compatible(&mut self, out: &TensorDescription) -> bool { + if self.current_output_shape.is_empty() { + self.current_output_shape = out.shape.clone(); + } else if self.current_output_shape != out.shape { + return false; + } + + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use burn_fusion::graph::{BinaryOpsDescription, Ops}; + use burn_fusion::Fusion; + use burn_tensor::Tensor; + + struct FakeAddOps; + + impl Ops for FakeAddOps { + type Args = BinaryOpsDescription; + + fn execute(&self, _: &Self::Args, _: &mut HandleContainer) { + todo!() + } + } + + #[test] + fn test_fusion_same_behavior() { + type Backend = Wgpu; + type FusedBackend = Fusion; + + let data_1 = + Tensor::::random([1, 32], burn_tensor::Distribution::Default).into_data(); + let data_2 = + Tensor::::random([32, 32], burn_tensor::Distribution::Default).into_data(); + + let tensor_1 = Tensor::::from_data(data_1.clone()); + let tensor_2 = Tensor::::from_data(data_2.clone()); + let tensor_3 = tensor_1.clone() + tensor_2; + let tensor_4 = tensor_3.clone() - tensor_1; + let tensor_5 = tensor_4 + 5.0; + let tensor_6 = tensor_5 + tensor_3; + let result_ref = tensor_6.into_data(); + + let tensor_1 = Tensor::::from_data(data_1); + let tensor_2 = Tensor::::from_data(data_2); + let tensor_3 = tensor_1.clone() + tensor_2; + let tensor_4 = tensor_3.clone() - tensor_1; + let tensor_5 = tensor_4 + 5.0; + let tensor_6 = tensor_5 + tensor_3; + let result_fused = tensor_6.into_data(); + + result_fused.assert_approx_eq(&result_ref, 3); + } +} diff --git a/burn-wgpu/src/fusion/kernel.rs b/burn-wgpu/src/fusion/kernel.rs new file mode 100644 index 000000000..31a0353cb --- /dev/null +++ b/burn-wgpu/src/fusion/kernel.rs @@ -0,0 +1,320 @@ +use super::codegen::Body; +use crate::compute::{compute_client, DynamicKernel, WgpuComputeClient}; +use crate::fusion::codegen::Function; +use crate::fusion::{calculate_num_elems_dyn_rank, strides_dyn_rank}; +use crate::fusion::{ + codegen::{ + Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize, + }, + WgpuFusionHandle, +}; +use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT}; +use crate::{FloatElement, GraphicsApi, IntElement, Wgpu}; +use burn_fusion::{HandleContainer, TensorDescription}; +use burn_tensor::Device; +use std::marker::PhantomData; + +/// Kernel creation input phase, see [fusion kernel](FusionKernel) for more details. +pub struct InputPhase; +/// Kernel creation body phase, see [fusion kernel](FusionKernel) for more details. +pub struct BodyPhase; +/// Kernel creation output phase, see [fusion kernel](FusionKernel) for more details. +pub struct OutputPhase; +/// Kernel execution phase, see [fusion kernel](FusionKernel) for more details. +pub struct ExecutionPhase; + +/// Allows to create custom wgsl kernels based on configured inputs, body and outputs. +/// +/// This type has 4 phases that must be executed in order, but no worry the type system won't allow +/// you to make mistakes. +/// +/// 1. 
[Input Phase](InputPhase) +/// This phase focuses on registering the input tensor descriptions that are going to be used by +/// the fused kernel. +/// 2. [Body Phase](BodyPhase) +/// After the input phase is done, all the operations that happen in the body must be +/// registered. +/// 3. [Output Phase](OutputPhase) +/// This step focuses on registering all tensor descriptions that the kernel needs to write to. +/// 4. [Execution Phase](ExecutionPhase) +/// Now that all other phases are completed, we can actually run the kernel on the given +/// [handles](HandleContainer). Note that the actual chosen kernel may vary based on the +/// handles provided. +pub struct FusionKernel +where + G: GraphicsApi, + F: FloatElement, + I: IntElement, +{ + operations: Vec, + input_bindings: Vec<(Binding, TensorDescription)>, + output_bindings: Vec<(Binding, TensorDescription)>, + named_bindings: Vec<(String, Binding, DataBuffer)>, + functions: Vec, + num_elems_output: usize, + device: Device>, + client: WgpuComputeClient, + _phase: PhantomData, +} + +enum DataBuffer { + F32(Vec), + U32(Vec), +} + +impl FusionKernel { + /// Create a new fusion kernel on the given device. + pub fn new(device: &Device>) -> Self { + let client = compute_client::(device); + + Self { + operations: Vec::new(), + input_bindings: Vec::new(), + output_bindings: Vec::new(), + named_bindings: Vec::new(), + functions: Vec::new(), + num_elems_output: 0, + device: device.clone(), + client, + _phase: PhantomData, + } + } + + /// Register the inputs used by the kernel. + pub fn inputs( + mut self, + inputs_tensor: &[&TensorDescription], + inputs_scalar_f32: &[f32], + ) -> FusionKernel { + for (i, input) in inputs_tensor.iter().enumerate() { + self.input_bindings.push(( + Binding { + elem: Elem::F32, + visibility: Visibility::Read, + location: Location::Storage, + size: None, + }, + (*input).clone(), + )); + + self.operations.push(Operator::ReadGlobal { + variable: Variable::Input(i as u16), + position: i, + position_out: inputs_tensor.len(), // First output + }); + } + + if !inputs_scalar_f32.is_empty() { + self.named_bindings.push(( + "scalars_f32".to_string(), + Binding { + elem: Elem::F32, + visibility: Visibility::Read, + location: Location::Storage, + size: Some(inputs_scalar_f32.len()), + }, + DataBuffer::F32(inputs_scalar_f32.to_vec()), + )); + } + + FusionKernel { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + num_elems_output: self.num_elems_output, + device: self.device, + client: self.client, + _phase: PhantomData, + } + } +} + +impl FusionKernel { + /// Register the [operators](Operator) that the kernel must execute in the order provided. + pub fn body(mut self, operators: &[Operator]) -> FusionKernel { + let mut register_function = |function: Function| { + if !self.functions.contains(&function) { + self.functions.push(function); + } + }; + + // Since not all operators are native to WGSL, we need to add the custom ones. 
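// In this first version only `Powf` and `Erf` require generated helper functions
// (see codegen/function.rs): `powf` adds the sign/parity handling for negative
// bases, and `erf` is built from the polynomial approximation defined there.
// Every other fused operator lowers directly to a native WGSL builtin such as
// exp, log, cos, sin, tanh or abs, so nothing extra is registered for it.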
+ for ops in operators.iter() { + match ops { + Operator::Powf { + lhs: _, + rhs: _, + out: _, + } => { + register_function(Function::Powf(Elem::F32)); + } + Operator::Erf { input: _, out: _ } => { + register_function(Function::Erf(Elem::F32)); + } + _ => {} + } + self.operations.push(ops.clone()); + } + + FusionKernel { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + num_elems_output: self.num_elems_output, + device: self.device, + client: self.client, + _phase: PhantomData, + } + } +} + +impl FusionKernel { + /// Register the outputs with their local variable index. + /// + /// Note that the index corresponds to the registered [operator](Operator) number at the + /// [body phase](BodyPhase). + /// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0). + pub fn outputs( + mut self, + outputs: &[&TensorDescription], + locals: &[u16], + ) -> FusionKernel { + let mut num_elems_launch_option = 0; + + for (i, (output, local)) in outputs.iter().zip(locals).enumerate() { + let num_elems_output = calculate_num_elems_dyn_rank(&output.shape); + if num_elems_output > num_elems_launch_option { + num_elems_launch_option = num_elems_output; + } + + self.output_bindings.push(( + Binding { + elem: Elem::F32, + visibility: Visibility::ReadWrite, + location: Location::Storage, + size: None, + }, + (*output).clone(), + )); + + self.operations.push(Operator::AssignGlobal { + input: Variable::Local(*local), + out: Variable::Output(i as u16), + }); + } + + self.num_elems_output = num_elems_launch_option; + + FusionKernel { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + num_elems_output: self.num_elems_output, + device: self.device, + client: self.client, + _phase: PhantomData, + } + } +} + +impl FusionKernel { + /// Execute the kernel on the provided [handles](HandleContainer). + pub fn execute(mut self, handle_container: &mut HandleContainer>) { + let mut inputs = Vec::with_capacity(self.input_bindings.len()); + let mut outputs = Vec::with_capacity(self.output_bindings.len()); + let mut named = Vec::with_capacity(2); + let mut info = Vec::new(); + let mut handles = + Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity()); + + // Inner function to fill the info buffer. + let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| { + if info.is_empty() { + info.push(handle.strides.len() as u32); + } + + for s in handle.strides.iter() { + info.push(*s as u32); + } + for s in tensor.shape.iter() { + info.push(*s as u32); + } + }; + + // We start by registering the inputs. + for (binding, tensor) in self.input_bindings.into_iter() { + let handle = handle_container.get_handle(&tensor); + register_info_tensor(&tensor, &handle); + + inputs.push(binding); + handles.push(handle.handle); + } + + // Then we follow with the outputs. 
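// Buffer layout sketch (editorial illustration, not from the patch): with the
// inputs registered above and the outputs registered below, the `info` buffer
// ends up as
//   [rank,
//    strides(tensor_0)..., shape(tensor_0)...,
//    strides(tensor_1)..., shape(tensor_1)..., ...]
// with tensors numbered inputs first, then outputs. This is the layout the
// generated `Operator::ReadGlobal` code expects: for tensor `n` it reads strides
// at `info[n * 2 * rank + i]` and shapes at `info[n * 2 * rank + rank + i]`
// for `i` in `1..=rank`.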
+ for (binding, tensor) in self.output_bindings { + let num_elems = calculate_num_elems_dyn_rank(&tensor.shape); + let handle_fusion = WgpuFusionHandle { + client: self.client.clone(), + device: self.device.clone(), + strides: strides_dyn_rank(&tensor.shape), + handle: self.client.empty(core::mem::size_of::() * num_elems), + }; + register_info_tensor(&tensor, &handle_fusion); + + handles.push(handle_fusion.handle.clone()); + handle_container.register_handle(tensor.id, handle_fusion); + outputs.push(binding); + } + + // Now we can create the info handle. + Self::build_info_handle(&mut self.named_bindings, info); + + // Finally we finish with the named bindings. + for (name, binding, data) in self.named_bindings { + let handle = self.client.create(match &data { + DataBuffer::F32(values) => bytemuck::cast_slice(values), + DataBuffer::U32(values) => bytemuck::cast_slice(values), + }); + named.push((name, binding)); + handles.push(handle); + } + + // We create the shader codegen type and launch the kernel. + let kernel = ComputeShader { + inputs, + outputs, + named, + workgroup_size: WorkgroupSize::default(), + body: Body::new(self.operations), + num_workgroups: true, + global_invocation_id: true, + functions: self.functions, + }; + + let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT); + let kernel = Box::new(DynamicKernel::new(kernel, workgroup)); + + self.client + .execute(kernel, &handles.iter().collect::>()); + } + + fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec) { + named_bindings.push(( + "info".to_string(), + Binding { + elem: Elem::U32, + visibility: Visibility::Read, + location: Location::Storage, + size: None, // We avoid putting the length here since it will force a new kernel + // for each tensor rank. + }, + DataBuffer::U32(info), + )); + } +} diff --git a/burn-wgpu/src/fusion/mod.rs b/burn-wgpu/src/fusion/mod.rs new file mode 100644 index 000000000..75193b747 --- /dev/null +++ b/burn-wgpu/src/fusion/mod.rs @@ -0,0 +1,8 @@ +mod base; +mod elemwise; + +pub(crate) mod codegen; +pub(crate) mod kernel; + +pub use base::*; +pub use elemwise::*; diff --git a/burn-wgpu/src/kernel/base.rs b/burn-wgpu/src/kernel/base.rs index c464ffe6b..133015f49 100644 --- a/burn-wgpu/src/kernel/base.rs +++ b/burn-wgpu/src/kernel/base.rs @@ -1,7 +1,8 @@ use super::SourceTemplate; use crate::{ - compute::{StaticKernel, WorkGroup}, + compute::{StaticKernel, WgpuComputeClient, WgpuHandle, WorkGroup}, element::WgpuElement, + kernel, tensor::WgpuTensor, }; use std::marker::PhantomData; @@ -76,6 +77,33 @@ pub fn into_contiguous( output } +/// Similar to [into contiguous](into_contiguous) but with dynamic rank. +pub fn into_contiguous_dyn( + client: WgpuComputeClient, + input: WgpuHandle, + input_shape: &[usize], + input_strides: &[usize], + output_shape: &[usize], + output_strides: &[usize], + num_elems: usize, +) -> WgpuHandle { + let handle = client.empty(num_elems * core::mem::size_of::()); + let info = kernel::build_info_dyn::( + &[input_shape, output_shape], + &[input_strides, output_strides], + ); + + let info_handle = client.create(bytemuck::cast_slice(&info)); + + let kernel = Box::new(StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT))); + + client.execute(kernel, &[&input, &handle, &info_handle]); + + handle +} + /// Generates kernel source code by replacing some information using templating. 
pub struct KernelSettings< K: StaticKernelSource, @@ -184,6 +212,28 @@ pub fn build_info(tensors: &[&WgpuTensor]) info } +/// Similar to [build info](build_info) but with dynamic rank. +pub fn build_info_dyn(shapes: &[&[usize]], strides: &[&[usize]]) -> Vec { + let rank = shapes.get(0).unwrap().len(); + let mut info: Vec = vec![0; shapes.len() * 2 * rank + 1]; + info[0] = rank as u32; + + let mut current = 1; + for stride in strides.iter() { + for d in 0..rank { + info[current] = stride[d] as u32; + current += 1; + } + } + for shape in shapes.iter() { + for d in 0..rank { + info[current] = shape[d] as u32; + current += 1; + } + } + info +} + pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup { let num_elem_per_invocation = workgroup_size * workgroup_size; let workgroups = f32::ceil(num_elems as f32 / num_elem_per_invocation as f32);
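// Worked example (illustrative, not from the patch): for two rank-3 tensors with
// shapes [2, 3, 4] and [4, 3, 2] and contiguous strides [12, 4, 1] and [6, 2, 1],
// build_info_dyn(&[&[2, 3, 4], &[4, 3, 2]], &[&[12, 4, 1], &[6, 2, 1]]) returns
//   [3,  12, 4, 1,  6, 2, 1,  2, 3, 4,  4, 3, 2]
// i.e. the rank first, then every tensor's strides, then every tensor's shape,
// so a single u32 buffer can describe tensors whose rank is only known at runtime.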