[burn-fusion] save all execution plans for any trigger (#1143)

Nathaniel Simard, 2024-01-16 14:02:42 -05:00 (committed by GitHub)
parent 6079f98950
commit b99726f804
30 changed files with 2178 additions and 1504 deletions


@ -51,6 +51,10 @@ pub fn save<B: Backend>(
.join("burn")
.join("backend-comparison");
for bench in benches.iter() {
println!("{bench}");
}
if !cache_dir.exists() {
fs::create_dir_all(&cache_dir)?;
}


@ -1,6 +1,6 @@
use crate::{
client::FusionClient,
stream::{Context, TensorOpsDescription},
stream::{Context, OperationDescription},
FusionClientLocator, FusionTensor,
};
use burn_tensor::{backend::Backend, Device, Shape};
@ -70,7 +70,7 @@ pub struct OptimizationProperties {
}
/// The fusion operation abstraction allows implementations to fuse many
/// [tensor operations](TensorOpsDescription) into one, improving the performance of the backend.
/// [tensor operations](OperationDescription) into one, improving the performance of the backend.
///
///
/// # Notes
@ -79,19 +79,25 @@ pub struct OptimizationProperties {
/// the speed and efficiency of the computational graph. It doesn't mean that all registered
/// operations should be fused, but that another way of executing them is more efficient.
///
/// Also, it is important to return (FusionStatus::Closed) when no more registered operation can
/// Also, it is important to return (OptimizationStatus::Closed) when no more registered operations can
/// improve the performance.
pub trait OptimizationBuilder<B: FusionBackend>: Send {
/// Register a new [tensor operation](TensorOpsDescription).
fn register(&mut self, ops: &TensorOpsDescription);
pub trait OptimizationBuilder<O>: Send {
/// Register a new [tensor operation](OperationDescription).
fn register(&mut self, operation: &OperationDescription);
/// Finish the optimization and create a fusion operation.
fn build(&self) -> B::Optimization;
fn build(&self) -> O;
/// Reset the state.
fn reset(&mut self);
/// Return the builder [status](OptimizationStatus).
fn status(&self) -> OptimizationStatus;
/// Return the builder [properties](OptimizationProperties).
fn properties(&self) -> OptimizationProperties;
/// The number of operations fused.
fn len(&self) -> usize;
/// If no operations are fused.
fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// The operation created from the [builder](OptimizationBuilder).
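For orientation, here is a minimal sketch of what an implementor of the new generic OptimizationBuilder trait could look like. Only the trait surface (register, build, reset, status, properties, len) comes from the diff above; the builder type, the variant matching, and ElemWiseOptimization are hypothetical stand-ins.

// Hypothetical builder that fuses runs of element-wise float operations.
struct ElemWiseBuilder {
    operations: Vec<OperationDescription>,
    status: OptimizationStatus,
}

impl OptimizationBuilder<ElemWiseOptimization> for ElemWiseBuilder {
    fn register(&mut self, operation: &OperationDescription) {
        // Once closed, ignore further registrations, as the docs above require.
        if let OptimizationStatus::Closed = self.status {
            return;
        }
        match operation {
            // Accept only operations this builder knows how to fuse;
            // `NumericFloat` is an assumed variant name.
            OperationDescription::NumericFloat(_) => self.operations.push(operation.clone()),
            _ => self.status = OptimizationStatus::Closed,
        }
    }

    fn build(&self) -> ElemWiseOptimization {
        // Compile the recorded operations into one fused kernel (hypothetical API).
        ElemWiseOptimization::compile(&self.operations)
    }

    fn reset(&mut self) {
        self.operations.clear();
        self.status = OptimizationStatus::Open; // `Open` variant assumed
    }

    fn status(&self) -> OptimizationStatus {
        self.status // assuming `OptimizationStatus: Copy`
    }

    fn properties(&self) -> OptimizationProperties {
        // Field names assumed; the struct's fields sit outside this hunk.
        OptimizationProperties {
            ready: self.operations.len() > 1,
            score: self.operations.len() as u64,
        }
    }

    fn len(&self) -> usize {
        self.operations.len()
    }
}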
@ -143,7 +149,8 @@ pub trait FusionBackend: Backend {
type FusionClient: FusionClient<FusionBackend = Self>;
/// The list of optimizations that will be used to optimize the computational graph.
fn optimizations(device: Device<Self>) -> Vec<Box<dyn OptimizationBuilder<Self>>>;
fn optimizations(device: Device<Self>)
-> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
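With optimizations now generic over the optimization type rather than the backend itself, a backend's wiring might look like the sketch below; MyBackend is a stand-in and ElemWiseBuilder::new is an assumed constructor for the hypothetical builder above.

impl FusionBackend for MyBackend {
    // ... associated types elided ...

    fn optimizations(
        device: Device<Self>,
    ) -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>> {
        // One boxed builder per fusion strategy; the stream can drive
        // several candidate builders side by side.
        vec![Box::new(ElemWiseBuilder::new(device))]
    }
}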
/// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive).
fn float_tensor<const D: usize>(


@ -1,5 +1,5 @@
use crate::{
stream::{Ops, TensorOpsDescription},
stream::{Operation, OperationDescription},
FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
};
use burn_tensor::{
@ -14,11 +14,11 @@ pub trait FusionClient: Send + Sync + Clone {
/// Create a new client for the given [fusion device](FusionBackend::FusionDevice).
fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
/// Register a new [tensor operation description](TensorOpsDescription).
fn register<O: Ops<Self::FusionBackend> + 'static>(
/// Register a new [tensor operation description](OperationDescription).
fn register<O: Operation<Self::FusionBackend> + 'static>(
&self,
description: TensorOpsDescription,
ops: O,
description: OperationDescription,
operation: O,
);
/// Register all lazy computation.
fn drain(&self);
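The call-site shape that the ops files below repeat for every tensor operation looks roughly like this; AbsOps mirrors the structs defined in those files, while the NumericFloat variant is an assumption.

// Describe the operation for the stream, then hand over a fallback
// `Operation` that can execute it eagerly if no fusion applies.
let desc = UnaryOperationDescription {
    input: tensor.into_description(),
    out: out.to_description_out(),
};
out.client.register(
    OperationDescription::NumericFloat(NumericOperationDescription::Abs(desc.clone())),
    AbsOps::<D>::new(desc),
);
// `drain` then flushes everything still queued, e.g. before a read-back.
out.client.drain();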


@ -1,5 +1,8 @@
use super::FusionClient;
use crate::{stream::TensorOpsDescription, FusionBackend, FusionServer, FusionTensor, Handle};
use crate::{
stream::{Operation, OperationDescription},
FusionBackend, FusionServer, FusionTensor, Handle,
};
use burn_tensor::ops::FloatElem;
use spin::Mutex;
use std::sync::Arc;
@ -38,12 +41,14 @@ where
}
}
fn register<O: crate::stream::Ops<Self::FusionBackend> + 'static>(
fn register<O: Operation<Self::FusionBackend> + 'static>(
&self,
description: TensorOpsDescription,
ops: O,
description: OperationDescription,
operation: O,
) {
self.server.lock().register(description, Box::new(ops))
self.server
.lock()
.register(description, Box::new(operation))
}
fn drain(&self) {


@ -7,10 +7,10 @@ macro_rules! binary_float_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: BinaryOpsDescription,
desc: BinaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let rhs = handles.get_float_tensor(&self.desc.rhs);
@ -31,10 +31,10 @@ macro_rules! binary_float_cmp_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: BinaryOpsDescription,
desc: BinaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let rhs = handles.get_float_tensor(&self.desc.rhs);
@ -55,10 +55,10 @@ macro_rules! binary_int_cmp_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: BinaryOpsDescription,
desc: BinaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
let rhs = handles.get_int_tensor(&self.desc.rhs);
@ -89,10 +89,10 @@ macro_rules! binary_int_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: BinaryOpsDescription,
desc: BinaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
let rhs = handles.get_int_tensor(&self.desc.rhs);

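Each of these macros is invoked once per operation further down in the crate. Assuming the matcher elided from this hunk is the usual ($name:ident, $ops:expr) pair, an invocation expands to a `struct $name` plus the renamed `impl Operation<B>` shown above:

// Hypothetical invocations:
binary_float_ops!(AddOps, B::float_add);
binary_int_cmp_ops!(LowerOps, B::int_lower);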

@ -3,9 +3,10 @@ use crate::{
get_client,
ops::binary::binary_ops_shape,
stream::{
BaseOpsDescription, BinaryOpsDescription, BoolOpsDescription, CatOpsDescription, Ops,
ReshapeDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription,
TensorOpsDescription, UnaryOpsDescription,
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, Operation, OperationDescription, ReshapeDescription,
SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription,
UnaryOperationDescription,
},
Fusion, FusionBackend,
};
@ -48,10 +49,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
) -> burn_tensor::ops::IntTensor<Self, D> {
#[derive(new)]
struct IntoIntOps<const D: usize> {
desc: UnaryOpsDescription,
desc: UnaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for IntoIntOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for IntoIntOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_into_int(input);
@ -61,12 +62,12 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = UnaryOpsDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BoolOps(BoolOpsDescription::IntoInt(desc.clone())),
OperationDescription::Bool(BoolOperationDescription::IntoInt(desc.clone())),
IntoIntOps::<D>::new(desc),
);
@ -78,10 +79,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
) -> burn_tensor::ops::FloatTensor<Self, D> {
#[derive(new)]
struct IntoFloatOps<const D: usize> {
desc: UnaryOpsDescription,
desc: UnaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for IntoFloatOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for IntoFloatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_into_float(input);
@ -91,12 +92,12 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = UnaryOpsDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BoolOps(BoolOpsDescription::IntoFloat(desc.clone())),
OperationDescription::Bool(BoolOperationDescription::IntoFloat(desc.clone())),
IntoFloatOps::<D>::new(desc),
);
@ -135,7 +136,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
desc: ReshapeDescription,
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for ReshapeDimsOps<D1, D2> {
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_bool_tensor::<D1>(&self.desc.input);
let output = B::bool_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
@ -151,7 +152,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Reshape(desc.clone())),
OperationDescription::BaseBool(BaseOperationDescription::Reshape(desc.clone())),
ReshapeDimsOps::<D1, D2>::new(desc),
);
@ -164,10 +165,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D1> {
#[derive(new)]
struct SliceOps<const D1: usize, const D2: usize> {
desc: SliceOpsDescription,
desc: SliceOperationDescription,
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for SliceOps<D1, D2> {
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_bool_tensor::<D1>(&self.desc.tensor);
@ -186,13 +187,13 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
let out = tensor.client.tensor_uninitialized(shape);
let desc = SliceOpsDescription {
let desc = SliceOperationDescription {
tensor: tensor.into_description(),
ranges: ranges.into(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Slice(desc.clone())),
OperationDescription::BaseBool(BaseOperationDescription::Slice(desc.clone())),
SliceOps::<D1, D2>::new(desc),
);
@ -206,10 +207,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D1> {
#[derive(new)]
struct SliceAssignOps<const D1: usize, const D2: usize> {
desc: SliceAssignOpsDescription,
desc: SliceAssignOperationDescription,
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for SliceAssignOps<D1, D2> {
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceAssignOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_bool_tensor::<D1>(&self.desc.tensor);
let value = handles.get_bool_tensor::<D1>(&self.desc.value);
@ -227,7 +228,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.tensor_uninitialized(shape);
let desc = SliceAssignOpsDescription {
let desc = SliceAssignOperationDescription {
tensor: tensor.into_description(),
ranges: ranges.into(),
value: value.into_description(),
@ -235,7 +236,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
};
out.client.register(
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::SliceAssign(desc.clone())),
OperationDescription::BaseBool(BaseOperationDescription::SliceAssign(desc.clone())),
SliceAssignOps::<D1, D2>::new(desc),
);
@ -248,10 +249,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
#[derive(new)]
struct CatOps<const D: usize> {
desc: CatOpsDescription,
desc: CatOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for CatOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for CatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensors = self
.desc
@ -278,13 +279,13 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
let out = client.tensor_uninitialized(shape);
let desc = CatOpsDescription {
let desc = CatOperationDescription {
tensors: tensors.into_iter().map(|t| t.into_description()).collect(),
dim,
out: out.to_description_out(),
};
client.register(
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Cat(desc.clone())),
OperationDescription::BaseBool(BaseOperationDescription::Cat(desc.clone())),
CatOps::<D>::new(desc),
);
@ -297,10 +298,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
) -> BoolTensor<Self, D> {
#[derive(new)]
struct EqualOps<const D: usize> {
desc: BinaryOpsDescription,
desc: BinaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for EqualOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for EqualOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let lhs = handles.get_bool_tensor::<D>(&self.desc.lhs);
let rhs = handles.get_bool_tensor(&self.desc.rhs);
@ -313,13 +314,13 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
let desc = BinaryOpsDescription {
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Equal(desc.clone())),
OperationDescription::BaseBool(BaseOperationDescription::Equal(desc.clone())),
EqualOps::<D>::new(desc),
);
@ -329,10 +330,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
fn bool_not<const D: usize>(tensor: BoolTensor<Self, D>) -> BoolTensor<Self, D> {
#[derive(new)]
struct NotOps<const D: usize> {
desc: UnaryOpsDescription,
desc: UnaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for NotOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for NotOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_not(input);
@ -342,13 +343,13 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = UnaryOpsDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BoolOps(crate::stream::BoolOpsDescription::Not(desc.clone())),
OperationDescription::Bool(crate::stream::BoolOperationDescription::Not(desc.clone())),
NotOps::<D>::new(desc),
);
@ -365,7 +366,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
desc: SwapDimsDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for SwapDimsOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
@ -386,7 +387,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::SwapDims(desc.clone())),
OperationDescription::BaseBool(BaseOperationDescription::SwapDims(desc.clone())),
SwapDimsOps::<D>::new(desc),
);
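Read together, the renamed registrations imply a two-level description enum. A sketch of the nesting, reconstructed only from the variants visible in this diff (the element-type parameter and the float side are assumptions):

enum OperationDescription {
    BaseBool(BaseOperationDescription),   // Reshape, Slice, SliceAssign, Cat, Equal, SwapDims
    BaseInt(BaseOperationDescription),
    Bool(BoolOperationDescription),       // IntoInt, IntoFloat, Not
    NumericInt(NumericOperationDescription<i32>),
    Int(IntOperationDescription),         // IntoFloat
    Module(ModuleOperationDescription),   // Conv*, pooling, ...
    // float counterparts do not appear in this excerpt
}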

File diff suppressed because it is too large.


@ -5,12 +5,13 @@ use crate::{
ops::binary::binary_ops_shape,
scalar_int_cmp_ops, scalar_int_ops,
stream::{
self, BaseOpsDescription, BinaryOpsDescription, CatOpsDescription, ClampOpsDescription,
GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription,
NumericOpsDescription, Ops, ReduceDimWithIndicesDescription, ReshapeDescription,
ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription,
SelectOpsDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription,
TensorOpsDescription, UnaryOpsDescription,
self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
MaskWhereOperationDescription, NumericOperationDescription, Operation,
OperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
SwapDimsDescription, UnaryOperationDescription,
},
unary_int_ops, Fusion, FusionBackend, TensorDescription,
};
@ -80,7 +81,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
desc: ReshapeDescription,
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for ReshapeDimsOps<D1, D2> {
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_int_tensor::<D1>(&self.desc.input);
let output = B::int_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
@ -96,7 +97,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Reshape(desc.clone())),
OperationDescription::BaseInt(BaseOperationDescription::Reshape(desc.clone())),
ReshapeDimsOps::<D1, D2>::new(desc),
);
@ -109,10 +110,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D1> {
#[derive(new)]
struct SliceOps<const D1: usize, const D2: usize> {
desc: SliceOpsDescription,
desc: SliceOperationDescription,
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for SliceOps<D1, D2> {
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D1>(&self.desc.tensor);
@ -131,13 +132,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = tensor.client.tensor_uninitialized(shape);
let desc = SliceOpsDescription {
let desc = SliceOperationDescription {
tensor: tensor.into_description(),
ranges: ranges.into(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Slice(desc.clone())),
OperationDescription::BaseInt(BaseOperationDescription::Slice(desc.clone())),
SliceOps::<D1, D2>::new(desc),
);
@ -151,10 +152,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D1> {
#[derive(new)]
struct SliceAssignOps<const D1: usize, const D2: usize> {
desc: SliceAssignOpsDescription,
desc: SliceAssignOperationDescription,
}
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for SliceAssignOps<D1, D2> {
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceAssignOps<D1, D2> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D1>(&self.desc.tensor);
let value = handles.get_int_tensor::<D1>(&self.desc.value);
@ -171,14 +172,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.tensor_uninitialized(shape);
let desc = SliceAssignOpsDescription {
let desc = SliceAssignOperationDescription {
tensor: tensor.into_description(),
ranges: ranges.into(),
value: value.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::SliceAssign(desc.clone())),
OperationDescription::BaseInt(BaseOperationDescription::SliceAssign(desc.clone())),
SliceAssignOps::<D1, D2>::new(desc),
);
@ -192,10 +193,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
#[derive(new)]
struct MaskWhereOps<const D: usize> {
desc: MaskWhereOpsDescription,
desc: MaskWhereOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for MaskWhereOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for MaskWhereOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let value = handles.get_int_tensor(&self.desc.value);
@ -210,14 +211,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.tensor_uninitialized(shape);
let desc = MaskWhereOpsDescription {
let desc = MaskWhereOperationDescription {
tensor: tensor.into_description(),
value: value.into_description(),
mask: mask.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MaskWhere(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::MaskWhere(desc.clone())),
MaskWhereOps::<D>::new(desc),
);
@ -231,10 +232,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
#[derive(new)]
struct MaskFillOps<const D: usize> {
desc: MaskFillOpsDescription<i32>,
desc: MaskFillOperationDescription<i32>,
}
impl<const D: usize, B: FusionBackend> Ops<B> for MaskFillOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for MaskFillOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let mask = handles.get_bool_tensor(&self.desc.mask);
@ -247,14 +248,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.tensor_uninitialized(shape);
let desc = MaskFillOpsDescription {
let desc = MaskFillOperationDescription {
tensor: tensor.into_description(),
value: value.elem(),
mask: mask.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MaskFill(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::MaskFill(desc.clone())),
MaskFillOps::<D>::new(desc),
);
@ -268,10 +269,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
#[derive(new)]
struct GatherOps<const D: usize> {
desc: GatherOpsDescription,
desc: GatherOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for GatherOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for GatherOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
@ -283,14 +284,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = indices.shape.clone();
let out = tensor.client.tensor_uninitialized(shape);
let desc = GatherOpsDescription {
let desc = GatherOperationDescription {
tensor: tensor.into_description(),
dim,
indices: indices.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Gather(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Gather(desc.clone())),
GatherOps::<D>::new(desc),
);
@ -305,10 +306,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
#[derive(new)]
struct ScatterOps<const D: usize> {
desc: ScatterOpsDescription,
desc: ScatterOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for ScatterOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for ScatterOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
@ -322,7 +323,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.tensor_uninitialized(shape);
let desc = ScatterOpsDescription {
let desc = ScatterOperationDescription {
tensor: tensor.into_description(),
dim,
indices: indices.into_description(),
@ -330,7 +331,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Scatter(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Scatter(desc.clone())),
ScatterOps::<D>::new(desc),
);
@ -344,10 +345,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
#[derive(new)]
struct SelectOps<const D: usize> {
desc: SelectOpsDescription,
desc: SelectOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for SelectOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for SelectOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
@ -361,14 +362,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let mut shape: Vec<usize> = tensor.shape.clone();
shape[dim] = indices.shape[0];
let out = tensor.client.tensor_uninitialized(shape);
let desc = SelectOpsDescription {
let desc = SelectOperationDescription {
tensor: tensor.into_description(),
dim,
indices: indices.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Select(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Select(desc.clone())),
SelectOps::<D>::new(desc),
);
@ -383,10 +384,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
#[derive(new)]
struct SelectAssignOps<const D: usize> {
desc: SelectAssignOpsDescription,
desc: SelectAssignOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for SelectAssignOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for SelectAssignOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let indices = handles.get_int_tensor(&self.desc.indices);
@ -400,7 +401,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let shape: Vec<usize> = tensor.shape.clone();
let out = tensor.client.tensor_uninitialized(shape);
let desc = SelectAssignOpsDescription {
let desc = SelectAssignOperationDescription {
tensor: tensor.into_description(),
dim,
indices: indices.into_description(),
@ -408,7 +409,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::SelectAssign(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::SelectAssign(
desc.clone(),
)),
SelectAssignOps::<D>::new(desc),
);
@ -418,10 +421,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_cat<const D: usize>(tensors: Vec<IntTensor<Self, D>>, dim: usize) -> IntTensor<Self, D> {
#[derive(new)]
struct CatOps<const D: usize> {
desc: CatOpsDescription,
desc: CatOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for CatOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for CatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensors = self
.desc
@ -448,13 +451,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = client.tensor_uninitialized(shape);
let desc = CatOpsDescription {
let desc = CatOperationDescription {
tensors: tensors.into_iter().map(|t| t.into_description()).collect(),
dim,
out: out.to_description_out(),
};
client.register(
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Cat(desc.clone())),
OperationDescription::BaseInt(BaseOperationDescription::Cat(desc.clone())),
CatOps::<D>::new(desc),
);
@ -471,13 +474,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
let desc = BinaryOpsDescription {
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Equal(desc.clone())),
OperationDescription::BaseInt(BaseOperationDescription::Equal(desc.clone())),
EqualOps::<D>::new(desc),
);
@ -492,13 +495,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::EqualElem(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::EqualElem(desc.clone())),
EqualElemOps::<D>::new(desc),
);
@ -515,13 +518,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
let desc = BinaryOpsDescription {
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Greater(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Greater(desc.clone())),
GreaterOps::<D>::new(desc),
);
@ -536,13 +539,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::GreaterElem(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::GreaterElem(
desc.clone(),
)),
GreaterElemOps::<D>::new(desc),
);
@ -559,13 +564,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
let desc = BinaryOpsDescription {
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::GreaterEqual(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::GreaterEqual(
desc.clone(),
)),
GreaterEqualOps::<D>::new(desc),
);
@ -580,13 +587,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::GreaterEqualElem(
OperationDescription::NumericInt(NumericOperationDescription::GreaterEqualElem(
desc.clone(),
)),
GreaterEqualElemOps::<D>::new(desc),
@ -605,13 +612,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
let desc = BinaryOpsDescription {
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Lower(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Lower(desc.clone())),
LowerOps::<D>::new(desc),
);
@ -626,13 +633,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::LowerElem(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::LowerElem(desc.clone())),
LowerElemOps::<D>::new(desc),
);
@ -649,13 +656,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
let desc = BinaryOpsDescription {
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::LowerEqual(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::LowerEqual(desc.clone())),
LowerEqualOps::<D>::new(desc),
);
@ -670,13 +677,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::LowerEqualElem(
OperationDescription::NumericInt(NumericOperationDescription::LowerEqualElem(
desc.clone(),
)),
LowerEqualElemOps::<D>::new(desc),
@ -695,13 +702,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
let desc = BinaryOpsDescription {
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Add(desc.clone())),
stream::OperationDescription::NumericInt(NumericOperationDescription::Add(
desc.clone(),
)),
AddOps::<D>::new(desc),
);
@ -716,13 +725,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::AddScalar(
stream::OperationDescription::NumericInt(NumericOperationDescription::AddScalar(
desc.clone(),
)),
AddOps::<D>::new(desc),
@ -741,13 +750,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
let desc = BinaryOpsDescription {
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Sub(desc.clone())),
stream::OperationDescription::NumericInt(NumericOperationDescription::Sub(
desc.clone(),
)),
SubOps::<D>::new(desc),
);
@ -762,13 +773,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::SubScalar(
stream::OperationDescription::NumericInt(NumericOperationDescription::SubScalar(
desc.clone(),
)),
SubOps::<D>::new(desc),
@ -787,13 +798,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
let desc = BinaryOpsDescription {
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Mul(desc.clone())),
stream::OperationDescription::NumericInt(NumericOperationDescription::Mul(
desc.clone(),
)),
MulOps::<D>::new(desc),
);
@ -808,13 +821,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MulScalar(
stream::OperationDescription::NumericInt(NumericOperationDescription::MulScalar(
desc.clone(),
)),
MulOps::<D>::new(desc),
@ -833,13 +846,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
let desc = BinaryOpsDescription {
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
out.client.register(
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Div(desc.clone())),
stream::OperationDescription::NumericInt(NumericOperationDescription::Div(
desc.clone(),
)),
DivOps::<D>::new(desc),
);
@ -854,13 +869,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.elem(),
out: out.to_description_out(),
};
out.client.register(
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::DivScalar(
stream::OperationDescription::NumericInt(NumericOperationDescription::DivScalar(
desc.clone(),
)),
DivOps::<D>::new(desc),
@ -875,7 +890,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
desc: TensorDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for ZerosOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for ZerosOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let shape = Shape::from(self.desc.shape.clone());
let output = B::int_zeros::<D>(shape, &handles.device);
@ -888,7 +903,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = client.tensor_uninitialized(shape);
let desc = out.to_description_out();
client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Zeros(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Zeros(desc.clone())),
ZerosOps::<D>::new(desc),
);
@ -901,7 +916,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
desc: TensorDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for OnesOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for OnesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let shape = Shape::from(self.desc.shape.clone());
let output = B::int_ones::<D>(shape, &handles.device);
@ -915,7 +930,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let desc = out.to_description_out();
client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Ones(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Ones(desc.clone())),
OnesOps::<D>::new(desc),
);
@ -927,12 +942,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = tensor.client.tensor_uninitialized(vec![1]);
let desc = UnaryOpsDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Sum(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Sum(desc.clone())),
SumOps::<D>::new(desc),
);
@ -946,13 +961,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
shape[dim] = 1;
let out = tensor.client.tensor_uninitialized(shape);
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::SumDim(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::SumDim(desc.clone())),
SumDimOps::<D>::new(desc),
);
@ -964,12 +979,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = tensor.client.tensor_uninitialized(vec![1]);
let desc = UnaryOpsDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Mean(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Mean(desc.clone())),
MeanOps::<D>::new(desc),
);
@ -983,13 +998,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
shape[dim] = 1;
let out = tensor.client.tensor_uninitialized(shape);
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MeanDim(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::MeanDim(desc.clone())),
MeanDimOps::<D>::new(desc),
);
@ -1003,13 +1018,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
shape[dim] = 1;
let out = tensor.client.tensor_uninitialized(shape);
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ArgMax(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::ArgMax(desc.clone())),
ArgMaxOps::<D>::new(desc),
);
@ -1023,13 +1038,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
shape[dim] = 1;
let out = tensor.client.tensor_uninitialized(shape);
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ArgMin(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::ArgMin(desc.clone())),
ArgMinOps::<D>::new(desc),
);
@ -1043,10 +1058,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
) -> IntTensor<Self, D> {
#[derive(new)]
struct ClampOps<const D: usize> {
desc: ClampOpsDescription<i32>,
desc: ClampOperationDescription<i32>,
}
impl<const D: usize, B: FusionBackend> Ops<B> for ClampOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for ClampOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.tensor);
let output = B::int_clamp(input, self.desc.min.elem(), self.desc.max.elem());
@ -1056,14 +1071,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = ClampOpsDescription {
let desc = ClampOperationDescription {
tensor: tensor.into_description(),
min: min.elem(),
max: max.elem(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Clamp(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Clamp(desc.clone())),
ClampOps::<D>::new(desc),
);
@ -1075,12 +1090,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = UnaryOpsDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Abs(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Abs(desc.clone())),
AbsOps::<D>::new(desc),
);
@ -1090,10 +1105,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_into_float<const D: usize>(tensor: IntTensor<Self, D>) -> FloatTensor<Self, D> {
#[derive(new)]
struct IntoFloatOps<const D: usize> {
desc: UnaryOpsDescription,
desc: UnaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for IntoFloatOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for IntoFloatOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let output = B::int_into_float(input);
@ -1102,12 +1117,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = UnaryOpsDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::IntOps(stream::IntOpsDescription::IntoFloat(desc.clone())),
OperationDescription::Int(stream::IntOperationDescription::IntoFloat(desc.clone())),
IntoFloatOps::<D>::new(desc),
);
@ -1124,7 +1139,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
desc: SwapDimsDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for SwapDimsOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2);
@ -1145,7 +1160,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::SwapDims(desc.clone())),
OperationDescription::BaseInt(BaseOperationDescription::SwapDims(desc.clone())),
SwapDimsOps::<D>::new(desc),
);
@ -1157,12 +1172,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = tensor.client.tensor_uninitialized(vec![1]);
let desc = UnaryOpsDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Max(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Max(desc.clone())),
MaxOps::<D>::new(desc),
);
@ -1176,13 +1191,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
shape[dim] = 1;
let out = tensor.client.tensor_uninitialized(shape);
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MaxDim(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::MaxDim(desc.clone())),
MaxDimOps::<D>::new(desc),
);
@ -1198,7 +1213,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
desc: ReduceDimWithIndicesDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for MaxDimWithIndicesOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for MaxDimWithIndicesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim);
@ -1220,7 +1235,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out_indices: out_indices.to_description_out(),
};
client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MaxDimWithIndices(
OperationDescription::NumericInt(NumericOperationDescription::MaxDimWithIndices(
desc.clone(),
)),
MaxDimWithIndicesOps::<D>::new(desc),
@ -1234,12 +1249,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let out = tensor.client.tensor_uninitialized(vec![1]);
let desc = UnaryOpsDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Min(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::Min(desc.clone())),
MinOps::<D>::new(desc),
);
@ -1253,13 +1268,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
shape[dim] = 1;
let out = tensor.client.tensor_uninitialized(shape);
let desc = ScalarOpsDescription {
let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MinDim(desc.clone())),
OperationDescription::NumericInt(NumericOperationDescription::MinDim(desc.clone())),
MinDimOps::<D>::new(desc),
);
@ -1275,7 +1290,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
desc: ReduceDimWithIndicesDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for MinDimWithIndicesOps<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for MinDimWithIndicesOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim);
@ -1297,7 +1312,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out_indices: out_indices.to_description_out(),
};
client.register(
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MinDimWithIndices(
OperationDescription::NumericInt(NumericOperationDescription::MinDimWithIndices(
desc.clone(),
)),
MinDimWithIndicesOps::<D>::new(desc),


@ -7,8 +7,8 @@ use crate::{
AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription,
ConvTranspose2dDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
MaxPool1dWithIndicesDescription, MaxPool2dDescription,
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, Ops,
TensorOpsDescription,
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, Operation,
OperationDescription,
},
Fusion, FusionBackend, HandleContainer,
};
@ -28,7 +28,7 @@ macro_rules! make_ops {
desc: $desc,
}
impl<B: FusionBackend> Ops<B> for $name {
impl<B: FusionBackend> Operation<B> for $name {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
#[allow(clippy::redundant_closure_call)]
$fn(self.desc, handles)
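make_ops! pairs a description type with an execution closure. A hypothetical invocation for conv1d, matching the single arm visible in this hunk (the real call sites are further down in the file):

make_ops!(
    Conv1dOps,
    Conv1dDescription,
    |desc: Conv1dDescription, handles: &mut HandleContainer<B>| {
        // Fetch inputs via `handles`, run B::conv1d, and write the
        // result back under `desc.out` (body illustrative).
    }
);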
@ -78,7 +78,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.clone().register(
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::Conv1d(
OperationDescription::Module(crate::stream::ModuleOperationDescription::Conv1d(
description.clone(),
)),
Conv1dOps::new(description),
@ -136,7 +136,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::Conv2d(
OperationDescription::Module(crate::stream::ModuleOperationDescription::Conv2d(
desc.clone(),
)),
Conv2dOps::new(desc),
@ -188,9 +188,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::ConvTranspose1d(
desc.clone(),
)),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::ConvTranspose1d(desc.clone()),
),
ConvTranspose1dOps::new(desc),
);
@ -248,9 +248,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::ConvTranspose2d(
desc.clone(),
)),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::ConvTranspose2d(desc.clone()),
),
ConvTranspose2dOps::new(desc),
);
@ -294,7 +294,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::AvgPool1d(
OperationDescription::Module(crate::stream::ModuleOperationDescription::AvgPool1d(
desc.clone(),
)),
AvgPool1dOps::new(desc),
@ -344,7 +344,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::AvgPool2d(
OperationDescription::Module(crate::stream::ModuleOperationDescription::AvgPool2d(
desc.clone(),
)),
AvgPool2dOps::new(desc),
@ -392,8 +392,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(
crate::stream::ModuleOpsDescription::AvgPool1dBackward(desc.clone()),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AvgPool1dBackward(desc.clone()),
),
AvgPool1dBackwardOps::new(desc),
);
@ -440,8 +440,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(
crate::stream::ModuleOpsDescription::AvgPool2dBackward(desc.clone()),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AvgPool2dBackward(desc.clone()),
),
AvgPool2dBackwardOps::new(desc),
);
@ -487,7 +487,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::MaxPool1d(
OperationDescription::Module(crate::stream::ModuleOperationDescription::MaxPool1d(
desc.clone(),
)),
MaxPool1dOps::new(desc),
@ -547,7 +547,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::MaxPool2d(
OperationDescription::Module(crate::stream::ModuleOperationDescription::MaxPool2d(
desc.clone(),
)),
MaxPool2dOps::new(desc),
@ -596,8 +596,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out_indices: out_indices.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(
crate::stream::ModuleOpsDescription::MaxPool1dWithIndices(desc.clone()),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::MaxPool1dWithIndices(desc.clone()),
),
MaxPool1dWithIndicesOps::new(desc),
);
@ -659,8 +659,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out_indices: out_indices.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(
crate::stream::ModuleOpsDescription::MaxPool2dWithIndices(desc.clone()),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::MaxPool2dWithIndices(desc.clone()),
),
MaxPool2dWithIndicesOps::new(desc),
);
@ -711,8 +711,10 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(
crate::stream::ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc.clone()),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::MaxPool1dWithIndicesBackward(
desc.clone(),
),
),
MaxPool1dWithIndicesBackwardOps::new(desc),
);
@ -763,8 +765,10 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(
crate::stream::ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc.clone()),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::MaxPool2dWithIndicesBackward(
desc.clone(),
),
),
MaxPool2dWithIndicesBackwardOps::new(desc),
);
@ -793,8 +797,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(
crate::stream::ModuleOpsDescription::AdaptiveAvgPool1d(desc.clone()),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AdaptiveAvgPool1d(desc.clone()),
),
AdaptiveAvgPool1dOps::new(desc),
);
@ -826,8 +830,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(
crate::stream::ModuleOpsDescription::AdaptiveAvgPool2d(desc.clone()),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AdaptiveAvgPool2d(desc.clone()),
),
AdaptiveAvgPool2dOps::new(desc),
);
@ -859,8 +863,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
};
out.client.register(
TensorOpsDescription::ModuleOps(
crate::stream::ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc.clone()),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc.clone()),
),
AdaptiveAvgPool1dBackwardOps::new(desc),
);
@ -892,8 +896,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out: out.to_description_out(),
};
out.client.register(
TensorOpsDescription::ModuleOps(
crate::stream::ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc.clone()),
OperationDescription::Module(
crate::stream::ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc.clone()),
),
AdaptiveAvgPool2dBackwardOps::new(desc),
);

View File

@ -14,10 +14,10 @@ macro_rules! scalar_float_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: ScalarOpsDescription<$elem>,
desc: ScalarOperationDescription<$elem>,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
@ -34,10 +34,10 @@ macro_rules! scalar_float_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: ScalarOpsDescription<$elem>,
desc: ScalarOperationDescription<$elem>,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs);
@ -58,10 +58,10 @@ macro_rules! scalar_float2int_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: ScalarOpsDescription<$elem>,
desc: ScalarOperationDescription<$elem>,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs.clone());
@ -81,10 +81,10 @@ macro_rules! unary_float_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: UnaryOpsDescription,
desc: UnaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let output = $ops(input);
@ -104,10 +104,10 @@ macro_rules! unary_int_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: UnaryOpsDescription,
desc: UnaryOperationDescription,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let output = $ops(input);
@ -127,10 +127,10 @@ macro_rules! scalar_float_cmp_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: ScalarOpsDescription<f32>,
desc: ScalarOperationDescription<f32>,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
@ -150,10 +150,10 @@ macro_rules! scalar_int_cmp_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: ScalarOpsDescription<i32>,
desc: ScalarOperationDescription<i32>,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
@ -180,10 +180,10 @@ macro_rules! scalar_int_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: ScalarOpsDescription<$elem>,
desc: ScalarOperationDescription<$elem>,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
@ -200,10 +200,10 @@ macro_rules! scalar_int_ops {
) => {
#[derive(new)]
struct $name<const D: usize> {
desc: ScalarOpsDescription<$elem>,
desc: ScalarOperationDescription<$elem>,
}
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
let output = $ops(lhs, self.desc.rhs);
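All of these macros share one shape: each invocation generates a struct that holds an operation description, plus an `Operation` impl that applies the given function when executed. Below is a minimal, self-contained mirror of that pattern; `ScalarDesc`, `Handles`, `Operation`, and `scalar_ops!` are simplified stand-ins for illustration, not the crate's real types.

// A simplified stand-in for the crate's description and handle types.
struct ScalarDesc {
    lhs: f32,
    rhs: f32,
}

struct Handles;

trait Operation {
    fn execute(self: Box<Self>, handles: &mut Handles);
}

// Mirrors the shape of `scalar_float_ops!`: generate a struct holding a
// description plus an `Operation` impl that applies `$ops` to its fields.
macro_rules! scalar_ops {
    ($name:ident, $ops:expr) => {
        struct $name {
            desc: ScalarDesc,
        }
        impl Operation for $name {
            fn execute(self: Box<Self>, _handles: &mut Handles) {
                let output = $ops(self.desc.lhs, self.desc.rhs);
                println!("{output}");
            }
        }
    };
}

// One invocation per tensor operation, as in the real macros.
scalar_ops!(AddOps, |lhs: f32, rhs: f32| lhs + rhs);

fn main() {
    let op = Box::new(AddOps {
        desc: ScalarDesc { lhs: 1.0, rhs: 2.0 },
    });
    op.execute(&mut Handles); // Prints 3.
}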

View File

@ -1,5 +1,5 @@
use crate::{
stream::{MultiStream, Ops, TensorOpsDescription},
stream::{MultiStream, Operation, OperationDescription},
FusionBackend, HandleContainer, TensorId,
};
use burn_tensor::ops::{FloatElem, IntElem};
@ -26,8 +26,8 @@ where
}
}
pub fn register(&mut self, ops_desc: TensorOpsDescription, ops: Box<dyn Ops<B>>) {
self.streams.register(ops_desc, ops, &mut self.handles)
pub fn register(&mut self, desc: OperationDescription, operation: Box<dyn Operation<B>>) {
self.streams.register(desc, operation, &mut self.handles)
}
pub fn drain_streams(&mut self) {

View File

@ -1,51 +1,52 @@
use super::Ops;
use super::RelativeStreamConverter;
use super::TensorOpsDescription;
use super::Operation;
use super::OperationConverter;
use super::OperationDescription;
use crate::FusionBackend;
/// A growing list of [tensor operation descriptions](TensorOpsDescription).
pub struct Stream<B: FusionBackend> {
pub(crate) global: Vec<TensorOpsDescription>,
pub(crate) relative: Vec<TensorOpsDescription>,
pub(crate) converter: RelativeStreamConverter,
pub(crate) ops: Vec<Box<dyn Ops<B>>>,
/// A growing list of [tensor operation descriptions](OperationDescription).
pub struct OperationQueue<B: FusionBackend> {
pub(crate) global: Vec<OperationDescription>,
pub(crate) relative: Vec<OperationDescription>,
pub(crate) converter: OperationConverter,
pub(crate) operations: Vec<Box<dyn Operation<B>>>,
}
impl<B: FusionBackend> Stream<B> {
pub(crate) fn new() -> Self {
impl<B: FusionBackend> Default for OperationQueue<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: FusionBackend> OperationQueue<B> {
/// Create a new empty queue.
pub fn new() -> Self {
Self {
global: Vec::new(),
relative: Vec::new(),
converter: RelativeStreamConverter::default(),
ops: Vec::new(),
converter: OperationConverter::default(),
operations: Vec::new(),
}
}
pub(crate) fn split_relative_stream(
&self,
) -> (&[TensorOpsDescription], Option<&TensorOpsDescription>) {
let len = self.relative.len();
if len < 1 {
return (&self.relative, None);
}
(&self.relative[0..len - 1], self.relative.last())
}
pub(crate) fn add(&mut self, global: TensorOpsDescription, ops: Box<dyn Ops<B>>) {
/// Add a new tensor operation to the queue.
///
/// The new [operation description](OperationDescription) will be converted to a local
/// representation that can be reused when the same pattern emerges in different but similar
/// scenarios, so that the same optimization can be used.
pub fn add(&mut self, global: OperationDescription, operation: Box<dyn Operation<B>>) {
let relative = global.to_relative(&mut self.converter);
self.relative.push(relative);
self.global.push(global);
self.ops.push(ops);
self.operations.push(operation);
}
/// The size of the stream.
pub(crate) fn len(&self) -> usize {
/// The size of the queue.
pub fn len(&self) -> usize {
self.global.len()
}
/// If the stream is empty.
pub(crate) fn is_empty(&self) -> bool {
/// If the queue is empty.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
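As an illustration of why the queue keeps both lists, here is a minimal, self-contained sketch of the relative conversion; `OpDesc` and `Converter` are simplified stand-ins for `OperationDescription` and `OperationConverter` (which also normalizes scalar values).

use std::collections::HashMap;

// Toy operation description: only the tensor ids it touches.
#[derive(Clone, PartialEq, Debug)]
struct OpDesc {
    inputs: Vec<u64>,
}

#[derive(Default)]
struct Converter {
    map: HashMap<u64, u64>,
}

impl Converter {
    // Remap global tensor ids to dense relative ids, so that two streams
    // with the same structure produce identical relative descriptions.
    fn to_relative(&mut self, global: &OpDesc) -> OpDesc {
        let inputs = global
            .inputs
            .iter()
            .map(|id| {
                let next = self.map.len() as u64;
                *self.map.entry(*id).or_insert(next)
            })
            .collect();
        OpDesc { inputs }
    }
}

fn main() {
    // Different global ids, same pattern: both convert to inputs [0, 1],
    // so an optimization learned on one stream applies to the other.
    let a = Converter::default().to_relative(&OpDesc { inputs: vec![40, 41] });
    let b = Converter::default().to_relative(&OpDesc { inputs: vec![7, 9] });
    assert_eq!(a, b);
}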

File diff suppressed because it is too large

View File

@ -1,32 +1,32 @@
use crate::{
stream::{
store::{OptimizationId, OptimizationStore},
Stream,
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy},
OperationQueue,
},
FusionBackend, HandleContainer, Optimization,
};
/// The mode in which the execution is done.
#[derive(Clone, Copy, Debug)]
pub(crate) enum ExecutionMode {
Lazy,
Sync,
}
impl<B: FusionBackend> Stream<B> {
/// Execute the stream.
///
/// If an [optimization id](OptimizationId) is provided, use it to execute the stream partially, otherwise
/// execute each [operation](crate::stream::Ops).
impl<B: FusionBackend> OperationQueue<B> {
/// Execute the queue partially following the execution strategy from the plan.
pub(crate) fn execute(
&mut self,
id: Option<OptimizationId>,
id: ExecutionPlanId,
handles: &mut HandleContainer<B>,
store: &mut OptimizationStore<B::Optimization>,
store: &mut ExecutionPlanStore<B::Optimization>,
) {
match id {
Some(id) => self.execute_optimization(handles, &mut store.get_mut_unchecked(id).value),
None => self.execute_operations(handles),
}
match &mut store.get_mut_unchecked(id).strategy {
ExecutionStrategy::Optimization(optimization) => {
self.execute_optimization(handles, optimization)
}
ExecutionStrategy::Operations => self.execute_operations(handles),
};
}
fn execute_optimization(
@ -40,14 +40,14 @@ impl<B: FusionBackend> Stream<B> {
optimization.execute(&mut context);
self.drain_stream(num_drained, handles);
self.ops.drain(0..num_drained);
self.operations.drain(0..num_drained);
}
fn execute_operations(&mut self, handles: &mut HandleContainer<B>) {
let num_drained = self.ops.len();
let num_drained = self.operations.len();
for ops in self.ops.drain(0..num_drained) {
ops.execute(handles);
for operation in self.operations.drain(0..num_drained) {
operation.execute(handles);
}
self.drain_stream(num_drained, handles);

View File

@ -1,105 +0,0 @@
use super::ExecutionMode;
use crate::{stream::Stream, FusionBackend, OptimizationBuilder, OptimizationStatus};
/// Explore and create new optimization.
pub struct Explorer<B: FusionBackend> {
builders: Vec<Box<dyn OptimizationBuilder<B>>>,
num_deferred: usize,
}
/// The result of an exploration.
///
/// Either a new optimization is found, or we just continue to explore further.
pub enum Exploration<'a, B: FusionBackend> {
OptimizationFound(Option<&'a dyn OptimizationBuilder<B>>),
Continue,
}
impl<B: FusionBackend> Explorer<B> {
pub(crate) fn new(optimizations: Vec<Box<dyn OptimizationBuilder<B>>>) -> Self {
Self {
builders: optimizations,
num_deferred: 0,
}
}
pub(crate) fn defer(&mut self) {
self.num_deferred += 1;
}
pub(crate) fn up_to_date(&self) -> bool {
self.num_deferred == 0
}
pub(crate) fn explore<'a>(
&'a mut self,
stream: &Stream<B>,
mode: ExecutionMode,
) -> Exploration<'a, B> {
// When we are executing with the new ops mode, we need to register the last ops of the
// stream even when there is no skipped operation.
let offset = match mode {
ExecutionMode::Lazy => 1,
ExecutionMode::Sync => 0,
};
for i in (0..self.num_deferred + offset).rev() {
let index = stream.relative.len() - 1 - i;
let relative = &stream.relative[index];
for builder in self.builders.iter_mut() {
builder.register(relative);
}
}
self.num_deferred = 0;
// Can only be lazy when not sync.
if let ExecutionMode::Lazy = mode {
if still_optimizing(&self.builders) {
return Exploration::Continue;
}
}
match find_best_optimization_index(&mut self.builders) {
Some(index) => Exploration::OptimizationFound(Some(self.builders[index].as_ref())),
None => Exploration::OptimizationFound(None),
}
}
pub(crate) fn reset(&mut self, stream: &Stream<B>) {
for ops in self.builders.iter_mut() {
ops.reset();
}
self.num_deferred = stream.relative.len();
}
}
fn still_optimizing<B: FusionBackend>(optimizations: &[Box<dyn OptimizationBuilder<B>>]) -> bool {
let mut num_stopped = 0;
for optimization in optimizations.iter() {
if let OptimizationStatus::Closed = optimization.status() {
num_stopped += 1
}
}
num_stopped < optimizations.len()
}
fn find_best_optimization_index<B: FusionBackend>(
optimizations: &mut [Box<dyn OptimizationBuilder<B>>],
) -> Option<usize> {
let mut best_index = None;
let mut best_score = 0;
for (i, optimization) in optimizations.iter().enumerate() {
let properties = optimization.properties();
if properties.ready && properties.score >= best_score {
best_index = Some(i);
best_score = properties.score;
}
}
best_index
}

View File

@ -0,0 +1,125 @@
use super::ExecutionMode;
use crate::{stream::OperationDescription, OptimizationBuilder, OptimizationStatus};
/// Explore and create new optimizations.
pub struct Explorer<O> {
builders: Vec<Box<dyn OptimizationBuilder<O>>>,
num_deferred: usize,
num_explored: usize,
}
/// The result of an exploration done by the [explorer](Explorer).
pub enum Exploration<'a, O> {
/// Found a new optimization.
Found(&'a dyn OptimizationBuilder<O>),
/// No optimization was found.
NotFound { num_explored: usize },
/// We should continue exploring before arriving at a conclusion.
Continue,
}
impl<O> Explorer<O> {
/// Create a new explorer.
pub(crate) fn new(optimizations: Vec<Box<dyn OptimizationBuilder<O>>>) -> Self {
Self {
builders: optimizations,
num_deferred: 0,
num_explored: 0,
}
}
/// Defer the exploration.
pub(crate) fn defer(&mut self) {
self.num_deferred += 1;
}
/// If the explorer is up to date.
pub(crate) fn is_up_to_date(&self) -> bool {
self.num_deferred == 0
}
/// Explore the provided operations.
pub(crate) fn explore<'a>(
&'a mut self,
operations: &[OperationDescription],
mode: ExecutionMode,
) -> Exploration<'a, O> {
// When executing in lazy mode, triggered by a new operation, we must also register the
// last operation of the stream even when no operation was deferred.
let offset = match mode {
ExecutionMode::Lazy => 1,
ExecutionMode::Sync => 0,
};
let mut is_still_optimizing = still_optimizing(&self.builders);
for i in (0..self.num_deferred + offset).rev() {
if !is_still_optimizing {
break;
}
let index = operations.len() - 1 - i;
let relative = &operations[index];
for builder in self.builders.iter_mut() {
builder.register(relative);
}
self.num_explored += 1;
is_still_optimizing = still_optimizing(&self.builders);
}
self.num_deferred = 0;
// Can only continue exploration when not sync.
if let ExecutionMode::Lazy = mode {
if is_still_optimizing {
return Exploration::Continue;
}
}
match find_best_optimization_index(&mut self.builders) {
Some(index) => Exploration::Found(self.builders[index].as_ref()),
None => Exploration::NotFound {
num_explored: self.num_explored,
},
}
}
/// Reset the state of the explorer to the provided list of operations.
pub(crate) fn reset(&mut self, operations: &[OperationDescription]) {
for operation in self.builders.iter_mut() {
operation.reset();
}
self.num_explored = 0;
self.num_deferred = operations.len();
}
}
fn still_optimizing<O>(optimizations: &[Box<dyn OptimizationBuilder<O>>]) -> bool {
let mut num_stopped = 0;
for optimization in optimizations.iter() {
if let OptimizationStatus::Closed = optimization.status() {
num_stopped += 1
}
}
num_stopped < optimizations.len()
}
fn find_best_optimization_index<O>(
optimizations: &mut [Box<dyn OptimizationBuilder<O>>],
) -> Option<usize> {
let mut best_index = None;
let mut best_score = 0;
for (i, optimization) in optimizations.iter().enumerate() {
let properties = optimization.properties();
if properties.ready && properties.score >= best_score {
best_index = Some(i);
best_score = properties.score;
}
}
best_index
}
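The selection rule in `find_best_optimization_index` can be exercised in isolation. A sketch where plain `(ready, score)` tuples stand in for `OptimizationProperties`:

fn find_best(properties: &[(bool, u64)]) -> Option<usize> {
    let mut best_index = None;
    let mut best_score = 0;
    for (i, (ready, score)) in properties.iter().enumerate() {
        if *ready && *score >= best_score {
            best_index = Some(i);
            best_score = *score;
        }
    }
    best_index
}

fn main() {
    // On equal scores, the later builder wins because the comparison is `>=`.
    assert_eq!(find_best(&[(true, 2), (true, 2)]), Some(1));
    // Builders that aren't ready are never selected, whatever their score.
    assert_eq!(find_best(&[(false, 9), (true, 1)]), Some(1));
}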

View File

@ -1,9 +1,12 @@
mod base;
mod exploration;
mod explorer;
mod policy;
mod processor;
pub(crate) use base::*;
pub(crate) use exploration::*;
pub(crate) use explorer::*;
pub(crate) use policy::*;
pub(crate) use processor::*;
#[cfg(test)]
mod tests;

View File

@ -1,42 +1,43 @@
use super::ExecutionMode;
use crate::stream::{
store::{OptimizationId, OptimizationStore, SearchQuery},
TensorOpsDescription,
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery},
OperationDescription,
};
use std::marker::PhantomData;
/// The stream policy keeps track of all possible optimizations for the current stream.
/// The policy keeps track of all possible execution plans for the current operations.
///
/// # Details
///
/// We keep track of each new operation added to the stream and invalidate potential optimizations
/// when we see a different operation is added while keeping track of the current stream.
/// We keep track of each new operation added and invalidate potential execution plans
/// when we see a different operation is added.
///
/// Therefore, the overhead is minimal, since the time complexity of checking for existing
/// optimizations scales with the number of concurrent potential optimizations for the current stream,
/// execution plans scales with the number of concurrent potential plans for the current operations,
/// which is expected to remain small at all times.
pub(crate) struct Policy<O> {
// The potential optimizations that we could apply to the current stream, but their streams
// The potential plans that we could apply to the current operations, but whose operation
// lists still exceed the current number of operations.
candidates: Vec<OptimizationId>,
// Optimizations that we find during the `updates`, but none of their `end_conditions` matches the
candidates: Vec<ExecutionPlanId>,
// Plans that we find during the `updates`, but none of whose `triggers` matches the
// current stream.
availables: Vec<(OptimizationId, usize)>,
// Optimization that we find during the `updates` where one of its `end_condition` matches the
availables: Vec<(ExecutionPlanId, usize)>,
// A plan that we find during the `updates` where one of its `triggers` matches the
// current stream.
found: Option<(OptimizationId, usize)>,
// The size of the stream currently analyzed.
stream_size: usize,
found: Option<(ExecutionPlanId, usize)>,
// The number of operations analyzed.
num_operations: usize,
_item_type: PhantomData<O>,
}
impl<O> Policy<O> {
/// Create a new policy.
pub(crate) fn new() -> Self {
Self {
candidates: Vec::new(),
availables: Vec::new(),
found: None,
stream_size: 0,
num_operations: 0,
_item_type: PhantomData,
}
}
@ -44,17 +45,12 @@ impl<O> Policy<O> {
/// Returns the [action](Action) that should be taken given the state of the policy.
pub fn action(
&self,
optimizations: &OptimizationStore<O>,
stream: &[TensorOpsDescription],
store: &ExecutionPlanStore<O>,
operations: &[OperationDescription],
mode: ExecutionMode,
) -> Action {
let num_minimum_analyzed = match mode {
ExecutionMode::Lazy => self.stream_size - 1,
ExecutionMode::Sync => self.stream_size,
};
if num_minimum_analyzed < stream.len() {
panic!("Internal Error: Can't retrieve the policy action when the number of operations analyzed is lower than the stream itself.");
if self.num_operations < operations.len() {
panic!("Internal Error: Can't retrieve the policy action on a list of operations bigger than what is analyzed.");
}
if let Some((id, _length)) = self.found {
@ -63,30 +59,27 @@ impl<O> Policy<O> {
match mode {
ExecutionMode::Lazy => {
if self.candidates.is_empty() {
// Even if there are optimizations available, we aren't sure if they are the best ones
// we can use. Exploring more optimizations might find a new `end_condition` or
// even find a better optimization.
return Action::Explore;
if !self.candidates.is_empty() {
return Action::Defer;
}
Action::Defer
Action::Explore
}
ExecutionMode::Sync => {
// If an optimization covers the _whole_ stream, we return it, else we explore new
// optimizations.
// If an execution plan covers the _whole_ operation list, we return it, else we explore new
// plans.
for (id, length) in self.availables.iter() {
if *length == stream.len() {
if *length == operations.len() {
return Action::Execute(*id);
}
}
for candidate in self.candidates.iter() {
let item = optimizations.get_unchecked(*candidate);
let item = store.get_unchecked(*candidate);
// The candidate can actually be executed, since the stream is of the same
// size.
if item.stream.len() == stream.len() {
if item.operations.len() == operations.len() {
return Action::Execute(*candidate);
}
}
@ -97,61 +90,68 @@ impl<O> Policy<O> {
}
/// Update the policy state.
pub fn update(&mut self, store: &OptimizationStore<O>, ops: &TensorOpsDescription) {
if self.stream_size == 0 {
self.candidates = store.find(SearchQuery::OptimizationsStartingWith(ops));
pub fn update(&mut self, store: &ExecutionPlanStore<O>, ops: &OperationDescription) {
if self.num_operations == 0 {
self.candidates = store.find(SearchQuery::PlansStartingWith(ops));
} else {
self.analyze_candidates(store, ops, self.stream_size);
self.analyze_candidates(store, ops);
}
self.stream_size += 1;
self.num_operations += 1;
}
// Reset the state of the policy.
pub fn reset(&mut self) {
self.candidates.clear();
self.availables.clear();
self.stream_size = 0;
self.num_operations = 0;
self.found = None;
}
fn analyze_candidates(
&mut self,
optimizations: &OptimizationStore<O>,
next_ops: &TensorOpsDescription,
stream_size: usize,
store: &ExecutionPlanStore<O>,
operation: &OperationDescription,
) {
// The index starts at zero.
let mut invalidated_candidates = Vec::new();
for id in self.candidates.iter() {
let item = optimizations.get_unchecked(*id);
let item = store.get_unchecked(*id);
if item.stream.len() == stream_size {
if item.end_conditions.contains(next_ops) {
self.found = Some((*id, item.stream.len()));
if item.operations.len() == self.num_operations + 1
&& item.operations.last().unwrap() == operation
&& item.triggers.contains(&ExecutionTrigger::Always)
{
self.found = Some((*id, item.operations.len()));
break;
}
if item.operations.len() == self.num_operations {
if item.should_stop_async(operation) {
self.found = Some((*id, item.operations.len()));
break;
} else {
// The optimization is available, but the current operation isn't an existing
// end_condition for this optimization, so we may find a better optimization by
// still growing the stream.
self.availables.push((*id, item.stream.len()));
// The plan is available, but the current operation isn't an existing
// trigger for this plan, so we may find a better plan by
// still growing the operation list.
self.availables.push((*id, item.operations.len()));
invalidated_candidates.push(*id);
continue;
}
};
let next_ops_candidate = match item.stream.get(stream_size) {
let operation_candidate = match item.operations.get(self.num_operations) {
Some(val) => val,
None => {
// Stream of different size, invalidated.
// Operation list of different size, invalidated.
invalidated_candidates.push(*id);
continue;
}
};
if next_ops_candidate != next_ops {
// Stream with different node at the current position, invalidated.
if operation_candidate != operation {
// Operation list with a different operation at the current position, invalidated.
invalidated_candidates.push(*id);
continue;
}
@ -170,31 +170,34 @@ impl<O> Policy<O> {
/// Action to be made depending on the stream.
#[derive(PartialEq, Eq, Debug)]
pub enum Action {
/// Continue exploring optimizations using the [builder](crate::OptimizationBuilder).
/// Continue exploring using the [builder](crate::OptimizationBuilder).
Explore,
/// The current policy indicates that an optimization may be possible in the future, so the
/// The current policy indicates that an exploration may be possible in the future, so the
/// best action is to defer any execution.
///
/// Sometimes, it can be a false positive and a new optimization should be built from scratch.
/// Sometimes, it can be a false positive and a new exploration should be built from scratch.
/// Therefore it's important to keep the previous operations to rebuild the state if it
/// happens.
Defer,
/// An optimization has been found, and the best action is to execute it!
Execute(OptimizationId),
/// An execution plan has been found, and the best action is to execute it!
Execute(ExecutionPlanId),
}
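For context, a caller is expected to branch on the returned action roughly as follows; the enum is mirrored locally, and the `react` helper is hypothetical, not part of the crate.

type PlanId = usize;

#[derive(Debug, PartialEq, Eq)]
enum Action {
    Explore,
    Defer,
    Execute(PlanId),
}

fn react(action: &Action) -> &'static str {
    match action {
        // Keep feeding operations to the optimization builders.
        Action::Explore => "register the operation with every builder",
        // A stored plan may still match: accumulate more operations first.
        Action::Defer => "wait before executing anything",
        // A stored plan applies right now.
        Action::Execute(_id) => "execute the stored plan",
    }
}

fn main() {
    assert_eq!(react(&Action::Execute(0)), "execute the stored plan");
}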
#[cfg(test)]
mod tests {
use super::*;
use crate::{
stream::{store::OptimizationItem, FloatOpsDescription, UnaryOpsDescription},
stream::{
store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger},
FloatOperationDescription, UnaryOperationDescription,
},
TensorDescription, TensorId, TensorStatus,
};
use std::ops::Range;
#[test]
fn given_no_optimization_should_explore() {
let store = OptimizationStore::default();
let store = ExecutionPlanStore::default();
let mut policy = Policy::new();
let stream = TestStream::new(3);
@ -208,14 +211,14 @@ mod tests {
#[test]
fn given_existing_optimization_when_sync_should_execute_optim() {
let mut store = OptimizationStore::default();
let mut store = ExecutionPlanStore::default();
let mut policy = Policy::new();
let stream = TestStream::new(2);
let id = store.add(OptimizationItem {
stream: stream.operations.clone(),
end_conditions: Vec::new(),
value: (),
let id = store.add(ExecutionPlan {
operations: stream.operations.clone(),
triggers: Vec::new(),
strategy: ExecutionStrategy::Operations,
});
stream.assert_updates(
@ -230,15 +233,18 @@ mod tests {
}
#[test]
fn given_existing_optimization_when_found_end_condition_should_execute_optim() {
let mut store = OptimizationStore::default();
fn given_existing_plan_when_found_trigger_should_execute_plan() {
let mut store = ExecutionPlanStore::default();
let mut policy = Policy::new();
let stream = TestStream::new(3);
let id = store.add(OptimizationItem {
stream: stream.operations[0..2].to_vec(),
end_conditions: stream.operations[2..3].to_vec(),
value: (),
let id = store.add(ExecutionPlan {
operations: stream.operations[0..2].to_vec(),
triggers: stream.operations[2..3]
.iter()
.map(|desc| ExecutionTrigger::OnOperation(desc.clone()))
.collect(),
strategy: ExecutionStrategy::Operations,
});
stream.assert_updates(
@ -256,8 +262,8 @@ mod tests {
}
#[test]
fn should_support_multiple_end_conditions() {
let mut store = OptimizationStore::default();
fn should_support_multiple_triggers() {
let mut store = ExecutionPlanStore::default();
let mut policy_1 = Policy::new();
let mut policy_2 = Policy::new();
@ -265,18 +271,18 @@ mod tests {
let mut stream_2 = TestStream::new(2);
// Create a different end operation for each stream.
let end_condition_id_1 = 5;
let end_condition_id_2 = 5;
stream_1.new_ops(end_condition_id_1);
stream_2.new_ops(end_condition_id_2);
let trigger_id_1 = 5;
let trigger_id_2 = 6;
stream_1.new_ops(trigger_id_1);
stream_2.new_ops(trigger_id_2);
let id = store.add(OptimizationItem {
stream: stream_1.operations[0..2].to_vec(),
end_conditions: vec![
stream_1.operations[2].clone(),
stream_2.operations[2].clone(),
let id = store.add(ExecutionPlan {
operations: stream_1.operations[0..2].to_vec(),
triggers: vec![
ExecutionTrigger::OnOperation(stream_1.operations[2].clone()),
ExecutionTrigger::OnOperation(stream_2.operations[2].clone()),
],
value: (),
strategy: ExecutionStrategy::Operations,
});
stream_1.assert_updates(
@ -295,20 +301,20 @@ mod tests {
stream_1.assert_updates(
&store,
&mut policy_1,
AssertUpdatesOptions::OperationsIndex(2..3), // First end condition.
AssertUpdatesOptions::OperationsIndex(2..3), // First trigger.
Action::Execute(id),
);
stream_2.assert_updates(
&store,
&mut policy_2,
AssertUpdatesOptions::OperationsIndex(2..3), // Second end condition.
AssertUpdatesOptions::OperationsIndex(2..3), // Second trigger.
Action::Execute(id),
);
}
#[test]
fn should_select_right_optimization() {
let mut store = OptimizationStore::default();
let mut store = ExecutionPlanStore::default();
let mut policy_1 = Policy::new();
let mut policy_2 = Policy::new();
@ -322,15 +328,21 @@ mod tests {
stream_2.new_ops(5);
stream_2.new_ops(6);
let optimization_stream_1 = store.add(OptimizationItem {
stream: stream_1.operations[0..3].to_vec(),
end_conditions: stream_1.operations[3..4].to_vec(),
value: (),
let optimization_stream_1 = store.add(ExecutionPlan {
operations: stream_1.operations[0..3].to_vec(),
triggers: stream_1.operations[3..4]
.iter()
.map(|desc| ExecutionTrigger::OnOperation(desc.clone()))
.collect(),
strategy: ExecutionStrategy::Operations,
});
let optimization_stream_2 = store.add(OptimizationItem {
stream: stream_2.operations[0..3].to_vec(),
end_conditions: stream_2.operations[3..4].to_vec(),
value: (),
let optimization_stream_2 = store.add(ExecutionPlan {
operations: stream_2.operations[0..3].to_vec(),
triggers: stream_2.operations[3..4]
.iter()
.map(|desc| ExecutionTrigger::OnOperation(desc.clone()))
.collect(),
strategy: ExecutionStrategy::Operations,
});
assert_ne!(optimization_stream_1, optimization_stream_2);
@ -363,16 +375,19 @@ mod tests {
#[test]
fn should_invalidate_wrong_optimizations() {
let mut store = OptimizationStore::default();
let mut store = ExecutionPlanStore::default();
let stream_1 = TestStream::new(4);
let mut stream_2 = TestStream::new(2);
stream_2.new_ops(6);
stream_2.new_ops(7);
store.add(OptimizationItem {
stream: stream_1.operations[0..3].to_vec(),
end_conditions: stream_1.operations[3..4].to_vec(),
value: (),
store.add(ExecutionPlan {
operations: stream_1.operations[0..3].to_vec(),
triggers: stream_1.operations[3..4]
.iter()
.map(|desc| ExecutionTrigger::OnOperation(desc.clone()))
.collect(),
strategy: ExecutionStrategy::Operations,
});
let mut policy = Policy::new();
@ -396,7 +411,7 @@ mod tests {
#[derive(Default, Debug)]
struct TestStream {
tensors: Vec<TensorDescription>,
operations: Vec<TensorOpsDescription>,
operations: Vec<OperationDescription>,
}
#[derive(Debug)]
@ -418,7 +433,7 @@ mod tests {
/// The first pass should only result in cache misses.
pub fn assert_updates(
&self,
optimizations: &OptimizationStore<()>,
optimizations: &ExecutionPlanStore<()>,
policy: &mut Policy<()>,
options: AssertUpdatesOptions,
action: Action,
@ -448,7 +463,7 @@ mod tests {
self.new_empty_node(out_id);
self.operations
.push(TensorOpsDescription::FloatOps(FloatOpsDescription::Log(
.push(OperationDescription::Float(FloatOperationDescription::Log(
self.unary_description(),
)));
}
@ -461,10 +476,10 @@ mod tests {
});
}
fn unary_description(&self) -> UnaryOpsDescription {
fn unary_description(&self) -> UnaryOperationDescription {
let size = self.tensors.len();
UnaryOpsDescription {
UnaryOperationDescription {
input: self.tensors[size - 2].clone(),
out: self.tensors[size - 1].clone(),
}

View File

@ -1,44 +1,53 @@
use super::{ExecutionMode, Exploration, Explorer};
use crate::stream::execution::{Action, Policy};
use crate::stream::store::{OptimizationId, OptimizationItem, OptimizationStore};
use crate::stream::{Stream, TensorOpsDescription};
use crate::{FusionBackend, HandleContainer, OptimizationBuilder};
use crate::stream::store::{
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
};
use crate::stream::OperationDescription;
use crate::OptimizationBuilder;
/// Process the [stream](Stream) following a [policy](Policy).
///
/// Explore and create new optimizations using explorations
pub(crate) struct Processor<B: FusionBackend> {
policy: Policy<B::Optimization>,
explorer: Explorer<B>,
/// Process a [stream segment](StreamSegment) following a [policy](Policy).
pub(crate) struct Processor<O> {
policy: Policy<O>,
explorer: Explorer<O>,
}
impl<B: FusionBackend> Processor<B> {
/// A part of a stream that can be executed partially using an [execution plan](ExecutionPlan).
pub(crate) trait StreamSegment<O> {
/// The operations in the segment.
fn operations(&self) -> &[OperationDescription];
/// Execute part of the segment using the given plan id.
fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<O>);
}
impl<O> Processor<O> {
/// Create a new stream processor.
pub fn new(optimizations: Vec<Box<dyn OptimizationBuilder<B>>>) -> Self {
pub fn new(optimizations: Vec<Box<dyn OptimizationBuilder<O>>>) -> Self {
Self {
policy: Policy::new(),
explorer: Explorer::new(optimizations),
}
}
/// Process the [stream](Stream) with the provided mode.
pub fn process(
/// Process the [stream segment](StreamSegment) with the provided [mode](ExecutionMode).
pub fn process<Segment>(
&mut self,
stream: &mut Stream<B>,
optimizations: &mut OptimizationStore<B::Optimization>,
handles: &mut HandleContainer<B>,
mut segment: Segment,
store: &mut ExecutionPlanStore<O>,
mode: ExecutionMode,
) {
) where
Segment: StreamSegment<O>,
{
loop {
if stream.is_empty() {
if segment.operations().is_empty() {
break;
}
match self.action(optimizations, stream, mode) {
match self.action(store, segment.operations(), mode) {
Action::Explore => {
self.explore(stream, optimizations, handles, mode);
self.explore(&mut segment, store, mode);
if self.explorer.up_to_date() {
if self.explorer.is_up_to_date() {
break;
}
}
@ -51,8 +60,12 @@ impl<B: FusionBackend> Processor<B> {
};
}
Action::Execute(id) => {
stream.execute(Some(id), handles, optimizations);
self.reset(optimizations, stream);
if let ExecutionMode::Sync = mode {
store.add_trigger(id, ExecutionTrigger::OnSync);
}
segment.execute(id, store);
self.reset(store, segment.operations());
}
};
@ -62,21 +75,34 @@ impl<B: FusionBackend> Processor<B> {
}
}
fn explore(
fn explore<Item: StreamSegment<O>>(
&mut self,
stream: &mut Stream<B>,
optimizations: &mut OptimizationStore<B::Optimization>,
handles: &mut HandleContainer<B>,
item: &mut Item,
store: &mut ExecutionPlanStore<O>,
mode: ExecutionMode,
) {
match self.explorer.explore(stream, mode) {
Exploration::OptimizationFound(optim) => {
let id = optim.map(|optim| {
Self::on_new_optimization(&self.policy, stream, optimizations, optim, mode)
});
stream.execute(id, handles, optimizations);
self.reset(optimizations, stream);
match self.explorer.explore(item.operations(), mode) {
Exploration::Found(optim) => {
let id = Self::on_optimization_found(
&self.policy,
item.operations(),
store,
optim,
mode,
);
item.execute(id, store);
self.reset(store, item.operations());
}
Exploration::NotFound { num_explored } => {
let id = Self::on_optimization_not_found(
&self.policy,
item.operations(),
store,
mode,
num_explored,
);
item.execute(id, store);
self.reset(store, item.operations());
}
Exploration::Continue => {
if let ExecutionMode::Sync = mode {
@ -86,87 +112,105 @@ impl<B: FusionBackend> Processor<B> {
}
}
fn reset(&mut self, store: &mut OptimizationStore<B::Optimization>, stream: &Stream<B>) {
self.explorer.reset(stream);
fn reset(&mut self, store: &mut ExecutionPlanStore<O>, operations: &[OperationDescription]) {
self.explorer.reset(operations);
self.policy.reset();
// Reset the policy state.
for i in 0..stream.relative.len() {
self.policy.update(store, &stream.relative[i]);
for operation in operations.iter() {
self.policy.update(store, operation);
}
}
fn action(
&mut self,
cache: &OptimizationStore<B::Optimization>,
stream: &Stream<B>,
store: &ExecutionPlanStore<O>,
operations: &[OperationDescription],
mode: ExecutionMode,
) -> Action {
let (stream, next_ops) = Self::split_stream_ref(stream, mode);
if let ExecutionMode::Lazy = mode {
// We only update the policy in lazy mode, since operations are registered one at a
// time and only the newest one hasn't been analyzed yet.
self.policy.update(
store,
operations
.last()
.expect("At least one operation in the operation list."),
);
};
if let Some(next_ops) = next_ops {
self.policy.update(cache, next_ops)
}
self.policy.action(cache, stream, mode)
self.policy.action(store, operations, mode)
}
fn split_stream_owned(
stream: &Stream<B>,
fn on_optimization_found(
policy: &Policy<O>,
operations: &[OperationDescription],
store: &mut ExecutionPlanStore<O>,
builder: &dyn OptimizationBuilder<O>,
mode: ExecutionMode,
) -> (Vec<TensorOpsDescription>, Option<TensorOpsDescription>) {
) -> ExecutionPlanId {
let num_fused = builder.len();
let relative = &operations[0..num_fused];
match mode {
ExecutionMode::Lazy => {
let stream = stream.split_relative_stream();
(stream.0.to_vec(), stream.1.cloned())
}
ExecutionMode::Sync => (stream.relative.clone(), None),
}
}
let next_ops = operations.get(num_fused);
fn split_stream_ref(
stream: &Stream<B>,
mode: ExecutionMode,
) -> (&[TensorOpsDescription], Option<&TensorOpsDescription>) {
match mode {
ExecutionMode::Lazy => stream.split_relative_stream(),
ExecutionMode::Sync => (stream.relative.as_slice(), None),
}
}
let trigger = if let Some(next_ops) = next_ops {
ExecutionTrigger::OnOperation(next_ops.clone())
} else {
// Happens when every operation is included in the fused set, so there is no
// way the builder can continue fusing.
ExecutionTrigger::Always
};
fn on_new_optimization(
policy: &Policy<B::Optimization>,
stream: &Stream<B>,
store: &mut OptimizationStore<B::Optimization>,
builder: &dyn OptimizationBuilder<B>,
mode: ExecutionMode,
) -> OptimizationId {
let (stream_relative, next_ops) = Self::split_stream_owned(stream, mode);
// Check if an optimization is available for this stream before creating a new optimization.
//
// Specify a sync execution mode signaling that we want to know if an optimization is
// available right now even if it isn't the best one.
match policy.action(store, &stream_relative, ExecutionMode::Sync) {
Action::Execute(id) => {
// When we are in lazy mode, a next operation will be available.
//
// Since we are adding new optimization only when the policy action is explore, we
// know the existing optimization wasn't flagged as optimal, since the `next_ops'
// wasn't included in the `end_conditions`.
//
// But in this case, we aren't able to actually find a better optimization, so we
// flag the next ops as a stopping criteria, so we won't enter exploration mode the
// next time we see a similar stream following the same pattern.
if let Some(next_ops) = next_ops {
store.add_end_condition(id, next_ops);
match policy.action(store, relative, ExecutionMode::Sync) {
Action::Execute(id) => {
store.add_trigger(id, trigger);
id
}
_ => store.add(ExecutionPlan {
operations: relative.to_vec(),
triggers: vec![trigger],
strategy: ExecutionStrategy::Optimization(builder.build()),
}),
}
}
ExecutionMode::Sync => match policy.action(store, relative, ExecutionMode::Sync) {
Action::Execute(id) => {
store.add_trigger(id, ExecutionTrigger::OnSync);
id
}
_ => store.add(ExecutionPlan {
operations: operations.to_vec(),
triggers: vec![ExecutionTrigger::OnSync],
strategy: ExecutionStrategy::Optimization(builder.build()),
}),
},
}
}
fn on_optimization_not_found(
policy: &Policy<O>,
operations: &[OperationDescription],
store: &mut ExecutionPlanStore<O>,
mode: ExecutionMode,
num_explored: usize,
) -> ExecutionPlanId {
let relative = &operations[0..num_explored];
let trigger = match mode {
ExecutionMode::Lazy => ExecutionTrigger::Always,
ExecutionMode::Sync => ExecutionTrigger::OnSync,
};
match policy.action(store, relative, ExecutionMode::Sync) {
Action::Execute(id) => {
store.add_trigger(id, trigger);
id
}
_ => store.add(OptimizationItem {
stream: stream_relative,
end_conditions: next_ops.map(|op| vec![op]).unwrap_or_default(),
value: builder.build(),
_ => store.add(ExecutionPlan {
operations: relative.to_vec(),
triggers: vec![trigger],
strategy: ExecutionStrategy::Operations,
}),
}
}
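The trigger attached to a freshly stored plan boils down to a small decision, sketched below; `OpDesc`, `Trigger`, and `Mode` are simplified stand-ins for the crate's types.

#[derive(Clone, PartialEq, Debug)]
struct OpDesc(u32);

#[derive(PartialEq, Debug)]
enum Trigger {
    OnOperation(OpDesc),
    OnSync,
    Always,
}

enum Mode {
    Lazy,
    Sync,
}

// Mirrors the decision above: in lazy mode, the first operation *not* covered
// by the optimization becomes the trigger; if the builder consumed everything,
// the plan triggers unconditionally. On sync, the trigger is the sync itself.
fn trigger_for(mode: Mode, operations: &[OpDesc], num_fused: usize) -> Trigger {
    match mode {
        Mode::Lazy => match operations.get(num_fused) {
            Some(next) => Trigger::OnOperation(next.clone()),
            None => Trigger::Always,
        },
        Mode::Sync => Trigger::OnSync,
    }
}

fn main() {
    let ops = vec![OpDesc(1), OpDesc(2), OpDesc(3)];
    assert_eq!(trigger_for(Mode::Lazy, &ops, 2), Trigger::OnOperation(OpDesc(3)));
    assert_eq!(trigger_for(Mode::Lazy, &ops, 3), Trigger::Always);
    assert_eq!(trigger_for(Mode::Sync, &ops, 3), Trigger::OnSync);
}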

View File

@ -0,0 +1,408 @@
//! A testing module that ensures the correctness of the explorer, policy, and processor.
//!
//! The primary focus is on validating the seamless interaction between these three components to
//! execute and optimize a stream of operations accurately.
//!
//! To test these components effectively, we create mock types for the stream, optimization,
//! optimization builder, and stream segment. These mock types aid in comprehensively
//! understanding the process of optimizing streams.
use crate::{
stream::{
store::{
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
},
BinaryOperationDescription, NumericOperationDescription, OperationDescription,
ScalarOperationDescription,
},
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription, TensorId,
TensorStatus,
};
use super::*;
/// A fake stream of operations for testing purposes.
struct TestStream {
processor: Processor<TestOptimization>,
store: ExecutionPlanStore<TestOptimization>,
executed: Vec<ExecutionPlanId>,
operations: Vec<OperationDescription>,
}
/// A fake [optimization builder](OptimizationBuilder) for testing purposes.
struct TestOptimizationBuilder {
builder_id: usize,
expected_operations: Vec<OperationDescription>,
expected_trigger: ExecutionTrigger,
actual: Vec<OperationDescription>,
}
/// A fake optimization for testing purposes.
#[derive(new, Debug, PartialEq)]
struct TestOptimization {
builder_id: usize,
size: usize,
}
/// A fake [stream segment](StreamSegment) for testing purposes.
#[derive(new)]
struct TestSegment<'i> {
operations: &'i mut Vec<OperationDescription>,
executed: &'i mut Vec<ExecutionPlanId>,
}
/// This is a substantial test case that examines a lengthy scenario with a diverse set of conditions.
///
/// While it's usually preferable to split tests into multiple independent scenarios, in this case, it is
/// crucial to verify that the stream's state is correctly updated when various cases occur consecutively.
///
/// Although it might complicate identifying the source of a bug in the code, this comprehensive
/// test case covers nearly all aspects of the implementation while remaining easy to read and
/// maintain.
#[test]
fn should_support_complex_stream() {
// We have 2 different optimization builders in this test case.
let builder_id_1 = 0;
let builder_id_2 = 1;
// We will have a total of 3 execution plans to execute.
let plan_id_1 = 0;
let plan_id_2 = 1;
let plan_id_3 = 2;
// The first builder only contains 2 operations, and the optimization is always available when
// the pattern is met.
let builder_1 = TestOptimizationBuilder::new(
builder_id_1,
vec![operation_1(), operation_2()],
ExecutionTrigger::Always,
);
// The second builder also contains 2 operations, but its optimization only becomes
// available when a specific operation follows the pattern.
let builder_2 = TestOptimizationBuilder::new(
builder_id_2,
vec![operation_2(), operation_2()],
ExecutionTrigger::OnOperation(operation_1()),
);
// We finally build the stream with those optimization builders.
let mut stream = TestStream::new(vec![Box::new(builder_1), Box::new(builder_2)]);
// builder_1 is still waiting to see whether the next op is operation_2
// builder_2 is closed because it's not the right operation
stream.add(operation_1());
stream.assert_number_of_operations(1);
stream.assert_number_of_executions(0);
// No optimization found for the first two operations.
stream.add(operation_1());
stream.assert_number_of_operations(0);
stream.assert_number_of_executions(1);
stream.assert_last_executed(plan_id_1);
stream.assert_plan(
plan_id_1,
ExecutionPlan {
operations: vec![operation_1(), operation_1()],
triggers: vec![ExecutionTrigger::Always],
strategy: ExecutionStrategy::Operations,
},
);
// Nothing to execute.
stream.add(operation_1());
stream.assert_number_of_operations(1);
stream.assert_number_of_executions(1);
// Now we should trigger the first optimization builder.
stream.add(operation_2());
stream.assert_number_of_operations(0);
stream.assert_number_of_executions(2);
stream.assert_last_executed(plan_id_2);
stream.assert_plan(
plan_id_2,
ExecutionPlan {
operations: vec![operation_1(), operation_2()],
triggers: vec![ExecutionTrigger::Always],
strategy: ExecutionStrategy::Optimization(TestOptimization::new(builder_id_1, 2)),
},
);
// Nothing to execute.
stream.add(operation_2());
stream.assert_number_of_operations(1);
stream.assert_number_of_executions(2);
// Nothing to execute.
stream.add(operation_2());
stream.assert_number_of_operations(2);
stream.assert_number_of_executions(2);
// Now we should trigger the second optimization builder.
stream.add(operation_1());
stream.assert_number_of_operations(1);
stream.assert_number_of_executions(3);
stream.assert_last_executed(plan_id_3);
stream.assert_plan(
plan_id_3,
ExecutionPlan {
operations: vec![operation_2(), operation_2()],
triggers: vec![ExecutionTrigger::OnOperation(operation_1())],
strategy: ExecutionStrategy::Optimization(TestOptimization::new(builder_id_2, 2)),
},
);
// Now we should trigger the first optimization builder (second plan).
stream.add(operation_2());
stream.assert_number_of_operations(0);
stream.assert_number_of_executions(4);
stream.assert_last_executed(plan_id_2);
stream.assert_plan(
plan_id_2,
ExecutionPlan {
operations: vec![operation_1(), operation_2()],
triggers: vec![ExecutionTrigger::Always],
strategy: ExecutionStrategy::Optimization(TestOptimization::new(builder_id_1, 2)),
},
);
// Nothing to execute.
stream.add(operation_2());
stream.assert_number_of_operations(1);
stream.assert_number_of_executions(4);
// Nothing to execute.
stream.add(operation_2());
stream.assert_number_of_operations(2);
stream.assert_number_of_executions(4);
// On sync we should execute all operations even if their trigger isn't met.
// In this case the optimization from builder 2 (plan 3).
stream.sync();
stream.assert_number_of_operations(0);
stream.assert_number_of_executions(5);
stream.assert_last_executed(plan_id_3);
stream.assert_plan(
plan_id_3,
ExecutionPlan {
operations: vec![operation_2(), operation_2()],
triggers: vec![
ExecutionTrigger::OnOperation(operation_1()),
ExecutionTrigger::OnSync, // We also add OnSync in the triggers.
],
strategy: ExecutionStrategy::Optimization(TestOptimization::new(builder_id_2, 2)),
},
);
}
impl TestStream {
/// Create a new stream with the given optimization builders.
fn new(optimizations: Vec<Box<dyn OptimizationBuilder<TestOptimization>>>) -> Self {
Self {
processor: Processor::<TestOptimization>::new(optimizations),
store: ExecutionPlanStore::<TestOptimization>::new(),
executed: Vec::new(),
operations: Vec::new(),
}
}
/// Add an operation to the stream.
fn add(&mut self, operation: OperationDescription) {
self.operations.push(operation);
self.processor.process(
TestSegment::new(&mut self.operations, &mut self.executed),
&mut self.store,
ExecutionMode::Lazy,
);
}
/// Sync the stream.
fn sync(&mut self) {
self.processor.process(
TestSegment::new(&mut self.operations, &mut self.executed),
&mut self.store,
ExecutionMode::Sync,
);
}
/// Assert that the stored plan matches the expected one.
fn assert_plan(&self, id: ExecutionPlanId, expected: ExecutionPlan<TestOptimization>) {
let actual = self.store.get_unchecked(id);
assert_eq!(actual.triggers, expected.triggers);
assert_eq!(actual.operations, expected.operations);
}
/// Assert that the given plan id has been the last executed.
fn assert_last_executed(&self, id: ExecutionPlanId) {
match self.executed.last() {
Some(last_id) => assert_eq!(*last_id, id),
None => panic!("No plan has been executed"),
}
}
/// Assert the number of executions since the start of the stream.
fn assert_number_of_executions(&self, number: usize) {
assert_eq!(self.executed.len(), number);
}
/// Assert the number of operations queued.
fn assert_number_of_operations(&self, number: usize) {
assert_eq!(self.operations.len(), number);
}
}
impl TestOptimizationBuilder {
/// Create a new optimization builder that follows a pattern with a trigger.
fn new(
builder_id: usize,
operations: Vec<OperationDescription>,
trigger: ExecutionTrigger,
) -> Self {
Self {
builder_id,
expected_operations: operations,
actual: Vec::new(),
expected_trigger: trigger,
}
}
}
impl OptimizationBuilder<TestOptimization> for TestOptimizationBuilder {
/// Register a new operation.
fn register(&mut self, operation: &OperationDescription) {
self.actual.push(operation.clone());
}
/// Build the optimization.
fn build(&self) -> TestOptimization {
TestOptimization::new(self.builder_id, self.len())
}
/// Reset the state.
fn reset(&mut self) {
self.actual.clear();
}
/// Return the optimization status.
fn status(&self) -> OptimizationStatus {
let actual_equal_expected = self.actual == self.expected_operations;
if self.actual.len() < self.expected_operations.len() {
let operations = &self.expected_operations[0..self.actual.len()];
return match self.actual == operations {
// Still optimizing.
true => OptimizationStatus::Open,
// Never gonna be possible on that stream.
false => OptimizationStatus::Closed,
};
}
if self.actual.len() == self.expected_operations.len() && actual_equal_expected {
return match self.expected_trigger {
// Stop right away.
ExecutionTrigger::Always => OptimizationStatus::Closed,
// Wait for the next operation to show up.
ExecutionTrigger::OnOperation(_) => OptimizationStatus::Open,
// Doesn't matter on sync, even open should trigger a build if possible.
ExecutionTrigger::OnSync => OptimizationStatus::Open,
};
}
OptimizationStatus::Closed
}
/// Return the properties of this optimization.
fn properties(&self) -> OptimizationProperties {
if self.actual.len() < self.expected_operations.len() {
// Optimization not possible.
return OptimizationProperties {
score: 0,
ready: false,
};
}
let stream_is_ok =
self.actual[0..self.expected_operations.len()] == self.expected_operations;
if !stream_is_ok {
// Optimization not possible.
return OptimizationProperties {
score: 0,
ready: false,
};
}
// Optimization possible.
OptimizationProperties {
score: 1,
ready: true,
}
}
// The number of operations that should be handled by the optimization.
fn len(&self) -> usize {
self.expected_operations.len()
}
}
impl<'i> StreamSegment<TestOptimization> for TestSegment<'i> {
// The operations in the segment.
fn operations(&self) -> &[OperationDescription] {
self.operations
}
// Execute part of the segment using the given plan id.
fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<TestOptimization>) {
let execution_plan = store.get_unchecked(id);
match &execution_plan.strategy {
ExecutionStrategy::Optimization(optimization) => {
self.operations.drain(0..optimization.size);
}
ExecutionStrategy::Operations => self.operations.clear(),
};
self.executed.push(id);
}
}
/// Just a simple operation.
fn operation_1() -> OperationDescription {
OperationDescription::NumericFloat(NumericOperationDescription::Add(
BinaryOperationDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
rhs: TensorDescription {
id: TensorId::new(1),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
out: TensorDescription {
id: TensorId::new(2),
shape: vec![32, 32],
status: TensorStatus::NotInit,
},
},
))
}
/// Just a simple operation.
fn operation_2() -> OperationDescription {
OperationDescription::NumericFloat(NumericOperationDescription::AddScalar(
ScalarOperationDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
rhs: 5.0,
out: TensorDescription {
id: TensorId::new(2),
shape: vec![32, 32],
status: TensorStatus::NotInit,
},
},
))
}

View File

@ -4,9 +4,9 @@ pub(crate) mod store;
mod base;
mod context;
mod multi;
mod ops;
mod operation;
pub use base::*;
pub use context::*;
pub use multi::*;
pub use ops::*;
pub use operation::*;

View File

@ -1,7 +1,7 @@
use super::{
execution::{ExecutionMode, Processor},
store::OptimizationStore,
Ops, Stream, TensorOpsDescription,
execution::{ExecutionMode, Processor, StreamSegment},
store::{ExecutionPlanId, ExecutionPlanStore},
Operation, OperationDescription, OperationQueue,
};
use crate::{FusionBackend, HandleContainer};
@ -9,37 +9,31 @@ use crate::{FusionBackend, HandleContainer};
///
/// TODO: Actually support multiple streams.
pub struct MultiStream<B: FusionBackend> {
items: Vec<Item<B>>,
optimizations: OptimizationStore<B::Optimization>,
}
struct Item<B: FusionBackend> {
stream: Stream<B>,
executor: Processor<B>,
streams: Vec<Stream<B>>,
optimizations: ExecutionPlanStore<B::Optimization>,
}
impl<B: FusionBackend> MultiStream<B> {
pub(crate) fn new(device: B::FusionDevice) -> Self {
Self {
items: vec![Item::new(device)],
optimizations: OptimizationStore::new(),
streams: vec![Stream::new(device)],
optimizations: ExecutionPlanStore::new(),
}
}
/// Register a new tensor operation.
pub fn register(
&mut self,
ops_desc: TensorOpsDescription,
ops: Box<dyn Ops<B>>,
desc: OperationDescription,
operation: Box<dyn Operation<B>>,
handles: &mut HandleContainer<B>,
) {
// TODO: Support more than one stream.
if let Some(item) = self.items.first_mut() {
item.stream.add(ops_desc, ops);
item.executor.process(
&mut item.stream,
if let Some(item) = self.streams.first_mut() {
item.queue.add(desc, operation);
item.processor.process(
Segment::new(&mut item.queue, handles),
&mut self.optimizations,
handles,
ExecutionMode::Lazy,
);
};
@ -47,22 +41,42 @@ impl<B: FusionBackend> MultiStream<B> {
/// Drain the streams.
pub fn drain(&mut self, handles: &mut HandleContainer<B>) {
self.items.iter_mut().for_each(|item| {
item.executor.process(
&mut item.stream,
self.streams.iter_mut().for_each(|item| {
item.processor.process(
Segment::new(&mut item.queue, handles),
&mut self.optimizations,
handles,
ExecutionMode::Sync,
);
});
}
}
impl<B: FusionBackend> Item<B> {
struct Stream<B: FusionBackend> {
queue: OperationQueue<B>,
processor: Processor<B::Optimization>,
}
#[derive(new)]
struct Segment<'a, B: FusionBackend> {
queue: &'a mut OperationQueue<B>,
handles: &'a mut HandleContainer<B>,
}
impl<'i, B: FusionBackend> StreamSegment<B::Optimization> for Segment<'i, B> {
fn operations(&self) -> &[OperationDescription] {
&self.queue.relative
}
fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<B::Optimization>) {
self.queue.execute(id, self.handles, store)
}
}
impl<B: FusionBackend> Stream<B> {
fn new(device: B::FusionDevice) -> Self {
Self {
executor: Processor::new(B::optimizations(device.into())),
stream: Stream::new(),
processor: Processor::new(B::optimizations(device.into())),
queue: OperationQueue::new(),
}
}
}

View File

@ -0,0 +1,111 @@
use super::{ExecutionPlanIndex, InsertQuery, SearchQuery};
use crate::stream::OperationDescription;
use serde::{Deserialize, Serialize};
/// The store that contains all execution plans explored on a device.
#[derive(Default, Serialize, Deserialize)]
pub(crate) struct ExecutionPlanStore<O> {
plans: Vec<ExecutionPlan<O>>,
index: ExecutionPlanIndex,
}
/// How a list of operations should be executed.
#[derive(PartialEq, Debug, Serialize, Deserialize, Clone)]
pub(crate) enum ExecutionStrategy<O> {
/// An optimization was found, and therefore should be executed.
Optimization(O),
/// No optimization was found; each operation should be executed individually.
Operations,
}
/// The trigger that indicates when to stop exploring.
#[allow(clippy::large_enum_variant)]
// Triggers are stored in a list, and you can have many `OnOperation` entries,
// but only one `OnSync` entry and one `Always` entry, so we don't mind the extra
// space needed to store them.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub(crate) enum ExecutionTrigger {
OnOperation(OperationDescription),
OnSync,
Always,
}
/// The unique identifier for an execution plan.
pub(crate) type ExecutionPlanId = usize;
/// The outcome of an exploration that can be stored.
#[derive(Serialize, Deserialize)]
pub(crate) struct ExecutionPlan<O> {
/// The operations to which the exploration relates.
pub(crate) operations: Vec<OperationDescription>,
/// The criteria that signal when this plan should be executed. Only one trigger needs to match.
pub(crate) triggers: Vec<ExecutionTrigger>,
/// The strategy that should be used when executing this plan.
pub(crate) strategy: ExecutionStrategy<O>,
}
impl<O> ExecutionPlan<O> {
/// Whether exploration should be stopped in async mode.
pub fn should_stop_async(&self, ops: &OperationDescription) -> bool {
for item in self.triggers.iter() {
match item {
ExecutionTrigger::OnOperation(val) => {
if val == ops {
return true;
}
}
ExecutionTrigger::Always => return true,
ExecutionTrigger::OnSync => continue,
}
}
false
}
}
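A usage sketch of the trigger check; `Trigger` and `Plan` are simplified stand-ins, with a `String` in place of `OperationDescription`.

#[derive(PartialEq)]
enum Trigger {
    OnOperation(String),
    OnSync,
    Always,
}

struct Plan {
    triggers: Vec<Trigger>,
}

impl Plan {
    // Same logic as `should_stop_async`: `OnSync` never stops a lazy stream.
    fn should_stop_async(&self, operation: &str) -> bool {
        self.triggers.iter().any(|trigger| match trigger {
            Trigger::OnOperation(val) => val == operation,
            Trigger::Always => true,
            Trigger::OnSync => false,
        })
    }
}

fn main() {
    let plan = Plan {
        triggers: vec![Trigger::OnSync, Trigger::OnOperation("log".into())],
    };
    assert!(plan.should_stop_async("log"));
    assert!(!plan.should_stop_async("add")); // `OnSync` alone never fires here.

    let always = Plan { triggers: vec![Trigger::Always] };
    assert!(always.should_stop_async("anything"));
}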
impl<O> ExecutionPlanStore<O> {
pub fn new() -> Self {
Self {
plans: Vec::new(),
index: ExecutionPlanIndex::default(),
}
}
pub fn find(&self, query: SearchQuery<'_>) -> Vec<ExecutionPlanId> {
self.index.find(query)
}
pub fn add(&mut self, exploration: ExecutionPlan<O>) -> ExecutionPlanId {
if exploration.operations.is_empty() {
panic!("Can't add an empty optimization.");
}
let id = self.plans.len();
self.index.insert(InsertQuery::NewPlan {
operations: &exploration.operations,
id,
});
self.plans.push(exploration);
id
}
pub fn get_mut_unchecked(&mut self, id: ExecutionPlanId) -> &mut ExecutionPlan<O> {
&mut self.plans[id]
}
pub fn get_unchecked(&self, id: ExecutionPlanId) -> &ExecutionPlan<O> {
&self.plans[id]
}
/// Add a new trigger to an execution plan.
pub fn add_trigger(&mut self, id: ExecutionPlanId, criterion: ExecutionTrigger) {
let criteria = &mut self.plans[id].triggers;
if !criteria.contains(&criterion) {
criteria.push(criterion);
}
}
}
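Typical usage follows an add-then-annotate pattern; a sketch of the `add_trigger` dedup contract, with simplified `Trigger`/`Plan`/`Store` stand-ins:

#[derive(PartialEq)]
enum Trigger {
    OnSync,
    Always,
}

struct Plan {
    triggers: Vec<Trigger>,
}

struct Store {
    plans: Vec<Plan>,
}

impl Store {
    // Ids are positions in the vector, like `ExecutionPlanId`.
    fn add(&mut self, plan: Plan) -> usize {
        let id = self.plans.len();
        self.plans.push(plan);
        id
    }

    // Same contract as `add_trigger`: appending an existing trigger is a no-op.
    fn add_trigger(&mut self, id: usize, trigger: Trigger) {
        let triggers = &mut self.plans[id].triggers;
        if !triggers.contains(&trigger) {
            triggers.push(trigger);
        }
    }
}

fn main() {
    let mut store = Store { plans: Vec::new() };
    let id = store.add(Plan { triggers: vec![Trigger::Always] });
    store.add_trigger(id, Trigger::OnSync);
    store.add_trigger(id, Trigger::OnSync); // Duplicate: ignored.
    assert_eq!(store.plans[id].triggers.len(), 2);
}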

View File

@ -1,4 +1,4 @@
use crate::stream::{store::OptimizationId, TensorOpsDescription};
use crate::stream::{store::ExecutionPlanId, OperationDescription};
use serde::{Deserialize, Serialize};
use std::{
collections::{hash_map::DefaultHasher, HashMap},
@ -7,52 +7,51 @@ use std::{
/// Index used to search execution plans.
#[derive(Default, Serialize, Deserialize, Clone)]
pub struct OptimizationIndex {
/// We can't use `HashMap<TensorOpsDescription, Vec<OptimizationId>>` since `TensorOpsDescription`
pub struct ExecutionPlanIndex {
/// We can't use `HashMap<OperationDescription, Vec<ExecutionPlanId>>` since `OperationDescription`
/// doesn't implement [`Eq`](core::cmp::Eq).
///
/// `TensorOpsDescription` can't implement `Eq` since float types don't implement it.
/// `OperationDescription` can't implement `Eq` since float types don't implement it.
///
/// We rely instead on [`PartialEq`](core::cmp::PartialEq) to manually handle hash collisions.
/// This is OK because we use `relative` streams where any scalar values are set to zeros,
/// This is OK because we use `relative` operations where any scalar values are set to zero,
/// see [`RelativeStreamConverter`](crate::stream::RelativeStreamConverter).
mapping: HashMap<u64, Vec<(TensorOpsDescription, usize)>>,
starters: Vec<Vec<OptimizationId>>,
mapping: HashMap<u64, Vec<(OperationDescription, usize)>>,
starters: Vec<Vec<ExecutionPlanId>>,
}
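
The collision handling described above boils down to a hash-then-scan lookup. A self-contained sketch of the pattern with plain `std` types (not this crate's actual code):

use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};

/// Look up `key` in a map keyed by hash, resolving collisions with `PartialEq`.
fn find_by_hash<K: Hash + PartialEq, V: Copy>(
    map: &HashMap<u64, Vec<(K, V)>>,
    key: &K,
) -> Option<V> {
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    let bucket = map.get(&hasher.finish())?;
    // Two different keys may hash to the same u64; `PartialEq` disambiguates.
    bucket.iter().find(|(k, _)| k == key).map(|(_, v)| *v)
}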
pub enum SearchQuery<'a> {
OptimizationsStartingWith(&'a TensorOpsDescription),
PlansStartingWith(&'a OperationDescription),
}
pub enum InsertQuery<'a> {
NewOptimization {
stream: &'a [TensorOpsDescription],
id: OptimizationId,
NewPlan {
operations: &'a [OperationDescription],
id: ExecutionPlanId,
},
}
impl OptimizationIndex {
impl ExecutionPlanIndex {
/// Search execution plans with the given [query](SearchQuery).
pub fn find(&self, query: SearchQuery<'_>) -> Vec<OptimizationId> {
pub fn find(&self, query: SearchQuery<'_>) -> Vec<ExecutionPlanId> {
match query {
SearchQuery::OptimizationsStartingWith(ops) => self.find_starting_with(ops),
SearchQuery::PlansStartingWith(ops) => self.find_starting_with(ops),
}
}
/// Register a new execution plan with the given [query](InsertQuery).
pub fn insert(&mut self, query: InsertQuery<'_>) {
match query {
InsertQuery::NewOptimization { stream, id } => self.insert_new_ops(
stream
.first()
.expect("An optimization should never have an empty stream."),
id,
),
InsertQuery::NewPlan { operations, id } => {
if let Some(operation) = operations.first() {
self.insert_new_operation(operation, id)
}
}
}
}
fn find_starting_with(&self, ops: &TensorOpsDescription) -> Vec<OptimizationId> {
let key = self.stream_key(ops);
fn find_starting_with(&self, operation: &OperationDescription) -> Vec<ExecutionPlanId> {
let key = self.operation_key(operation);
let values = match self.mapping.get(&key) {
Some(val) => val,
None => return Vec::new(),
@ -62,7 +61,7 @@ impl OptimizationIndex {
return Vec::new();
}
let (_, index) = match values.iter().find(|value| &value.0 == ops) {
let (_, index) = match values.iter().find(|value| &value.0 == operation) {
Some(val) => val,
None => return Vec::new(),
};
@ -75,8 +74,8 @@ impl OptimizationIndex {
val
}
fn insert_new_ops(&mut self, ops: &TensorOpsDescription, new_id: OptimizationId) {
let key = self.stream_key(ops);
fn insert_new_operation(&mut self, ops: &OperationDescription, new_id: ExecutionPlanId) {
let key = self.operation_key(ops);
let values = match self.mapping.get_mut(&key) {
Some(val) => val,
None => {
@ -106,8 +105,8 @@ impl OptimizationIndex {
.push(new_id);
}
// Hash the value of the first operation in a stream.
fn stream_key(&self, ops: &TensorOpsDescription) -> u64 {
// Hash the value of the first operation in a list.
fn operation_key(&self, ops: &OperationDescription) -> u64 {
let mut hasher = DefaultHasher::new();
ops.hash(&mut hasher);
hasher.finish()
@ -118,80 +117,82 @@ impl OptimizationIndex {
mod tests {
use super::*;
use crate::{
stream::{BinaryOpsDescription, NumericOpsDescription, ScalarOpsDescription},
stream::{
BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription,
},
TensorDescription, TensorId, TensorStatus,
};
#[test]
fn should_find_optimization_id_based_on_tensor_ops() {
let mut index = OptimizationIndex::default();
let mut index = ExecutionPlanIndex::default();
let stream_1 = [ops_1()];
let optimization_id_1 = 0;
index.insert(InsertQuery::NewOptimization {
stream: &stream_1,
index.insert(InsertQuery::NewPlan {
operations: &stream_1,
id: optimization_id_1,
});
let found = index.find(SearchQuery::OptimizationsStartingWith(&stream_1[0]));
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
assert_eq!(found, vec![optimization_id_1]);
}
#[test]
fn should_support_multiple_optimization_ids_with_same_starting_ops() {
let mut index = OptimizationIndex::default();
let mut index = ExecutionPlanIndex::default();
let stream_1 = [ops_1(), ops_2(), ops_1()];
let stream_2 = [ops_1(), ops_1(), ops_2()];
let optimization_id_1 = 0;
let optimization_id_2 = 1;
index.insert(InsertQuery::NewOptimization {
stream: &stream_1,
index.insert(InsertQuery::NewPlan {
operations: &stream_1,
id: optimization_id_1,
});
index.insert(InsertQuery::NewOptimization {
stream: &stream_2,
index.insert(InsertQuery::NewPlan {
operations: &stream_2,
id: optimization_id_2,
});
let found = index.find(SearchQuery::OptimizationsStartingWith(&stream_1[0]));
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
assert_eq!(found, vec![optimization_id_1, optimization_id_2]);
}
#[test]
fn should_only_find_optimization_with_correct_starting_ops() {
let mut index = OptimizationIndex::default();
let mut index = ExecutionPlanIndex::default();
let stream_1 = [ops_1(), ops_1()];
let stream_2 = [ops_2(), ops_1()];
let optimization_id_1 = 0;
let optimization_id_2 = 1;
index.insert(InsertQuery::NewOptimization {
stream: &stream_1,
index.insert(InsertQuery::NewPlan {
operations: &stream_1,
id: optimization_id_1,
});
index.insert(InsertQuery::NewOptimization {
stream: &stream_2,
index.insert(InsertQuery::NewPlan {
operations: &stream_2,
id: optimization_id_2,
});
let found = index.find(SearchQuery::OptimizationsStartingWith(&stream_1[0]));
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
assert_eq!(found, vec![optimization_id_1]);
}
#[test]
fn should_handle_hash_collisions() {
let mut index = OptimizationIndex::default();
let mut index = ExecutionPlanIndex::default();
let stream_1 = [ops_1(), ops_1()];
let stream_2 = [ops_3(), ops_1()];
let optimization_id_1 = 0;
let optimization_id_2 = 1;
let stream_1_key = index.stream_key(&stream_1[0]);
let stream_2_key = index.stream_key(&stream_2[0]);
let stream_1_key = index.operation_key(&stream_1[0]);
let stream_2_key = index.operation_key(&stream_2[0]);
assert_eq!(
stream_1_key, stream_2_key,
@ -199,43 +200,45 @@ mod tests {
);
assert_ne!(stream_1[0], stream_2[0], "Ops 1 and Ops 3 are different.");
index.insert(InsertQuery::NewOptimization {
stream: &stream_1,
index.insert(InsertQuery::NewPlan {
operations: &stream_1,
id: optimization_id_1,
});
index.insert(InsertQuery::NewOptimization {
stream: &stream_2,
index.insert(InsertQuery::NewPlan {
operations: &stream_2,
id: optimization_id_2,
});
let found = index.find(SearchQuery::OptimizationsStartingWith(&stream_1[0]));
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
assert_eq!(found, vec![optimization_id_1]);
}
fn ops_1() -> TensorOpsDescription {
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::Add(BinaryOpsDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
fn ops_1() -> OperationDescription {
OperationDescription::NumericFloat(NumericOperationDescription::Add(
BinaryOperationDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
rhs: TensorDescription {
id: TensorId::new(1),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
out: TensorDescription {
id: TensorId::new(2),
shape: vec![32, 32],
status: TensorStatus::NotInit,
},
},
rhs: TensorDescription {
id: TensorId::new(1),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
out: TensorDescription {
id: TensorId::new(2),
shape: vec![32, 32],
status: TensorStatus::NotInit,
},
}))
))
}
fn ops_2() -> TensorOpsDescription {
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::AddScalar(
ScalarOpsDescription {
fn ops_2() -> OperationDescription {
OperationDescription::NumericFloat(NumericOperationDescription::AddScalar(
ScalarOperationDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
@ -251,23 +254,25 @@ mod tests {
))
}
fn ops_3() -> TensorOpsDescription {
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::Sub(BinaryOpsDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
fn ops_3() -> OperationDescription {
OperationDescription::NumericFloat(NumericOperationDescription::Sub(
BinaryOperationDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
rhs: TensorDescription {
id: TensorId::new(1),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
out: TensorDescription {
id: TensorId::new(2),
shape: vec![32, 32],
status: TensorStatus::NotInit,
},
},
rhs: TensorDescription {
id: TensorId::new(1),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
},
out: TensorDescription {
id: TensorId::new(2),
shape: vec![32, 32],
status: TensorStatus::NotInit,
},
}))
))
}
}

View File

@ -1,5 +1,5 @@
mod base;
mod index;
mod optimization;
pub(crate) use base::*;
pub(super) use index::*;
pub(crate) use optimization::*;

View File

@ -1,57 +0,0 @@
use super::{InsertQuery, OptimizationIndex, SearchQuery};
use crate::stream::TensorOpsDescription;
use serde::{Deserialize, Serialize};
#[derive(Default, Serialize, Deserialize)]
pub(crate) struct OptimizationStore<O> {
pub(super) optimizations: Vec<OptimizationItem<O>>,
pub(super) index: OptimizationIndex,
}
pub(crate) type OptimizationId = usize;
#[derive(Serialize, Deserialize)]
pub(crate) struct OptimizationItem<O> {
pub(crate) stream: Vec<TensorOpsDescription>,
pub(crate) end_conditions: Vec<TensorOpsDescription>,
pub(crate) value: O,
}
impl<O> OptimizationStore<O> {
pub fn new() -> Self {
Self {
optimizations: Vec::new(),
index: OptimizationIndex::default(),
}
}
pub fn find(&self, query: SearchQuery<'_>) -> Vec<OptimizationId> {
self.index.find(query)
}
pub fn add(&mut self, optimization: OptimizationItem<O>) -> OptimizationId {
let id = self.optimizations.len();
self.index.insert(InsertQuery::NewOptimization {
stream: &optimization.stream,
id,
});
self.optimizations.push(optimization);
id
}
pub fn get_mut_unchecked(&mut self, id: OptimizationId) -> &mut OptimizationItem<O> {
&mut self.optimizations[id]
}
pub fn get_unchecked(&self, id: OptimizationId) -> &OptimizationItem<O> {
&self.optimizations[id]
}
/// Add a new end condition for an optimization.
pub fn add_end_condition(&mut self, id: OptimizationId, end_condition: TensorOpsDescription) {
self.optimizations[id].end_conditions.push(end_condition)
}
}

View File

@ -82,7 +82,9 @@ where
type Handle = WgpuFusionHandle;
type FusionClient = MutexFusionClient<Self>;
fn optimizations(device: WgpuDevice) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self>>> {
fn optimizations(
device: WgpuDevice,
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
vec![Box::new(ElementWiseBuilder::new(device))]
}
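
These builders are consumed by the stream processor (`Processor::new(B::optimizations(device.into()))` earlier in this diff). A rough sketch of the feeding loop, not the actual `Processor` implementation:

use burn_fusion::{stream::OperationDescription, OptimizationBuilder};

fn feed_builders<O>(
    builders: &mut [Box<dyn OptimizationBuilder<O>>],
    operations: &[OperationDescription],
) {
    for operation in operations {
        for builder in builders.iter_mut() {
            // Each builder decides whether it can still fuse; for example,
            // `ElementWiseBuilder::register` below ignores operations once
            // its status is `Closed`.
            builder.register(operation);
        }
    }
}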

View File

@ -7,8 +7,9 @@ use crate::{
};
use burn_fusion::{
stream::{
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
ScalarOpsDescription, TensorOpsDescription, UnaryOpsDescription,
BaseOperationDescription, BinaryOperationDescription, FloatOperationDescription,
NumericOperationDescription, OperationDescription, ScalarOperationDescription,
UnaryOperationDescription,
},
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription, TensorId,
};
@ -35,43 +36,43 @@ where
pub(crate) device: Device<Wgpu<G, F, I>>,
}
impl<G, F, I> OptimizationBuilder<Wgpu<G, F, I>> for ElementWiseBuilder<G, F, I>
impl<G, F, I> OptimizationBuilder<WgpuOptimization<G, F, I>> for ElementWiseBuilder<G, F, I>
where
G: GraphicsApi,
F: FloatElement,
I: IntElement,
{
fn register(&mut self, ops: &TensorOpsDescription) {
fn register(&mut self, ops: &OperationDescription) {
if let OptimizationStatus::Closed = self.status {
return;
}
match ops {
TensorOpsDescription::BaseOpsFloat(ops) => {
OperationDescription::BaseFloat(ops) => {
if !self.register_base::<F>(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
TensorOpsDescription::BaseOpsInt(ops) => {
OperationDescription::BaseInt(ops) => {
if !self.register_base::<I>(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
TensorOpsDescription::FloatOps(ops) => {
OperationDescription::Float(ops) => {
if !self.register_float::<F>(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
TensorOpsDescription::NumericOpsFloat(ops) => {
OperationDescription::NumericFloat(ops) => {
if !self.register_numeric::<F, _>(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
TensorOpsDescription::NumericOpsInt(ops) => {
OperationDescription::NumericInt(ops) => {
if !self.register_numeric::<I, _>(ops) {
self.status = OptimizationStatus::Closed;
return;
@ -107,6 +108,10 @@ where
WgpuOptimization::ElementWise(op.compile())
}
fn len(&self) -> usize {
self.operators.len()
}
fn reset(&mut self) {
self.inputs.clear();
self.locals.drain();
@ -394,9 +399,9 @@ where
Variable::Local(local_index, Item::Scalar(elem))
}
fn register_base<E: WgpuElement>(&mut self, ops: &BaseOpsDescription) -> bool {
fn register_base<E: WgpuElement>(&mut self, ops: &BaseOperationDescription) -> bool {
match ops {
BaseOpsDescription::Equal(desc) => self.register_binary_ops(
BaseOperationDescription::Equal(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Equal { lhs, rhs, out },
@ -405,49 +410,49 @@ where
}
}
fn register_float<E: WgpuElement>(&mut self, ops: &FloatOpsDescription) -> bool {
fn register_float<E: WgpuElement>(&mut self, ops: &FloatOperationDescription) -> bool {
match ops {
FloatOpsDescription::Exp(desc) => {
FloatOperationDescription::Exp(desc) => {
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Exp { input, out }
})
}
FloatOpsDescription::Log(desc) => {
FloatOperationDescription::Log(desc) => {
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Log { input, out }
})
}
FloatOpsDescription::Log1p(desc) => {
FloatOperationDescription::Log1p(desc) => {
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Log1p { input, out }
})
}
FloatOpsDescription::Cos(desc) => {
FloatOperationDescription::Cos(desc) => {
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Cos { input, out }
})
}
FloatOpsDescription::Sin(desc) => {
FloatOperationDescription::Sin(desc) => {
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Sin { input, out }
})
}
FloatOpsDescription::Powf(desc) => self.register_scalar_ops(
FloatOperationDescription::Powf(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Powf { lhs, rhs, out },
),
FloatOpsDescription::Tanh(desc) => {
FloatOperationDescription::Tanh(desc) => {
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Tanh { input, out }
})
}
FloatOpsDescription::Erf(desc) => {
FloatOperationDescription::Erf(desc) => {
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Erf { input, out }
})
}
FloatOpsDescription::Recip(desc) => {
FloatOperationDescription::Recip(desc) => {
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Recip { input, out }
})
@ -458,100 +463,100 @@ where
fn register_numeric<E: WgpuElement, EDesc: WgpuElement>(
&mut self,
ops: &NumericOpsDescription<EDesc>,
ops: &NumericOperationDescription<EDesc>,
) -> bool {
match ops {
NumericOpsDescription::Add(desc) => self.register_binary_ops(
NumericOperationDescription::Add(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Add { lhs, rhs, out },
),
NumericOpsDescription::AddScalar(desc) => self.register_scalar_ops(
NumericOperationDescription::AddScalar(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Add { lhs, rhs, out },
),
NumericOpsDescription::Sub(desc) => self.register_binary_ops(
NumericOperationDescription::Sub(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Sub { lhs, rhs, out },
),
NumericOpsDescription::SubScalar(desc) => self.register_scalar_ops(
NumericOperationDescription::SubScalar(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Sub { lhs, rhs, out },
),
NumericOpsDescription::Mul(desc) => self.register_binary_ops(
NumericOperationDescription::Mul(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Mul { lhs, rhs, out },
),
NumericOpsDescription::MulScalar(desc) => self.register_scalar_ops(
NumericOperationDescription::MulScalar(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Mul { lhs, rhs, out },
),
NumericOpsDescription::Div(desc) => self.register_binary_ops(
NumericOperationDescription::Div(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Div { lhs, rhs, out },
),
NumericOpsDescription::DivScalar(desc) => self.register_scalar_ops(
NumericOperationDescription::DivScalar(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Div { lhs, rhs, out },
),
NumericOpsDescription::Abs(desc) => {
NumericOperationDescription::Abs(desc) => {
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Abs { input, out }
})
}
NumericOpsDescription::Lower(desc) => self.register_binary_ops(
NumericOperationDescription::Lower(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Lower { lhs, rhs, out },
),
NumericOpsDescription::LowerElem(desc) => self.register_scalar_ops(
NumericOperationDescription::LowerElem(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Lower { lhs, rhs, out },
),
NumericOpsDescription::Greater(desc) => self.register_binary_ops(
NumericOperationDescription::Greater(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Greater { lhs, rhs, out },
),
NumericOpsDescription::GreaterElem(desc) => self.register_scalar_ops(
NumericOperationDescription::GreaterElem(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Greater { lhs, rhs, out },
),
NumericOpsDescription::LowerEqual(desc) => self.register_binary_ops(
NumericOperationDescription::LowerEqual(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
),
NumericOpsDescription::LowerEqualElem(desc) => self.register_scalar_ops(
NumericOperationDescription::LowerEqualElem(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
),
NumericOpsDescription::GreaterEqual(desc) => self.register_binary_ops(
NumericOperationDescription::GreaterEqual(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
),
NumericOpsDescription::GreaterEqualElem(desc) => self.register_scalar_ops(
NumericOperationDescription::GreaterEqualElem(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
),
NumericOpsDescription::EqualElem(desc) => self.register_scalar_ops(
NumericOperationDescription::EqualElem(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Equal { lhs, rhs, out },
),
NumericOpsDescription::MaskWhere(desc) => {
NumericOperationDescription::MaskWhere(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
@ -571,7 +576,7 @@ where
true
}
NumericOpsDescription::MaskFill(desc) => {
NumericOperationDescription::MaskFill(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
@ -596,7 +601,7 @@ where
fn register_binary_ops<Func>(
&mut self,
desc: &BinaryOpsDescription,
desc: &BinaryOperationDescription,
(elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
func: Func,
) -> bool
@ -618,7 +623,7 @@ where
fn register_unary_ops<Func>(
&mut self,
desc: &UnaryOpsDescription,
desc: &UnaryOperationDescription,
(elem_input, elem_out): (Elem, Elem),
func: Func,
) -> bool
@ -639,7 +644,7 @@ where
fn register_scalar_ops<Func, E: Element>(
&mut self,
desc: &ScalarOpsDescription<E>,
desc: &ScalarOperationDescription<E>,
(elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
func: Func,
) -> bool

View File

@ -283,7 +283,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use burn_fusion::stream::Ops;
use burn_fusion::stream::Operation;
use burn_fusion::{Fusion, FusionBackend};
use burn_tensor::Int;
use burn_tensor::{backend::Backend, Data, Tensor};
@ -419,7 +419,7 @@ mod tests {
struct FakeAddOps;
impl<B: FusionBackend> Ops<B> for FakeAddOps {
impl<B: FusionBackend> Operation<B> for FakeAddOps {
fn execute(self: Box<Self>, _: &mut burn_fusion::HandleContainer<B>) {
panic!("Should always fused during tests.")
}