WGPU: Support elemwise operation fusion (#948)

Nathaniel Simard 2023-11-15 15:13:37 -05:00 committed by GitHub
parent 4fc0c27e31
commit 24014aca33
30 changed files with 1718 additions and 367 deletions

View File

@ -21,6 +21,7 @@ ndarray-blas-openblas = ["burn/ndarray-blas-openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
wgpu-fusion = ["burn/wgpu", "burn/fusion"]
[dependencies]
burn = { path = "../burn" }
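
For reference, a minimal sketch of what the new `wgpu-fusion` feature selects on the Rust side, mirroring the benchmark macro in the next file (type paths taken from that macro):

#[cfg(feature = "wgpu-fusion")]
use burn::backend::{wgpu::{AutoGraphicsApi, Wgpu}, Fusion};

// The fused backend simply wraps the regular WGPU backend.
#[cfg(feature = "wgpu-fusion")]
type Backend = Fusion<Wgpu<AutoGraphicsApi, f32, i32>>;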

View File

@ -1,6 +1,14 @@
#[macro_export]
macro_rules! bench_on_backend {
() => {
#[cfg(feature = "wgpu-fusion")]
{
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn::backend::Fusion;
bench::<Fusion<Wgpu<AutoGraphicsApi, f32, i32>>>(&WgpuDevice::default());
}
#[cfg(feature = "wgpu")]
{
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
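
A hedged usage sketch: a benchmark binary is assumed to expand this macro from its entry point, so building with `--features wgpu-fusion` runs the same benchmark on the fused backend (the `bench` function is assumed to be provided by the surrounding benchmark module):

fn main() {
    // Expands to one `bench::<B>(&WgpuDevice::default())` call per enabled
    // backend feature, including Fusion<Wgpu<...>> under `wgpu-fusion`.
    bench_on_backend!();
}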

View File

@ -2,7 +2,7 @@ use crate::{
client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor,
HandleContainer,
};
use burn_tensor::{backend::Backend, Shape};
use burn_tensor::{backend::Backend, Device, Shape};
use core::marker::PhantomData;
use std::sync::Arc;
@ -36,12 +36,18 @@ impl<B: FusionBackend> Backend for Fusion<B> {
type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;
fn name() -> String {
format!("Fusion<{}>", B::name())
format!("fusion<{}>", B::name())
}
fn seed(seed: u64) {
B::seed(seed);
}
fn sync(device: &Self::Device) {
let client = CLIENTS.client::<B::FusionClient>(&device.clone().into());
client.drain_graph();
B::sync(device)
}
}
/// The status of a [fusion ops](FusionOps).
@ -116,14 +122,14 @@ pub trait FusionBackend: Backend {
/// The device type that can return an ID.
///
/// It can be the same as (Backend::Device), but must implement (FusionDevice).
type FusionDevice: FusionDevice + From<Self::Device> + Into<Self::Device>;
type FusionDevice: FusionDevice + From<Self::Device> + Into<Self::Device> + core::fmt::Debug;
/// The type that can be used to point to a tensor of any kind.
type Handle: Sync + Send + Clone;
/// What kind of client should be used.
type FusionClient: FusionClient<FusionBackend = Self>;
/// The list of operations that will be used to optimize the computational graph.
fn operations() -> Vec<Box<dyn FusionOps<Self>>>;
fn operations(device: &Device<Self>) -> Vec<Box<dyn FusionOps<Self>>>;
/// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive).
fn float_tensor<const D: usize>(
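
One user-visible consequence of this change: syncing the fused backend now drains the lazily registered operation graph before syncing the inner backend. A minimal sketch, assuming `Fusion<Wgpu<...>>` keeps `WgpuDevice` as its device type, as the benchmark macro above suggests:

use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn::backend::Fusion;
use burn::tensor::backend::Backend;

type B = Fusion<Wgpu<AutoGraphicsApi, f32, i32>>;

fn flush(device: &WgpuDevice) {
    // Executes every pending (possibly fused) operation, then syncs WGPU.
    B::sync(device);
}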

View File

@ -1,6 +1,6 @@
use crate::{
graph::{GraphExecution, TensorOpsDescription},
FusionBackend, FusionTensor, TensorDescription, TensorId,
FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
};
use burn_tensor::{
ops::{FloatElem, IntElem},
@ -18,26 +18,18 @@ pub trait FusionClient: Send + Sync + Clone {
fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
/// Register a new [tensor operation description](TensorOpsDescription).
fn register(&self, ops: TensorOpsDescription<Self::FusionBackend>);
/// Sync the computation.
fn sync(&self);
/// Execute all lazily registered computations (drain the graph).
fn drain_graph(&self);
/// Get the current device used by all operations handled by this client.
fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice;
/// Create an empty tensor.
fn create_tensor_empty(&self, shape: Vec<usize>) -> FusionTensor<Self>;
/// Create a float tensor with the given values.
fn create_tensor_float(
/// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it.
fn tensor_uninitialized(&self, shape: Vec<usize>) -> FusionTensor<Self>;
/// Create a tensor with the given handle and shape.
fn register_tensor(
&self,
values: Vec<FloatElem<Self::FusionBackend>>,
handle: Handle<Self::FusionBackend>,
shape: Vec<usize>,
) -> FusionTensor<Self>;
/// Create an integer tensor with the given values.
fn create_tensor_int(
&self,
values: Vec<IntElem<Self::FusionBackend>>,
shape: Vec<usize>,
) -> FusionTensor<Self>;
/// Create a bool tensor with the given values.
fn create_tensor_bool(&self, values: Vec<bool>, shape: Vec<usize>) -> FusionTensor<Self>;
/// Read the values contained by a float tensor.
fn read_tensor_float<const D: usize>(
&self,
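
The tensor-op implementations later in this commit all follow the same two-step pattern with this API: allocate an uninitialized output, then register an operation description for lazy (and possibly fused) execution. A condensed sketch of a binary float op; the `into_description` helpers are assumed names, and `AddOps` stands in for the structs generated by the op macros in the float-ops file:

let out = lhs
    .client
    .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsFloat(
    NumericOpsDescription::Add(
        BinaryOpsDescription {
            lhs: lhs.into_description(),  // assumed conversion to TensorDescription
            rhs: rhs.into_description(),
            out: out.to_description_out(),
        },
        Box::new(AddOps::<D>),
    ),
));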

View File

@ -1,7 +1,7 @@
use super::FusionClient;
use crate::{
graph::{GraphExecution, TensorOpsDescription},
FusionBackend, FusionServer, FusionTensor,
FusionBackend, FusionServer, FusionTensor, Handle,
};
use burn_tensor::ops::FloatElem;
use spin::Mutex;
@ -49,10 +49,11 @@ where
self.server.lock().register(ops);
}
fn sync(&self) {
self.server.lock().sync();
fn drain_graph(&self) {
self.server.lock().drain_graph();
}
fn create_tensor_empty(&self, shape: Vec<usize>) -> FusionTensor<Self> {
fn tensor_uninitialized(&self, shape: Vec<usize>) -> FusionTensor<Self> {
let id = self.server.lock().create_empty_handle();
FusionTensor::new(id, shape, self.clone())
@ -61,6 +62,18 @@ where
fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice {
&self.device
}
fn register_tensor(
&self,
handle: Handle<Self::FusionBackend>,
shape: Vec<usize>,
) -> FusionTensor<Self> {
let mut server = self.server.lock();
let id = server.create_empty_handle();
server.handles.register_handle(id.as_ref().clone(), handle);
core::mem::drop(server);
FusionTensor::new(id, shape, self.clone())
}
fn read_tensor_float<const D: usize>(
&self,
@ -69,32 +82,6 @@ where
self.server.lock().read_float(tensor)
}
fn create_tensor_float(
&self,
values: Vec<FloatElem<Self::FusionBackend>>,
shape: Vec<usize>,
) -> FusionTensor<Self> {
let id = self.server.lock().create_float_handle(values);
FusionTensor::new(id, shape, self.clone())
}
fn create_tensor_int(
&self,
values: Vec<burn_tensor::ops::IntElem<Self::FusionBackend>>,
shape: Vec<usize>,
) -> FusionTensor<Self> {
let id = self.server.lock().create_int_handle(values);
FusionTensor::new(id, shape, self.clone())
}
fn create_tensor_bool(&self, values: Vec<bool>, shape: Vec<usize>) -> FusionTensor<Self> {
let id = self.server.lock().create_bool_handle(values);
FusionTensor::new(id, shape, self.clone())
}
fn read_tensor_int<const D: usize>(
&self,
tensor: crate::TensorDescription,
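
For context, the `empty` implementation in the float-ops file later in this commit uses `register_tensor` exactly this way (condensed excerpt):

let client = get_client::<B>(&device.clone().into());
// The tensor is created eagerly on the inner backend, then tracked by handle.
let tensor = B::empty(shape.clone(), device);
client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into())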

View File

@ -32,8 +32,13 @@ impl<B: FusionBackend> GraphExecution<B> for GreedyGraphExecution {
}
match find_best_optimization_index(optimizations) {
Some(index) => graph.execute_optimization(handles, index, optimizations),
None => graph.execute(handles),
Some(index) => {
graph.execute_optimization(handles, index, optimizations);
}
None => {
graph.execute(handles);
optimizations.iter_mut().for_each(|ops| ops.reset());
}
}
if graph.is_empty() {

View File

@ -5,6 +5,7 @@ use burn_tensor::{
ops::{ConvOptions, ConvTransposeOptions},
Distribution, Element,
};
use core::hash::Hash;
use std::ops::Range;
/// General trait to abstract how a single operation is executed.
@ -652,6 +653,7 @@ pub enum BoolOpsDescription<B: FusionBackend> {
),
}
#[derive(Hash)]
/// Swap dim operation description.
pub struct SwapDimsDescription {
/// Input tensor description.
@ -664,6 +666,7 @@ pub struct SwapDimsDescription {
pub dim2: usize,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct ReshapeDescription {
pub input: TensorDescription,
@ -671,6 +674,7 @@ pub struct ReshapeDescription {
pub shape: Vec<usize>,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct BinaryOpsDescription {
pub lhs: TensorDescription,
@ -678,6 +682,7 @@ pub struct BinaryOpsDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct UnaryOpsDescription {
pub input: TensorDescription,
@ -691,6 +696,7 @@ pub struct ScalarOpsDescription<E> {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct GatherOpsDescription {
pub tensor: TensorDescription,
@ -699,6 +705,7 @@ pub struct GatherOpsDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct ScatterOpsDescription {
pub tensor: TensorDescription,
@ -708,6 +715,7 @@ pub struct ScatterOpsDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct SelectOpsDescription {
pub tensor: TensorDescription,
@ -716,6 +724,7 @@ pub struct SelectOpsDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct SelectAssignOpsDescription {
pub tensor: TensorDescription,
@ -725,6 +734,7 @@ pub struct SelectAssignOpsDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct SliceOpsDescription {
pub tensor: TensorDescription,
@ -732,6 +742,7 @@ pub struct SliceOpsDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct SliceAssignOpsDescription {
pub tensor: TensorDescription,
@ -740,6 +751,7 @@ pub struct SliceAssignOpsDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct MaskWhereOpsDescription {
pub tensor: TensorDescription,
@ -773,6 +785,7 @@ pub struct RepeatOpsDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct CatOpsDescription {
pub tensors: Vec<TensorDescription>,
@ -780,6 +793,7 @@ pub struct CatOpsDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesDescription {
pub tensor: TensorDescription,
@ -788,6 +802,7 @@ pub struct ReduceDimWithIndicesDescription {
pub out_indices: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct EmbeddingDescription {
pub weights: TensorDescription,
@ -795,6 +810,7 @@ pub struct EmbeddingDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardDescription {
pub weights: TensorDescription,
@ -803,6 +819,7 @@ pub struct EmbeddingBackwardDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct Conv1dDescription {
pub x: TensorDescription,
@ -812,6 +829,7 @@ pub struct Conv1dDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct Conv2dDescription {
pub x: TensorDescription,
@ -821,6 +839,7 @@ pub struct Conv2dDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct ConvTranspose1dDescription {
pub x: TensorDescription,
@ -830,6 +849,7 @@ pub struct ConvTranspose1dDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct ConvTranspose2dDescription {
pub x: TensorDescription,
@ -839,6 +859,7 @@ pub struct ConvTranspose2dDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct AvgPool1dDescription {
pub x: TensorDescription,
@ -849,6 +870,7 @@ pub struct AvgPool1dDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct AvgPool2dDescription {
pub x: TensorDescription,
@ -859,6 +881,7 @@ pub struct AvgPool2dDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardDescription {
pub x: TensorDescription,
@ -870,6 +893,7 @@ pub struct AvgPool1dBackwardDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardDescription {
pub x: TensorDescription,
@ -881,6 +905,7 @@ pub struct AvgPool2dBackwardDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dDescription {
pub x: TensorDescription,
@ -888,6 +913,7 @@ pub struct AdaptiveAvgPool1dDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dDescription {
pub x: TensorDescription,
@ -895,6 +921,7 @@ pub struct AdaptiveAvgPool2dDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardDescription {
pub x: TensorDescription,
@ -902,6 +929,7 @@ pub struct AdaptiveAvgPool1dBackwardDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardDescription {
pub x: TensorDescription,
@ -909,6 +937,7 @@ pub struct AdaptiveAvgPool2dBackwardDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct MaxPool1dDescription {
pub x: TensorDescription,
@ -919,6 +948,7 @@ pub struct MaxPool1dDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesDescription {
pub x: TensorDescription,
@ -930,6 +960,7 @@ pub struct MaxPool1dWithIndicesDescription {
pub out_indices: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardDescription {
pub x: TensorDescription,
@ -942,6 +973,7 @@ pub struct MaxPool1dWithIndicesBackwardDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct MaxPool2dDescription {
pub x: TensorDescription,
@ -952,6 +984,7 @@ pub struct MaxPool2dDescription {
pub out: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesDescription {
pub x: TensorDescription,
@ -963,6 +996,7 @@ pub struct MaxPool2dWithIndicesDescription {
pub out_indices: TensorDescription,
}
#[derive(Hash)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardDescription {
pub x: TensorDescription,
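
These `Hash` derives let the fusion layer fingerprint a sequence of operation descriptions, for example to key cached fused kernels. A minimal sketch of that assumed usage:

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Hypothetical helper: reduce a slice of hashable descriptions to one key.
fn graph_key<D: Hash>(descriptions: &[D]) -> u64 {
    let mut hasher = DefaultHasher::new();
    for description in descriptions {
        description.hash(&mut hasher);
    }
    hasher.finish()
}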

View File

@ -1,8 +1,5 @@
use crate::{FusionBackend, TensorDescription, TensorId, TensorStatus};
use burn_tensor::{
ops::{FloatElem, IntElem},
Data, ElementConversion, Shape,
};
use burn_tensor::Shape;
use std::{collections::HashMap, sync::Arc};
/// Keep all [tensor handles](FusionBackend::Handle) in one place and ensure that all resources
@ -17,10 +14,7 @@ pub struct HandleContainer<B: FusionBackend> {
}
enum Handle<B: FusionBackend> {
Empty,
DataFloat(Vec<FloatElem<B>>),
DataInt(Vec<IntElem<B>>),
DataBool(Vec<bool>),
NotInit,
Existing(B::Handle),
}
@ -34,47 +28,38 @@ impl<B: FusionBackend> HandleContainer<B> {
}
}
/// Register a handle for the given [tensor id](TensorId).
pub fn register_handle(&mut self, id: TensorId, handle: B::Handle) {
self.handles.insert(id, Handle::Existing(handle));
}
/// Get the handle for the given [tensor id](TensorId).
pub fn get_handle(&mut self, tensor: &TensorDescription) -> B::Handle {
let (id, handle) = self
.handles
.remove_entry(&tensor.id)
.unwrap_or_else(|| panic!("Should have handle for tensor {:?}", tensor.id));
match handle {
Handle::Existing(handle) => match tensor.status {
TensorStatus::ReadOnly => {
self.handles.insert(id, Handle::Existing(handle.clone()));
handle
}
TensorStatus::ReadWrite => handle,
TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."),
},
Handle::NotInit => panic!("Cannot get uninitialized handle."),
}
}
/// Get the [float tensor](burn_tensor::backend::Backend::TensorPrimitive) corresponding to the
/// given [tensor description](TensorDescription).
pub fn get_float_tensor<const D: usize>(
&mut self,
tensor: &TensorDescription,
) -> B::TensorPrimitive<D> {
let (id, handle) = self
.handles
.remove_entry(&tensor.id)
.unwrap_or_else(|| panic!("Should have handle for tensor {:?}", tensor.id));
if let Handle::Existing(handle) = handle {
match tensor.status {
TensorStatus::ReadOnly => {
self.handles.insert(id, Handle::Existing(handle.clone()));
return B::float_tensor(handle, Shape::from(tensor.shape.clone()));
}
TensorStatus::ReadWrite => {
return B::float_tensor(handle, Shape::from(tensor.shape.clone()));
}
TensorStatus::NotInit => panic!("Can't get uninitialized tensor."),
}
}
let output = match handle {
Handle::Empty => B::empty(Shape::from(tensor.shape.clone()), &self.device),
Handle::DataFloat(values) => B::from_data(
Data::new(values, Shape::from(tensor.shape.clone())),
&self.device,
),
Handle::Existing(_) => unreachable!(),
Handle::DataInt(_) => panic!("From int unsupported when getting float tensor."),
Handle::DataBool(_) => panic!("From bool unsupported when getting float tensor."),
};
if let TensorStatus::ReadOnly = tensor.status {
self.handles
.insert(id, Handle::Existing(B::float_tensor_handle(output.clone())));
}
output
B::float_tensor(self.get_handle(tensor), Shape::from(&tensor.shape))
}
/// Get the [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) corresponding to the
@ -83,38 +68,7 @@ impl<B: FusionBackend> HandleContainer<B> {
&mut self,
tensor: &TensorDescription,
) -> B::IntTensorPrimitive<D> {
let (id, handle) = self.handles.remove_entry(&tensor.id).unwrap();
if let Handle::Existing(handle) = handle {
match tensor.status {
TensorStatus::ReadOnly => {
self.handles.insert(id, Handle::Existing(handle.clone()));
return B::int_tensor(handle, Shape::from(tensor.shape.clone()));
}
TensorStatus::ReadWrite => {
return B::int_tensor(handle, Shape::from(tensor.shape.clone()));
}
TensorStatus::NotInit => panic!("Can get uninitialized tensor."),
}
}
let output = match handle {
Handle::Empty => B::int_empty(Shape::from(tensor.shape.clone()), &self.device),
Handle::DataInt(values) => B::int_from_data(
Data::new(values, Shape::from(tensor.shape.clone())),
&self.device,
),
Handle::Existing(_) => unreachable!(),
Handle::DataFloat(_) => panic!("From float unsupported when getting int tensor."),
Handle::DataBool(_) => panic!("From bool unsupported when getting int tensor."),
};
if let TensorStatus::ReadOnly = tensor.status {
self.handles
.insert(id, Handle::Existing(B::int_tensor_handle(output.clone())));
}
output
B::int_tensor(self.get_handle(tensor), Shape::from(&tensor.shape))
}
/// Get the [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) corresponding to the
@ -123,41 +77,7 @@ impl<B: FusionBackend> HandleContainer<B> {
&mut self,
tensor: &TensorDescription,
) -> B::BoolTensorPrimitive<D> {
let (id, handle) = self.handles.remove_entry(&tensor.id).unwrap();
if let Handle::Existing(handle) = handle {
match tensor.status {
TensorStatus::ReadOnly => {
self.handles.insert(id, Handle::Existing(handle.clone()));
return B::bool_tensor(handle, Shape::from(tensor.shape.clone()));
}
TensorStatus::ReadWrite => {
return B::bool_tensor(handle, Shape::from(tensor.shape.clone()));
}
TensorStatus::NotInit => panic!("Can get uninitialized tensor."),
}
}
let output = match handle {
Handle::Empty => B::int_equal_elem(
B::int_empty(Shape::from(tensor.shape.clone()), &self.device),
0.elem(),
),
Handle::DataBool(data) => B::bool_from_data(
Data::new(data, Shape::from(tensor.shape.clone())),
&self.device,
),
Handle::Existing(_) => unreachable!(),
Handle::DataFloat(_) => panic!("From float unsupported when getting bool tensor."),
Handle::DataInt(_) => panic!("From int unsupported when getting bool tensor."),
};
if let TensorStatus::ReadOnly = tensor.status {
self.handles
.insert(id, Handle::Existing(B::bool_tensor_handle(output.clone())));
}
output
B::bool_tensor(self.get_handle(tensor), Shape::from(&tensor.shape))
}
/// Register a new [float tensor](burn_tensor::backend::Backend::TensorPrimitive) with the corresponding [tensor id](TensorId).
@ -191,37 +111,10 @@ impl<B: FusionBackend> HandleContainer<B> {
}
/// Lazily create a new empty tensor and return its corresponding [tensor id](TensorId).
pub fn create_tensor_empty(&mut self) -> Arc<TensorId> {
pub fn create_tensor_uninit(&mut self) -> Arc<TensorId> {
let id = TensorId::new(self.counter);
self.counter += 1;
self.handles.insert(id.clone(), Handle::Empty);
Arc::new(id)
}
/// Lazily create a new float tensor and return its corresponding [tensor id](TensorId).
pub(crate) fn create_tensor_float(&mut self, values: Vec<FloatElem<B>>) -> Arc<TensorId> {
let id = TensorId::new(self.counter);
self.counter += 1;
self.handles.insert(id.clone(), Handle::DataFloat(values));
Arc::new(id)
}
/// Lazily create a new int tensor and return its corresponding [tensor id](TensorId).
pub(crate) fn create_tensor_int(&mut self, values: Vec<IntElem<B>>) -> Arc<TensorId> {
let id = TensorId::new(self.counter);
self.counter += 1;
self.handles.insert(id.clone(), Handle::DataInt(values));
Arc::new(id)
}
/// Lazily create a new bool tensor and return its corresponding [tensor id](TensorId).
pub(crate) fn create_tensor_bool(&mut self, values: Vec<bool>) -> Arc<TensorId> {
let id = TensorId::new(self.counter);
self.counter += 1;
self.handles.insert(id.clone(), Handle::DataBool(values));
self.handles.insert(id.clone(), Handle::NotInit);
Arc::new(id)
}
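
A hedged sketch of the contract that `get_handle` now centralizes, assuming the public items of this file are in scope: a `ReadOnly` description leaves a clone of the handle in the container for later reads, a `ReadWrite` description consumes it, and an uninitialized (`NotInit`) handle panics.

fn read_then_reuse<B: FusionBackend>(
    container: &mut HandleContainer<B>,
    readonly: &TensorDescription, // assumed to carry TensorStatus::ReadOnly
) -> (B::Handle, B::Handle) {
    let first = container.get_handle(readonly);
    // Safe only because the status is ReadOnly: the handle was re-inserted.
    let second = container.get_handle(readonly);
    (first, second)
}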

View File

@ -17,8 +17,9 @@ use burn_tensor::{
impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let tensor = B::bool_empty(shape.clone(), device);
client.create_tensor_empty(shape.dims.into())
client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into())
}
fn bool_shape<const D: usize>(tensor: &BoolTensor<Self, D>) -> Shape<D> {
@ -36,8 +37,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
device: &Device<Self>,
) -> BoolTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let tensor = B::bool_from_data(data, device);
let shape = B::bool_shape(&tensor);
client.create_tensor_bool(data.value, data.shape.dims.into())
client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into())
}
fn bool_into_int<const D: usize>(
@ -55,7 +58,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
}
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client
.register(TensorOpsDescription::BoolOps(BoolOpsDescription::IntoInt(
@ -84,7 +87,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
}
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::BoolOps(
BoolOpsDescription::IntoFloat(
@ -139,7 +142,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = shape.dims.into();
let out = tensor.client.create_tensor_empty(shape.clone());
let out = tensor.client.tensor_uninitialized(shape.clone());
tensor
.client
@ -183,7 +186,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
shape.push(tensor.shape[i]);
}
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -227,7 +230,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -279,7 +282,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
shape[dim] += tensor.shape[dim];
}
let out = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Cat(
CatOpsDescription {
@ -312,7 +315,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::BaseOpsBool(
BaseOpsDescription::Equal(
@ -341,7 +344,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
}
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::BoolOps(
crate::graph::BoolOpsDescription::Not(
@ -377,7 +380,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
shape[dim1] = tensor.shape[dim2];
shape[dim2] = tensor.shape[dim1];
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client

View File

@ -26,8 +26,10 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
device: &Device<Self>,
) -> FloatTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let tensor = B::from_data(data, device);
let shape = B::shape(&tensor);
client.create_tensor_float(data.value, data.shape.dims.into())
client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into())
}
fn random<const D: usize>(
@ -54,7 +56,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let out = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Random(
(out.to_description_out(), distribution),
@ -79,7 +81,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let out = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::<D>)),
@ -103,7 +105,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let out = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::<D>)),
@ -131,7 +133,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let out = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Full(
@ -188,7 +190,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
}
}
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::FloatOps(
FloatOpsDescription::IntoInt(
@ -205,8 +207,9 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let tensor = B::empty(shape.clone(), device);
client.create_tensor_empty(shape.dims.into())
client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into())
}
fn add<const D: usize>(
@ -217,7 +220,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Add(
@ -239,7 +242,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> FloatTensor<Self, D> {
scalar_float_ops!(AddOps, B::add_scalar);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::AddScalar(
@ -261,7 +264,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> FloatTensor<Self, D> {
scalar_float_ops!(ClampMinOps, B::clamp_min);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::ClampMin(
@ -283,7 +286,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> FloatTensor<Self, D> {
scalar_float_ops!(ClampMaxOps, B::clamp_max);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::ClampMax(
@ -317,7 +320,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
}
}
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Clamp(
@ -342,7 +345,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Sub(
@ -364,7 +367,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> FloatTensor<Self, D> {
scalar_float_ops!(SubOps, B::sub_scalar);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::SubScalar(
@ -388,7 +391,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Mul(
@ -410,7 +413,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> FloatTensor<Self, D> {
scalar_float_ops!(MulOps, B::mul_scalar);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::MulScalar(
@ -434,7 +437,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Div(
@ -456,7 +459,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> FloatTensor<Self, D> {
scalar_float_ops!(DivOps, B::div_scalar);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::DivScalar(
@ -483,7 +486,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
shape[D - 2] = lhs.shape[D - 2];
shape[D - 1] = rhs.shape[D - 1];
let out = lhs.client.create_tensor_empty(shape);
let out = lhs.client.tensor_uninitialized(shape);
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Matmul(
@ -519,7 +522,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
shape[dim1] = tensor.shape[dim2];
shape[dim2] = tensor.shape[dim1];
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -556,7 +559,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = shape.dims.into();
let out = tensor.client.create_tensor_empty(shape.clone());
let out = tensor.client.tensor_uninitialized(shape.clone());
tensor
.client
@ -595,7 +598,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = indices.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -638,7 +641,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -681,7 +684,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let mut shape: Vec<usize> = tensor.shape.clone();
shape[dim] = indices.shape[0];
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -724,7 +727,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -769,7 +772,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
shape.push(tensor.shape[i]);
}
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -813,7 +816,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -855,7 +858,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -896,7 +899,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -924,7 +927,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::BaseOpsFloat(
BaseOpsDescription::Equal(
@ -946,7 +949,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
scalar_float_cmp_ops!(EqualElemOps, B::equal_elem);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::EqualElem(
@ -970,7 +973,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Greater(
@ -992,7 +995,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
scalar_float_cmp_ops!(GreaterElemOps, B::greater_elem);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::GreaterElem(
@ -1016,7 +1019,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::GreaterEqual(
@ -1038,7 +1041,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
scalar_float_cmp_ops!(GreaterEqualElemOps, B::greater_equal_elem);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::GreaterEqualElem(
@ -1062,7 +1065,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Lower(
@ -1084,7 +1087,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
scalar_float_cmp_ops!(LowerElemOps, B::lower_elem);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::LowerElem(
@ -1108,7 +1111,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::LowerEqual(
@ -1130,7 +1133,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
scalar_float_cmp_ops!(LowerEqualElemOps, B::lower_equal_elem);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::LowerEqualElem(
@ -1149,7 +1152,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(SumOps, B::sum);
let out = tensor.client.create_tensor_empty(vec![1]);
let out = tensor.client.tensor_uninitialized(vec![1]);
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Sum(
@ -1169,7 +1172,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::SumDim(
@ -1188,7 +1191,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn mean<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(MeanOps, B::mean);
let out = tensor.client.create_tensor_empty(vec![1]);
let out = tensor.client.tensor_uninitialized(vec![1]);
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Mean(
@ -1208,7 +1211,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::MeanDim(
@ -1239,7 +1242,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(ExpOps, B::exp);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Exp(
@ -1256,7 +1259,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(LogOps, B::log);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log(
@ -1273,7 +1276,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(Log1pOps, B::log1p);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log1p(
@ -1290,7 +1293,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
scalar_float_ops!(PowfOps, B::powf, f32);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Powf(
@ -1308,7 +1311,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(SqrtOps, B::sqrt);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sqrt(
@ -1325,7 +1328,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(AbsOps, B::abs);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Abs(
@ -1343,7 +1346,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(CosOps, B::cos);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Cos(
@ -1360,7 +1363,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(SinOps, B::sin);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sin(
@ -1377,7 +1380,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(TanhOps, B::tanh);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Tanh(
@ -1394,7 +1397,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(Recip, B::recip);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip(
UnaryOpsDescription {
@ -1409,7 +1412,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(TanhOps, B::erf);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Erf(
@ -1452,7 +1455,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
shape[dim] += tensor.shape[dim];
}
let out = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::BaseOpsFloat(BaseOpsDescription::Cat(
CatOpsDescription {
@ -1471,7 +1474,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::ArgMax(
@ -1492,7 +1495,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::ArgMin(
@ -1511,7 +1514,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn max<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(MaxOps, B::max);
let out = tensor.client.create_tensor_empty(vec![1]);
let out = tensor.client.tensor_uninitialized(vec![1]);
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Max(
@ -1531,7 +1534,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::MaxDim(
@ -1568,8 +1571,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let client = tensor.client.clone();
let out = client.create_tensor_empty(shape.clone());
let out_indices = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape.clone());
let out_indices = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::MaxDimWithIndices(
@ -1589,7 +1592,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
fn min<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
unary_float_ops!(MinOps, B::min);
let out = tensor.client.create_tensor_empty(vec![1]);
let out = tensor.client.tensor_uninitialized(vec![1]);
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::Min(
@ -1609,7 +1612,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::MinDim(
@ -1646,8 +1649,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let client = tensor.client.clone();
let out = client.create_tensor_empty(shape.clone());
let out_indices = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape.clone());
let out_indices = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::NumericOpsFloat(
NumericOpsDescription::MinDimWithIndices(

View File

@ -22,8 +22,9 @@ use core::ops::Range;
impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let tensor = B::int_empty(shape.clone(), device);
client.create_tensor_empty(shape.dims.into())
client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into())
}
fn int_shape<const D: usize>(tensor: &IntTensor<Self, D>) -> Shape<D> {
@ -39,8 +40,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
device: &Device<Self>,
) -> IntTensor<Self, D> {
let client = get_client::<B>(&device.clone().into());
let tensor = B::int_from_data(data, device);
let shape = B::int_shape(&tensor);
client.create_tensor_int(data.value, data.shape.dims.into())
client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into())
}
fn int_device<const D: usize>(tensor: &IntTensor<Self, D>) -> Device<Self> {
@ -83,7 +86,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = shape.dims.into();
let out = tensor.client.create_tensor_empty(shape.clone());
let out = tensor.client.tensor_uninitialized(shape.clone());
tensor
.client
@ -127,7 +130,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
shape.push(tensor.shape[i]);
}
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -169,7 +172,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -211,7 +214,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -252,7 +255,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -292,7 +295,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = indices.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -335,7 +338,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -378,7 +381,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let mut shape: Vec<usize> = tensor.shape.clone();
shape[dim] = indices.shape[0];
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -421,7 +424,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -471,7 +474,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
shape[dim] += tensor.shape[dim];
}
let out = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Cat(
CatOpsDescription {
@ -493,7 +496,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client
.register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Equal(
@ -514,7 +517,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::EqualElem(
@ -538,7 +541,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::Greater(
@ -560,7 +563,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::GreaterElem(
@ -584,7 +587,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::GreaterEqual(
@ -606,7 +609,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::GreaterEqualElem(
@ -630,7 +633,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::Lower(
@ -652,7 +655,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::LowerElem(
@ -676,7 +679,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::LowerEqual(
@ -698,7 +701,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::LowerEqualElem(
@ -722,7 +725,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client
.register(graph::TensorOpsDescription::NumericOpsInt(
@ -745,7 +748,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
scalar_int_ops!(AddOps, B::int_add_scalar);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client
.register(graph::TensorOpsDescription::NumericOpsInt(
@ -770,7 +773,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client
.register(graph::TensorOpsDescription::NumericOpsInt(
@ -793,7 +796,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
scalar_int_ops!(SubOps, B::int_sub_scalar);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client
.register(graph::TensorOpsDescription::NumericOpsInt(
@ -818,7 +821,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client
.register(graph::TensorOpsDescription::NumericOpsInt(
@ -841,7 +844,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
scalar_int_ops!(MulOps, B::int_mul_scalar);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client
.register(graph::TensorOpsDescription::NumericOpsInt(
@ -866,7 +869,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs
.client
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
out.client
.register(graph::TensorOpsDescription::NumericOpsInt(
@ -889,7 +892,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
scalar_int_ops!(DivOps, B::int_div_scalar);
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
out.client
.register(graph::TensorOpsDescription::NumericOpsInt(
@ -921,7 +924,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let out = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::<D>)),
@ -945,7 +948,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = shape.dims.into();
let client = get_client::<B>(&device.clone().into());
let out = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::<D>)),
@ -957,7 +960,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_sum<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
unary_int_ops!(SumOps, B::int_sum);
let out = tensor.client.create_tensor_empty(vec![1]);
let out = tensor.client.tensor_uninitialized(vec![1]);
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::Sum(
@ -977,7 +980,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::SumDim(
@ -996,7 +999,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_mean<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
unary_int_ops!(MeanOps, B::int_mean);
let out = tensor.client.create_tensor_empty(vec![1]);
let out = tensor.client.tensor_uninitialized(vec![1]);
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::Mean(
@ -1016,7 +1019,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::MeanDim(
@ -1037,7 +1040,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::ArgMax(
@ -1058,7 +1061,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::ArgMin(
@ -1080,7 +1083,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
scalar_int_ops!(ClampMinOps, B::int_clamp_min);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::ClampMin(
@ -1102,7 +1105,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
scalar_int_ops!(ClampMaxOps, B::int_clamp_max);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::ClampMax(
@ -1136,7 +1139,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
}
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::Clamp(
@ -1156,7 +1159,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
unary_int_ops!(AbsOps, B::int_abs);
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::Abs(
@ -1184,7 +1187,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
}
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
out.client.register(TensorOpsDescription::IntOps(
graph::IntOpsDescription::IntoFloat(
@ -1220,7 +1223,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
shape[dim1] = tensor.shape[dim2];
shape[dim2] = tensor.shape[dim1];
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
tensor
.client
@ -1243,7 +1246,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_max<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
unary_int_ops!(MaxOps, B::int_max);
let out = tensor.client.create_tensor_empty(vec![1]);
let out = tensor.client.tensor_uninitialized(vec![1]);
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::Max(
@ -1263,7 +1266,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::MaxDim(
@ -1300,8 +1303,8 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let client = tensor.client.clone();
let out = client.create_tensor_empty(shape.clone());
let out_indices = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape.clone());
let out_indices = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::MaxDimWithIndices(
@ -1321,7 +1324,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_min<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
unary_int_ops!(MinOps, B::int_min);
let out = tensor.client.create_tensor_empty(vec![1]);
let out = tensor.client.tensor_uninitialized(vec![1]);
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::Min(
@ -1341,7 +1344,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor.client.create_tensor_empty(shape);
let out = tensor.client.tensor_uninitialized(shape);
out.client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::MinDim(
@ -1378,8 +1381,8 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let client = tensor.client.clone();
let out = client.create_tensor_empty(shape.clone());
let out_indices = client.create_tensor_empty(shape);
let out = client.tensor_uninitialized(shape.clone());
let out_indices = client.tensor_uninitialized(shape);
client.register(TensorOpsDescription::NumericOpsInt(
NumericOpsDescription::MinDimWithIndices(

View File

@ -56,7 +56,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
);
let shape = vec![x.shape[0], weight.shape[0], size];
let out = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::Conv1d(
@ -115,7 +115,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
);
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
let out = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::Conv2d(
@ -168,7 +168,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
);
let shape = vec![x.shape[0], weight.shape[1] * options.groups, size];
let out = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::ConvTranspose1d(
@ -229,7 +229,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
);
let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1];
let out = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::ConvTranspose2d(
@ -275,7 +275,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]);
let shape = vec![x.shape[0], x.shape[1], size];
let out = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::AvgPool1d(
@ -326,7 +326,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]);
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
let out = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::AvgPool2d(
@ -374,7 +374,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
}
}
let out = x.client.create_tensor_empty(x.shape.clone());
let out = x.client.tensor_uninitialized(x.shape.clone());
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::AvgPool1dBackward(
@ -423,7 +423,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
}
}
let out = x.client.create_tensor_empty(x.shape.clone());
let out = x.client.tensor_uninitialized(x.shape.clone());
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::AvgPool2dBackward(
@ -472,7 +472,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
let shape = vec![x.shape[0], x.shape[1], size];
let out = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::MaxPool1d(
@ -533,7 +533,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
);
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
let out = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::MaxPool2d(
@ -581,8 +581,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
let shape = vec![x.shape[0], x.shape[1], size];
let out = x.client.create_tensor_empty(shape.clone());
let out_indices = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape.clone());
let out_indices = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::MaxPool1dWithIndices(
@ -645,8 +645,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
);
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
let out = x.client.create_tensor_empty(shape.clone());
let out_indices = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape.clone());
let out_indices = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::MaxPool2dWithIndices(
@ -698,7 +698,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
}
}
let out = x.client.create_tensor_empty(x.shape.clone());
let out = x.client.tensor_uninitialized(x.shape.clone());
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::MaxPool1dWithIndicesBackward(
@ -751,7 +751,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
}
}
let out = x.client.create_tensor_empty(x.shape.clone());
let out = x.client.tensor_uninitialized(x.shape.clone());
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::MaxPool2dWithIndicesBackward(
@ -787,7 +787,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
}
let shape = vec![x.shape[0], x.shape[1], output_size];
let out = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::AdaptiveAvgPool1d(
@ -821,7 +821,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
}
let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
let out = x.client.create_tensor_empty(shape);
let out = x.client.tensor_uninitialized(shape);
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::AdaptiveAvgPool2d(
@ -855,7 +855,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
}
}
let out = x.client.create_tensor_empty(x.shape.clone());
let out = x.client.tensor_uninitialized(x.shape.clone());
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::AdaptiveAvgPool1dBackward(
@ -889,7 +889,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
}
}
let out = x.client.create_tensor_empty(x.shape.clone());
let out = x.client.tensor_uninitialized(x.shape.clone());
x.client.clone().register(TensorOpsDescription::ModuleOps(
crate::graph::ModuleOpsDescription::AdaptiveAvgPool2dBackward(

View File

@ -23,7 +23,7 @@ where
G: GraphExecution<B>,
{
pub fn new(device: B::FusionDevice) -> Self {
let optimizations = B::operations()
let optimizations = B::operations(&device.clone().into())
.into_iter()
.map(|ops| Optimization::new(ops, FusionStatus::Open(FusionProperties::default())))
.collect();
@ -53,7 +53,11 @@ where
);
}
pub fn sync(&mut self) {
pub fn drain_graph(&mut self) {
if self.graph.is_empty() {
return;
}
self.execution.maybe_execute(
&mut self.graph,
&mut self.handles,
@ -63,19 +67,7 @@ where
}
pub fn create_empty_handle(&mut self) -> Arc<TensorId> {
self.handles.create_tensor_empty()
}
pub fn create_float_handle(&mut self, values: Vec<FloatElem<B>>) -> Arc<TensorId> {
self.handles.create_tensor_float(values)
}
pub fn create_int_handle(&mut self, values: Vec<IntElem<B>>) -> Arc<TensorId> {
self.handles.create_tensor_int(values)
}
pub fn create_bool_handle(&mut self, values: Vec<bool>) -> Arc<TensorId> {
self.handles.create_tensor_bool(values)
self.handles.create_tensor_uninit()
}
pub fn read_float<const D: usize>(
@ -84,7 +76,7 @@ where
) -> burn_tensor::Reader<burn_tensor::Data<FloatElem<B>, D>> {
// Make sure all registered operations are executed.
// The underlying backend can still be async.
self.sync();
self.drain_graph();
let tensor = self.handles.get_float_tensor(&tensor);
B::into_data(tensor)
@ -96,7 +88,7 @@ where
) -> burn_tensor::Reader<burn_tensor::Data<IntElem<B>, D>> {
// Make sure all registered operations are executed.
// The underlying backend can still be async.
self.sync();
self.drain_graph();
let tensor = self.handles.get_int_tensor(&tensor);
B::int_into_data(tensor)
@ -108,7 +100,7 @@ where
) -> burn_tensor::Reader<burn_tensor::Data<bool, D>> {
// Make sure all registered operations are executed.
// The underlying backend can still be async.
self.sync();
self.drain_graph();
let tensor = self.handles.get_bool_tensor(&tensor);
B::bool_into_data(tensor)

View File

@ -127,7 +127,7 @@ pub struct TensorId {
}
/// The status of the current tensor.
#[derive(Clone, Debug)]
#[derive(Hash, Clone, Debug, PartialEq, Eq)]
pub enum TensorStatus {
/// The tensor can be read, but not written.
ReadOnly,
@ -147,7 +147,7 @@ pub enum TensorStatus {
/// 2. Status::ReadOnly
/// 3. Status::ReadOnly
/// 4. Status::ReadWrite
#[derive(Debug)]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct TensorDescription {
/// The [tensor id](TensorId).
pub id: TensorId,
@ -158,7 +158,8 @@ pub struct TensorDescription {
}
impl TensorId {
pub(crate) fn new(value: u64) -> Self {
/// Create a new tensor id.
pub fn new(value: u64) -> Self {
Self { value }
}
}

View File

@ -66,7 +66,7 @@ pub struct Conv1dBackward<B: Backend> {
}
/// Convolution options.
#[derive(new, Debug, Clone)]
#[derive(new, Debug, Clone, Hash)]
pub struct ConvOptions<const N: usize> {
/// Stride.
pub stride: [usize; N],
@ -82,7 +82,7 @@ pub struct ConvOptions<const N: usize> {
}
/// Transposed convolution options.
#[derive(new, Debug, Clone)]
#[derive(new, Debug, Clone, Hash)]
pub struct ConvTransposeOptions<const N: usize> {
/// Stride.
pub stride: [usize; N],

View File

@ -53,11 +53,16 @@ burn-tensor = { path = "../burn-tensor", version = "0.11.0", default-features =
burn-ndarray = { path = "../burn-ndarray", version = "0.11.0" }
burn-fusion = { path = "../burn-fusion", version = "0.11.0" }
serial_test = "2.0.0"
pretty_assertions = {workspace = true}
[[bench]]
name = "matmul"
harness = false
[[bench]]
name = "fused_elemwise"
harness = false
[[bench]]
name = "reduction"
harness = false

View File

@ -0,0 +1,74 @@
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_fusion::Fusion;
use burn_tensor::backend::Backend;
use burn_tensor::{Distribution, Shape, Tensor};
use burn_wgpu::Wgpu;
use burn_wgpu::WgpuDevice;
use derive_new::new;
use std::marker::PhantomData;
#[derive(new)]
struct ElemWiseBenchmark<B: Backend> {
shape: Shape<3>,
device: B::Device,
repeat: usize,
_b: PhantomData<B>,
}
impl<B: Backend> Benchmark for ElemWiseBenchmark<B> {
type Args = (Tensor<B, 3>, Tensor<B, 3>);
fn name(&self) -> String {
format!(
"Backend {} Shape {:?} Repeat {}",
B::name(),
self.shape.dims,
self.repeat
)
}
fn num_samples(&self) -> usize {
10
}
fn execute(&self, (lhs, rhs): Self::Args) {
for _ in 0..self.repeat {
let tmp_0 = lhs.clone() + rhs.clone();
let tmp_1 = rhs.clone() * tmp_0.clone();
let tmp_2 = rhs.clone().exp();
let tmp_3 = tmp_0 * tmp_1;
let _tmp_4 = tmp_2 / tmp_3;
}
}
fn prepare(&self) -> Self::Args {
B::seed(10);
let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);
let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);
(lhs, rhs)
}
fn sync(&self) {
B::sync(&self.device)
}
}
#[allow(dead_code)]
/// Runs the element-wise benchmark on wgpu, with and without fusion
pub fn bench(device: &WgpuDevice) {
run_benchmark(ElemWiseBenchmark::<Wgpu>::new(
Shape::new([256, 256, 1024]),
device.clone(),
10,
));
run_benchmark(ElemWiseBenchmark::<Fusion<Wgpu>>::new(
Shape::new([256, 256, 1024]),
device.clone(),
10,
));
}
fn main() {
bench(&WgpuDevice::BestAvailable)
}

View File

@ -154,7 +154,9 @@ where
return pipeline.clone();
}
let pipeline = self.compile_source(&kernel.source().complete());
let source = kernel.source().complete();
log::trace!("Compiling kernel {kernel_id}:\n {source}");
let pipeline = self.compile_source(&source);
self.pipelines.insert(kernel_id.clone(), pipeline.clone());
pipeline

View File

@ -1,6 +1,7 @@
use crate::{
compute::{WgpuComputeClient, WgpuHandle},
element::WgpuElement,
fusion::FloatElementWiseFusionOps,
tensor::WgpuTensor,
FloatElement, GraphicsApi, IntElement, Wgpu, WgpuDevice,
};
@ -32,8 +33,8 @@ where
type Handle = WgpuFusionHandle;
type FusionClient = MutexFusionClient<Self, GreedyGraphExecution>;
fn operations() -> Vec<Box<dyn burn_fusion::FusionOps<Self>>> {
Vec::new()
fn operations(device: &WgpuDevice) -> Vec<Box<dyn burn_fusion::FusionOps<Self>>> {
vec![Box::new(FloatElementWiseFusionOps::new(device.clone()))]
}
fn float_tensor<const D: usize>(
@ -70,7 +71,27 @@ where
}
}
#[derive(Debug, Clone)]
pub fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
let mut strides = vec![0; shape.len()];
let mut current = 1;
shape.iter().enumerate().rev().for_each(|(index, val)| {
strides[index] = current;
current *= val;
});
strides
}
pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
let mut num_elems = 1;
for i in shape.iter() {
num_elems *= i;
}
num_elems
}
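// A minimal sketch (not part of the original change) of the row-major layout the two helpers
// above assume: for a shape [2, 3, 4] the strides are [12, 4, 1] and the element count is 24.
#[cfg(test)]
mod dyn_rank_helpers_sketch {
    use super::{calculate_num_elems_dyn_rank, strides_dyn_rank};

    #[test]
    fn row_major_layout() {
        let shape = [2, 3, 4];
        // The innermost dimension is contiguous; every outer stride is the product of the
        // dimensions to its right.
        assert_eq!(strides_dyn_rank(&shape), vec![12, 4, 1]);
        assert_eq!(calculate_num_elems_dyn_rank(&shape), 24);
    }
}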
#[derive(new, Debug, Clone)]
/// Handle to be used when fusing operations.
pub struct WgpuFusionHandle {
/// Compute client for wgpu.

View File

@ -0,0 +1,27 @@
use super::Operator;
use std::fmt::Display;
/// A body is composed of a list of [operators](Operator).
///
/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size
/// X and Y, but with Z=1.
#[derive(Hash, new)]
pub struct Body {
operators: Vec<Operator>,
}
impl Display for Body {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(
"let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n",
)?;
f.write_str("let rank: u32 = info[0];\n\n")?;
for ops in self.operators.iter() {
f.write_fmt(format_args!("{ops}"))?;
f.write_str("\n")?;
}
Ok(())
}
}
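// A rendering sketch (not part of the original change); it assumes the sibling `Operator` and
// `Variable` types are reachable through the parent codegen module, as its re-exports suggest.
#[cfg(test)]
mod body_render_sketch {
    use super::super::{Body, Operator, Variable};

    #[test]
    fn renders_one_statement_per_operator() {
        let body = Body::new(vec![Operator::Add {
            lhs: Variable::Input(0),
            rhs: Variable::Input(1),
            out: Variable::Local(0),
        }]);
        let wgsl = body.to_string();

        // The body first computes the linear invocation id and reads the rank from `info`,
        // then emits one WGSL statement per registered operator.
        assert!(wgsl.contains("let rank: u32 = info[0];"));
        assert!(wgsl.contains("let local_0 = input_0 + input_1;"));
    }
}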

View File

@ -0,0 +1,71 @@
use super::Elem;
use std::fmt::Display;
/// Not all functions are native to WGSL, so this enum makes it possible to support additional functions.
#[derive(Hash, PartialEq, Eq, Clone)]
pub enum Function {
Powf(Elem),
Erf(Elem),
}
impl Display for Function {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Function::Powf(elem) => format_powf(f, elem),
Function::Erf(elem) => format_erf(f, elem),
}
}
}
fn format_powf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
f.write_fmt(format_args!(
"
fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{
let modulo = rhs % 2.0;
if (modulo == 0.0) {{
// Even number
return pow(abs(lhs), rhs);
}} else if (modulo == 1.0 && lhs < 0.0) {{
// Odd number
return -1.0 * pow(-1.0 * lhs, rhs);
}} else {{
// Float number
return pow(lhs, rhs);
}}
}}
"
))
}
fn format_erf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
f.write_fmt(format_args!(
"
/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
///
/// > (maximum error: 1.5×10⁻⁷)
/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf(-x) = -erf(x).
fn erf_positive(x: {elem}) -> {elem} {{
let p = 0.3275911;
let a1 = 0.254829592;
let a2 = -0.284496736;
let a3 = 1.421413741;
let a4 = -1.453152027;
let a5 = 1.061405429;
let t = 1.0 / (1.0 + p * abs(x));
let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
return 1.0 - (tmp * t * exp(-x * x));
}}
fn erf(x: {elem}) -> {elem} {{
if (x < 0.0) {{
return -1.0 * erf_positive(-1.0 * x);
}}
return erf_positive(x);
}}
"
))
}
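// A host-side mirror (not part of the original change) of the WGSL approximation above, useful
// for eyeballing its accuracy against reference values such as erf(1) ≈ 0.8427.
#[cfg(test)]
mod erf_sketch {
    fn erf_approx(x: f32) -> f32 {
        let erf_positive = |x: f32| {
            let p = 0.3275911;
            let (a1, a2, a3, a4, a5) =
                (0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429);
            let t = 1.0 / (1.0 + p * x.abs());
            let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
            1.0 - (tmp * t * (-x * x).exp())
        };
        // Odd-function identity used for negative inputs, as in the generated WGSL.
        if x < 0.0 {
            -erf_positive(-x)
        } else {
            erf_positive(x)
        }
    }

    #[test]
    fn close_to_reference_values() {
        assert!((erf_approx(1.0) - 0.8427).abs() < 1e-3);
        assert!((erf_approx(-1.0) + 0.8427).abs() < 1e-3);
    }
}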

View File

@ -0,0 +1,11 @@
mod body;
mod function;
mod operator;
mod shader;
mod variable;
pub use body::*;
pub use function::*;
pub use operator::*;
pub use shader::*;
pub use variable::*;

View File

@ -0,0 +1,146 @@
use super::Variable;
use std::fmt::Display;
/// All operators that can be fused in a WGSL compute shader.
#[derive(Debug, Hash, Clone)]
pub enum Operator {
Add {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Sub {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Mul {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Div {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Abs {
input: Variable,
out: Variable,
},
Exp {
input: Variable,
out: Variable,
},
Log {
input: Variable,
out: Variable,
},
Log1p {
input: Variable,
out: Variable,
},
Cos {
input: Variable,
out: Variable,
},
Sin {
input: Variable,
out: Variable,
},
Tanh {
input: Variable,
out: Variable,
},
Powf {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Erf {
input: Variable,
out: Variable,
},
AssignGlobal {
input: Variable,
out: Variable,
},
ReadGlobal {
variable: Variable,
position: usize,
position_out: usize,
},
}
impl Display for Operator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Operator::Add { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} + {rhs};"))
}
Operator::Sub { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} - {rhs};"))
}
Operator::Mul { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} * {rhs};"))
}
Operator::Div { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} / {rhs};"))
}
Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")),
Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")),
Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")),
Operator::Powf { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});"))
}
Operator::Log1p { input, out } => {
f.write_fmt(format_args!("let {out} = log({input} + 1.0);"))
}
Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")),
Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")),
Operator::Tanh { input, out } => {
f.write_fmt(format_args!("let {out} = tanh({input});"))
}
Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")),
Operator::AssignGlobal { input, out } => {
f.write_fmt(format_args!("{out}_global[id] = {input};"))
}
Operator::ReadGlobal {
variable,
position,
position_out,
} => {
let (global, local) = match variable {
Variable::Input(number) => {
(format!("input_{number}_global"), format!("input_{number}"))
}
Variable::Local(_) => panic!("Can't read global for a local variable."),
Variable::Output(number) => (
format!("output_{number}_global"),
format!("output_{number}"),
),
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
};
f.write_fmt(format_args!(
"
var index_{local}: u32 = 0u;
for (var i: u32 = 1u; i <= rank; i++) {{
let position = {position}u * (2u * rank);
let position_out = {position_out}u * (2u * rank);
let stride = info[position + i];
let stride_out = info[position_out + i];
let shape = info[position + rank + i];
index_{local} += id / stride_out % shape * stride;
}}
let {local} = {global}[index_{local}];
"
))
}
}
}
}
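// A quick render sketch (not part of the original change) showing how the write-back operator
// above maps a local variable onto its global output buffer.
#[cfg(test)]
mod operator_render_sketch {
    use super::{Operator, Variable};

    #[test]
    fn assign_global_writes_at_the_linear_id() {
        let ops = Operator::AssignGlobal {
            input: Variable::Local(0),
            out: Variable::Output(0),
        };
        assert_eq!(ops.to_string(), "output_0_global[id] = local_0;");
    }
}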

View File

@ -0,0 +1,201 @@
use super::{Body, Function};
use crate::kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT};
use std::{
collections::hash_map::DefaultHasher,
fmt::Display,
hash::{Hash, Hasher},
};
#[derive(Hash, PartialEq, Eq)]
pub enum Location {
Storage,
#[allow(dead_code)]
Workgroup,
}
#[derive(Hash, PartialEq, Eq)]
pub enum Visibility {
Read,
ReadWrite,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum Elem {
F32,
#[allow(dead_code)]
I32,
U32,
}
#[derive(Hash, PartialEq, Eq)]
pub struct Binding {
pub location: Location,
pub visibility: Visibility,
pub elem: Elem,
pub size: Option<usize>,
}
#[derive(Hash, PartialEq, Eq)]
pub struct WorkgroupSize {
pub x: usize,
pub y: usize,
pub z: usize,
}
impl Default for WorkgroupSize {
fn default() -> Self {
Self {
x: WORKGROUP_DEFAULT,
y: WORKGROUP_DEFAULT,
z: 1,
}
}
}
#[derive(Hash)]
pub struct ComputeShader {
pub inputs: Vec<Binding>,
pub outputs: Vec<Binding>,
pub named: Vec<(String, Binding)>,
pub workgroup_size: WorkgroupSize,
pub global_invocation_id: bool,
pub num_workgroups: bool,
pub body: Body,
pub functions: Vec<Function>,
}
impl DynamicKernelSource for ComputeShader {
fn source(&self) -> SourceTemplate {
SourceTemplate::new(self.to_string())
}
fn id(&self) -> String {
let mut s = DefaultHasher::new();
self.hash(&mut s);
s.finish().to_string()
}
}
impl Display for ComputeShader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Self::format_bindings(f, "input", &self.inputs, 0)?;
Self::format_bindings(f, "output", &self.outputs, self.inputs.len())?;
for (i, (name, binding)) in self.named.iter().enumerate() {
Self::format_binding(
f,
name.as_str(),
binding,
self.inputs.len() + self.outputs.len() + i,
)?;
}
f.write_fmt(format_args!(
"const WORKGROUP_SIZE_X = {}u;
const WORKGROUP_SIZE_Y = {}u;
const WORKGROUP_SIZE_Z = {}u;\n",
self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z
))?;
f.write_fmt(format_args!(
"
@compute
@workgroup_size({}, {}, {})
fn main(
",
self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z
))?;
if self.global_invocation_id {
f.write_str(" @builtin(global_invocation_id) global_id: vec3<u32>,\n")?;
}
if self.num_workgroups {
f.write_str(" @builtin(num_workgroups) num_workgroups: vec3<u32>,\n")?;
}
f.write_fmt(format_args!(
") {{
{}
}}",
self.body
))?;
for function in self.functions.iter() {
f.write_fmt(format_args!("{function}\n\n"))?;
}
Ok(())
}
}
impl ComputeShader {
fn format_bindings(
f: &mut core::fmt::Formatter<'_>,
prefix: &str,
bindings: &[Binding],
num_entry: usize,
) -> core::fmt::Result {
for (i, binding) in bindings.iter().enumerate() {
Self::format_binding(
f,
format!("{prefix}_{i}_global").as_str(),
binding,
num_entry + i,
)?;
}
Ok(())
}
fn format_binding(
f: &mut core::fmt::Formatter<'_>,
name: &str,
binding: &Binding,
num_entry: usize,
) -> core::fmt::Result {
let ty = match binding.size {
Some(size) => format!("array<{}, {}>", binding.elem, size),
None => format!("array<{}>", binding.elem),
};
f.write_fmt(format_args!(
"@group(0)
@binding({})
var<{}, {}> {}: {};
\n",
num_entry, binding.location, binding.visibility, name, ty
))?;
Ok(())
}
}
impl Display for Location {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Location::Storage => f.write_str("storage"),
Location::Workgroup => f.write_str("workgroup"),
}
}
}
impl Display for Elem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Elem::F32 => f.write_str("f32"),
Elem::I32 => f.write_str("i32"),
Elem::U32 => f.write_str("u32"),
}
}
}
impl Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Visibility::Read => f.write_str("read"),
Visibility::ReadWrite => f.write_str("read_write"),
}
}
}

View File

@ -0,0 +1,21 @@
use super::Elem;
use std::fmt::Display;
#[derive(Debug, Hash, Clone)]
pub enum Variable {
Input(u16),
Scalar(u16, Elem),
Local(u16),
Output(u16),
}
impl Display for Variable {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Variable::Input(number) => f.write_fmt(format_args!("input_{number}")),
Variable::Local(number) => f.write_fmt(format_args!("local_{number}")),
Variable::Output(number) => f.write_fmt(format_args!("output_{number}")),
Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")),
}
}
}

View File

@ -0,0 +1,3 @@
mod ops;
pub use ops::*;

View File

@ -0,0 +1,463 @@
use crate::{
fusion::codegen::{Elem, Operator, Variable},
fusion::kernel::FusionKernel,
FloatElement, GraphicsApi, IntElement, Wgpu,
};
use burn_fusion::{
graph::{
BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription,
TensorOpsDescription, UnaryOpsDescription,
},
FusionBackend, FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription,
TensorId,
};
use burn_tensor::{Device, Element};
use hashbrown::HashMap;
use std::sync::Arc;
/// Fused element-wise operations that are normally memory bound.
pub struct FloatElementWiseFusionOps<G, F, I>
where
G: GraphicsApi,
F: FloatElement,
I: IntElement,
{
pub(crate) inputs: Vec<TensorDescription>,
pub(crate) locals: HashMap<TensorId, u16>,
pub(crate) tensors: HashMap<TensorId, TensorDescription>,
pub(crate) scalars_f32: Vec<f32>,
pub(crate) operators: Vec<Operator>,
pub(crate) properties: FusionProperties,
pub(crate) current_output_shape: Vec<usize>,
device: Device<Wgpu<G, F, I>>,
}
impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G, F, I>>
for FloatElementWiseFusionOps<G, F, I>
{
fn register(&mut self, ops: Arc<TensorOpsDescription<Wgpu<G, F, I>>>) -> FusionStatus {
match ops.as_ref() {
TensorOpsDescription::FloatOps(ops) => {
if !self.register_float(ops) {
return FusionStatus::Closed(self.properties);
}
}
TensorOpsDescription::NumericOpsFloat(ops) => {
if !self.register_numeric(ops) {
return FusionStatus::Closed(self.properties);
}
}
_ => {
return FusionStatus::Closed(self.properties);
}
};
self.properties.score += 1;
self.properties.ready = self.operators.len() > 1;
FusionStatus::Open(self.properties)
}
fn execute(&mut self, handles: &mut HandleContainer<Wgpu<G, F, I>>) {
let inputs = self.input_descriptions();
let outputs = self.output_descriptions();
let locals = outputs
.iter()
.map(|out| *self.locals.get(&out.id).unwrap())
.collect::<Vec<_>>();
FusionKernel::new(&self.device)
.inputs(&inputs, &self.scalars_f32)
.body(&self.operators)
.outputs(&outputs, &locals)
.execute(handles);
}
fn reset(&mut self) {
self.inputs.clear();
self.locals.drain();
self.tensors.clear();
self.scalars_f32.clear();
self.operators.clear();
self.properties = FusionProperties::default();
self.current_output_shape.clear();
}
fn len(&self) -> usize {
self.operators.len()
}
}
impl<G, F, I> FloatElementWiseFusionOps<G, F, I>
where
G: GraphicsApi,
F: FloatElement,
I: IntElement,
{
pub fn new(device: Device<Wgpu<G, F, I>>) -> Self {
Self {
inputs: Vec::new(),
locals: HashMap::new(),
tensors: HashMap::new(),
scalars_f32: Vec::new(),
operators: Vec::new(),
current_output_shape: Vec::new(),
properties: FusionProperties::default(),
device,
}
}
fn input_descriptions(&self) -> Vec<&TensorDescription> {
self.inputs
.iter()
.map(|input| {
let updated_tensor = self.tensors.get(&input.id).unwrap();
updated_tensor
})
.collect::<Vec<_>>()
}
fn output_descriptions(&self) -> Vec<&TensorDescription> {
let mut outputs = Vec::new();
let mut local_tensor_ids_input = Vec::new();
let mut local_tensor_ids_output = Vec::new();
// Mark the tensor id behind a variable by adding it to the provided list.
//
// Only local variables can become outputs.
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
if let Variable::Local(index) = var {
if let Some((id, _)) = self
.locals
.iter()
.find(|(_id, position)| *position == index)
{
if !list.contains(id) {
list.push(id.clone());
}
}
}
};
// For all operators, mark their local tensor id in the proper set.
for ops in self.operators.iter() {
match ops {
Operator::AssignGlobal { input: _, out: _ } => {
// Nothing to do here.
}
Operator::ReadGlobal {
variable: _,
position: _,
position_out: _,
} => {
// Nothing to do here.
}
Operator::Add { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Sub { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Mul { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Div { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Exp { input, out } => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Abs { input, out } => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Erf { input, out } => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Log { input, out } => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Log1p { input, out } => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Cos { input, out } => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Sin { input, out } => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Tanh { input, out } => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Powf { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
}
}
// All output tensors that are never read by a following operation should be written to
// since they are essentially the "logical" output of the shader.
for out in local_tensor_ids_output {
let is_read = local_tensor_ids_input.contains(&out);
if !is_read {
outputs.push(self.tensors.get(&out).unwrap());
}
}
// All tensors whose latest description is read-only should also be written to, since they
// are going to be used by other operations after the fused kernel.
for tensor in self.tensors.values() {
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
if self.locals.contains_key(&tensor.id) {
outputs.push(tensor);
}
}
}
outputs
}
fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable {
let already_exists = self.tensors.contains_key(&tensor.id);
let variable = match already_exists {
false => {
// New input
let var = Variable::Input(self.inputs.len() as u16);
self.inputs.push(tensor.clone());
var
}
true => match self.locals.get(&tensor.id) {
// Is a local variable.
Some(local_index) => Variable::Local(*local_index),
// Isn't a local variable, so must be an existing input.
None => {
let input = self
.inputs
.iter()
.enumerate()
.find(|(_, input)| input.id == tensor.id)
.unwrap();
let input_index = input.0;
Variable::Input(input_index as u16)
}
},
};
// Update the tensor description with the new version.
self.tensors.insert(tensor.id.clone(), tensor.clone());
variable
}
fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable {
// Update the tensor description to the new version.
self.tensors.insert(tensor.id.clone(), tensor.clone());
// Output already registered as a local variable.
if let Some(index) = self.locals.get(&tensor.id) {
return Variable::Local(*index);
}
// New local variable.
let local_index = self.locals.len() as u16;
self.locals.insert(tensor.id.clone(), local_index);
Variable::Local(local_index)
}
fn register_float<B: FusionBackend>(&mut self, ops: &FloatOpsDescription<B>) -> bool {
match ops {
FloatOpsDescription::Exp(desc, _) => {
self.register_unary_ops(desc, |input, out| Operator::Exp { input, out })
}
FloatOpsDescription::Log(desc, _) => {
self.register_unary_ops(desc, |input, out| Operator::Log { input, out })
}
FloatOpsDescription::Log1p(desc, _) => {
self.register_unary_ops(desc, |input, out| Operator::Log1p { input, out })
}
FloatOpsDescription::Cos(desc, _) => {
self.register_unary_ops(desc, |input, out| Operator::Cos { input, out })
}
FloatOpsDescription::Sin(desc, _) => {
self.register_unary_ops(desc, |input, out| Operator::Sin { input, out })
}
FloatOpsDescription::Powf(desc, _) => {
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out })
}
FloatOpsDescription::Tanh(desc, _) => {
self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out })
}
FloatOpsDescription::Erf(desc, _) => {
self.register_unary_ops(desc, |input, out| Operator::Erf { input, out })
}
_ => false,
}
}
fn register_numeric<B: FusionBackend, E: Element>(
&mut self,
ops: &NumericOpsDescription<B, E>,
) -> bool {
match ops {
NumericOpsDescription::Add(desc, _) => {
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
}
NumericOpsDescription::AddScalar(desc, _) => {
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
}
NumericOpsDescription::Sub(desc, _) => {
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
}
NumericOpsDescription::SubScalar(desc, _) => {
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
}
NumericOpsDescription::Mul(desc, _) => {
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
}
NumericOpsDescription::MulScalar(desc, _) => {
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
}
NumericOpsDescription::Div(desc, _) => {
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
}
NumericOpsDescription::DivScalar(desc, _) => {
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
}
NumericOpsDescription::Abs(desc, _) => {
self.register_unary_ops(desc, |input, out| Operator::Abs { input, out })
}
_ => false,
}
}
fn register_binary_ops<Func>(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool
where
Func: Fn(Variable, Variable, Variable) -> Operator,
{
if !self.output_is_compatible(&desc.out) {
return false;
}
let lhs = self.input_to_var(&desc.lhs);
let rhs = self.input_to_var(&desc.rhs);
let out = self.output_to_var(&desc.out);
self.operators.push(func(lhs, rhs, out));
true
}
fn register_unary_ops<Func>(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool
where
Func: Fn(Variable, Variable) -> Operator,
{
if !self.output_is_compatible(&desc.out) {
return false;
}
let input = self.input_to_var(&desc.input);
let out = self.output_to_var(&desc.out);
self.operators.push(func(input, out));
true
}
fn register_scalar_ops<Func, E: Element>(
&mut self,
desc: &ScalarOpsDescription<E>,
func: Func,
) -> bool
where
Func: Fn(Variable, Variable, Variable) -> Operator,
{
if !self.output_is_compatible(&desc.out) {
return false;
}
let lhs = self.input_to_var(&desc.lhs);
let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32);
self.scalars_f32.push(desc.rhs.elem());
let out = self.output_to_var(&desc.out);
self.operators.push(func(lhs, rhs, out));
true
}
fn output_is_compatible(&mut self, out: &TensorDescription) -> bool {
if self.current_output_shape.is_empty() {
self.current_output_shape = out.shape.clone();
} else if self.current_output_shape != out.shape {
return false;
}
true
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn_fusion::graph::{BinaryOpsDescription, Ops};
use burn_fusion::Fusion;
use burn_tensor::Tensor;
struct FakeAddOps;
impl<B: FusionBackend> Ops<B> for FakeAddOps {
type Args = BinaryOpsDescription;
fn execute(&self, _: &Self::Args, _: &mut HandleContainer<B>) {
todo!()
}
}
#[test]
fn test_fusion_same_behavior() {
type Backend = Wgpu;
type FusedBackend = Fusion<Wgpu>;
let data_1 =
Tensor::<Backend, 2>::random([1, 32], burn_tensor::Distribution::Default).into_data();
let data_2 =
Tensor::<Backend, 2>::random([32, 32], burn_tensor::Distribution::Default).into_data();
let tensor_1 = Tensor::<Backend, 2>::from_data(data_1.clone());
let tensor_2 = Tensor::<Backend, 2>::from_data(data_2.clone());
let tensor_3 = tensor_1.clone() + tensor_2;
let tensor_4 = tensor_3.clone() - tensor_1;
let tensor_5 = tensor_4 + 5.0;
let tensor_6 = tensor_5 + tensor_3;
let result_ref = tensor_6.into_data();
let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
let tensor_3 = tensor_1.clone() + tensor_2;
let tensor_4 = tensor_3.clone() - tensor_1;
let tensor_5 = tensor_4 + 5.0;
let tensor_6 = tensor_5 + tensor_3;
let result_fused = tensor_6.into_data();
result_fused.assert_approx_eq(&result_ref, 3);
}
}

View File

@ -0,0 +1,320 @@
use super::codegen::Body;
use crate::compute::{compute_client, DynamicKernel, WgpuComputeClient};
use crate::fusion::codegen::Function;
use crate::fusion::{calculate_num_elems_dyn_rank, strides_dyn_rank};
use crate::fusion::{
codegen::{
Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize,
},
WgpuFusionHandle,
};
use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT};
use crate::{FloatElement, GraphicsApi, IntElement, Wgpu};
use burn_fusion::{HandleContainer, TensorDescription};
use burn_tensor::Device;
use std::marker::PhantomData;
/// Kernel creation input phase, see [fusion kernel](FusionKernel) for more details.
pub struct InputPhase;
/// Kernel creation body phase, see [fusion kernel](FusionKernel) for more details.
pub struct BodyPhase;
/// Kernel creation output phase, see [fusion kernel](FusionKernel) for more details.
pub struct OutputPhase;
/// Kernel execution phase, see [fusion kernel](FusionKernel) for more details.
pub struct ExecutionPhase;
/// Allows creating custom WGSL kernels based on configured inputs, a body, and outputs.
///
/// This type has 4 phases that must be executed in order, but don't worry: the type system won't
/// let you make mistakes.
///
/// 1. [Input Phase](InputPhase)
/// This phase focuses on registering the input tensor descriptions that are going to be used by
/// the fused kernel.
/// 2. [Body Phase](BodyPhase)
/// After the input phase is done, all the operations that happen in the body must be
/// registered.
/// 3. [Output Phase](OutputPhase)
/// This step focuses on registering all tensor descriptions that the kernel needs to write to.
/// 4. [Execution Phase](ExecutionPhase)
/// Now that all other phases are completed, we can actually run the kernel on the given
/// [handles](HandleContainer). Note that the actual chosen kernel may vary based on the
/// handles provided.
pub struct FusionKernel<G, F, I, Phase = InputPhase>
where
G: GraphicsApi,
F: FloatElement,
I: IntElement,
{
operations: Vec<Operator>,
input_bindings: Vec<(Binding, TensorDescription)>,
output_bindings: Vec<(Binding, TensorDescription)>,
named_bindings: Vec<(String, Binding, DataBuffer)>,
functions: Vec<Function>,
num_elems_output: usize,
device: Device<Wgpu<G, F, I>>,
client: WgpuComputeClient,
_phase: PhantomData<Phase>,
}
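// The four phases are typically driven in a single chain, as the element-wise fusion ops does
// elsewhere in this change (usage sketch only, not additional functionality):
//
//     FusionKernel::new(&device)
//         .inputs(&inputs, &scalars_f32)
//         .body(&operators)
//         .outputs(&outputs, &locals)
//         .execute(handles);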
enum DataBuffer {
F32(Vec<f32>),
U32(Vec<u32>),
}
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, InputPhase> {
/// Create a new fusion kernel on the given device.
pub fn new(device: &Device<Wgpu<G, F, I>>) -> Self {
let client = compute_client::<G>(device);
Self {
operations: Vec::new(),
input_bindings: Vec::new(),
output_bindings: Vec::new(),
named_bindings: Vec::new(),
functions: Vec::new(),
num_elems_output: 0,
device: device.clone(),
client,
_phase: PhantomData,
}
}
/// Register the inputs used by the kernel.
pub fn inputs(
mut self,
inputs_tensor: &[&TensorDescription],
inputs_scalar_f32: &[f32],
) -> FusionKernel<G, F, I, BodyPhase> {
for (i, input) in inputs_tensor.iter().enumerate() {
self.input_bindings.push((
Binding {
elem: Elem::F32,
visibility: Visibility::Read,
location: Location::Storage,
size: None,
},
(*input).clone(),
));
self.operations.push(Operator::ReadGlobal {
variable: Variable::Input(i as u16),
position: i,
position_out: inputs_tensor.len(), // First output
});
}
if !inputs_scalar_f32.is_empty() {
self.named_bindings.push((
"scalars_f32".to_string(),
Binding {
elem: Elem::F32,
visibility: Visibility::Read,
location: Location::Storage,
size: Some(inputs_scalar_f32.len()),
},
DataBuffer::F32(inputs_scalar_f32.to_vec()),
));
}
FusionKernel {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
functions: self.functions,
num_elems_output: self.num_elems_output,
device: self.device,
client: self.client,
_phase: PhantomData,
}
}
}
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, BodyPhase> {
/// Register the [operators](Operator) that the kernel must execute in the order provided.
pub fn body(mut self, operators: &[Operator]) -> FusionKernel<G, F, I, OutputPhase> {
let mut register_function = |function: Function| {
if !self.functions.contains(&function) {
self.functions.push(function);
}
};
// Since not all operators are native to WGSL, we need to add the custom ones.
for ops in operators.iter() {
match ops {
Operator::Powf {
lhs: _,
rhs: _,
out: _,
} => {
register_function(Function::Powf(Elem::F32));
}
Operator::Erf { input: _, out: _ } => {
register_function(Function::Erf(Elem::F32));
}
_ => {}
}
self.operations.push(ops.clone());
}
FusionKernel {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
functions: self.functions,
num_elems_output: self.num_elems_output,
device: self.device,
client: self.client,
_phase: PhantomData,
}
}
}
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, OutputPhase> {
/// Register the outputs with their local variable index.
///
/// Note that the index corresponds to the registered [operator](Operator) number at the
/// [body phase](BodyPhase).
/// So the 4th operator registered creates the local variable 3 (N-1, since indices start at 0).
pub fn outputs(
mut self,
outputs: &[&TensorDescription],
locals: &[u16],
) -> FusionKernel<G, F, I, ExecutionPhase> {
let mut num_elems_launch_option = 0;
for (i, (output, local)) in outputs.iter().zip(locals).enumerate() {
let num_elems_output = calculate_num_elems_dyn_rank(&output.shape);
if num_elems_output > num_elems_launch_option {
num_elems_launch_option = num_elems_output;
}
self.output_bindings.push((
Binding {
elem: Elem::F32,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
},
(*output).clone(),
));
self.operations.push(Operator::AssignGlobal {
input: Variable::Local(*local),
out: Variable::Output(i as u16),
});
}
self.num_elems_output = num_elems_launch_option;
FusionKernel {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
functions: self.functions,
num_elems_output: self.num_elems_output,
device: self.device,
client: self.client,
_phase: PhantomData,
}
}
}
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, ExecutionPhase> {
/// Execute the kernel on the provided [handles](HandleContainer).
pub fn execute(mut self, handle_container: &mut HandleContainer<Wgpu<G, F, I>>) {
let mut inputs = Vec::with_capacity(self.input_bindings.len());
let mut outputs = Vec::with_capacity(self.output_bindings.len());
let mut named = Vec::with_capacity(2);
let mut info = Vec::new();
let mut handles =
Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity());
// Inner function to fill the info buffer.
let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| {
if info.is_empty() {
info.push(handle.strides.len() as u32);
}
for s in handle.strides.iter() {
info.push(*s as u32);
}
for s in tensor.shape.iter() {
info.push(*s as u32);
}
};
// We start by registering the inputs.
for (binding, tensor) in self.input_bindings.into_iter() {
let handle = handle_container.get_handle(&tensor);
register_info_tensor(&tensor, &handle);
inputs.push(binding);
handles.push(handle.handle);
}
// Then we follow with the outputs.
for (binding, tensor) in self.output_bindings {
let num_elems = calculate_num_elems_dyn_rank(&tensor.shape);
let handle_fusion = WgpuFusionHandle {
client: self.client.clone(),
device: self.device.clone(),
strides: strides_dyn_rank(&tensor.shape),
handle: self.client.empty(core::mem::size_of::<F>() * num_elems),
};
register_info_tensor(&tensor, &handle_fusion);
handles.push(handle_fusion.handle.clone());
handle_container.register_handle(tensor.id, handle_fusion);
outputs.push(binding);
}
// Now we can create the info handle.
Self::build_info_handle(&mut self.named_bindings, info);
// Finally we finish with the named bindings.
for (name, binding, data) in self.named_bindings {
let handle = self.client.create(match &data {
DataBuffer::F32(values) => bytemuck::cast_slice(values),
DataBuffer::U32(values) => bytemuck::cast_slice(values),
});
named.push((name, binding));
handles.push(handle);
}
// We create the shader codegen type and launch the kernel.
let kernel = ComputeShader {
inputs,
outputs,
named,
workgroup_size: WorkgroupSize::default(),
body: Body::new(self.operations),
num_workgroups: true,
global_invocation_id: true,
functions: self.functions,
};
let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT);
let kernel = Box::new(DynamicKernel::new(kernel, workgroup));
self.client
.execute(kernel, &handles.iter().collect::<Vec<_>>());
}
fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec<u32>) {
named_bindings.push((
"info".to_string(),
Binding {
elem: Elem::U32,
visibility: Visibility::Read,
location: Location::Storage,
size: None, // We avoid putting the length here since it will force a new kernel
// for each tensor rank.
},
DataBuffer::U32(info),
));
}
}
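// A layout sketch (not part of the original change) of the `info` buffer that
// `register_info_tensor` fills above and that the generated `ReadGlobal` code indexes on the
// WGSL side: one leading rank entry, then `strides` followed by `shape` for every bound tensor,
// inputs first and outputs after.
#[cfg(test)]
mod info_layout_sketch {
    #[test]
    fn rank_then_strides_and_shape_per_tensor() {
        // Two rank-2 tensors: one input and one output, both of shape [2, 3].
        let (strides_in, shape_in) = (vec![3u32, 1], vec![2u32, 3]);
        let (strides_out, shape_out) = (vec![3u32, 1], vec![2u32, 3]);

        let mut info = vec![2u32]; // rank
        info.extend(&strides_in);
        info.extend(&shape_in);
        info.extend(&strides_out);
        info.extend(&shape_out);

        // Tensor `k` starts at offset 1 + k * 2 * rank, matching the shader-side
        // `position = k * (2 * rank)` before `i` (1..=rank) is added.
        assert_eq!(info.len(), 1 + 2 * 2 * 2);
        assert_eq!(&info[1..3], &[3, 1]); // strides of tensor 0
        assert_eq!(&info[3..5], &[2, 3]); // shape of tensor 0
    }
}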

View File

@ -0,0 +1,8 @@
mod base;
mod elemwise;
pub(crate) mod codegen;
pub(crate) mod kernel;
pub use base::*;
pub use elemwise::*;

View File

@ -1,7 +1,8 @@
use super::SourceTemplate;
use crate::{
compute::{StaticKernel, WorkGroup},
compute::{StaticKernel, WgpuComputeClient, WgpuHandle, WorkGroup},
element::WgpuElement,
kernel,
tensor::WgpuTensor,
};
use std::marker::PhantomData;
@ -76,6 +77,33 @@ pub fn into_contiguous<E: WgpuElement, const D: usize>(
output
}
/// Similar to [into contiguous](into_contiguous) but with dynamic rank.
pub fn into_contiguous_dyn<E: WgpuElement>(
client: WgpuComputeClient,
input: WgpuHandle,
input_shape: &[usize],
input_strides: &[usize],
output_shape: &[usize],
output_strides: &[usize],
num_elems: usize,
) -> WgpuHandle {
let handle = client.empty(num_elems * core::mem::size_of::<E>());
let info = kernel::build_info_dyn::<E>(
&[input_shape, output_shape],
&[input_strides, output_strides],
);
let info_handle = client.create(bytemuck::cast_slice(&info));
let kernel = Box::new(StaticKernel::<
KernelSettings<ContiguousRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)));
client.execute(kernel, &[&input, &handle, &info_handle]);
handle
}
/// Generates kernel source code by replacing some information using templating.
pub struct KernelSettings<
K: StaticKernelSource,
@ -184,6 +212,28 @@ pub fn build_info<E: WgpuElement, const D: usize>(tensors: &[&WgpuTensor<E, D>])
info
}
/// Similar to [build info](build_info) but with dynamic rank.
pub fn build_info_dyn<E: WgpuElement>(shapes: &[&[usize]], strides: &[&[usize]]) -> Vec<u32> {
let rank = shapes.get(0).unwrap().len();
let mut info: Vec<u32> = vec![0; shapes.len() * 2 * rank + 1];
info[0] = rank as u32;
let mut current = 1;
for stride in strides.iter() {
for d in 0..rank {
info[current] = stride[d] as u32;
current += 1;
}
}
for shape in shapes.iter() {
for d in 0..rank {
info[current] = shape[d] as u32;
current += 1;
}
}
info
}
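// A usage sketch (not part of the original change) of the dynamic-rank layout produced above:
// rank first, then the strides of every tensor, then the shapes of every tensor. It assumes
// `f32` implements `WgpuElement`, as it does for the statically ranked kernels.
#[cfg(test)]
mod build_info_dyn_sketch {
    use super::build_info_dyn;

    #[test]
    fn rank_then_all_strides_then_all_shapes() {
        let shapes: [&[usize]; 2] = [&[2, 3], &[2, 3]];
        let strides: [&[usize]; 2] = [&[3, 1], &[6, 1]];

        let info = build_info_dyn::<f32>(&shapes, &strides);

        assert_eq!(info, vec![2, 3, 1, 6, 1, 2, 3, 2, 3]);
    }
}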
pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup {
let num_elem_per_invocation = workgroup_size * workgroup_size;
let workgroups = f32::ceil(num_elems as f32 / num_elem_per_invocation as f32);