mirror of https://github.com/tracel-ai/burn.git

WGPU: Support elemwise operation fusion (#948)

commit 24014aca33
parent 4fc0c27e31
@@ -21,6 +21,7 @@ ndarray-blas-openblas = ["burn/ndarray-blas-openblas"]
 tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu"]
+wgpu-fusion = ["burn/wgpu", "burn/fusion"]

 [dependencies]
 burn = { path = "../burn" }
@@ -1,6 +1,14 @@
 #[macro_export]
 macro_rules! bench_on_backend {
     () => {
+        #[cfg(feature = "wgpu-fusion")]
+        {
+            use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
+            use burn::backend::Fusion;
+
+            bench::<Fusion<Wgpu<AutoGraphicsApi, f32, i32>>>(&WgpuDevice::default());
+        }
+
        #[cfg(feature = "wgpu")]
        {
            use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
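Note: with the wgpu-fusion feature from the Cargo.toml hunk above enabled, the fused backend is just the regular WGPU backend wrapped in Fusion. A minimal sketch of the wiring the macro selects; the bench function here is a stand-in for the one each benchmark defines, not part of this commit:

    use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
    use burn::backend::Fusion;
    use burn::tensor::backend::Backend;

    // Stand-in for the benchmark's own `bench` function.
    fn bench<B: Backend>(device: &B::Device) {
        let _ = device;
    }

    fn main() {
        bench::<Fusion<Wgpu<AutoGraphicsApi, f32, i32>>>(&WgpuDevice::default());
    }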
@@ -2,7 +2,7 @@ use crate::{
     client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor,
     HandleContainer,
 };
-use burn_tensor::{backend::Backend, Shape};
+use burn_tensor::{backend::Backend, Device, Shape};
 use core::marker::PhantomData;
 use std::sync::Arc;
@@ -36,12 +36,18 @@ impl<B: FusionBackend> Backend for Fusion<B> {
     type BoolTensorPrimitive<const D: usize> = FusionTensor<B::FusionClient>;

     fn name() -> String {
-        format!("Fusion<{}>", B::name())
+        format!("fusion<{}>", B::name())
     }

     fn seed(seed: u64) {
         B::seed(seed);
     }
+
+    fn sync(device: &Self::Device) {
+        let client = CLIENTS.client::<B::FusionClient>(&device.clone().into());
+        client.drain_graph();
+        B::sync(device)
+    }
 }

 /// The status of a [fusion ops](FusionOps).
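Note: the new sync implementation makes the contract explicit: the lazy operation graph is drained before the inner backend is synchronized, so timings taken around a sync include the queued fused kernels. A hedged usage sketch, assuming only the Backend trait shown in this hunk:

    fn wait_for_all_work<B: burn_tensor::backend::Backend>(device: &B::Device) {
        // For Fusion<B>, this first executes every queued (possibly fused)
        // operation, then blocks until the inner backend is idle.
        B::sync(device);
    }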
@@ -116,14 +122,14 @@ pub trait FusionBackend: Backend {
     /// The device type that can return an ID.
     ///
     /// It can be the same as (Backend::Device), but must implement (FusionDevice).
-    type FusionDevice: FusionDevice + From<Self::Device> + Into<Self::Device>;
+    type FusionDevice: FusionDevice + From<Self::Device> + Into<Self::Device> + core::fmt::Debug;
     /// The type that can be used to point to a tensor of any kind.
     type Handle: Sync + Send + Clone;
     /// What kind of client should be used.
     type FusionClient: FusionClient<FusionBackend = Self>;

     /// The list of operations that will be used to optimize the computational graph.
-    fn operations() -> Vec<Box<dyn FusionOps<Self>>>;
+    fn operations(device: &Device<Self>) -> Vec<Box<dyn FusionOps<Self>>>;

     /// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive).
     fn float_tensor<const D: usize>(
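Note: operations becoming device-aware means the set of graph optimizations can be chosen per device. An illustrative helper (the helper name is hypothetical; only the trait items shown above are assumed):

    fn optimizations_for<B: FusionBackend>(
        device: &burn_tensor::Device<B>,
    ) -> Vec<Box<dyn FusionOps<B>>> {
        // A GPU backend could, e.g., size its fused kernels from device limits here.
        B::operations(device)
    }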
@@ -1,6 +1,6 @@
 use crate::{
     graph::{GraphExecution, TensorOpsDescription},
-    FusionBackend, FusionTensor, TensorDescription, TensorId,
+    FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
 };
 use burn_tensor::{
     ops::{FloatElem, IntElem},
@@ -18,26 +18,18 @@ pub trait FusionClient: Send + Sync + Clone {
     fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
     /// Register a new [tensor operation description](TensorOpsDescription).
     fn register(&self, ops: TensorOpsDescription<Self::FusionBackend>);
-    /// Sync the computation.
-    fn sync(&self);
+    /// Register all lazy computation.
+    fn drain_graph(&self);
     /// Get the current device used by all operations handled by this client.
     fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice;
-    /// Create an empty tensor.
-    fn create_tensor_empty(&self, shape: Vec<usize>) -> FusionTensor<Self>;
-    /// Create a float tensor with the given values.
-    fn create_tensor_float(
+    /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it.
+    fn tensor_uninitialized(&self, shape: Vec<usize>) -> FusionTensor<Self>;
+    /// Create a tensor with the given handle and shape.
+    fn register_tensor(
         &self,
-        values: Vec<FloatElem<Self::FusionBackend>>,
+        handle: Handle<Self::FusionBackend>,
         shape: Vec<usize>,
     ) -> FusionTensor<Self>;
-    /// Create an integer tensor with the given values.
-    fn create_tensor_int(
-        &self,
-        values: Vec<IntElem<Self::FusionBackend>>,
-        shape: Vec<usize>,
-    ) -> FusionTensor<Self>;
-    /// Create a bool tensor with the given values.
-    fn create_tensor_bool(&self, values: Vec<bool>, shape: Vec<usize>) -> FusionTensor<Self>;
     /// Read the values contained by a float tensor.
     fn read_tensor_float<const D: usize>(
         &self,
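Note: a sketch of the renamed client surface (types as declared in this trait; the function itself is illustrative). Eagerly created tensors wrap an existing backend handle, while operation outputs start with no resources attached:

    fn wrap_and_reserve<C: FusionClient>(
        client: &C,
        existing: Handle<C::FusionBackend>,
    ) -> (FusionTensor<C>, FusionTensor<C>) {
        // Wrap a tensor the backend has already allocated.
        let input = client.register_tensor(existing, vec![2, 3]);
        // Reserve an output; a registered operation description is expected
        // to initialize it later, when the graph executes.
        let output = client.tensor_uninitialized(vec![2, 3]);
        (input, output)
    }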
@@ -1,7 +1,7 @@
 use super::FusionClient;
 use crate::{
     graph::{GraphExecution, TensorOpsDescription},
-    FusionBackend, FusionServer, FusionTensor,
+    FusionBackend, FusionServer, FusionTensor, Handle,
 };
 use burn_tensor::ops::FloatElem;
 use spin::Mutex;
@@ -49,10 +49,11 @@ where
         self.server.lock().register(ops);
     }

-    fn sync(&self) {
-        self.server.lock().sync();
+    fn drain_graph(&self) {
+        self.server.lock().drain_graph();
     }
-    fn create_tensor_empty(&self, shape: Vec<usize>) -> FusionTensor<Self> {
+
+    fn tensor_uninitialized(&self, shape: Vec<usize>) -> FusionTensor<Self> {
         let id = self.server.lock().create_empty_handle();

         FusionTensor::new(id, shape, self.clone())
@@ -61,6 +62,18 @@ where
     fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice {
         &self.device
     }
+    fn register_tensor(
+        &self,
+        handle: Handle<Self::FusionBackend>,
+        shape: Vec<usize>,
+    ) -> FusionTensor<Self> {
+        let mut server = self.server.lock();
+        let id = server.create_empty_handle();
+        server.handles.register_handle(id.as_ref().clone(), handle);
+        core::mem::drop(server);
+
+        FusionTensor::new(id, shape, self.clone())
+    }

     fn read_tensor_float<const D: usize>(
         &self,
@@ -69,32 +82,6 @@ where
         self.server.lock().read_float(tensor)
     }

-    fn create_tensor_float(
-        &self,
-        values: Vec<FloatElem<Self::FusionBackend>>,
-        shape: Vec<usize>,
-    ) -> FusionTensor<Self> {
-        let id = self.server.lock().create_float_handle(values);
-
-        FusionTensor::new(id, shape, self.clone())
-    }
-
-    fn create_tensor_int(
-        &self,
-        values: Vec<burn_tensor::ops::IntElem<Self::FusionBackend>>,
-        shape: Vec<usize>,
-    ) -> FusionTensor<Self> {
-        let id = self.server.lock().create_int_handle(values);
-
-        FusionTensor::new(id, shape, self.clone())
-    }
-
-    fn create_tensor_bool(&self, values: Vec<bool>, shape: Vec<usize>) -> FusionTensor<Self> {
-        let id = self.server.lock().create_bool_handle(values);
-
-        FusionTensor::new(id, shape, self.clone())
-    }
-
     fn read_tensor_int<const D: usize>(
         &self,
         tensor: crate::TensorDescription,
@@ -32,8 +32,13 @@ impl<B: FusionBackend> GraphExecution<B> for GreedyGraphExecution {
         }

         match find_best_optimization_index(optimizations) {
-            Some(index) => graph.execute_optimization(handles, index, optimizations),
-            None => graph.execute(handles),
+            Some(index) => {
+                graph.execute_optimization(handles, index, optimizations);
+            }
+            None => {
+                graph.execute(handles);
+                optimizations.iter_mut().for_each(|ops| ops.reset());
+            }
         }

         if graph.is_empty() {
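Note: the behavioral fix here is the reset in the None arm. Each optimization accumulates state while watching registered operations; once the graph is flushed unoptimized, that state describes operations that no longer exist. A self-contained sketch of the idea, with stand-in types rather than burn's:

    struct ElemWiseBuilder {
        matched: usize, // consecutive fusable ops seen so far
    }

    impl ElemWiseBuilder {
        fn reset(&mut self) {
            self.matched = 0;
        }
    }

    fn after_unoptimized_execution(builders: &mut [ElemWiseBuilder]) {
        // Stale partial matches must be cleared before new ops arrive.
        builders.iter_mut().for_each(|b| b.reset());
    }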
@@ -5,6 +5,7 @@ use burn_tensor::{
     ops::{ConvOptions, ConvTransposeOptions},
     Distribution, Element,
 };
+use core::hash::Hash;
 use std::ops::Range;

 /// General trait to abstract how a single operation is executed.
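Note: the long run of hunks below all add #[derive(Hash)] to the operation description types. One plausible motivation (an assumption, not stated in the diff) is that a hashable description lets a sequence of queued operations be fingerprinted, e.g. to recognize a previously compiled fused kernel. A self-contained sketch using only the standard library:

    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // Combine the hashes of all descriptions in registration order.
    fn graph_key<T: Hash>(descriptions: &[T]) -> u64 {
        let mut hasher = DefaultHasher::new();
        for desc in descriptions {
            desc.hash(&mut hasher);
        }
        hasher.finish()
    }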
@@ -652,6 +653,7 @@ pub enum BoolOpsDescription<B: FusionBackend> {
     ),
 }

+#[derive(Hash)]
 /// Swap dim operation description.
 pub struct SwapDimsDescription {
     /// Input tensor description.

@@ -664,6 +666,7 @@ pub struct SwapDimsDescription {
     pub dim2: usize,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct ReshapeDescription {
     pub input: TensorDescription,

@@ -671,6 +674,7 @@ pub struct ReshapeDescription {
     pub shape: Vec<usize>,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct BinaryOpsDescription {
     pub lhs: TensorDescription,

@@ -678,6 +682,7 @@ pub struct BinaryOpsDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct UnaryOpsDescription {
     pub input: TensorDescription,

@@ -691,6 +696,7 @@ pub struct ScalarOpsDescription<E> {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct GatherOpsDescription {
     pub tensor: TensorDescription,

@@ -699,6 +705,7 @@ pub struct GatherOpsDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct ScatterOpsDescription {
     pub tensor: TensorDescription,

@@ -708,6 +715,7 @@ pub struct ScatterOpsDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct SelectOpsDescription {
     pub tensor: TensorDescription,

@@ -716,6 +724,7 @@ pub struct SelectOpsDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct SelectAssignOpsDescription {
     pub tensor: TensorDescription,

@@ -725,6 +734,7 @@ pub struct SelectAssignOpsDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct SliceOpsDescription {
     pub tensor: TensorDescription,

@@ -732,6 +742,7 @@ pub struct SliceOpsDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct SliceAssignOpsDescription {
     pub tensor: TensorDescription,

@@ -740,6 +751,7 @@ pub struct SliceAssignOpsDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct MaskWhereOpsDescription {
     pub tensor: TensorDescription,

@@ -773,6 +785,7 @@ pub struct RepeatOpsDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct CatOpsDescription {
     pub tensors: Vec<TensorDescription>,

@@ -780,6 +793,7 @@ pub struct CatOpsDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct ReduceDimWithIndicesDescription {
     pub tensor: TensorDescription,

@@ -788,6 +802,7 @@ pub struct ReduceDimWithIndicesDescription {
     pub out_indices: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct EmbeddingDescription {
     pub weights: TensorDescription,

@@ -795,6 +810,7 @@ pub struct EmbeddingDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct EmbeddingBackwardDescription {
     pub weights: TensorDescription,

@@ -803,6 +819,7 @@ pub struct EmbeddingBackwardDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct Conv1dDescription {
     pub x: TensorDescription,

@@ -812,6 +829,7 @@ pub struct Conv1dDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct Conv2dDescription {
     pub x: TensorDescription,

@@ -821,6 +839,7 @@ pub struct Conv2dDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct ConvTranspose1dDescription {
     pub x: TensorDescription,

@@ -830,6 +849,7 @@ pub struct ConvTranspose1dDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct ConvTranspose2dDescription {
     pub x: TensorDescription,

@@ -839,6 +859,7 @@ pub struct ConvTranspose2dDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct AvgPool1dDescription {
     pub x: TensorDescription,

@@ -849,6 +870,7 @@ pub struct AvgPool1dDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct AvgPool2dDescription {
     pub x: TensorDescription,

@@ -859,6 +881,7 @@ pub struct AvgPool2dDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct AvgPool1dBackwardDescription {
     pub x: TensorDescription,

@@ -870,6 +893,7 @@ pub struct AvgPool1dBackwardDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct AvgPool2dBackwardDescription {
     pub x: TensorDescription,

@@ -881,6 +905,7 @@ pub struct AvgPool2dBackwardDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct AdaptiveAvgPool1dDescription {
     pub x: TensorDescription,

@@ -888,6 +913,7 @@ pub struct AdaptiveAvgPool1dDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct AdaptiveAvgPool2dDescription {
     pub x: TensorDescription,

@@ -895,6 +921,7 @@ pub struct AdaptiveAvgPool2dDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct AdaptiveAvgPool1dBackwardDescription {
     pub x: TensorDescription,

@@ -902,6 +929,7 @@ pub struct AdaptiveAvgPool1dBackwardDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct AdaptiveAvgPool2dBackwardDescription {
     pub x: TensorDescription,

@@ -909,6 +937,7 @@ pub struct AdaptiveAvgPool2dBackwardDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct MaxPool1dDescription {
     pub x: TensorDescription,

@@ -919,6 +948,7 @@ pub struct MaxPool1dDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct MaxPool1dWithIndicesDescription {
     pub x: TensorDescription,

@@ -930,6 +960,7 @@ pub struct MaxPool1dWithIndicesDescription {
     pub out_indices: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct MaxPool1dWithIndicesBackwardDescription {
     pub x: TensorDescription,

@@ -942,6 +973,7 @@ pub struct MaxPool1dWithIndicesBackwardDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct MaxPool2dDescription {
     pub x: TensorDescription,

@@ -952,6 +984,7 @@ pub struct MaxPool2dDescription {
     pub out: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct MaxPool2dWithIndicesDescription {
     pub x: TensorDescription,

@@ -963,6 +996,7 @@ pub struct MaxPool2dWithIndicesDescription {
     pub out_indices: TensorDescription,
 }

+#[derive(Hash)]
 #[allow(missing_docs)]
 pub struct MaxPool2dWithIndicesBackwardDescription {
     pub x: TensorDescription,
@@ -1,8 +1,5 @@
 use crate::{FusionBackend, TensorDescription, TensorId, TensorStatus};
-use burn_tensor::{
-    ops::{FloatElem, IntElem},
-    Data, ElementConversion, Shape,
-};
+use burn_tensor::Shape;
 use std::{collections::HashMap, sync::Arc};

 /// Keep all [tensor handles](FusionBackend::Handle) in one place and ensure that all resources
@@ -17,10 +14,7 @@ pub struct HandleContainer<B: FusionBackend> {
 }

 enum Handle<B: FusionBackend> {
-    Empty,
-    DataFloat(Vec<FloatElem<B>>),
-    DataInt(Vec<IntElem<B>>),
-    DataBool(Vec<bool>),
+    NotInit,
     Existing(B::Handle),
 }
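Note: the data-carrying variants disappear because tensors are now created eagerly on the backend and registered as existing handles; every id starts as NotInit and is swapped to Existing once a resource is attached. A minimal sketch of that life cycle with stand-in names:

    enum State<H> {
        NotInit,      // reserved id, no backend resource yet
        Existing(H),  // backed by a real backend tensor
    }

    fn attach<H>(slot: &mut State<H>, handle: H) {
        *slot = State::Existing(handle);
    }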
@@ -34,47 +28,38 @@ impl<B: FusionBackend> HandleContainer<B> {
         }
     }

+    /// Register a handle for the given [tensor id](TensorId).
+    pub fn register_handle(&mut self, id: TensorId, handle: B::Handle) {
+        self.handles.insert(id, Handle::Existing(handle));
+    }
+
+    /// Get the handle for the given [tensor id](TensorId).
+    pub fn get_handle(&mut self, tensor: &TensorDescription) -> B::Handle {
+        let (id, handle) = self
+            .handles
+            .remove_entry(&tensor.id)
+            .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", tensor.id));
+
+        match handle {
+            Handle::Existing(handle) => match tensor.status {
+                TensorStatus::ReadOnly => {
+                    self.handles.insert(id, Handle::Existing(handle.clone()));
+                    handle
+                }
+                TensorStatus::ReadWrite => handle,
+                TensorStatus::NotInit => panic!("Cannot get uninitialized tensor."),
+            },
+            Handle::NotInit => panic!("Cannot get uninitialized handle."),
+        }
+    }
+
     /// Get the [float tensor](burn_tensor::backend::Backend::TensorPrimitive) corresponding to the
     /// given [tensor description](TensorDescription).
     pub fn get_float_tensor<const D: usize>(
         &mut self,
         tensor: &TensorDescription,
     ) -> B::TensorPrimitive<D> {
-        let (id, handle) = self
-            .handles
-            .remove_entry(&tensor.id)
-            .unwrap_or_else(|| panic!("Should have handle for tensor {:?}", tensor.id));
-
-        if let Handle::Existing(handle) = handle {
-            match tensor.status {
-                TensorStatus::ReadOnly => {
-                    self.handles.insert(id, Handle::Existing(handle.clone()));
-                    return B::float_tensor(handle, Shape::from(tensor.shape.clone()));
-                }
-                TensorStatus::ReadWrite => {
-                    return B::float_tensor(handle, Shape::from(tensor.shape.clone()));
-                }
-                TensorStatus::NotInit => panic!("Can't get uninitialized tensor."),
-            }
-        }
-
-        let output = match handle {
-            Handle::Empty => B::empty(Shape::from(tensor.shape.clone()), &self.device),
-            Handle::DataFloat(values) => B::from_data(
-                Data::new(values, Shape::from(tensor.shape.clone())),
-                &self.device,
-            ),
-            Handle::Existing(_) => unreachable!(),
-            Handle::DataInt(_) => panic!("From int unsupported when getting float tensor."),
-            Handle::DataBool(_) => panic!("From bool unsupported when getting float tensor."),
-        };
-
-        if let TensorStatus::ReadOnly = tensor.status {
-            self.handles
-                .insert(id, Handle::Existing(B::float_tensor_handle(output.clone())));
-        }
-
-        output
+        B::float_tensor(self.get_handle(tensor), Shape::from(&tensor.shape))
     }

     /// Get the [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) corresponding to the
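Note: the three getters now share get_handle, which enforces the status contract in one place. A hedged sketch of that contract (types from this diff; the function itself is hypothetical):

    fn read_twice<B: FusionBackend>(
        container: &mut HandleContainer<B>,
        desc: &TensorDescription,
    ) {
        // Fine while desc.status is ReadOnly: the handle is cloned back in.
        let first = container.get_handle(desc);
        // With ReadWrite, the first call consumes the entry and this second
        // call would panic ("Should have handle for tensor ...").
        let second = container.get_handle(desc);
        let _ = (first, second);
    }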
@@ -83,38 +68,7 @@ impl<B: FusionBackend> HandleContainer<B> {
         &mut self,
         tensor: &TensorDescription,
     ) -> B::IntTensorPrimitive<D> {
-        let (id, handle) = self.handles.remove_entry(&tensor.id).unwrap();
-
-        if let Handle::Existing(handle) = handle {
-            match tensor.status {
-                TensorStatus::ReadOnly => {
-                    self.handles.insert(id, Handle::Existing(handle.clone()));
-                    return B::int_tensor(handle, Shape::from(tensor.shape.clone()));
-                }
-                TensorStatus::ReadWrite => {
-                    return B::int_tensor(handle, Shape::from(tensor.shape.clone()));
-                }
-                TensorStatus::NotInit => panic!("Can get uninitialized tensor."),
-            }
-        }
-
-        let output = match handle {
-            Handle::Empty => B::int_empty(Shape::from(tensor.shape.clone()), &self.device),
-            Handle::DataInt(values) => B::int_from_data(
-                Data::new(values, Shape::from(tensor.shape.clone())),
-                &self.device,
-            ),
-            Handle::Existing(_) => unreachable!(),
-            Handle::DataFloat(_) => panic!("From float unsupported when getting int tensor."),
-            Handle::DataBool(_) => panic!("From bool unsupported when getting int tensor."),
-        };
-
-        if let TensorStatus::ReadOnly = tensor.status {
-            self.handles
-                .insert(id, Handle::Existing(B::int_tensor_handle(output.clone())));
-        }
-
-        output
+        B::int_tensor(self.get_handle(tensor), Shape::from(&tensor.shape))
     }

     /// Get the [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) corresponding to the
@@ -123,41 +77,7 @@ impl<B: FusionBackend> HandleContainer<B> {
         &mut self,
         tensor: &TensorDescription,
     ) -> B::BoolTensorPrimitive<D> {
-        let (id, handle) = self.handles.remove_entry(&tensor.id).unwrap();
-
-        if let Handle::Existing(handle) = handle {
-            match tensor.status {
-                TensorStatus::ReadOnly => {
-                    self.handles.insert(id, Handle::Existing(handle.clone()));
-                    return B::bool_tensor(handle, Shape::from(tensor.shape.clone()));
-                }
-                TensorStatus::ReadWrite => {
-                    return B::bool_tensor(handle, Shape::from(tensor.shape.clone()));
-                }
-                TensorStatus::NotInit => panic!("Can get uninitialized tensor."),
-            }
-        }
-
-        let output = match handle {
-            Handle::Empty => B::int_equal_elem(
-                B::int_empty(Shape::from(tensor.shape.clone()), &self.device),
-                0.elem(),
-            ),
-            Handle::DataBool(data) => B::bool_from_data(
-                Data::new(data, Shape::from(tensor.shape.clone())),
-                &self.device,
-            ),
-            Handle::Existing(_) => unreachable!(),
-            Handle::DataFloat(_) => panic!("From float unsupported when getting bool tensor."),
-            Handle::DataInt(_) => panic!("From int unsupported when getting bool tensor."),
-        };
-
-        if let TensorStatus::ReadOnly = tensor.status {
-            self.handles
-                .insert(id, Handle::Existing(B::bool_tensor_handle(output.clone())));
-        }
-
-        output
+        B::bool_tensor(self.get_handle(tensor), Shape::from(&tensor.shape))
     }

     /// Register a new [float tensor](burn_tensor::backend::Backend::TensorPrimitive) with the corresponding [tensor id](TensorId).
@@ -191,37 +111,10 @@ impl<B: FusionBackend> HandleContainer<B> {
     }

     /// Lazily create a new empty tensor and return its corresponding [tensor id](TensorId).
-    pub fn create_tensor_empty(&mut self) -> Arc<TensorId> {
+    pub fn create_tensor_uninit(&mut self) -> Arc<TensorId> {
         let id = TensorId::new(self.counter);
         self.counter += 1;
-        self.handles.insert(id.clone(), Handle::Empty);
-
-        Arc::new(id)
-    }
-
-    /// Lazily create a new float tensor and return its corresponding [tensor id](TensorId).
-    pub(crate) fn create_tensor_float(&mut self, values: Vec<FloatElem<B>>) -> Arc<TensorId> {
-        let id = TensorId::new(self.counter);
-        self.counter += 1;
-        self.handles.insert(id.clone(), Handle::DataFloat(values));
-
-        Arc::new(id)
-    }
-
-    /// Lazily create a new int tensor and return its corresponding [tensor id](TensorId).
-    pub(crate) fn create_tensor_int(&mut self, values: Vec<IntElem<B>>) -> Arc<TensorId> {
-        let id = TensorId::new(self.counter);
-        self.counter += 1;
-        self.handles.insert(id.clone(), Handle::DataInt(values));
-
-        Arc::new(id)
-    }
-
-    /// Lazily create a new bool tensor and return its corresponding [tensor id](TensorId).
-    pub(crate) fn create_tensor_bool(&mut self, values: Vec<bool>) -> Arc<TensorId> {
-        let id = TensorId::new(self.counter);
-        self.counter += 1;
-        self.handles.insert(id.clone(), Handle::DataBool(values));
+        self.handles.insert(id.clone(), Handle::NotInit);

         Arc::new(id)
     }
@@ -17,8 +17,9 @@ use burn_tensor::{
 impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
     fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
         let client = get_client::<B>(&device.clone().into());
+        let tensor = B::bool_empty(shape.clone(), device);

-        client.create_tensor_empty(shape.dims.into())
+        client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into())
     }

     fn bool_shape<const D: usize>(tensor: &BoolTensor<Self, D>) -> Shape<D> {
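Note: this hunk shows the recurring pattern of the whole commit. Creation ops now run eagerly on the inner backend and register the resulting handle; compute ops (further down) only reserve an uninitialized output and queue a description. A self-contained sketch of the split, using stand-in types rather than burn's FusionTensor/TensorOpsDescription:

    enum TensorState<H> {
        Uninitialized,   // output of a queued op, no memory yet
        Initialized(H),  // wraps a live backend handle
    }

    fn create_eager<H>(backend_alloc: impl FnOnce() -> H) -> TensorState<H> {
        // Creation ops allocate immediately and register the handle.
        TensorState::Initialized(backend_alloc())
    }

    fn create_lazy<H>() -> TensorState<H> {
        // Compute ops reserve an output that fused execution will fill in.
        TensorState::Uninitialized
    }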
@@ -36,8 +37,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         device: &Device<Self>,
     ) -> BoolTensor<Self, D> {
         let client = get_client::<B>(&device.clone().into());
+        let tensor = B::bool_from_data(data, device);
+        let shape = B::bool_shape(&tensor);

-        client.create_tensor_bool(data.value, data.shape.dims.into())
+        client.register_tensor(B::bool_tensor_handle(tensor), shape.dims.into())
     }

     fn bool_into_int<const D: usize>(
@@ -55,7 +58,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         }
     }

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client
             .register(TensorOpsDescription::BoolOps(BoolOpsDescription::IntoInt(

@@ -84,7 +87,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         }
     }

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client.register(TensorOpsDescription::BoolOps(
             BoolOpsDescription::IntoFloat(

@@ -139,7 +142,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = shape.dims.into();
-        let out = tensor.client.create_tensor_empty(shape.clone());
+        let out = tensor.client.tensor_uninitialized(shape.clone());

         tensor
             .client

@@ -183,7 +186,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
             shape.push(tensor.shape[i]);
         }

-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -227,7 +230,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -279,7 +282,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
             shape[dim] += tensor.shape[dim];
         }

-        let out = client.create_tensor_empty(shape);
+        let out = client.tensor_uninitialized(shape);

         client.register(TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Cat(
             CatOpsDescription {

@@ -312,7 +315,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client.register(TensorOpsDescription::BaseOpsBool(
             BaseOpsDescription::Equal(

@@ -341,7 +344,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         }
     }

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client.register(TensorOpsDescription::BoolOps(
             crate::graph::BoolOpsDescription::Not(

@@ -377,7 +380,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         shape[dim1] = tensor.shape[dim2];
         shape[dim2] = tensor.shape[dim1];

-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client
@@ -26,8 +26,10 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         device: &Device<Self>,
     ) -> FloatTensor<Self, D> {
         let client = get_client::<B>(&device.clone().into());
+        let tensor = B::from_data(data, device);
+        let shape = B::shape(&tensor);

-        client.create_tensor_float(data.value, data.shape.dims.into())
+        client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into())
     }

     fn random<const D: usize>(
@@ -54,7 +56,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let shape: Vec<usize> = shape.dims.into();
         let client = get_client::<B>(&device.clone().into());
-        let out = client.create_tensor_empty(shape);
+        let out = client.tensor_uninitialized(shape);

         client.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Random(
             (out.to_description_out(), distribution),

@@ -79,7 +81,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let shape: Vec<usize> = shape.dims.into();
         let client = get_client::<B>(&device.clone().into());
-        let out = client.create_tensor_empty(shape);
+        let out = client.tensor_uninitialized(shape);

         client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::<D>)),

@@ -103,7 +105,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let shape: Vec<usize> = shape.dims.into();
         let client = get_client::<B>(&device.clone().into());
-        let out = client.create_tensor_empty(shape);
+        let out = client.tensor_uninitialized(shape);

         client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::<D>)),

@@ -131,7 +133,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let shape: Vec<usize> = shape.dims.into();
         let client = get_client::<B>(&device.clone().into());
-        let out = client.create_tensor_empty(shape);
+        let out = client.tensor_uninitialized(shape);

         client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Full(

@@ -188,7 +190,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         }
     }

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client.register(TensorOpsDescription::FloatOps(
             FloatOpsDescription::IntoInt(
@@ -205,8 +207,9 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

     fn empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
         let client = get_client::<B>(&device.clone().into());
+        let tensor = B::empty(shape.clone(), device);

-        client.create_tensor_empty(shape.dims.into())
+        client.register_tensor(B::float_tensor_handle(tensor), shape.dims.into())
     }

     fn add<const D: usize>(
@@ -217,7 +220,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Add(

@@ -239,7 +242,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> FloatTensor<Self, D> {
         scalar_float_ops!(AddOps, B::add_scalar);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::AddScalar(
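Note: every compute op from here on follows the same lazy recipe: reserve an uninitialized output, register a description, and let the drained graph execute (possibly fused) later. A self-contained toy model of that queuing idea, with a Vec of strings standing in for burn's TensorOpsDescription graph:

    struct Graph {
        ops: Vec<String>, // stand-in for registered operation descriptions
    }

    impl Graph {
        fn register(&mut self, desc: impl Into<String>) {
            self.ops.push(desc.into()); // nothing executes here
        }
        fn drain(&mut self) {
            for op in self.ops.drain(..) {
                println!("executing fused segment containing: {op}");
            }
        }
    }

    fn main() {
        let mut graph = Graph { ops: Vec::new() };
        graph.register("add");
        graph.register("add_scalar");
        graph.drain(); // everything runs here, not at registration time
    }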
@@ -261,7 +264,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> FloatTensor<Self, D> {
         scalar_float_ops!(ClampMinOps, B::clamp_min);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::ClampMin(

@@ -283,7 +286,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> FloatTensor<Self, D> {
         scalar_float_ops!(ClampMaxOps, B::clamp_max);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::ClampMax(

@@ -317,7 +320,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         }
     }

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Clamp(

@@ -342,7 +345,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Sub(
@@ -364,7 +367,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> FloatTensor<Self, D> {
         scalar_float_ops!(SubOps, B::sub_scalar);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::SubScalar(

@@ -388,7 +391,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Mul(

@@ -410,7 +413,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> FloatTensor<Self, D> {
         scalar_float_ops!(MulOps, B::mul_scalar);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::MulScalar(

@@ -434,7 +437,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Div(

@@ -456,7 +459,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> FloatTensor<Self, D> {
         scalar_float_ops!(DivOps, B::div_scalar);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::DivScalar(

@@ -483,7 +486,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         shape[D - 2] = lhs.shape[D - 2];
         shape[D - 1] = rhs.shape[D - 1];

-        let out = lhs.client.create_tensor_empty(shape);
+        let out = lhs.client.tensor_uninitialized(shape);

         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Matmul(
@@ -519,7 +522,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         shape[dim1] = tensor.shape[dim2];
         shape[dim2] = tensor.shape[dim1];

-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -556,7 +559,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = shape.dims.into();
-        let out = tensor.client.create_tensor_empty(shape.clone());
+        let out = tensor.client.tensor_uninitialized(shape.clone());

         tensor
             .client

@@ -595,7 +598,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = indices.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -638,7 +641,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -681,7 +684,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let mut shape: Vec<usize> = tensor.shape.clone();
         shape[dim] = indices.shape[0];
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -724,7 +727,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -769,7 +772,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
             shape.push(tensor.shape[i]);
         }

-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -813,7 +816,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -855,7 +858,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -896,7 +899,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client
@@ -924,7 +927,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client.register(TensorOpsDescription::BaseOpsFloat(
             BaseOpsDescription::Equal(

@@ -946,7 +949,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> BoolTensor<Self, D> {
         scalar_float_cmp_ops!(EqualElemOps, B::equal_elem);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::EqualElem(

@@ -970,7 +973,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Greater(

@@ -992,7 +995,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> BoolTensor<Self, D> {
         scalar_float_cmp_ops!(GreaterElemOps, B::greater_elem);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::GreaterElem(

@@ -1016,7 +1019,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::GreaterEqual(

@@ -1038,7 +1041,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> BoolTensor<Self, D> {
         scalar_float_cmp_ops!(GreaterEqualElemOps, B::greater_equal_elem);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::GreaterEqualElem(

@@ -1062,7 +1065,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Lower(

@@ -1084,7 +1087,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> BoolTensor<Self, D> {
         scalar_float_cmp_ops!(LowerElemOps, B::lower_elem);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::LowerElem(

@@ -1108,7 +1111,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::LowerEqual(

@@ -1130,7 +1133,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     ) -> BoolTensor<Self, D> {
         scalar_float_cmp_ops!(LowerEqualElemOps, B::lower_equal_elem);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::LowerEqualElem(
@@ -1149,7 +1152,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
         unary_float_ops!(SumOps, B::sum);

-        let out = tensor.client.create_tensor_empty(vec![1]);
+        let out = tensor.client.tensor_uninitialized(vec![1]);

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Sum(

@@ -1169,7 +1172,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::SumDim(

@@ -1188,7 +1191,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn mean<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
         unary_float_ops!(MeanOps, B::mean);

-        let out = tensor.client.create_tensor_empty(vec![1]);
+        let out = tensor.client.tensor_uninitialized(vec![1]);

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Mean(

@@ -1208,7 +1211,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::MeanDim(

@@ -1239,7 +1242,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(ExpOps, B::exp);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Exp(

@@ -1256,7 +1259,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(LogOps, B::log);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log(

@@ -1273,7 +1276,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(Log1pOps, B::log1p);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Log1p(

@@ -1290,7 +1293,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
         scalar_float_ops!(PowfOps, B::powf, f32);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Powf(

@@ -1308,7 +1311,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(SqrtOps, B::sqrt);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sqrt(

@@ -1325,7 +1328,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(AbsOps, B::abs);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Abs(

@@ -1343,7 +1346,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(CosOps, B::cos);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Cos(

@@ -1360,7 +1363,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(SinOps, B::sin);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Sin(

@@ -1377,7 +1380,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(TanhOps, B::tanh);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Tanh(

@@ -1394,7 +1397,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(Recip, B::recip);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip(
                 UnaryOpsDescription {
@@ -1409,7 +1412,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
         unary_float_ops!(TanhOps, B::erf);

-        let out = tensor.client.create_tensor_empty(tensor.shape.clone());
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

         out.client
             .register(TensorOpsDescription::FloatOps(FloatOpsDescription::Erf(

@@ -1452,7 +1455,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
             shape[dim] += tensor.shape[dim];
         }

-        let out = client.create_tensor_empty(shape);
+        let out = client.tensor_uninitialized(shape);

         client.register(TensorOpsDescription::BaseOpsFloat(BaseOpsDescription::Cat(
             CatOpsDescription {

@@ -1471,7 +1474,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::ArgMax(

@@ -1492,7 +1495,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::ArgMin(

@@ -1511,7 +1514,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn max<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
         unary_float_ops!(MaxOps, B::max);

-        let out = tensor.client.create_tensor_empty(vec![1]);
+        let out = tensor.client.tensor_uninitialized(vec![1]);

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Max(

@@ -1531,7 +1534,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::MaxDim(

@@ -1568,8 +1571,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
         let client = tensor.client.clone();
-        let out = client.create_tensor_empty(shape.clone());
-        let out_indices = client.create_tensor_empty(shape);
+        let out = client.tensor_uninitialized(shape.clone());
+        let out_indices = client.tensor_uninitialized(shape);

         client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::MaxDimWithIndices(

@@ -1589,7 +1592,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
     fn min<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
         unary_float_ops!(MinOps, B::min);

-        let out = tensor.client.create_tensor_empty(vec![1]);
+        let out = tensor.client.tensor_uninitialized(vec![1]);

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::Min(

@@ -1609,7 +1612,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {

         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         out.client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::MinDim(

@@ -1646,8 +1649,8 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
         let mut shape = tensor.shape.clone();
         shape[dim] = 1;
         let client = tensor.client.clone();
-        let out = client.create_tensor_empty(shape.clone());
-        let out_indices = client.create_tensor_empty(shape);
+        let out = client.tensor_uninitialized(shape.clone());
+        let out_indices = client.tensor_uninitialized(shape);

         client.register(TensorOpsDescription::NumericOpsFloat(
             NumericOpsDescription::MinDimWithIndices(
@@ -22,8 +22,9 @@ use core::ops::Range;
 impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
     fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
         let client = get_client::<B>(&device.clone().into());
+        let tensor = B::int_empty(shape.clone(), device);

-        client.create_tensor_empty(shape.dims.into())
+        client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into())
     }

     fn int_shape<const D: usize>(tensor: &IntTensor<Self, D>) -> Shape<D> {
@@ -39,8 +40,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         device: &Device<Self>,
     ) -> IntTensor<Self, D> {
         let client = get_client::<B>(&device.clone().into());
+        let tensor = B::int_from_data(data, device);
+        let shape = B::int_shape(&tensor);

-        client.create_tensor_int(data.value, data.shape.dims.into())
+        client.register_tensor(B::int_tensor_handle(tensor), shape.dims.into())
     }

     fn int_device<const D: usize>(tensor: &IntTensor<Self, D>) -> Device<Self> {
@@ -83,7 +86,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = shape.dims.into();
-        let out = tensor.client.create_tensor_empty(shape.clone());
+        let out = tensor.client.tensor_uninitialized(shape.clone());

         tensor
             .client

@@ -127,7 +130,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
             shape.push(tensor.shape[i]);
         }

-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -169,7 +172,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -211,7 +214,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -252,7 +255,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -292,7 +295,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = indices.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -335,7 +338,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -378,7 +381,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {

         let mut shape: Vec<usize> = tensor.shape.clone();
         shape[dim] = indices.shape[0];
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -421,7 +424,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         }

         let shape: Vec<usize> = tensor.shape.clone();
-        let out = tensor.client.create_tensor_empty(shape);
+        let out = tensor.client.tensor_uninitialized(shape);

         tensor
             .client

@@ -471,7 +474,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
             shape[dim] += tensor.shape[dim];
         }

-        let out = client.create_tensor_empty(shape);
+        let out = client.tensor_uninitialized(shape);

         client.register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Cat(
             CatOpsDescription {

@@ -493,7 +496,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));

         out.client
             .register(TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Equal(

@@ -514,7 +517,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
     ) -> BoolTensor<Self, D> {
         scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem);

-        let out = lhs.client.create_tensor_empty(lhs.shape.clone());
+        let out = lhs.client.tensor_uninitialized(lhs.shape.clone());

         out.client.register(TensorOpsDescription::NumericOpsInt(
             NumericOpsDescription::EqualElem(

@@ -538,7 +541,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {

         let out = lhs
             .client
-            .create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
+            .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::Greater(
|
||||
|
@ -560,7 +563,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> BoolTensor<Self, D> {
|
||||
scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem);
|
||||
|
||||
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::GreaterElem(
|
||||
|
@ -584,7 +587,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs
|
||||
.client
|
||||
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::GreaterEqual(
|
||||
|
@ -606,7 +609,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> BoolTensor<Self, D> {
|
||||
scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem);
|
||||
|
||||
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::GreaterEqualElem(
|
||||
|
@ -630,7 +633,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs
|
||||
.client
|
||||
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::Lower(
|
||||
|
@ -652,7 +655,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> BoolTensor<Self, D> {
|
||||
scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem);
|
||||
|
||||
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::LowerElem(
|
||||
|
@ -676,7 +679,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs
|
||||
.client
|
||||
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::LowerEqual(
|
||||
|
@ -698,7 +701,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> BoolTensor<Self, D> {
|
||||
scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem);
|
||||
|
||||
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::LowerEqualElem(
|
||||
|
@ -722,7 +725,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs
|
||||
.client
|
||||
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
out.client
|
||||
.register(graph::TensorOpsDescription::NumericOpsInt(
|
||||
|
@ -745,7 +748,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
scalar_int_ops!(AddOps, B::int_add_scalar);
|
||||
|
||||
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
out.client
|
||||
.register(graph::TensorOpsDescription::NumericOpsInt(
|
||||
|
@ -770,7 +773,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs
|
||||
.client
|
||||
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
out.client
|
||||
.register(graph::TensorOpsDescription::NumericOpsInt(
|
||||
|
@ -793,7 +796,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
scalar_int_ops!(SubOps, B::int_sub_scalar);
|
||||
|
||||
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
out.client
|
||||
.register(graph::TensorOpsDescription::NumericOpsInt(
|
||||
|
@ -818,7 +821,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs
|
||||
.client
|
||||
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
out.client
|
||||
.register(graph::TensorOpsDescription::NumericOpsInt(
|
||||
|
@ -841,7 +844,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
scalar_int_ops!(MulOps, B::int_mul_scalar);
|
||||
|
||||
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
out.client
|
||||
.register(graph::TensorOpsDescription::NumericOpsInt(
|
||||
|
@ -866,7 +869,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs
|
||||
.client
|
||||
.create_tensor_empty(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
out.client
|
||||
.register(graph::TensorOpsDescription::NumericOpsInt(
|
||||
|
@ -889,7 +892,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
scalar_int_ops!(DivOps, B::int_div_scalar);
|
||||
|
||||
let out = lhs.client.create_tensor_empty(lhs.shape.clone());
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
out.client
|
||||
.register(graph::TensorOpsDescription::NumericOpsInt(
|
||||
|
@ -921,7 +924,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let out = client.create_tensor_empty(shape);
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::Zeros(out.to_description_out(), Box::new(ZerosOps::<D>)),
|
||||
|
@ -945,7 +948,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let client = get_client::<B>(&device.clone().into());
|
||||
let out = client.create_tensor_empty(shape);
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::Ones(out.to_description_out(), Box::new(OnesOps::<D>)),
|
||||
|
@ -957,7 +960,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
fn int_sum<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
|
||||
unary_int_ops!(SumOps, B::int_sum);
|
||||
|
||||
let out = tensor.client.create_tensor_empty(vec![1]);
|
||||
let out = tensor.client.tensor_uninitialized(vec![1]);
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::Sum(
|
||||
|
@ -977,7 +980,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = 1;
|
||||
let out = tensor.client.create_tensor_empty(shape);
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::SumDim(
|
||||
|
@ -996,7 +999,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
fn int_mean<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
|
||||
unary_int_ops!(MeanOps, B::int_mean);
|
||||
|
||||
let out = tensor.client.create_tensor_empty(vec![1]);
|
||||
let out = tensor.client.tensor_uninitialized(vec![1]);
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::Mean(
|
||||
|
@ -1016,7 +1019,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = 1;
|
||||
let out = tensor.client.create_tensor_empty(shape);
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::MeanDim(
|
||||
|
@ -1037,7 +1040,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = 1;
|
||||
let out = tensor.client.create_tensor_empty(shape);
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::ArgMax(
|
||||
|
@ -1058,7 +1061,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = 1;
|
||||
let out = tensor.client.create_tensor_empty(shape);
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::ArgMin(
|
||||
|
@ -1080,7 +1083,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
scalar_int_ops!(ClampMinOps, B::int_clamp_min);
|
||||
|
||||
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::ClampMin(
|
||||
|
@ -1102,7 +1105,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
scalar_int_ops!(ClampMaxOps, B::int_clamp_max);
|
||||
|
||||
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::ClampMax(
|
||||
|
@ -1136,7 +1139,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
}
|
||||
|
||||
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::Clamp(
|
||||
|
@ -1156,7 +1159,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
|
||||
unary_int_ops!(AbsOps, B::int_abs);
|
||||
|
||||
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::Abs(
|
||||
|
@ -1184,7 +1187,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
}
|
||||
|
||||
let out = tensor.client.create_tensor_empty(tensor.shape.clone());
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
out.client.register(TensorOpsDescription::IntOps(
|
||||
graph::IntOpsDescription::IntoFloat(
|
||||
|
@ -1220,7 +1223,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
shape[dim1] = tensor.shape[dim2];
|
||||
shape[dim2] = tensor.shape[dim1];
|
||||
|
||||
let out = tensor.client.create_tensor_empty(shape);
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
tensor
|
||||
.client
|
||||
|
@ -1243,7 +1246,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
fn int_max<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
|
||||
unary_int_ops!(MaxOps, B::int_max);
|
||||
|
||||
let out = tensor.client.create_tensor_empty(vec![1]);
|
||||
let out = tensor.client.tensor_uninitialized(vec![1]);
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::Max(
|
||||
|
@ -1263,7 +1266,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = 1;
|
||||
let out = tensor.client.create_tensor_empty(shape);
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::MaxDim(
|
||||
|
@ -1300,8 +1303,8 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = 1;
|
||||
let client = tensor.client.clone();
|
||||
let out = client.create_tensor_empty(shape.clone());
|
||||
let out_indices = client.create_tensor_empty(shape);
|
||||
let out = client.tensor_uninitialized(shape.clone());
|
||||
let out_indices = client.tensor_uninitialized(shape);
|
||||
|
||||
client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::MaxDimWithIndices(
|
||||
|
@ -1321,7 +1324,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
fn int_min<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
|
||||
unary_int_ops!(MinOps, B::int_min);
|
||||
|
||||
let out = tensor.client.create_tensor_empty(vec![1]);
|
||||
let out = tensor.client.tensor_uninitialized(vec![1]);
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::Min(
|
||||
|
@ -1341,7 +1344,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = 1;
|
||||
let out = tensor.client.create_tensor_empty(shape);
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
out.client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::MinDim(
|
||||
|
@ -1378,8 +1381,8 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = 1;
|
||||
let client = tensor.client.clone();
|
||||
let out = client.create_tensor_empty(shape.clone());
|
||||
let out_indices = client.create_tensor_empty(shape);
|
||||
let out = client.tensor_uninitialized(shape.clone());
|
||||
let out_indices = client.tensor_uninitialized(shape);
|
||||
|
||||
client.register(TensorOpsDescription::NumericOpsInt(
|
||||
NumericOpsDescription::MinDimWithIndices(
|
||||
|
|
|
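Every binary hunk above sizes its output with binary_ops_shape, whose definition is not part of this diff. A plausible reconstruction, shown only to make the shape rule explicit and labeled as an assumption, broadcasts dimension-wise, taking the larger extent when one side is 1:

// Hypothetical reconstruction of `binary_ops_shape`; not taken from this diff.
fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    lhs.iter()
        .zip(rhs.iter())
        .map(|(&l, &r)| usize::max(l, r)) // broadcast: a dim of 1 stretches
        .collect()
}

fn main() {
    assert_eq!(binary_ops_shape(&[2, 1, 4], &[2, 3, 4]), vec![2, 3, 4]);
}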
@@ -56,7 +56,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
         );

         let shape = vec![x.shape[0], weight.shape[0], size];
-        let out = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::Conv1d(

@@ -115,7 +115,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
         );

         let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
-        let out = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::Conv2d(

@@ -168,7 +168,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
         );

         let shape = vec![x.shape[0], weight.shape[1] * options.groups, size];
-        let out = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::ConvTranspose1d(

@@ -229,7 +229,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
         );

         let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1];
-        let out = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::ConvTranspose2d(

@@ -275,7 +275,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {

         let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]);
         let shape = vec![x.shape[0], x.shape[1], size];
-        let out = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::AvgPool1d(

@@ -326,7 +326,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
             calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]);

         let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
-        let out = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::AvgPool2d(

@@ -374,7 +374,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
             }
         }

-        let out = x.client.create_tensor_empty(x.shape.clone());
+        let out = x.client.tensor_uninitialized(x.shape.clone());

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::AvgPool1dBackward(

@@ -423,7 +423,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
             }
         }

-        let out = x.client.create_tensor_empty(x.shape.clone());
+        let out = x.client.tensor_uninitialized(x.shape.clone());

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::AvgPool2dBackward(

@@ -472,7 +472,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
         let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);

         let shape = vec![x.shape[0], x.shape[1], size];
-        let out = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::MaxPool1d(

@@ -533,7 +533,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
         );

         let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
-        let out = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::MaxPool2d(

@@ -581,8 +581,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {

         let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
         let shape = vec![x.shape[0], x.shape[1], size];
-        let out = x.client.create_tensor_empty(shape.clone());
-        let out_indices = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape.clone());
+        let out_indices = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::MaxPool1dWithIndices(

@@ -645,8 +645,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
         );

         let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
-        let out = x.client.create_tensor_empty(shape.clone());
-        let out_indices = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape.clone());
+        let out_indices = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::MaxPool2dWithIndices(

@@ -698,7 +698,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
             }
         }

-        let out = x.client.create_tensor_empty(x.shape.clone());
+        let out = x.client.tensor_uninitialized(x.shape.clone());

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::MaxPool1dWithIndicesBackward(

@@ -751,7 +751,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
             }
         }

-        let out = x.client.create_tensor_empty(x.shape.clone());
+        let out = x.client.tensor_uninitialized(x.shape.clone());

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::MaxPool2dWithIndicesBackward(

@@ -787,7 +787,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
         }

         let shape = vec![x.shape[0], x.shape[1], output_size];
-        let out = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::AdaptiveAvgPool1d(

@@ -821,7 +821,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
         }

         let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
-        let out = x.client.create_tensor_empty(shape);
+        let out = x.client.tensor_uninitialized(shape);

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::AdaptiveAvgPool2d(

@@ -855,7 +855,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
             }
         }

-        let out = x.client.create_tensor_empty(x.shape.clone());
+        let out = x.client.tensor_uninitialized(x.shape.clone());

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::AdaptiveAvgPool1dBackward(

@@ -889,7 +889,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
             }
         }

-        let out = x.client.create_tensor_empty(x.shape.clone());
+        let out = x.client.tensor_uninitialized(x.shape.clone());

         x.client.clone().register(TensorOpsDescription::ModuleOps(
             crate::graph::ModuleOpsDescription::AdaptiveAvgPool2dBackward(
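All of these module ops derive their output shapes from calculate_conv_output_size and calculate_pool_output_size, neither of which appears in this diff. They presumably follow the textbook one-dimensional rule; a hedged standalone version for reference:

// Standard output-length formula for conv/pool along one dimension.
// The real helpers are not shown in this diff; this is the conventional
// rule they are expected to implement.
fn output_size(kernel: usize, stride: usize, padding: usize, dilation: usize, len: usize) -> usize {
    (len + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // kernel = 3, stride = 1, padding = 1, dilation = 1 preserves the length.
    assert_eq!(output_size(3, 1, 1, 1, 32), 32);
}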
@@ -23,7 +23,7 @@ where
     G: GraphExecution<B>,
 {
     pub fn new(device: B::FusionDevice) -> Self {
-        let optimizations = B::operations()
+        let optimizations = B::operations(&device.clone().into())
             .into_iter()
             .map(|ops| Optimization::new(ops, FusionStatus::Open(FusionProperties::default())))
             .collect();

@@ -53,7 +53,11 @@
         );
     }

-    pub fn sync(&mut self) {
+    pub fn drain_graph(&mut self) {
+        if self.graph.is_empty() {
+            return;
+        }
+
         self.execution.maybe_execute(
             &mut self.graph,
             &mut self.handles,

@@ -63,19 +67,7 @@
     }

     pub fn create_empty_handle(&mut self) -> Arc<TensorId> {
-        self.handles.create_tensor_empty()
-    }
-
-    pub fn create_float_handle(&mut self, values: Vec<FloatElem<B>>) -> Arc<TensorId> {
-        self.handles.create_tensor_float(values)
-    }
-
-    pub fn create_int_handle(&mut self, values: Vec<IntElem<B>>) -> Arc<TensorId> {
-        self.handles.create_tensor_int(values)
-    }
-
-    pub fn create_bool_handle(&mut self, values: Vec<bool>) -> Arc<TensorId> {
-        self.handles.create_tensor_bool(values)
+        self.handles.create_tensor_uninit()
     }

     pub fn read_float<const D: usize>(

@@ -84,7 +76,7 @@
     ) -> burn_tensor::Reader<burn_tensor::Data<FloatElem<B>, D>> {
         // Make sure all registered operations are executed.
         // The underlying backend can still be async.
-        self.sync();
+        self.drain_graph();

         let tensor = self.handles.get_float_tensor(&tensor);
         B::into_data(tensor)

@@ -96,7 +88,7 @@
     ) -> burn_tensor::Reader<burn_tensor::Data<IntElem<B>, D>> {
         // Make sure all registered operations are executed.
         // The underlying backend can still be async.
-        self.sync();
+        self.drain_graph();

         let tensor = self.handles.get_int_tensor(&tensor);
         B::int_into_data(tensor)

@@ -108,7 +100,7 @@
     ) -> burn_tensor::Reader<burn_tensor::Data<bool, D>> {
         // Make sure all registered operations are executed.
         // The underlying backend can still be async.
-        self.sync();
+        self.drain_graph();

         let tensor = self.handles.get_bool_tensor(&tensor);
         B::bool_into_data(tensor)
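The sync-to-drain_graph rename separates "flush the recorded op graph" from "wait on the device": reads drain the graph first and then defer to the backend, which may still be asynchronous. A simplified, self-contained sketch of that contract (hypothetical types, not the crate's):

// Simplified sketch of the drain-before-read contract.
struct Server {
    pending: Vec<String>, // queued op descriptions
}

impl Server {
    fn drain_graph(&mut self) {
        if self.pending.is_empty() {
            return; // nothing recorded since the last drain
        }
        // In the real code, GraphExecution decides whether to run a fused
        // kernel or fall back to one backend call per op.
        self.pending.clear();
    }

    fn read(&mut self) -> Vec<f32> {
        self.drain_graph(); // a read is a synchronization point for the graph
        vec![]              // then defer to the (possibly still async) backend
    }
}

fn main() {
    let mut server = Server { pending: vec!["add".into(), "mul".into()] };
    let _data = server.read(); // forces the queued graph to execute first
}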
@@ -127,7 +127,7 @@ pub struct TensorId {
 }

 /// The status of the current tensor.
-#[derive(Clone, Debug)]
+#[derive(Hash, Clone, Debug, PartialEq, Eq)]
 pub enum TensorStatus {
     /// The tensor can be read, but not written.
     ReadOnly,

@@ -147,7 +147,7 @@ pub enum TensorStatus {
 /// 2. Status::ReadOnly
 /// 3. Status::ReadOnly
 /// 4. Status::ReadWrite
-#[derive(Debug)]
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub struct TensorDescription {
     /// The [tensor id](TensorId).
     pub id: TensorId,

@@ -158,7 +158,8 @@ pub struct TensorDescription {
 }

 impl TensorId {
-    pub(crate) fn new(value: u64) -> Self {
+    /// Create a new tensor id.
+    pub fn new(value: u64) -> Self {
         Self { value }
     }
 }
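The new Hash, PartialEq, and Eq derives are what let tensor descriptions participate in cache keys: later in this commit, ComputeShader::id() hashes the whole shader with a DefaultHasher to index the pipeline cache. A standalone analogue of that pattern:

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Minimal analogue of hashing a description into a kernel-cache key,
// mirroring how ComputeShader::id() works later in this commit.
#[derive(Hash)]
struct Description {
    id: u64,
    shape: Vec<usize>,
}

fn cache_key(desc: &Description) -> String {
    let mut s = DefaultHasher::new();
    desc.hash(&mut s);
    s.finish().to_string()
}

fn main() {
    let a = Description { id: 0, shape: vec![2, 3] };
    let b = Description { id: 0, shape: vec![2, 3] };
    assert_eq!(cache_key(&a), cache_key(&b)); // equal descriptions, same key
}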
@@ -66,7 +66,7 @@ pub struct Conv1dBackward<B: Backend> {
 }

 /// Convolution options.
-#[derive(new, Debug, Clone)]
+#[derive(new, Debug, Clone, Hash)]
 pub struct ConvOptions<const N: usize> {
     /// Stride.
     pub stride: [usize; N],

@@ -82,7 +82,7 @@ pub struct ConvOptions<const N: usize> {
 }

 /// Transposed convolution options.
-#[derive(new, Debug, Clone)]
+#[derive(new, Debug, Clone, Hash)]
 pub struct ConvTransposeOptions<const N: usize> {
     /// Stride.
     pub stride: [usize; N],
@@ -53,11 +53,16 @@ burn-tensor = { path = "../burn-tensor", version = "0.11.0", default-features =
 burn-ndarray = { path = "../burn-ndarray", version = "0.11.0" }
+burn-fusion = { path = "../burn-fusion", version = "0.11.0" }
 serial_test = "2.0.0"
 pretty_assertions = {workspace = true}

 [[bench]]
 name = "matmul"
 harness = false

+[[bench]]
+name = "fused_elemwise"
+harness = false
+
 [[bench]]
 name = "reduction"
 harness = false
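With the bench target declared, the new benchmark can be run from the crate directory with Cargo's standard bench selector (whether extra feature flags are required is not shown in this hunk):

cargo bench --bench fused_elemwise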
@@ -0,0 +1,74 @@
+use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_fusion::Fusion;
+use burn_tensor::backend::Backend;
+use burn_tensor::{Distribution, Shape, Tensor};
+use burn_wgpu::Wgpu;
+use burn_wgpu::WgpuDevice;
+use derive_new::new;
+use std::marker::PhantomData;
+
+#[derive(new)]
+struct ElemWiseBenchmark<B: Backend> {
+    shape: Shape<3>,
+    device: B::Device,
+    repeat: usize,
+    _b: PhantomData<B>,
+}
+
+impl<B: Backend> Benchmark for ElemWiseBenchmark<B> {
+    type Args = (Tensor<B, 3>, Tensor<B, 3>);
+
+    fn name(&self) -> String {
+        format!(
+            "Backend {} Shape {:?} Repeat {}",
+            B::name(),
+            self.shape.dims,
+            self.repeat
+        )
+    }
+
+    fn num_samples(&self) -> usize {
+        10
+    }
+
+    fn execute(&self, (lhs, rhs): Self::Args) {
+        for _ in 0..self.repeat {
+            let tmp_0 = lhs.clone() + rhs.clone();
+            let tmp_1 = rhs.clone() * tmp_0.clone();
+            let tmp_2 = rhs.clone().exp();
+            let tmp_3 = tmp_0 * tmp_1;
+            let _tmp_4 = tmp_2 / tmp_3;
+        }
+    }
+
+    fn prepare(&self) -> Self::Args {
+        B::seed(10);
+        let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);
+        let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);
+
+        (lhs, rhs)
+    }
+
+    fn sync(&self) {
+        B::sync(&self.device)
+    }
+}
+
+#[allow(dead_code)]
+/// Runs the benchmark comparing plain wgpu execution against fused element-wise execution.
+pub fn bench(device: &WgpuDevice) {
+    run_benchmark(ElemWiseBenchmark::<Wgpu>::new(
+        Shape::new([256, 256, 1024]),
+        device.clone(),
+        10,
+    ));
+    run_benchmark(ElemWiseBenchmark::<Fusion<Wgpu>>::new(
+        Shape::new([256, 256, 1024]),
+        device.clone(),
+        10,
+    ));
+}
+
+fn main() {
+    bench(&WgpuDevice::BestAvailable)
+}
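The execute chain above is exactly the kind of sequence the element-wise fuser captures: five ops, one kernel. A sketch of the operator list the fuser would record for it, assuming the Operator and Variable types introduced later in this commit are reachable under the module path shown (the path and the concrete input/local indices are assumptions, chosen for illustration):

// Assumed import path; the codegen module's visibility is not shown here.
use burn_wgpu::fusion::codegen::{Operator, Variable};

fn fused_chain() -> Vec<Operator> {
    vec![
        // tmp_0 = lhs + rhs
        Operator::Add {
            lhs: Variable::Input(0),
            rhs: Variable::Input(1),
            out: Variable::Local(0),
        },
        // tmp_1 = rhs * tmp_0
        Operator::Mul {
            lhs: Variable::Input(1),
            rhs: Variable::Local(0),
            out: Variable::Local(1),
        },
        // tmp_2 = exp(rhs)
        Operator::Exp {
            input: Variable::Input(1),
            out: Variable::Local(2),
        },
        // tmp_3 = tmp_0 * tmp_1
        Operator::Mul {
            lhs: Variable::Local(0),
            rhs: Variable::Local(1),
            out: Variable::Local(3),
        },
        // tmp_4 = tmp_2 / tmp_3
        Operator::Div {
            lhs: Variable::Local(2),
            rhs: Variable::Local(3),
            out: Variable::Local(4),
        },
    ]
}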
@@ -154,7 +154,9 @@ where
             return pipeline.clone();
         }

-        let pipeline = self.compile_source(&kernel.source().complete());
+        let source = kernel.source().complete();
+        log::trace!("Compiling kernel {kernel_id}:\n {source}");
+        let pipeline = self.compile_source(&source);
         self.pipelines.insert(kernel_id.clone(), pipeline.clone());

         pipeline
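The new trace message goes through the `log` facade, so a consumer binary has to install a logger to actually see the generated WGSL. A hedged example, assuming the common env_logger crate (any `log`-compatible logger works):

// Assumes the `env_logger` crate is a dependency of the consumer binary.
fn main() {
    std::env::set_var("RUST_LOG", "trace"); // or export it in the shell
    env_logger::init();
    // ... run a fused workload; compiled kernel sources are traced.
}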
@@ -1,6 +1,7 @@
 use crate::{
     compute::{WgpuComputeClient, WgpuHandle},
     element::WgpuElement,
+    fusion::FloatElementWiseFusionOps,
     tensor::WgpuTensor,
     FloatElement, GraphicsApi, IntElement, Wgpu, WgpuDevice,
 };

@@ -32,8 +33,8 @@ where
     type Handle = WgpuFusionHandle;
     type FusionClient = MutexFusionClient<Self, GreedyGraphExecution>;

-    fn operations() -> Vec<Box<dyn burn_fusion::FusionOps<Self>>> {
-        Vec::new()
+    fn operations(device: &WgpuDevice) -> Vec<Box<dyn burn_fusion::FusionOps<Self>>> {
+        vec![Box::new(FloatElementWiseFusionOps::new(device.clone()))]
     }

     fn float_tensor<const D: usize>(

@@ -70,7 +71,27 @@ where
     }
 }

-#[derive(Debug, Clone)]
+pub fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
+    let mut strides = vec![0; shape.len()];
+
+    let mut current = 1;
+    shape.iter().enumerate().rev().for_each(|(index, val)| {
+        strides[index] = current;
+        current *= val;
+    });
+
+    strides
+}
+
+pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
+    let mut num_elems = 1;
+    for i in shape.iter() {
+        num_elems *= i;
+    }
+    num_elems
+}
+
+#[derive(new, Debug, Clone)]
 /// Handle to be used when fusing operations.
 pub struct WgpuFusionHandle {
     /// Compute client for wgpu.
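strides_dyn_rank computes row-major strides as the running product of the dimensions to its right; copied verbatim below with a worked example so it runs standalone:

// Worked example for strides_dyn_rank (copied from the hunk above).
fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];
    let mut current = 1;
    shape.iter().enumerate().rev().for_each(|(index, val)| {
        strides[index] = current;
        current *= val;
    });
    strides
}

fn main() {
    assert_eq!(strides_dyn_rank(&[2, 3, 4]), vec![12, 4, 1]);
    assert_eq!(strides_dyn_rank(&[5]), vec![1]);
}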
@@ -0,0 +1,27 @@
+use super::Operator;
+use std::fmt::Display;
+
+/// A body is composed of a list of [operators](Operator).
+///
+/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size
+/// X and Y, but with Z=1.
+#[derive(Hash, new)]
+pub struct Body {
+    operators: Vec<Operator>,
+}
+
+impl Display for Body {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(
+            "let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n",
+        )?;
+        f.write_str("let rank: u32 = info[0];\n\n")?;
+
+        for ops in self.operators.iter() {
+            f.write_fmt(format_args!("{ops}"))?;
+            f.write_str("\n")?;
+        }
+
+        Ok(())
+    }
+}
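For a body holding a single Add operator, the Display impl above would emit roughly the following WGSL (captured here as a Rust constant; the exact local/input numbering depends on how the fuser assigned variables):

// Expected shape of `Body::to_string()` for one Add operator,
// derived from the Display impls in this commit.
const EXPECTED: &str = "let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let rank: u32 = info[0];

let local_0 = input_0 + input_1;
";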
@@ -0,0 +1,71 @@
+use super::Elem;
+use std::fmt::Display;
+
+/// Not all functions are native to WGSL, so this enum allows us to support more functions.
+#[derive(Hash, PartialEq, Eq, Clone)]
+pub enum Function {
+    Powf(Elem),
+    Erf(Elem),
+}
+
+impl Display for Function {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Function::Powf(elem) => format_powf(f, elem),
+            Function::Erf(elem) => format_erf(f, elem),
+        }
+    }
+}
+
+fn format_powf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
+    f.write_fmt(format_args!(
+        "
+fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{
+    let modulo = rhs % 2.0;
+
+    if (modulo == 0.0) {{
+        // Even exponent
+        return pow(abs(lhs), rhs);
+    }} else if (modulo == 1.0 && lhs < 0.0) {{
+        // Odd exponent, negative base
+        return -1.0 * pow(-1.0 * lhs, rhs);
+    }} else {{
+        // Fractional exponent
+        return pow(lhs, rhs);
+    }}
+}}
+"
+    ))
+}
+
+fn format_erf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
+    f.write_fmt(format_args!(
+        "
+/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
+///
+/// > (maximum error: 1.5 x 10^-7)
+/// > All of these approximations are valid for x >= 0. To use them for negative x,
+/// > use the fact that erf(x) is an odd function, so erf(x) = -erf(-x).
+fn erf_positive(x: {elem}) -> {elem} {{
+    let p = 0.3275911;
+    let a1 = 0.254829592;
+    let a2 = -0.284496736;
+    let a3 = 1.421413741;
+    let a4 = -1.453152027;
+    let a5 = 1.061405429;
+
+    let t = 1.0 / (1.0 + p * abs(x));
+    let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
+
+    return 1.0 - (tmp * t * exp(-x * x));
+}}
+
+fn erf(x: {elem}) -> {elem} {{
+    if (x < 0.0) {{
+        return -1.0 * erf_positive(-1.0 * x);
+    }}
+
+    return erf_positive(x);
+}}
+"
+    ))
+}
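The constants in erf_positive are the classic Abramowitz and Stegun approximation 7.1.26. For reference, in LaTeX:

% Abramowitz & Stegun 7.1.26, the approximation implemented by erf_positive:
\operatorname{erf}(x) \approx 1 - \left(a_1 t + a_2 t^2 + a_3 t^3 + a_4 t^4 + a_5 t^5\right) e^{-x^2},
\qquad t = \frac{1}{1 + p\,x}, \quad x \ge 0

with p and a_1 through a_5 as in the code, and a maximum absolute error of 1.5 x 10^-7.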
@@ -0,0 +1,11 @@
+mod body;
+mod function;
+mod operator;
+mod shader;
+mod variable;
+
+pub use body::*;
+pub use function::*;
+pub use operator::*;
+pub use shader::*;
+pub use variable::*;
@@ -0,0 +1,146 @@
+use super::Variable;
+use std::fmt::Display;
+
+/// All operators that can be fused in a WGSL compute shader.
+#[derive(Debug, Hash, Clone)]
+pub enum Operator {
+    Add {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    Sub {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    Mul {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    Div {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    Abs {
+        input: Variable,
+        out: Variable,
+    },
+    Exp {
+        input: Variable,
+        out: Variable,
+    },
+    Log {
+        input: Variable,
+        out: Variable,
+    },
+    Log1p {
+        input: Variable,
+        out: Variable,
+    },
+    Cos {
+        input: Variable,
+        out: Variable,
+    },
+    Sin {
+        input: Variable,
+        out: Variable,
+    },
+    Tanh {
+        input: Variable,
+        out: Variable,
+    },
+    Powf {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    Erf {
+        input: Variable,
+        out: Variable,
+    },
+    AssignGlobal {
+        input: Variable,
+        out: Variable,
+    },
+    ReadGlobal {
+        variable: Variable,
+        position: usize,
+        position_out: usize,
+    },
+}
+
+impl Display for Operator {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Operator::Add { lhs, rhs, out } => {
+                f.write_fmt(format_args!("let {out} = {lhs} + {rhs};"))
+            }
+            Operator::Sub { lhs, rhs, out } => {
+                f.write_fmt(format_args!("let {out} = {lhs} - {rhs};"))
+            }
+            Operator::Mul { lhs, rhs, out } => {
+                f.write_fmt(format_args!("let {out} = {lhs} * {rhs};"))
+            }
+            Operator::Div { lhs, rhs, out } => {
+                f.write_fmt(format_args!("let {out} = {lhs} / {rhs};"))
+            }
+            Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")),
+            Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")),
+            Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")),
+            Operator::Powf { lhs, rhs, out } => {
+                f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});"))
+            }
+            Operator::Log1p { input, out } => {
+                f.write_fmt(format_args!("let {out} = log({input} + 1.0);"))
+            }
+            Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")),
+            Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")),
+            Operator::Tanh { input, out } => {
+                f.write_fmt(format_args!("let {out} = tanh({input});"))
+            }
+            Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")),
+            Operator::AssignGlobal { input, out } => {
+                f.write_fmt(format_args!("{out}_global[id] = {input};"))
+            }
+            Operator::ReadGlobal {
+                variable,
+                position,
+                position_out,
+            } => {
+                let (global, local) = match variable {
+                    Variable::Input(number) => {
+                        (format!("input_{number}_global"), format!("input_{number}"))
+                    }
+                    Variable::Local(_) => panic!("Can't read a local variable from global memory."),
+                    Variable::Output(number) => (
+                        format!("output_{number}_global"),
+                        format!("output_{number}"),
+                    ),
+                    Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
+                };
+
+                f.write_fmt(format_args!(
+                    "
+var index_{local}: u32 = 0u;
+
+for (var i: u32 = 1u; i <= rank; i++) {{
+    let position = {position}u * (2u * rank);
+    let position_out = {position_out}u * (2u * rank);
+
+    let stride = info[position + i];
+    let stride_out = info[position_out + i];
+    let shape = info[position + rank + i];
+
+    index_{local} += id / stride_out % shape * stride;
+}}
+
+let {local} = {global}[index_{local}];
+"
+                ))
+            }
+        }
+    }
+}
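ReadGlobal is the one operator that does index arithmetic: it recovers each output coordinate from the flat id via the output strides, then re-projects it through the input strides, so broadcast or transposed inputs read the right element. The same arithmetic on the CPU, 0-based instead of the WGSL loop's info-table offsets:

// CPU mirror of the index remapping that ReadGlobal emits in WGSL.
fn strided_index(id: usize, strides_in: &[usize], strides_out: &[usize], shape: &[usize]) -> usize {
    let mut index = 0;
    for i in 0..shape.len() {
        // coordinate along dim i, re-projected with the input stride
        index += id / strides_out[i] % shape[i] * strides_in[i];
    }
    index
}

fn main() {
    // A 2x3 output read through swapped strides behaves like a transpose.
    let strides_out = [3, 1]; // contiguous 2x3 output
    let strides_in = [1, 2];  // 3x2 input viewed transposed
    assert_eq!(strided_index(1, &strides_in, &strides_out, &[2, 3]), 2);
}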
@@ -0,0 +1,201 @@
+use super::{Body, Function};
+use crate::kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT};
+use std::{
+    collections::hash_map::DefaultHasher,
+    fmt::Display,
+    hash::{Hash, Hasher},
+};
+
+#[derive(Hash, PartialEq, Eq)]
+pub enum Location {
+    Storage,
+    #[allow(dead_code)]
+    Workgroup,
+}
+
+#[derive(Hash, PartialEq, Eq)]
+pub enum Visibility {
+    Read,
+    ReadWrite,
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+pub enum Elem {
+    F32,
+    #[allow(dead_code)]
+    I32,
+    U32,
+}
+
+#[derive(Hash, PartialEq, Eq)]
+pub struct Binding {
+    pub location: Location,
+    pub visibility: Visibility,
+    pub elem: Elem,
+    pub size: Option<usize>,
+}
+
+#[derive(Hash, PartialEq, Eq)]
+pub struct WorkgroupSize {
+    pub x: usize,
+    pub y: usize,
+    pub z: usize,
+}
+
+impl Default for WorkgroupSize {
+    fn default() -> Self {
+        Self {
+            x: WORKGROUP_DEFAULT,
+            y: WORKGROUP_DEFAULT,
+            z: 1,
+        }
+    }
+}
+
+#[derive(Hash)]
+pub struct ComputeShader {
+    pub inputs: Vec<Binding>,
+    pub outputs: Vec<Binding>,
+    pub named: Vec<(String, Binding)>,
+    pub workgroup_size: WorkgroupSize,
+    pub global_invocation_id: bool,
+    pub num_workgroups: bool,
+    pub body: Body,
+    pub functions: Vec<Function>,
+}
+
+impl DynamicKernelSource for ComputeShader {
+    fn source(&self) -> SourceTemplate {
+        SourceTemplate::new(self.to_string())
+    }
+
+    fn id(&self) -> String {
+        let mut s = DefaultHasher::new();
+        self.hash(&mut s);
+
+        s.finish().to_string()
+    }
+}
+
+impl Display for ComputeShader {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        Self::format_bindings(f, "input", &self.inputs, 0)?;
+        Self::format_bindings(f, "output", &self.outputs, self.inputs.len())?;
+
+        for (i, (name, binding)) in self.named.iter().enumerate() {
+            Self::format_binding(
+                f,
+                name.as_str(),
+                binding,
+                self.inputs.len() + self.outputs.len() + i,
+            )?;
+        }
+
+        f.write_fmt(format_args!(
+            "const WORKGROUP_SIZE_X = {}u;
+const WORKGROUP_SIZE_Y = {}u;
+const WORKGROUP_SIZE_Z = {}u;\n",
+            self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z
+        ))?;
+
+        f.write_fmt(format_args!(
+            "
+@compute
+@workgroup_size({}, {}, {})
+fn main(
+",
+            self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z
+        ))?;
+
+        if self.global_invocation_id {
+            f.write_str("    @builtin(global_invocation_id) global_id: vec3<u32>,\n")?;
+        }
+
+        if self.num_workgroups {
+            f.write_str("    @builtin(num_workgroups) num_workgroups: vec3<u32>,\n")?;
+        }
+
+        f.write_fmt(format_args!(
+            ") {{
+{}
+}}",
+            self.body
+        ))?;
+
+        for function in self.functions.iter() {
+            f.write_fmt(format_args!("{function}\n\n"))?;
+        }
+
+        Ok(())
+    }
+}
+
+impl ComputeShader {
+    fn format_bindings(
+        f: &mut core::fmt::Formatter<'_>,
+        prefix: &str,
+        bindings: &[Binding],
+        num_entry: usize,
+    ) -> core::fmt::Result {
+        for (i, binding) in bindings.iter().enumerate() {
+            Self::format_binding(
+                f,
+                format!("{prefix}_{i}_global").as_str(),
+                binding,
+                num_entry + i,
+            )?;
+        }
+
+        Ok(())
+    }
+
+    fn format_binding(
+        f: &mut core::fmt::Formatter<'_>,
+        name: &str,
+        binding: &Binding,
+        num_entry: usize,
+    ) -> core::fmt::Result {
+        let ty = match binding.size {
+            Some(size) => format!("array<{}, {}>", binding.elem, size),
+            None => format!("array<{}>", binding.elem),
+        };
+
+        f.write_fmt(format_args!(
+            "@group(0)
+@binding({})
+var<{}, {}> {}: {};
+\n",
+            num_entry, binding.location, binding.visibility, name, ty
+        ))?;
+
+        Ok(())
+    }
+}
+
+impl Display for Location {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Location::Storage => f.write_str("storage"),
+            Location::Workgroup => f.write_str("workgroup"),
+        }
+    }
+}
+
+impl Display for Elem {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Elem::F32 => f.write_str("f32"),
+            Elem::I32 => f.write_str("i32"),
+            Elem::U32 => f.write_str("u32"),
+        }
+    }
+}
+
+impl Display for Visibility {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Visibility::Read => f.write_str("read"),
+            Visibility::ReadWrite => f.write_str("read_write"),
+        }
+    }
+}
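For a concrete sense of the output, format_binding with Location::Storage, Visibility::Read, Elem::F32, and size None emits the following WGSL for the first input (trailing blank line omitted):

// The text format_binding writes for binding slot 0:
const BINDING: &str = "@group(0)
@binding(0)
var<storage, read> input_0_global: array<f32>;
";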
@@ -0,0 +1,21 @@
+use super::Elem;
+use std::fmt::Display;
+
+#[derive(Debug, Hash, Clone)]
+pub enum Variable {
+    Input(u16),
+    Scalar(u16, Elem),
+    Local(u16),
+    Output(u16),
+}
+
+impl Display for Variable {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Variable::Input(number) => f.write_fmt(format_args!("input_{number}")),
+            Variable::Local(number) => f.write_fmt(format_args!("local_{number}")),
+            Variable::Output(number) => f.write_fmt(format_args!("output_{number}")),
+            Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")),
+        }
+    }
+}
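The Display impl fixes the naming convention used everywhere in the generated WGSL:

// How each Variable renders (from the Display impl above):
//   Variable::Input(0)              -> "input_0"
//   Variable::Local(3)              -> "local_3"
//   Variable::Output(1)             -> "output_1"
//   Variable::Scalar(2, Elem::F32)  -> "scalars_f32[2]"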
@@ -0,0 +1,3 @@
+mod ops;
+
+pub use ops::*;
@ -0,0 +1,463 @@
|
|||
use crate::{
|
||||
fusion::codegen::{Elem, Operator, Variable},
|
||||
fusion::kernel::FusionKernel,
|
||||
FloatElement, GraphicsApi, IntElement, Wgpu,
|
||||
};
|
||||
use burn_fusion::{
|
||||
graph::{
|
||||
BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription,
|
||||
TensorOpsDescription, UnaryOpsDescription,
|
||||
},
|
||||
FusionBackend, FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription,
|
||||
TensorId,
|
||||
};
|
||||
use burn_tensor::{Device, Element};
|
||||
use hashbrown::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Fused element wise operations that are normally memory bound.
|
||||
pub struct FloatElementWiseFusionOps<G, F, I>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
pub(crate) inputs: Vec<TensorDescription>,
|
||||
pub(crate) locals: HashMap<TensorId, u16>,
|
||||
pub(crate) tensors: HashMap<TensorId, TensorDescription>,
|
||||
pub(crate) scalars_f32: Vec<f32>,
|
||||
pub(crate) operators: Vec<Operator>,
|
||||
pub(crate) properties: FusionProperties,
|
||||
pub(crate) current_output_shape: Vec<usize>,
|
||||
device: Device<Wgpu<G, F, I>>,
|
||||
}
|
||||
|
||||
impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G, F, I>>
|
||||
for FloatElementWiseFusionOps<G, F, I>
|
||||
{
|
||||
fn register(&mut self, ops: Arc<TensorOpsDescription<Wgpu<G, F, I>>>) -> FusionStatus {
|
||||
match ops.as_ref() {
|
||||
TensorOpsDescription::FloatOps(ops) => {
|
||||
if !self.register_float(ops) {
|
||||
return FusionStatus::Closed(self.properties);
|
||||
}
|
||||
}
|
||||
TensorOpsDescription::NumericOpsFloat(ops) => {
|
||||
if !self.register_numeric(ops) {
|
||||
return FusionStatus::Closed(self.properties);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return FusionStatus::Closed(self.properties);
|
||||
}
|
||||
};
|
||||
|
||||
self.properties.score += 1;
|
||||
self.properties.ready = self.operators.len() > 1;
|
||||
|
||||
FusionStatus::Open(self.properties)
|
||||
}
|
||||
|
||||
fn execute(&mut self, handles: &mut HandleContainer<Wgpu<G, F, I>>) {
|
||||
let inputs = self.input_descriptions();
|
||||
let outputs = self.output_descriptions();
|
||||
let locals = outputs
|
||||
.iter()
|
||||
.map(|out| *self.locals.get(&out.id).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
FusionKernel::new(&self.device)
|
||||
.inputs(&inputs, &self.scalars_f32)
|
||||
.body(&self.operators)
|
||||
.outputs(&outputs, &locals)
|
||||
.execute(handles);
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.inputs.clear();
|
||||
self.locals.drain();
|
||||
self.tensors.clear();
|
||||
self.scalars_f32.clear();
|
||||
self.operators.clear();
|
||||
self.properties = FusionProperties::default();
|
||||
self.current_output_shape.clear();
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.operators.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<G, F, I> FloatElementWiseFusionOps<G, F, I>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
pub fn new(device: Device<Wgpu<G, F, I>>) -> Self {
|
||||
Self {
|
||||
inputs: Vec::new(),
|
||||
locals: HashMap::new(),
|
||||
tensors: HashMap::new(),
|
||||
scalars_f32: Vec::new(),
|
||||
operators: Vec::new(),
|
||||
current_output_shape: Vec::new(),
|
||||
properties: FusionProperties::default(),
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
fn input_descriptions(&self) -> Vec<&TensorDescription> {
|
||||
self.inputs
|
||||
.iter()
|
||||
.map(|input| {
|
||||
let updated_tensor = self.tensors.get(&input.id).unwrap();
|
||||
updated_tensor
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn output_descriptions(&self) -> Vec<&TensorDescription> {
|
||||
let mut outputs = Vec::new();
|
||||
let mut local_tensor_ids_input = Vec::new();
|
||||
let mut local_tensor_ids_output = Vec::new();
|
||||
|
||||
// Mark a variable to the provided list of tensor ids using the variable list.
|
||||
//
|
||||
// Only local variables can become outputs.
|
||||
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
|
||||
if let Variable::Local(index) = var {
|
||||
if let Some((id, _)) = self
|
||||
.locals
|
||||
.iter()
|
||||
.find(|(_id, position)| *position == index)
|
||||
{
|
||||
if !list.contains(id) {
|
||||
list.push(id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// For all operators, mark their local tensor id in the proper set.
|
||||
for ops in self.operators.iter() {
|
||||
match ops {
|
||||
Operator::AssignGlobal { input: _, out: _ } => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
Operator::ReadGlobal {
|
||||
variable: _,
|
||||
position: _,
|
||||
position_out: _,
|
||||
} => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
Operator::Add { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Sub { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Mul { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Div { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Exp { input, out } => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Abs { input, out } => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Erf { input, out } => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Log { input, out } => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Log1p { input, out } => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Cos { input, out } => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Sin { input, out } => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Tanh { input, out } => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Powf { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All output tensors that are never read by a following operation should be written to
|
||||
// since they are essentially the "logical" output of the shader.
|
||||
for out in local_tensor_ids_output {
|
||||
let is_read = local_tensor_ids_input.contains(&out);
|
||||
|
||||
if !is_read {
|
||||
outputs.push(self.tensors.get(&out).unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// All tensors where their latest description is read only should be written to since they
|
||||
// are going to be used after the fused kernel by other operations.
|
||||
for tensor in self.tensors.values() {
|
||||
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
|
||||
if self.locals.contains_key(&tensor.id) {
|
||||
outputs.push(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outputs
|
||||
}
|
||||
|
||||
fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable {
|
||||
let already_exists = self.tensors.contains_key(&tensor.id);
|
||||
|
||||
let variable = match already_exists {
|
||||
false => {
|
||||
// New input
|
||||
let var = Variable::Input(self.inputs.len() as u16);
|
||||
self.inputs.push(tensor.clone());
|
||||
var
|
||||
}
|
||||
true => match self.locals.get(&tensor.id) {
|
||||
// Is a local variable.
|
||||
Some(local_index) => Variable::Local(*local_index),
|
||||
// Isn't a local variable, so must be an existing input.
|
||||
None => {
|
||||
let input = self
|
||||
.inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, input)| input.id == tensor.id)
|
||||
.unwrap();
|
||||
let input_index = input.0;
|
||||
Variable::Input(input_index as u16)
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// Update the tensor description with the new version.
|
||||
self.tensors.insert(tensor.id.clone(), tensor.clone());
|
||||
|
||||
variable
|
||||
}

    fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable {
        // Update the tensor description to the new version.
        self.tensors.insert(tensor.id.clone(), tensor.clone());

        // Output already registered as a local variable.
        if let Some(index) = self.locals.get(&tensor.id) {
            return Variable::Local(*index);
        }

        // New local variable.
        let local_index = self.locals.len() as u16;
        self.locals.insert(tensor.id.clone(), local_index);
        Variable::Local(local_index)
    }
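
    // A sketch of the mapping performed by `input_to_var`/`output_to_var`
    // (assumed operation stream): registering `t3 = t1 + t2` and then
    // `t4 = t3 * t1` yields
    //
    //     t1 -> Variable::Input(0)   (reused by the second operation)
    //     t2 -> Variable::Input(1)
    //     t3 -> Variable::Local(0)   (an output later read as an input)
    //     t4 -> Variable::Local(1)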

    fn register_float<B: FusionBackend>(&mut self, ops: &FloatOpsDescription<B>) -> bool {
        match ops {
            FloatOpsDescription::Exp(desc, _) => {
                self.register_unary_ops(desc, |input, out| Operator::Exp { input, out })
            }
            FloatOpsDescription::Log(desc, _) => {
                self.register_unary_ops(desc, |input, out| Operator::Log { input, out })
            }
            FloatOpsDescription::Log1p(desc, _) => {
                self.register_unary_ops(desc, |input, out| Operator::Log1p { input, out })
            }
            FloatOpsDescription::Cos(desc, _) => {
                self.register_unary_ops(desc, |input, out| Operator::Cos { input, out })
            }
            FloatOpsDescription::Sin(desc, _) => {
                self.register_unary_ops(desc, |input, out| Operator::Sin { input, out })
            }
            FloatOpsDescription::Powf(desc, _) => {
                self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out })
            }
            FloatOpsDescription::Tanh(desc, _) => {
                self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out })
            }
            FloatOpsDescription::Erf(desc, _) => {
                self.register_unary_ops(desc, |input, out| Operator::Erf { input, out })
            }
            _ => false,
        }
    }

    fn register_numeric<B: FusionBackend, E: Element>(
        &mut self,
        ops: &NumericOpsDescription<B, E>,
    ) -> bool {
        match ops {
            NumericOpsDescription::Add(desc, _) => {
                self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
            }
            NumericOpsDescription::AddScalar(desc, _) => {
                self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
            }
            NumericOpsDescription::Sub(desc, _) => {
                self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
            }
            NumericOpsDescription::SubScalar(desc, _) => {
                self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
            }
            NumericOpsDescription::Mul(desc, _) => {
                self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
            }
            NumericOpsDescription::MulScalar(desc, _) => {
                self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
            }
            NumericOpsDescription::Div(desc, _) => {
                self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
            }
            NumericOpsDescription::DivScalar(desc, _) => {
                self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
            }
            NumericOpsDescription::Abs(desc, _) => {
                self.register_unary_ops(desc, |input, out| Operator::Abs { input, out })
            }
            _ => false,
        }
    }

    fn register_binary_ops<Func>(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool
    where
        Func: Fn(Variable, Variable, Variable) -> Operator,
    {
        if !self.output_is_compatible(&desc.out) {
            return false;
        }

        let lhs = self.input_to_var(&desc.lhs);
        let rhs = self.input_to_var(&desc.rhs);
        let out = self.output_to_var(&desc.out);

        self.operators.push(func(lhs, rhs, out));

        true
    }

    fn register_unary_ops<Func>(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool
    where
        Func: Fn(Variable, Variable) -> Operator,
    {
        if !self.output_is_compatible(&desc.out) {
            return false;
        }

        let input = self.input_to_var(&desc.input);
        let out = self.output_to_var(&desc.out);

        self.operators.push(func(input, out));

        true
    }

    fn register_scalar_ops<Func, E: Element>(
        &mut self,
        desc: &ScalarOpsDescription<E>,
        func: Func,
    ) -> bool
    where
        Func: Fn(Variable, Variable, Variable) -> Operator,
    {
        if !self.output_is_compatible(&desc.out) {
            return false;
        }

        let lhs = self.input_to_var(&desc.lhs);
        let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32);
        self.scalars_f32.push(desc.rhs.elem());
        let out = self.output_to_var(&desc.out);

        self.operators.push(func(lhs, rhs, out));

        true
    }
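
    // A sketch of the scalar path above (assumed input): registering
    // `t2 = t1 + 5.0` produces `lhs = Variable::Input(0)` and
    // `rhs = Variable::Scalar(0, Elem::F32)`, pushes `5.0` into
    // `scalars_f32`, and binds `t2` to `Variable::Local(0)`.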

    fn output_is_compatible(&mut self, out: &TensorDescription) -> bool {
        if self.current_output_shape.is_empty() {
            self.current_output_shape = out.shape.clone();
        } else if self.current_output_shape != out.shape {
            return false;
        }

        true
    }
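
    // Example of the compatibility rule (assumed shapes): the first registered
    // output fixes `current_output_shape`, e.g. to `[32, 32]`; any later
    // operation whose output is also `[32, 32]` can join the fused kernel,
    // while one producing `[1, 32]` stops the fusion instead.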
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn_fusion::graph::{BinaryOpsDescription, Ops};
    use burn_fusion::Fusion;
    use burn_tensor::Tensor;

    struct FakeAddOps;

    impl<B: FusionBackend> Ops<B> for FakeAddOps {
        type Args = BinaryOpsDescription;

        fn execute(&self, _: &Self::Args, _: &mut HandleContainer<B>) {
            todo!()
        }
    }

    #[test]
    fn test_fusion_same_behavior() {
        type Backend = Wgpu;
        type FusedBackend = Fusion<Wgpu>;

        let data_1 =
            Tensor::<Backend, 2>::random([1, 32], burn_tensor::Distribution::Default).into_data();
        let data_2 =
            Tensor::<Backend, 2>::random([32, 32], burn_tensor::Distribution::Default).into_data();

        let tensor_1 = Tensor::<Backend, 2>::from_data(data_1.clone());
        let tensor_2 = Tensor::<Backend, 2>::from_data(data_2.clone());
        let tensor_3 = tensor_1.clone() + tensor_2;
        let tensor_4 = tensor_3.clone() - tensor_1;
        let tensor_5 = tensor_4 + 5.0;
        let tensor_6 = tensor_5 + tensor_3;
        let result_ref = tensor_6.into_data();

        let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
        let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
        let tensor_3 = tensor_1.clone() + tensor_2;
        let tensor_4 = tensor_3.clone() - tensor_1;
        let tensor_5 = tensor_4 + 5.0;
        let tensor_6 = tensor_5 + tensor_3;
        let result_fused = tensor_6.into_data();

        result_fused.assert_approx_eq(&result_ref, 3);
    }
}

@ -0,0 +1,320 @@
use super::codegen::Body;
use crate::compute::{compute_client, DynamicKernel, WgpuComputeClient};
use crate::fusion::codegen::Function;
use crate::fusion::{calculate_num_elems_dyn_rank, strides_dyn_rank};
use crate::fusion::{
    codegen::{
        Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize,
    },
    WgpuFusionHandle,
};
use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT};
use crate::{FloatElement, GraphicsApi, IntElement, Wgpu};
use burn_fusion::{HandleContainer, TensorDescription};
use burn_tensor::Device;
use std::marker::PhantomData;

/// Kernel creation input phase, see [fusion kernel](FusionKernel) for more details.
pub struct InputPhase;
/// Kernel creation body phase, see [fusion kernel](FusionKernel) for more details.
pub struct BodyPhase;
/// Kernel creation output phase, see [fusion kernel](FusionKernel) for more details.
pub struct OutputPhase;
/// Kernel execution phase, see [fusion kernel](FusionKernel) for more details.
pub struct ExecutionPhase;

/// Allows creating custom WGSL kernels based on configured inputs, body, and outputs.
///
/// This type has 4 phases that must be executed in order, but don't worry: the type system
/// won't allow you to make mistakes.
///
/// 1. [Input Phase](InputPhase)
///    This phase focuses on registering the input tensor descriptions that are going to be used
///    by the fused kernel.
/// 2. [Body Phase](BodyPhase)
///    After the input phase is done, all the operations that happen in the body must be
///    registered.
/// 3. [Output Phase](OutputPhase)
///    This step focuses on registering all tensor descriptions that the kernel needs to write to.
/// 4. [Execution Phase](ExecutionPhase)
///    Now that all other phases are completed, we can actually run the kernel on the given
///    [handles](HandleContainer). Note that the actual chosen kernel may vary based on the
///    handles provided.
pub struct FusionKernel<G, F, I, Phase = InputPhase>
where
    G: GraphicsApi,
    F: FloatElement,
    I: IntElement,
{
    operations: Vec<Operator>,
    input_bindings: Vec<(Binding, TensorDescription)>,
    output_bindings: Vec<(Binding, TensorDescription)>,
    named_bindings: Vec<(String, Binding, DataBuffer)>,
    functions: Vec<Function>,
    num_elems_output: usize,
    device: Device<Wgpu<G, F, I>>,
    client: WgpuComputeClient,
    _phase: PhantomData<Phase>,
}

enum DataBuffer {
    F32(Vec<f32>),
    U32(Vec<u32>),
}
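
// A minimal usage sketch of the phase-typed builder (assumed caller-side
// values: `device`, `input_desc`, `output_desc`, `operators`, and `handles`
// are hypothetical and must exist in the caller's scope):
//
//     FusionKernel::<G, F, I>::new(&device)
//         .inputs(&[&input_desc], &[5.0])   // InputPhase  -> BodyPhase
//         .body(&operators)                 // BodyPhase   -> OutputPhase
//         .outputs(&[&output_desc], &[0])   // OutputPhase -> ExecutionPhase
//         .execute(&mut handles);           // consumes the kernel
//
// Each call consumes `self` and returns the next phase's type, so skipping or
// reordering a phase is a compile-time error.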

impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, InputPhase> {
    /// Create a new fusion kernel on the given device.
    pub fn new(device: &Device<Wgpu<G, F, I>>) -> Self {
        let client = compute_client::<G>(device);

        Self {
            operations: Vec::new(),
            input_bindings: Vec::new(),
            output_bindings: Vec::new(),
            named_bindings: Vec::new(),
            functions: Vec::new(),
            num_elems_output: 0,
            device: device.clone(),
            client,
            _phase: PhantomData,
        }
    }

    /// Register the inputs used by the kernel.
    pub fn inputs(
        mut self,
        inputs_tensor: &[&TensorDescription],
        inputs_scalar_f32: &[f32],
    ) -> FusionKernel<G, F, I, BodyPhase> {
        for (i, input) in inputs_tensor.iter().enumerate() {
            self.input_bindings.push((
                Binding {
                    elem: Elem::F32,
                    visibility: Visibility::Read,
                    location: Location::Storage,
                    size: None,
                },
                (*input).clone(),
            ));

            self.operations.push(Operator::ReadGlobal {
                variable: Variable::Input(i as u16),
                position: i,
                position_out: inputs_tensor.len(), // First output
            });
        }

        if !inputs_scalar_f32.is_empty() {
            self.named_bindings.push((
                "scalars_f32".to_string(),
                Binding {
                    elem: Elem::F32,
                    visibility: Visibility::Read,
                    location: Location::Storage,
                    size: Some(inputs_scalar_f32.len()),
                },
                DataBuffer::F32(inputs_scalar_f32.to_vec()),
            ));
        }

        FusionKernel {
            operations: self.operations,
            input_bindings: self.input_bindings,
            output_bindings: self.output_bindings,
            named_bindings: self.named_bindings,
            functions: self.functions,
            num_elems_output: self.num_elems_output,
            device: self.device,
            client: self.client,
            _phase: PhantomData,
        }
    }
}

impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, BodyPhase> {
    /// Register the [operators](Operator) that the kernel must execute in the order provided.
    pub fn body(mut self, operators: &[Operator]) -> FusionKernel<G, F, I, OutputPhase> {
        let mut register_function = |function: Function| {
            if !self.functions.contains(&function) {
                self.functions.push(function);
            }
        };

        // Since not all operators are native to WGSL, we need to add the custom ones.
        for ops in operators.iter() {
            match ops {
                Operator::Powf {
                    lhs: _,
                    rhs: _,
                    out: _,
                } => {
                    register_function(Function::Powf(Elem::F32));
                }
                Operator::Erf { input: _, out: _ } => {
                    register_function(Function::Erf(Elem::F32));
                }
                _ => {}
            }
            self.operations.push(ops.clone());
        }

        FusionKernel {
            operations: self.operations,
            input_bindings: self.input_bindings,
            output_bindings: self.output_bindings,
            named_bindings: self.named_bindings,
            functions: self.functions,
            num_elems_output: self.num_elems_output,
            device: self.device,
            client: self.client,
            _phase: PhantomData,
        }
    }
}

impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, OutputPhase> {
    /// Register the outputs with their local variable index.
    ///
    /// Note that the index corresponds to the registered [operator](Operator) number at the
    /// [body phase](BodyPhase).
    /// So the 4th operator registered creates local variable 3 (N-1, since indexing starts at 0).
    pub fn outputs(
        mut self,
        outputs: &[&TensorDescription],
        locals: &[u16],
    ) -> FusionKernel<G, F, I, ExecutionPhase> {
        let mut num_elems_launch_option = 0;

        for (i, (output, local)) in outputs.iter().zip(locals).enumerate() {
            let num_elems_output = calculate_num_elems_dyn_rank(&output.shape);
            if num_elems_output > num_elems_launch_option {
                num_elems_launch_option = num_elems_output;
            }

            self.output_bindings.push((
                Binding {
                    elem: Elem::F32,
                    visibility: Visibility::ReadWrite,
                    location: Location::Storage,
                    size: None,
                },
                (*output).clone(),
            ));

            self.operations.push(Operator::AssignGlobal {
                input: Variable::Local(*local),
                out: Variable::Output(i as u16),
            });
        }

        self.num_elems_output = num_elems_launch_option;

        FusionKernel {
            operations: self.operations,
            input_bindings: self.input_bindings,
            output_bindings: self.output_bindings,
            named_bindings: self.named_bindings,
            functions: self.functions,
            num_elems_output: self.num_elems_output,
            device: self.device,
            client: self.client,
            _phase: PhantomData,
        }
    }
}
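
// A sketch of the local-index convention documented above (assumed body): if
// the body registered three operators producing locals 0, 1, and 2, writing
// the results of the 2nd and 3rd operators to global memory would be
//
//     kernel.outputs(&[&out_b, &out_c], &[1, 2])
//
// where `out_b` and `out_c` are hypothetical `TensorDescription`s.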

impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, ExecutionPhase> {
    /// Execute the kernel on the provided [handles](HandleContainer).
    pub fn execute(mut self, handle_container: &mut HandleContainer<Wgpu<G, F, I>>) {
        let mut inputs = Vec::with_capacity(self.input_bindings.len());
        let mut outputs = Vec::with_capacity(self.output_bindings.len());
        let mut named = Vec::with_capacity(2);
        let mut info = Vec::new();
        let mut handles =
            Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity());

        // Inner function to fill the info buffer.
        let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| {
            if info.is_empty() {
                info.push(handle.strides.len() as u32);
            }

            for s in handle.strides.iter() {
                info.push(*s as u32);
            }
            for s in tensor.shape.iter() {
                info.push(*s as u32);
            }
        };

        // We start by registering the inputs.
        for (binding, tensor) in self.input_bindings.into_iter() {
            let handle = handle_container.get_handle(&tensor);
            register_info_tensor(&tensor, &handle);

            inputs.push(binding);
            handles.push(handle.handle);
        }

        // Then we follow with the outputs.
        for (binding, tensor) in self.output_bindings {
            let num_elems = calculate_num_elems_dyn_rank(&tensor.shape);
            let handle_fusion = WgpuFusionHandle {
                client: self.client.clone(),
                device: self.device.clone(),
                strides: strides_dyn_rank(&tensor.shape),
                handle: self.client.empty(core::mem::size_of::<F>() * num_elems),
            };
            register_info_tensor(&tensor, &handle_fusion);

            handles.push(handle_fusion.handle.clone());
            handle_container.register_handle(tensor.id, handle_fusion);
            outputs.push(binding);
        }

        // Now we can create the info handle.
        Self::build_info_handle(&mut self.named_bindings, info);

        // Finally we finish with the named bindings.
        for (name, binding, data) in self.named_bindings {
            let handle = self.client.create(match &data {
                DataBuffer::F32(values) => bytemuck::cast_slice(values),
                DataBuffer::U32(values) => bytemuck::cast_slice(values),
            });
            named.push((name, binding));
            handles.push(handle);
        }

        // We create the shader codegen type and launch the kernel.
        let kernel = ComputeShader {
            inputs,
            outputs,
            named,
            workgroup_size: WorkgroupSize::default(),
            body: Body::new(self.operations),
            num_workgroups: true,
            global_invocation_id: true,
            functions: self.functions,
        };

        let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT);
        let kernel = Box::new(DynamicKernel::new(kernel, workgroup));

        self.client
            .execute(kernel, &handles.iter().collect::<Vec<_>>());
    }

    fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec<u32>) {
        named_bindings.push((
            "info".to_string(),
            Binding {
                elem: Elem::U32,
                visibility: Visibility::Read,
                location: Location::Storage,
                size: None, // We avoid putting the length here since it will force a new kernel
                            // for each tensor rank.
            },
            DataBuffer::U32(info),
        ));
    }
}

@ -0,0 +1,8 @@
mod base;
mod elemwise;

pub(crate) mod codegen;
pub(crate) mod kernel;

pub use base::*;
pub use elemwise::*;

@ -1,7 +1,8 @@
use super::SourceTemplate;
use crate::{
    compute::{StaticKernel, WorkGroup},
    compute::{StaticKernel, WgpuComputeClient, WgpuHandle, WorkGroup},
    element::WgpuElement,
    kernel,
    tensor::WgpuTensor,
};
use std::marker::PhantomData;

@ -76,6 +77,33 @@ pub fn into_contiguous<E: WgpuElement, const D: usize>(
    output
}

/// Similar to [into contiguous](into_contiguous) but with dynamic rank.
pub fn into_contiguous_dyn<E: WgpuElement>(
    client: WgpuComputeClient,
    input: WgpuHandle,
    input_shape: &[usize],
    input_strides: &[usize],
    output_shape: &[usize],
    output_strides: &[usize],
    num_elems: usize,
) -> WgpuHandle {
    let handle = client.empty(num_elems * core::mem::size_of::<E>());
    let info = kernel::build_info_dyn::<E>(
        &[input_shape, output_shape],
        &[input_strides, output_strides],
    );

    let info_handle = client.create(bytemuck::cast_slice(&info));

    let kernel = Box::new(StaticKernel::<
        KernelSettings<ContiguousRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
    >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)));

    client.execute(kernel, &[&input, &handle, &info_handle]);

    handle
}
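
// A usage sketch (assumed values): making a transposed 2x3 view contiguous.
// `client` and `input` are hypothetical and must come from the caller.
//
//     let output = into_contiguous_dyn::<f32>(
//         client,
//         input,
//         &[2, 3], // input shape
//         &[1, 2], // input strides (transposed layout)
//         &[2, 3], // output shape
//         &[3, 1], // output strides (contiguous layout)
//         6,       // number of elements
//     );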

/// Generates kernel source code by replacing some information using templating.
pub struct KernelSettings<
    K: StaticKernelSource,

@ -184,6 +212,28 @@ pub fn build_info<E: WgpuElement, const D: usize>(tensors: &[&WgpuTensor<E, D>])
    info
}

/// Similar to [build info](build_info) but with dynamic rank.
pub fn build_info_dyn<E: WgpuElement>(shapes: &[&[usize]], strides: &[&[usize]]) -> Vec<u32> {
    let rank = shapes.get(0).unwrap().len();
    let mut info: Vec<u32> = vec![0; shapes.len() * 2 * rank + 1];
    info[0] = rank as u32;

    let mut current = 1;
    for stride in strides.iter() {
        for d in 0..rank {
            info[current] = stride[d] as u32;
            current += 1;
        }
    }
    for shape in shapes.iter() {
        for d in 0..rank {
            info[current] = shape[d] as u32;
            current += 1;
        }
    }
    info
}
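
// A worked example of the `info` layout built above (assumed arguments): with
// shapes = [[2, 3], [2, 3]] and strides = [[1, 2], [3, 1]], the buffer is
//
//     [2,  1, 2, 3, 1,  2, 3, 2, 3]
//      ^   ^--------^   ^--------^
//     rank   strides      shapes
//
// i.e. the rank first, then every strides array, then every shape array.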

pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup {
    let num_elem_per_invocation = workgroup_size * workgroup_size;
    let workgroups = f32::ceil(num_elems as f32 / num_elem_per_invocation as f32);