mirror of https://github.com/tracel-ai/burn.git
[burn-fusion] save all execution plans for any trigger (#1143)
This commit is contained in:
parent
6079f98950
commit
b99726f804
|
@ -51,6 +51,10 @@ pub fn save<B: Backend>(
|
|||
.join("burn")
|
||||
.join("backend-comparison");
|
||||
|
||||
for bench in benches.iter() {
|
||||
println!("{bench}");
|
||||
}
|
||||
|
||||
if !cache_dir.exists() {
|
||||
fs::create_dir_all(&cache_dir)?;
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
client::FusionClient,
|
||||
stream::{Context, TensorOpsDescription},
|
||||
stream::{Context, OperationDescription},
|
||||
FusionClientLocator, FusionTensor,
|
||||
};
|
||||
use burn_tensor::{backend::Backend, Device, Shape};
|
||||
|
@ -70,7 +70,7 @@ pub struct OptimizationProperties {
|
|||
}
|
||||
|
||||
/// The fusion operation abstraction allows implementations to fuse many
|
||||
/// [tensor operations](TensorOpsDescription) into one, improving the performance of the backend.
|
||||
/// [tensor operations](OperationDescription) into one, improving the performance of the backend.
|
||||
///
|
||||
///
|
||||
/// # Notes
|
||||
|
@ -79,19 +79,25 @@ pub struct OptimizationProperties {
|
|||
/// the speed and efficiency of the computational graph. It doesn't mean that all registered
|
||||
/// operations should be fused, but that another way of executing them is more efficient.
|
||||
///
|
||||
/// Also, it is important to return (FusionStatus::Closed) when no more registered operation can
|
||||
/// Also, it is important to return (OptimizationStatus::Closed) when no more registered operation can
|
||||
/// improve the performance.
|
||||
pub trait OptimizationBuilder<B: FusionBackend>: Send {
|
||||
/// Register a new [tensor operation](TensorOpsDescription).
|
||||
fn register(&mut self, ops: &TensorOpsDescription);
|
||||
pub trait OptimizationBuilder<O>: Send {
|
||||
/// Register a new [tensor operation](OperationDescription).
|
||||
fn register(&mut self, operation: &OperationDescription);
|
||||
/// Finish the optimization and create a fusion operation.
|
||||
fn build(&self) -> B::Optimization;
|
||||
fn build(&self) -> O;
|
||||
/// Reset the state.
|
||||
fn reset(&mut self);
|
||||
/// Return the builder [status](OptimizationStatus).
|
||||
fn status(&self) -> OptimizationStatus;
|
||||
/// Return the builder [properties](OptimizationProperties).
|
||||
fn properties(&self) -> OptimizationProperties;
|
||||
/// The number of operation fused.
|
||||
fn len(&self) -> usize;
|
||||
/// If no operations are fused.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
/// The operation created from the [builder](OptimizationBuilder).
|
||||
|
@ -143,7 +149,8 @@ pub trait FusionBackend: Backend {
|
|||
type FusionClient: FusionClient<FusionBackend = Self>;
|
||||
|
||||
/// The list of optimizations that will be used to optimize the computational graph.
|
||||
fn optimizations(device: Device<Self>) -> Vec<Box<dyn OptimizationBuilder<Self>>>;
|
||||
fn optimizations(device: Device<Self>)
|
||||
-> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;
|
||||
|
||||
/// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive).
|
||||
fn float_tensor<const D: usize>(
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
stream::{Ops, TensorOpsDescription},
|
||||
stream::{Operation, OperationDescription},
|
||||
FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
|
||||
};
|
||||
use burn_tensor::{
|
||||
|
@ -14,11 +14,11 @@ pub trait FusionClient: Send + Sync + Clone {
|
|||
|
||||
/// Create a new client for the given [fusion device](FusionBackend::FusionDevice).
|
||||
fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
|
||||
/// Register a new [tensor operation description](TensorOpsDescription).
|
||||
fn register<O: Ops<Self::FusionBackend> + 'static>(
|
||||
/// Register a new [tensor operation description](OperationDescription).
|
||||
fn register<O: Operation<Self::FusionBackend> + 'static>(
|
||||
&self,
|
||||
description: TensorOpsDescription,
|
||||
ops: O,
|
||||
description: OperationDescription,
|
||||
operation: O,
|
||||
);
|
||||
/// Register all lazy computation.
|
||||
fn drain(&self);
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
use super::FusionClient;
|
||||
use crate::{stream::TensorOpsDescription, FusionBackend, FusionServer, FusionTensor, Handle};
|
||||
use crate::{
|
||||
stream::{Operation, OperationDescription},
|
||||
FusionBackend, FusionServer, FusionTensor, Handle,
|
||||
};
|
||||
use burn_tensor::ops::FloatElem;
|
||||
use spin::Mutex;
|
||||
use std::sync::Arc;
|
||||
|
@ -38,12 +41,14 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn register<O: crate::stream::Ops<Self::FusionBackend> + 'static>(
|
||||
fn register<O: Operation<Self::FusionBackend> + 'static>(
|
||||
&self,
|
||||
description: TensorOpsDescription,
|
||||
ops: O,
|
||||
description: OperationDescription,
|
||||
operation: O,
|
||||
) {
|
||||
self.server.lock().register(description, Box::new(ops))
|
||||
self.server
|
||||
.lock()
|
||||
.register(description, Box::new(operation))
|
||||
}
|
||||
|
||||
fn drain(&self) {
|
||||
|
|
|
@ -7,10 +7,10 @@ macro_rules! binary_float_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: BinaryOpsDescription,
|
||||
desc: BinaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let rhs = handles.get_float_tensor(&self.desc.rhs);
|
||||
|
@ -31,10 +31,10 @@ macro_rules! binary_float_cmp_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: BinaryOpsDescription,
|
||||
desc: BinaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let rhs = handles.get_float_tensor(&self.desc.rhs);
|
||||
|
@ -55,10 +55,10 @@ macro_rules! binary_int_cmp_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: BinaryOpsDescription,
|
||||
desc: BinaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
|
||||
let rhs = handles.get_int_tensor(&self.desc.rhs);
|
||||
|
@ -89,10 +89,10 @@ macro_rules! binary_int_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: BinaryOpsDescription,
|
||||
desc: BinaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
|
||||
let rhs = handles.get_int_tensor(&self.desc.rhs);
|
||||
|
|
|
@ -3,9 +3,10 @@ use crate::{
|
|||
get_client,
|
||||
ops::binary::binary_ops_shape,
|
||||
stream::{
|
||||
BaseOpsDescription, BinaryOpsDescription, BoolOpsDescription, CatOpsDescription, Ops,
|
||||
ReshapeDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription,
|
||||
TensorOpsDescription, UnaryOpsDescription,
|
||||
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
|
||||
CatOperationDescription, Operation, OperationDescription, ReshapeDescription,
|
||||
SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription,
|
||||
UnaryOperationDescription,
|
||||
},
|
||||
Fusion, FusionBackend,
|
||||
};
|
||||
|
@ -48,10 +49,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
) -> burn_tensor::ops::IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct IntoIntOps<const D: usize> {
|
||||
desc: UnaryOpsDescription,
|
||||
desc: UnaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for IntoIntOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for IntoIntOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let output = B::bool_into_int(input);
|
||||
|
@ -61,12 +62,12 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = UnaryOpsDescription {
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BoolOps(BoolOpsDescription::IntoInt(desc.clone())),
|
||||
OperationDescription::Bool(BoolOperationDescription::IntoInt(desc.clone())),
|
||||
IntoIntOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -78,10 +79,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
) -> burn_tensor::ops::FloatTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct IntoFloatOps<const D: usize> {
|
||||
desc: UnaryOpsDescription,
|
||||
desc: UnaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for IntoFloatOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for IntoFloatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let output = B::bool_into_float(input);
|
||||
|
@ -91,12 +92,12 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = UnaryOpsDescription {
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BoolOps(BoolOpsDescription::IntoFloat(desc.clone())),
|
||||
OperationDescription::Bool(BoolOperationDescription::IntoFloat(desc.clone())),
|
||||
IntoFloatOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -135,7 +136,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
desc: ReshapeDescription,
|
||||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for ReshapeDimsOps<D1, D2> {
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D1>(&self.desc.input);
|
||||
let output = B::bool_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
|
||||
|
@ -151,7 +152,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Reshape(desc.clone())),
|
||||
OperationDescription::BaseBool(BaseOperationDescription::Reshape(desc.clone())),
|
||||
ReshapeDimsOps::<D1, D2>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -164,10 +165,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
) -> BoolTensor<Self, D1> {
|
||||
#[derive(new)]
|
||||
struct SliceOps<const D1: usize, const D2: usize> {
|
||||
desc: SliceOpsDescription,
|
||||
desc: SliceOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for SliceOps<D1, D2> {
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_bool_tensor::<D1>(&self.desc.tensor);
|
||||
|
||||
|
@ -186,13 +187,13 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = SliceOpsDescription {
|
||||
let desc = SliceOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
ranges: ranges.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Slice(desc.clone())),
|
||||
OperationDescription::BaseBool(BaseOperationDescription::Slice(desc.clone())),
|
||||
SliceOps::<D1, D2>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -206,10 +207,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
) -> BoolTensor<Self, D1> {
|
||||
#[derive(new)]
|
||||
struct SliceAssignOps<const D1: usize, const D2: usize> {
|
||||
desc: SliceAssignOpsDescription,
|
||||
desc: SliceAssignOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for SliceAssignOps<D1, D2> {
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceAssignOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_bool_tensor::<D1>(&self.desc.tensor);
|
||||
let value = handles.get_bool_tensor::<D1>(&self.desc.value);
|
||||
|
@ -227,7 +228,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
let shape: Vec<usize> = tensor.shape.clone();
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = SliceAssignOpsDescription {
|
||||
let desc = SliceAssignOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
ranges: ranges.into(),
|
||||
value: value.into_description(),
|
||||
|
@ -235,7 +236,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
|
||||
out.client.register(
|
||||
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::SliceAssign(desc.clone())),
|
||||
OperationDescription::BaseBool(BaseOperationDescription::SliceAssign(desc.clone())),
|
||||
SliceAssignOps::<D1, D2>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -248,10 +249,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
) -> BoolTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct CatOps<const D: usize> {
|
||||
desc: CatOpsDescription,
|
||||
desc: CatOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for CatOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for CatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensors = self
|
||||
.desc
|
||||
|
@ -278,13 +279,13 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = CatOpsDescription {
|
||||
let desc = CatOperationDescription {
|
||||
tensors: tensors.into_iter().map(|t| t.into_description()).collect(),
|
||||
dim,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
client.register(
|
||||
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Cat(desc.clone())),
|
||||
OperationDescription::BaseBool(BaseOperationDescription::Cat(desc.clone())),
|
||||
CatOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -297,10 +298,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
) -> BoolTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct EqualOps<const D: usize> {
|
||||
desc: BinaryOpsDescription,
|
||||
desc: BinaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for EqualOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for EqualOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_bool_tensor::<D>(&self.desc.lhs);
|
||||
let rhs = handles.get_bool_tensor(&self.desc.rhs);
|
||||
|
@ -313,13 +314,13 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
let desc = BinaryOpsDescription {
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::Equal(desc.clone())),
|
||||
OperationDescription::BaseBool(BaseOperationDescription::Equal(desc.clone())),
|
||||
EqualOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -329,10 +330,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
fn bool_not<const D: usize>(tensor: BoolTensor<Self, D>) -> BoolTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct NotOps<const D: usize> {
|
||||
desc: UnaryOpsDescription,
|
||||
desc: UnaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for NotOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for NotOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let output = B::bool_not(input);
|
||||
|
@ -342,13 +343,13 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = UnaryOpsDescription {
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
out.client.register(
|
||||
TensorOpsDescription::BoolOps(crate::stream::BoolOpsDescription::Not(desc.clone())),
|
||||
OperationDescription::Bool(crate::stream::BoolOperationDescription::Not(desc.clone())),
|
||||
NotOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -365,7 +366,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
desc: SwapDimsDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for SwapDimsOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
|
||||
|
@ -386,7 +387,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BaseOpsBool(BaseOpsDescription::SwapDims(desc.clone())),
|
||||
OperationDescription::BaseBool(BaseOperationDescription::SwapDims(desc.clone())),
|
||||
SwapDimsOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -5,12 +5,13 @@ use crate::{
|
|||
ops::binary::binary_ops_shape,
|
||||
scalar_int_cmp_ops, scalar_int_ops,
|
||||
stream::{
|
||||
self, BaseOpsDescription, BinaryOpsDescription, CatOpsDescription, ClampOpsDescription,
|
||||
GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription,
|
||||
NumericOpsDescription, Ops, ReduceDimWithIndicesDescription, ReshapeDescription,
|
||||
ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription,
|
||||
SelectOpsDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription,
|
||||
TensorOpsDescription, UnaryOpsDescription,
|
||||
self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
|
||||
ClampOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
|
||||
MaskWhereOperationDescription, NumericOperationDescription, Operation,
|
||||
OperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
|
||||
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
|
||||
SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
|
||||
SwapDimsDescription, UnaryOperationDescription,
|
||||
},
|
||||
unary_int_ops, Fusion, FusionBackend, TensorDescription,
|
||||
};
|
||||
|
@ -80,7 +81,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
desc: ReshapeDescription,
|
||||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for ReshapeDimsOps<D1, D2> {
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for ReshapeDimsOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D1>(&self.desc.input);
|
||||
let output = B::int_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
|
||||
|
@ -96,7 +97,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Reshape(desc.clone())),
|
||||
OperationDescription::BaseInt(BaseOperationDescription::Reshape(desc.clone())),
|
||||
ReshapeDimsOps::<D1, D2>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -109,10 +110,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D1> {
|
||||
#[derive(new)]
|
||||
struct SliceOps<const D1: usize, const D2: usize> {
|
||||
desc: SliceOpsDescription,
|
||||
desc: SliceOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for SliceOps<D1, D2> {
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D1>(&self.desc.tensor);
|
||||
|
||||
|
@ -131,13 +132,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = SliceOpsDescription {
|
||||
let desc = SliceOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
ranges: ranges.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Slice(desc.clone())),
|
||||
OperationDescription::BaseInt(BaseOperationDescription::Slice(desc.clone())),
|
||||
SliceOps::<D1, D2>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -151,10 +152,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D1> {
|
||||
#[derive(new)]
|
||||
struct SliceAssignOps<const D1: usize, const D2: usize> {
|
||||
desc: SliceAssignOpsDescription,
|
||||
desc: SliceAssignOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for SliceAssignOps<D1, D2> {
|
||||
impl<const D1: usize, const D2: usize, B: FusionBackend> Operation<B> for SliceAssignOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D1>(&self.desc.tensor);
|
||||
let value = handles.get_int_tensor::<D1>(&self.desc.value);
|
||||
|
@ -171,14 +172,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let shape: Vec<usize> = tensor.shape.clone();
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
let desc = SliceAssignOpsDescription {
|
||||
let desc = SliceAssignOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
ranges: ranges.into(),
|
||||
value: value.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::SliceAssign(desc.clone())),
|
||||
OperationDescription::BaseInt(BaseOperationDescription::SliceAssign(desc.clone())),
|
||||
SliceAssignOps::<D1, D2>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -192,10 +193,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct MaskWhereOps<const D: usize> {
|
||||
desc: MaskWhereOpsDescription,
|
||||
desc: MaskWhereOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for MaskWhereOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MaskWhereOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let value = handles.get_int_tensor(&self.desc.value);
|
||||
|
@ -210,14 +211,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
let shape: Vec<usize> = tensor.shape.clone();
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = MaskWhereOpsDescription {
|
||||
let desc = MaskWhereOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
value: value.into_description(),
|
||||
mask: mask.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MaskWhere(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MaskWhere(desc.clone())),
|
||||
MaskWhereOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -231,10 +232,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct MaskFillOps<const D: usize> {
|
||||
desc: MaskFillOpsDescription<i32>,
|
||||
desc: MaskFillOperationDescription<i32>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for MaskFillOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MaskFillOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let mask = handles.get_bool_tensor(&self.desc.mask);
|
||||
|
@ -247,14 +248,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let shape: Vec<usize> = tensor.shape.clone();
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
let desc = MaskFillOpsDescription {
|
||||
let desc = MaskFillOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
value: value.elem(),
|
||||
mask: mask.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MaskFill(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MaskFill(desc.clone())),
|
||||
MaskFillOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -268,10 +269,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct GatherOps<const D: usize> {
|
||||
desc: GatherOpsDescription,
|
||||
desc: GatherOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for GatherOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for GatherOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
|
@ -283,14 +284,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let shape: Vec<usize> = indices.shape.clone();
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
let desc = GatherOpsDescription {
|
||||
let desc = GatherOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
dim,
|
||||
indices: indices.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Gather(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Gather(desc.clone())),
|
||||
GatherOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -305,10 +306,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct ScatterOps<const D: usize> {
|
||||
desc: ScatterOpsDescription,
|
||||
desc: ScatterOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for ScatterOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for ScatterOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
|
@ -322,7 +323,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let shape: Vec<usize> = tensor.shape.clone();
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
let desc = ScatterOpsDescription {
|
||||
let desc = ScatterOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
dim,
|
||||
indices: indices.into_description(),
|
||||
|
@ -330,7 +331,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Scatter(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Scatter(desc.clone())),
|
||||
ScatterOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -344,10 +345,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct SelectOps<const D: usize> {
|
||||
desc: SelectOpsDescription,
|
||||
desc: SelectOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for SelectOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SelectOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
|
@ -361,14 +362,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
let mut shape: Vec<usize> = tensor.shape.clone();
|
||||
shape[dim] = indices.shape[0];
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
let desc = SelectOpsDescription {
|
||||
let desc = SelectOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
dim,
|
||||
indices: indices.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Select(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Select(desc.clone())),
|
||||
SelectOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -383,10 +384,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct SelectAssignOps<const D: usize> {
|
||||
desc: SelectAssignOpsDescription,
|
||||
desc: SelectAssignOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for SelectAssignOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SelectAssignOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let indices = handles.get_int_tensor(&self.desc.indices);
|
||||
|
@ -400,7 +401,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let shape: Vec<usize> = tensor.shape.clone();
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
let desc = SelectAssignOpsDescription {
|
||||
let desc = SelectAssignOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
dim,
|
||||
indices: indices.into_description(),
|
||||
|
@ -408,7 +409,9 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::SelectAssign(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::SelectAssign(
|
||||
desc.clone(),
|
||||
)),
|
||||
SelectAssignOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -418,10 +421,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
fn int_cat<const D: usize>(tensors: Vec<IntTensor<Self, D>>, dim: usize) -> IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct CatOps<const D: usize> {
|
||||
desc: CatOpsDescription,
|
||||
desc: CatOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for CatOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for CatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensors = self
|
||||
.desc
|
||||
|
@ -448,13 +451,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = CatOpsDescription {
|
||||
let desc = CatOperationDescription {
|
||||
tensors: tensors.into_iter().map(|t| t.into_description()).collect(),
|
||||
dim,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
client.register(
|
||||
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Cat(desc.clone())),
|
||||
OperationDescription::BaseInt(BaseOperationDescription::Cat(desc.clone())),
|
||||
CatOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -471,13 +474,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
let desc = BinaryOpsDescription {
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::Equal(desc.clone())),
|
||||
OperationDescription::BaseInt(BaseOperationDescription::Equal(desc.clone())),
|
||||
EqualOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -492,13 +495,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::EqualElem(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::EqualElem(desc.clone())),
|
||||
EqualElemOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -515,13 +518,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
let desc = BinaryOpsDescription {
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Greater(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Greater(desc.clone())),
|
||||
GreaterOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -536,13 +539,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::GreaterElem(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::GreaterElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
GreaterElemOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -559,13 +564,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
let desc = BinaryOpsDescription {
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::GreaterEqual(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::GreaterEqual(
|
||||
desc.clone(),
|
||||
)),
|
||||
GreaterEqualOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -580,13 +587,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::GreaterEqualElem(
|
||||
OperationDescription::NumericInt(NumericOperationDescription::GreaterEqualElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
GreaterEqualElemOps::<D>::new(desc),
|
||||
|
@ -605,13 +612,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
let desc = BinaryOpsDescription {
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Lower(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Lower(desc.clone())),
|
||||
LowerOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -626,13 +633,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::LowerElem(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::LowerElem(desc.clone())),
|
||||
LowerElemOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -649,13 +656,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
let desc = BinaryOpsDescription {
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::LowerEqual(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::LowerEqual(desc.clone())),
|
||||
LowerEqualOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -670,13 +677,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::LowerEqualElem(
|
||||
OperationDescription::NumericInt(NumericOperationDescription::LowerEqualElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
LowerEqualElemOps::<D>::new(desc),
|
||||
|
@ -695,13 +702,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
let desc = BinaryOpsDescription {
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Add(desc.clone())),
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::Add(
|
||||
desc.clone(),
|
||||
)),
|
||||
AddOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -716,13 +725,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::AddScalar(
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::AddScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
AddOps::<D>::new(desc),
|
||||
|
@ -741,13 +750,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
let desc = BinaryOpsDescription {
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Sub(desc.clone())),
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::Sub(
|
||||
desc.clone(),
|
||||
)),
|
||||
SubOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -762,13 +773,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::SubScalar(
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::SubScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
SubOps::<D>::new(desc),
|
||||
|
@ -787,13 +798,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
let desc = BinaryOpsDescription {
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Mul(desc.clone())),
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::Mul(
|
||||
desc.clone(),
|
||||
)),
|
||||
MulOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -808,13 +821,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MulScalar(
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::MulScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
MulOps::<D>::new(desc),
|
||||
|
@ -833,13 +846,15 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape));
|
||||
|
||||
let desc = BinaryOpsDescription {
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Div(desc.clone())),
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::Div(
|
||||
desc.clone(),
|
||||
)),
|
||||
DivOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -854,13 +869,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = lhs.client.tensor_uninitialized(lhs.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
stream::TensorOpsDescription::NumericOpsInt(NumericOpsDescription::DivScalar(
|
||||
stream::OperationDescription::NumericInt(NumericOperationDescription::DivScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
DivOps::<D>::new(desc),
|
||||
|
@ -875,7 +890,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
desc: TensorDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for ZerosOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for ZerosOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let shape = Shape::from(self.desc.shape.clone());
|
||||
let output = B::int_zeros::<D>(shape, &handles.device);
|
||||
|
@ -888,7 +903,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
let out = client.tensor_uninitialized(shape);
|
||||
let desc = out.to_description_out();
|
||||
client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Zeros(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Zeros(desc.clone())),
|
||||
ZerosOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -901,7 +916,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
desc: TensorDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for OnesOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for OnesOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let shape = Shape::from(self.desc.shape.clone());
|
||||
let output = B::int_ones::<D>(shape, &handles.device);
|
||||
|
@ -915,7 +930,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let desc = out.to_description_out();
|
||||
client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Ones(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Ones(desc.clone())),
|
||||
OnesOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -927,12 +942,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = tensor.client.tensor_uninitialized(vec![1]);
|
||||
|
||||
let desc = UnaryOpsDescription {
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Sum(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Sum(desc.clone())),
|
||||
SumOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -946,13 +961,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
shape[dim] = 1;
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: tensor.into_description(),
|
||||
rhs: dim,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::SumDim(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::SumDim(desc.clone())),
|
||||
SumDimOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -964,12 +979,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = tensor.client.tensor_uninitialized(vec![1]);
|
||||
|
||||
let desc = UnaryOpsDescription {
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Mean(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Mean(desc.clone())),
|
||||
MeanOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -983,13 +998,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
shape[dim] = 1;
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: tensor.into_description(),
|
||||
rhs: dim,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MeanDim(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MeanDim(desc.clone())),
|
||||
MeanDimOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1003,13 +1018,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
shape[dim] = 1;
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: tensor.into_description(),
|
||||
rhs: dim,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ArgMax(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::ArgMax(desc.clone())),
|
||||
ArgMaxOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1023,13 +1038,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
shape[dim] = 1;
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: tensor.into_description(),
|
||||
rhs: dim,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ArgMin(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::ArgMin(desc.clone())),
|
||||
ArgMinOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1043,10 +1058,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
) -> IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct ClampOps<const D: usize> {
|
||||
desc: ClampOpsDescription<i32>,
|
||||
desc: ClampOperationDescription<i32>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for ClampOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for ClampOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let output = B::int_clamp(input, self.desc.min.elem(), self.desc.max.elem());
|
||||
|
@ -1056,14 +1071,14 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
let desc = ClampOpsDescription {
|
||||
let desc = ClampOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
min: min.elem(),
|
||||
max: max.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Clamp(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Clamp(desc.clone())),
|
||||
ClampOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1075,12 +1090,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = UnaryOpsDescription {
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Abs(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Abs(desc.clone())),
|
||||
AbsOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1090,10 +1105,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
fn int_into_float<const D: usize>(tensor: IntTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct IntoFloatOps<const D: usize> {
|
||||
desc: UnaryOpsDescription,
|
||||
desc: UnaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for IntoFloatOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for IntoFloatOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.input);
|
||||
let output = B::int_into_float(input);
|
||||
|
@ -1102,12 +1117,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
let desc = UnaryOpsDescription {
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::IntOps(stream::IntOpsDescription::IntoFloat(desc.clone())),
|
||||
OperationDescription::Int(stream::IntOperationDescription::IntoFloat(desc.clone())),
|
||||
IntoFloatOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1124,7 +1139,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
desc: SwapDimsDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for SwapDimsOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for SwapDimsOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.input);
|
||||
let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2);
|
||||
|
@ -1145,7 +1160,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::BaseOpsInt(BaseOpsDescription::SwapDims(desc.clone())),
|
||||
OperationDescription::BaseInt(BaseOperationDescription::SwapDims(desc.clone())),
|
||||
SwapDimsOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1157,12 +1172,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = tensor.client.tensor_uninitialized(vec![1]);
|
||||
|
||||
let desc = UnaryOpsDescription {
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Max(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Max(desc.clone())),
|
||||
MaxOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1176,13 +1191,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
shape[dim] = 1;
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: tensor.into_description(),
|
||||
rhs: dim,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MaxDim(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MaxDim(desc.clone())),
|
||||
MaxDimOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1198,7 +1213,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
desc: ReduceDimWithIndicesDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for MaxDimWithIndicesOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MaxDimWithIndicesOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim);
|
||||
|
@ -1220,7 +1235,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
out_indices: out_indices.to_description_out(),
|
||||
};
|
||||
client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MaxDimWithIndices(
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MaxDimWithIndices(
|
||||
desc.clone(),
|
||||
)),
|
||||
MaxDimWithIndicesOps::<D>::new(desc),
|
||||
|
@ -1234,12 +1249,12 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let out = tensor.client.tensor_uninitialized(vec![1]);
|
||||
|
||||
let desc = UnaryOpsDescription {
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::Min(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Min(desc.clone())),
|
||||
MinOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1253,13 +1268,13 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
shape[dim] = 1;
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
let desc = ScalarOperationDescription {
|
||||
lhs: tensor.into_description(),
|
||||
rhs: dim,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MinDim(desc.clone())),
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MinDim(desc.clone())),
|
||||
MinDimOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1275,7 +1290,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
desc: ReduceDimWithIndicesDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for MinDimWithIndicesOps<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for MinDimWithIndicesOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let tensor = handles.get_int_tensor::<D>(&self.desc.tensor);
|
||||
let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim);
|
||||
|
@ -1297,7 +1312,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
out_indices: out_indices.to_description_out(),
|
||||
};
|
||||
client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::MinDimWithIndices(
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MinDimWithIndices(
|
||||
desc.clone(),
|
||||
)),
|
||||
MinDimWithIndicesOps::<D>::new(desc),
|
||||
|
|
|
@ -7,8 +7,8 @@ use crate::{
|
|||
AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription,
|
||||
ConvTranspose2dDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
|
||||
MaxPool1dWithIndicesDescription, MaxPool2dDescription,
|
||||
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, Ops,
|
||||
TensorOpsDescription,
|
||||
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, Operation,
|
||||
OperationDescription,
|
||||
},
|
||||
Fusion, FusionBackend, HandleContainer,
|
||||
};
|
||||
|
@ -28,7 +28,7 @@ macro_rules! make_ops {
|
|||
desc: $desc,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Ops<B> for $name {
|
||||
impl<B: FusionBackend> Operation<B> for $name {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
$fn(self.desc, handles)
|
||||
|
@ -78,7 +78,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.clone().register(
|
||||
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::Conv1d(
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::Conv1d(
|
||||
description.clone(),
|
||||
)),
|
||||
Conv1dOps::new(description),
|
||||
|
@ -136,7 +136,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::Conv2d(
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::Conv2d(
|
||||
desc.clone(),
|
||||
)),
|
||||
Conv2dOps::new(desc),
|
||||
|
@ -188,9 +188,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::ConvTranspose1d(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::ConvTranspose1d(desc.clone()),
|
||||
),
|
||||
ConvTranspose1dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -248,9 +248,9 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::ConvTranspose2d(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::ConvTranspose2d(desc.clone()),
|
||||
),
|
||||
ConvTranspose2dOps::new(desc),
|
||||
);
|
||||
|
||||
|
@ -294,7 +294,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::AvgPool1d(
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::AvgPool1d(
|
||||
desc.clone(),
|
||||
)),
|
||||
AvgPool1dOps::new(desc),
|
||||
|
@ -344,7 +344,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::AvgPool2d(
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::AvgPool2d(
|
||||
desc.clone(),
|
||||
)),
|
||||
AvgPool2dOps::new(desc),
|
||||
|
@ -392,8 +392,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(
|
||||
crate::stream::ModuleOpsDescription::AvgPool1dBackward(desc.clone()),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AvgPool1dBackward(desc.clone()),
|
||||
),
|
||||
AvgPool1dBackwardOps::new(desc),
|
||||
);
|
||||
|
@ -440,8 +440,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(
|
||||
crate::stream::ModuleOpsDescription::AvgPool2dBackward(desc.clone()),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AvgPool2dBackward(desc.clone()),
|
||||
),
|
||||
AvgPool2dBackwardOps::new(desc),
|
||||
);
|
||||
|
@ -487,7 +487,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::MaxPool1d(
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::MaxPool1d(
|
||||
desc.clone(),
|
||||
)),
|
||||
MaxPool1dOps::new(desc),
|
||||
|
@ -547,7 +547,7 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(crate::stream::ModuleOpsDescription::MaxPool2d(
|
||||
OperationDescription::Module(crate::stream::ModuleOperationDescription::MaxPool2d(
|
||||
desc.clone(),
|
||||
)),
|
||||
MaxPool2dOps::new(desc),
|
||||
|
@ -596,8 +596,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out_indices: out_indices.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(
|
||||
crate::stream::ModuleOpsDescription::MaxPool1dWithIndices(desc.clone()),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::MaxPool1dWithIndices(desc.clone()),
|
||||
),
|
||||
MaxPool1dWithIndicesOps::new(desc),
|
||||
);
|
||||
|
@ -659,8 +659,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out_indices: out_indices.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(
|
||||
crate::stream::ModuleOpsDescription::MaxPool2dWithIndices(desc.clone()),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::MaxPool2dWithIndices(desc.clone()),
|
||||
),
|
||||
MaxPool2dWithIndicesOps::new(desc),
|
||||
);
|
||||
|
@ -711,8 +711,10 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(
|
||||
crate::stream::ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc.clone()),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::MaxPool1dWithIndicesBackward(
|
||||
desc.clone(),
|
||||
),
|
||||
),
|
||||
MaxPool1dWithIndicesBackwardOps::new(desc),
|
||||
);
|
||||
|
@ -763,8 +765,10 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(
|
||||
crate::stream::ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc.clone()),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::MaxPool2dWithIndicesBackward(
|
||||
desc.clone(),
|
||||
),
|
||||
),
|
||||
MaxPool2dWithIndicesBackwardOps::new(desc),
|
||||
);
|
||||
|
@ -793,8 +797,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(
|
||||
crate::stream::ModuleOpsDescription::AdaptiveAvgPool1d(desc.clone()),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AdaptiveAvgPool1d(desc.clone()),
|
||||
),
|
||||
AdaptiveAvgPool1dOps::new(desc),
|
||||
);
|
||||
|
@ -826,8 +830,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(
|
||||
crate::stream::ModuleOpsDescription::AdaptiveAvgPool2d(desc.clone()),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AdaptiveAvgPool2d(desc.clone()),
|
||||
),
|
||||
AdaptiveAvgPool2dOps::new(desc),
|
||||
);
|
||||
|
@ -859,8 +863,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
};
|
||||
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(
|
||||
crate::stream::ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc.clone()),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc.clone()),
|
||||
),
|
||||
AdaptiveAvgPool1dBackwardOps::new(desc),
|
||||
);
|
||||
|
@ -892,8 +896,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::ModuleOps(
|
||||
crate::stream::ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc.clone()),
|
||||
OperationDescription::Module(
|
||||
crate::stream::ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc.clone()),
|
||||
),
|
||||
AdaptiveAvgPool2dBackwardOps::new(desc),
|
||||
);
|
||||
|
|
|
@ -14,10 +14,10 @@ macro_rules! scalar_float_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: ScalarOpsDescription<$elem>,
|
||||
desc: ScalarOperationDescription<$elem>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
|
||||
|
@ -34,10 +34,10 @@ macro_rules! scalar_float_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: ScalarOpsDescription<$elem>,
|
||||
desc: ScalarOperationDescription<$elem>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs);
|
||||
|
@ -58,10 +58,10 @@ macro_rules! scalar_float2int_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: ScalarOpsDescription<$elem>,
|
||||
desc: ScalarOperationDescription<$elem>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs.clone());
|
||||
|
@ -81,10 +81,10 @@ macro_rules! unary_float_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: UnaryOpsDescription,
|
||||
desc: UnaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D>(&self.desc.input);
|
||||
let output = $ops(input);
|
||||
|
@ -104,10 +104,10 @@ macro_rules! unary_int_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: UnaryOpsDescription,
|
||||
desc: UnaryOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.input);
|
||||
let output = $ops(input);
|
||||
|
@ -127,10 +127,10 @@ macro_rules! scalar_float_cmp_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: ScalarOpsDescription<f32>,
|
||||
desc: ScalarOperationDescription<f32>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
|
||||
|
@ -150,10 +150,10 @@ macro_rules! scalar_int_cmp_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: ScalarOpsDescription<i32>,
|
||||
desc: ScalarOperationDescription<i32>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
|
||||
|
@ -180,10 +180,10 @@ macro_rules! scalar_int_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: ScalarOpsDescription<$elem>,
|
||||
desc: ScalarOperationDescription<$elem>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs));
|
||||
|
@ -200,10 +200,10 @@ macro_rules! scalar_int_ops {
|
|||
) => {
|
||||
#[derive(new)]
|
||||
struct $name<const D: usize> {
|
||||
desc: ScalarOpsDescription<$elem>,
|
||||
desc: ScalarOperationDescription<$elem>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
|
||||
let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
|
||||
let output = $ops(lhs, self.desc.rhs);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
stream::{MultiStream, Ops, TensorOpsDescription},
|
||||
stream::{MultiStream, Operation, OperationDescription},
|
||||
FusionBackend, HandleContainer, TensorId,
|
||||
};
|
||||
use burn_tensor::ops::{FloatElem, IntElem};
|
||||
|
@ -26,8 +26,8 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub fn register(&mut self, ops_desc: TensorOpsDescription, ops: Box<dyn Ops<B>>) {
|
||||
self.streams.register(ops_desc, ops, &mut self.handles)
|
||||
pub fn register(&mut self, desc: OperationDescription, operation: Box<dyn Operation<B>>) {
|
||||
self.streams.register(desc, operation, &mut self.handles)
|
||||
}
|
||||
|
||||
pub fn drain_streams(&mut self) {
|
||||
|
|
|
@ -1,51 +1,52 @@
|
|||
use super::Ops;
|
||||
use super::RelativeStreamConverter;
|
||||
use super::TensorOpsDescription;
|
||||
use super::Operation;
|
||||
use super::OperationConverter;
|
||||
use super::OperationDescription;
|
||||
use crate::FusionBackend;
|
||||
|
||||
/// A growing list of [tensor operation descriptions](TensorOpsDescription).
|
||||
pub struct Stream<B: FusionBackend> {
|
||||
pub(crate) global: Vec<TensorOpsDescription>,
|
||||
pub(crate) relative: Vec<TensorOpsDescription>,
|
||||
pub(crate) converter: RelativeStreamConverter,
|
||||
pub(crate) ops: Vec<Box<dyn Ops<B>>>,
|
||||
/// A growing list of [tensor operation descriptions](OperationDescription).
|
||||
pub struct OperationQueue<B: FusionBackend> {
|
||||
pub(crate) global: Vec<OperationDescription>,
|
||||
pub(crate) relative: Vec<OperationDescription>,
|
||||
pub(crate) converter: OperationConverter,
|
||||
pub(crate) operations: Vec<Box<dyn Operation<B>>>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Stream<B> {
|
||||
pub(crate) fn new() -> Self {
|
||||
impl<B: FusionBackend> Default for OperationQueue<B> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> OperationQueue<B> {
|
||||
/// Create a new empty queue.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
global: Vec::new(),
|
||||
relative: Vec::new(),
|
||||
converter: RelativeStreamConverter::default(),
|
||||
ops: Vec::new(),
|
||||
converter: OperationConverter::default(),
|
||||
operations: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn split_relative_stream(
|
||||
&self,
|
||||
) -> (&[TensorOpsDescription], Option<&TensorOpsDescription>) {
|
||||
let len = self.relative.len();
|
||||
if len < 1 {
|
||||
return (&self.relative, None);
|
||||
}
|
||||
|
||||
(&self.relative[0..len - 1], self.relative.last())
|
||||
}
|
||||
|
||||
pub(crate) fn add(&mut self, global: TensorOpsDescription, ops: Box<dyn Ops<B>>) {
|
||||
/// Add a new tensor operation to the queue.
|
||||
///
|
||||
/// The new [operation description](OperationDescription) will be converted to a local
|
||||
/// representation that can be reused when the same pattern emerge in different but similar
|
||||
/// scenario, so that the same optimization can be used.
|
||||
pub fn add(&mut self, global: OperationDescription, operation: Box<dyn Operation<B>>) {
|
||||
let relative = global.to_relative(&mut self.converter);
|
||||
self.relative.push(relative);
|
||||
self.global.push(global);
|
||||
self.ops.push(ops);
|
||||
self.operations.push(operation);
|
||||
}
|
||||
|
||||
/// The size of the stream.
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
/// The size of the queue.
|
||||
pub fn len(&self) -> usize {
|
||||
self.global.len()
|
||||
}
|
||||
|
||||
/// If the stream is empty.
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
/// If the queue is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,32 +1,32 @@
|
|||
use crate::{
|
||||
stream::{
|
||||
store::{OptimizationId, OptimizationStore},
|
||||
Stream,
|
||||
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy},
|
||||
OperationQueue,
|
||||
},
|
||||
FusionBackend, HandleContainer, Optimization,
|
||||
};
|
||||
|
||||
/// The mode in which the execution is done.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(crate) enum ExecutionMode {
|
||||
Lazy,
|
||||
Sync,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Stream<B> {
|
||||
/// Execute the stream.
|
||||
///
|
||||
/// If an [optimization id](OptimizationId) is provided, use it to execute the stream partially, otherwise
|
||||
/// execute each [operation](crate::stream::Ops).
|
||||
impl<B: FusionBackend> OperationQueue<B> {
|
||||
/// Execute the queue partially following the execution strategy from the plan.
|
||||
pub(crate) fn execute(
|
||||
&mut self,
|
||||
id: Option<OptimizationId>,
|
||||
id: ExecutionPlanId,
|
||||
handles: &mut HandleContainer<B>,
|
||||
store: &mut OptimizationStore<B::Optimization>,
|
||||
store: &mut ExecutionPlanStore<B::Optimization>,
|
||||
) {
|
||||
match id {
|
||||
Some(id) => self.execute_optimization(handles, &mut store.get_mut_unchecked(id).value),
|
||||
None => self.execute_operations(handles),
|
||||
}
|
||||
match &mut store.get_mut_unchecked(id).strategy {
|
||||
ExecutionStrategy::Optimization(optimization) => {
|
||||
self.execute_optimization(handles, optimization)
|
||||
}
|
||||
ExecutionStrategy::Operations => self.execute_operations(handles),
|
||||
};
|
||||
}
|
||||
|
||||
fn execute_optimization(
|
||||
|
@ -40,14 +40,14 @@ impl<B: FusionBackend> Stream<B> {
|
|||
optimization.execute(&mut context);
|
||||
|
||||
self.drain_stream(num_drained, handles);
|
||||
self.ops.drain(0..num_drained);
|
||||
self.operations.drain(0..num_drained);
|
||||
}
|
||||
|
||||
fn execute_operations(&mut self, handles: &mut HandleContainer<B>) {
|
||||
let num_drained = self.ops.len();
|
||||
let num_drained = self.operations.len();
|
||||
|
||||
for ops in self.ops.drain(0..num_drained) {
|
||||
ops.execute(handles);
|
||||
for operation in self.operations.drain(0..num_drained) {
|
||||
operation.execute(handles);
|
||||
}
|
||||
|
||||
self.drain_stream(num_drained, handles);
|
||||
|
|
|
@ -1,105 +0,0 @@
|
|||
use super::ExecutionMode;
|
||||
use crate::{stream::Stream, FusionBackend, OptimizationBuilder, OptimizationStatus};
|
||||
|
||||
/// Explore and create new optimization.
|
||||
pub struct Explorer<B: FusionBackend> {
|
||||
builders: Vec<Box<dyn OptimizationBuilder<B>>>,
|
||||
num_deferred: usize,
|
||||
}
|
||||
|
||||
/// The result of an exploration.
|
||||
///
|
||||
/// Either a new optimization is found, or we just continue to explore further.
|
||||
pub enum Exploration<'a, B: FusionBackend> {
|
||||
OptimizationFound(Option<&'a dyn OptimizationBuilder<B>>),
|
||||
Continue,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Explorer<B> {
|
||||
pub(crate) fn new(optimizations: Vec<Box<dyn OptimizationBuilder<B>>>) -> Self {
|
||||
Self {
|
||||
builders: optimizations,
|
||||
num_deferred: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn defer(&mut self) {
|
||||
self.num_deferred += 1;
|
||||
}
|
||||
|
||||
pub(crate) fn up_to_date(&self) -> bool {
|
||||
self.num_deferred == 0
|
||||
}
|
||||
|
||||
pub(crate) fn explore<'a>(
|
||||
&'a mut self,
|
||||
stream: &Stream<B>,
|
||||
mode: ExecutionMode,
|
||||
) -> Exploration<'a, B> {
|
||||
// When we are executing with the new ops mode, we need to register the last ops of the
|
||||
// stream even when there is no skipped operation.
|
||||
let offset = match mode {
|
||||
ExecutionMode::Lazy => 1,
|
||||
ExecutionMode::Sync => 0,
|
||||
};
|
||||
|
||||
for i in (0..self.num_deferred + offset).rev() {
|
||||
let index = stream.relative.len() - 1 - i;
|
||||
let relative = &stream.relative[index];
|
||||
|
||||
for builder in self.builders.iter_mut() {
|
||||
builder.register(relative);
|
||||
}
|
||||
}
|
||||
self.num_deferred = 0;
|
||||
|
||||
// Can only be lazy when not sync.
|
||||
if let ExecutionMode::Lazy = mode {
|
||||
if still_optimizing(&self.builders) {
|
||||
return Exploration::Continue;
|
||||
}
|
||||
}
|
||||
|
||||
match find_best_optimization_index(&mut self.builders) {
|
||||
Some(index) => Exploration::OptimizationFound(Some(self.builders[index].as_ref())),
|
||||
None => Exploration::OptimizationFound(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn reset(&mut self, stream: &Stream<B>) {
|
||||
for ops in self.builders.iter_mut() {
|
||||
ops.reset();
|
||||
}
|
||||
self.num_deferred = stream.relative.len();
|
||||
}
|
||||
}
|
||||
|
||||
fn still_optimizing<B: FusionBackend>(optimizations: &[Box<dyn OptimizationBuilder<B>>]) -> bool {
|
||||
let mut num_stopped = 0;
|
||||
|
||||
for optimization in optimizations.iter() {
|
||||
if let OptimizationStatus::Closed = optimization.status() {
|
||||
num_stopped += 1
|
||||
}
|
||||
}
|
||||
|
||||
num_stopped < optimizations.len()
|
||||
}
|
||||
|
||||
fn find_best_optimization_index<B: FusionBackend>(
|
||||
optimizations: &mut [Box<dyn OptimizationBuilder<B>>],
|
||||
) -> Option<usize> {
|
||||
let mut best_index = None;
|
||||
let mut best_score = 0;
|
||||
|
||||
for (i, optimization) in optimizations.iter().enumerate() {
|
||||
let properties = optimization.properties();
|
||||
|
||||
if properties.ready && properties.score >= best_score {
|
||||
best_index = Some(i);
|
||||
best_score = properties.score;
|
||||
}
|
||||
}
|
||||
|
||||
best_index
|
||||
}
|
|
@ -0,0 +1,125 @@
|
|||
use super::ExecutionMode;
|
||||
use crate::{stream::OperationDescription, OptimizationBuilder, OptimizationStatus};
|
||||
|
||||
/// Explore and create new optimization.
|
||||
pub struct Explorer<O> {
|
||||
builders: Vec<Box<dyn OptimizationBuilder<O>>>,
|
||||
num_deferred: usize,
|
||||
num_explored: usize,
|
||||
}
|
||||
|
||||
/// The result of an exploration done by the [explorer](Explorer).
|
||||
pub enum Exploration<'a, O> {
|
||||
/// Found a new optimization.
|
||||
Found(&'a dyn OptimizationBuilder<O>),
|
||||
/// No optimization is found.
|
||||
NotFound { num_explored: usize },
|
||||
/// We should continue exploring before arriving at a conclusion.
|
||||
Continue,
|
||||
}
|
||||
|
||||
impl<O> Explorer<O> {
|
||||
/// Create a new explorer.
|
||||
pub(crate) fn new(optimizations: Vec<Box<dyn OptimizationBuilder<O>>>) -> Self {
|
||||
Self {
|
||||
builders: optimizations,
|
||||
num_deferred: 0,
|
||||
num_explored: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Defer the exploration.
|
||||
pub(crate) fn defer(&mut self) {
|
||||
self.num_deferred += 1;
|
||||
}
|
||||
|
||||
/// If the explorer is up to date.
|
||||
pub(crate) fn is_up_to_date(&self) -> bool {
|
||||
self.num_deferred == 0
|
||||
}
|
||||
|
||||
/// Explore the provided operations.
|
||||
pub(crate) fn explore<'a>(
|
||||
&'a mut self,
|
||||
operations: &[OperationDescription],
|
||||
mode: ExecutionMode,
|
||||
) -> Exploration<'a, O> {
|
||||
// When we are executing with the new operation mode, we need to register the last ops of the
|
||||
// stream even when there is no skipped operation.
|
||||
let offset = match mode {
|
||||
ExecutionMode::Lazy => 1,
|
||||
ExecutionMode::Sync => 0,
|
||||
};
|
||||
|
||||
let mut is_still_optimizing = still_optimizing(&self.builders);
|
||||
|
||||
for i in (0..self.num_deferred + offset).rev() {
|
||||
if !is_still_optimizing {
|
||||
break;
|
||||
}
|
||||
let index = operations.len() - 1 - i;
|
||||
let relative = &operations[index];
|
||||
|
||||
for builder in self.builders.iter_mut() {
|
||||
builder.register(relative);
|
||||
}
|
||||
self.num_explored += 1;
|
||||
|
||||
is_still_optimizing = still_optimizing(&self.builders);
|
||||
}
|
||||
self.num_deferred = 0;
|
||||
|
||||
// Can only continue exploration when not sync.
|
||||
if let ExecutionMode::Lazy = mode {
|
||||
if is_still_optimizing {
|
||||
return Exploration::Continue;
|
||||
}
|
||||
}
|
||||
|
||||
match find_best_optimization_index(&mut self.builders) {
|
||||
Some(index) => Exploration::Found(self.builders[index].as_ref()),
|
||||
None => Exploration::NotFound {
|
||||
num_explored: self.num_explored,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the state of the explorer to the provided list of operations.
|
||||
pub(crate) fn reset(&mut self, operations: &[OperationDescription]) {
|
||||
for operation in self.builders.iter_mut() {
|
||||
operation.reset();
|
||||
}
|
||||
self.num_explored = 0;
|
||||
self.num_deferred = operations.len();
|
||||
}
|
||||
}
|
||||
|
||||
fn still_optimizing<O>(optimizations: &[Box<dyn OptimizationBuilder<O>>]) -> bool {
|
||||
let mut num_stopped = 0;
|
||||
|
||||
for optimization in optimizations.iter() {
|
||||
if let OptimizationStatus::Closed = optimization.status() {
|
||||
num_stopped += 1
|
||||
}
|
||||
}
|
||||
|
||||
num_stopped < optimizations.len()
|
||||
}
|
||||
|
||||
fn find_best_optimization_index<O>(
|
||||
optimizations: &mut [Box<dyn OptimizationBuilder<O>>],
|
||||
) -> Option<usize> {
|
||||
let mut best_index = None;
|
||||
let mut best_score = 0;
|
||||
|
||||
for (i, optimization) in optimizations.iter().enumerate() {
|
||||
let properties = optimization.properties();
|
||||
|
||||
if properties.ready && properties.score >= best_score {
|
||||
best_index = Some(i);
|
||||
best_score = properties.score;
|
||||
}
|
||||
}
|
||||
|
||||
best_index
|
||||
}
|
|
@ -1,9 +1,12 @@
|
|||
mod base;
|
||||
mod exploration;
|
||||
mod explorer;
|
||||
mod policy;
|
||||
mod processor;
|
||||
|
||||
pub(crate) use base::*;
|
||||
pub(crate) use exploration::*;
|
||||
pub(crate) use explorer::*;
|
||||
pub(crate) use policy::*;
|
||||
pub(crate) use processor::*;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
|
|
@ -1,42 +1,43 @@
|
|||
use super::ExecutionMode;
|
||||
use crate::stream::{
|
||||
store::{OptimizationId, OptimizationStore, SearchQuery},
|
||||
TensorOpsDescription,
|
||||
store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery},
|
||||
OperationDescription,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// The stream policy keeps track of all possible optimizations for the current stream.
|
||||
/// The policy keeps track of all possible execution plans for the current operations.
|
||||
///
|
||||
/// # Details
|
||||
///
|
||||
/// We keep track of each new operation added to the stream and invalidate potential optimizations
|
||||
/// when we see a different operation is added while keeping track of the current stream.
|
||||
/// We keep track of each new operation added and invalidate potential execution plans
|
||||
/// when we see a different operation is added.
|
||||
///
|
||||
/// Therefore, the overhead is very minimal, since the time-complexity of checking for existing
|
||||
/// optimizations scales with the number of concurrent potential optimizations for the current stream,
|
||||
/// execution plans scales with the number of concurrent potential plans for the current operations,
|
||||
/// which isn't supposed to be big at any time.
|
||||
pub(crate) struct Policy<O> {
|
||||
// The potential optimizations that we could apply to the current stream, but their streams
|
||||
// The potential explorations that we could apply to the current stream, but their streams
|
||||
// still exceed the size of the current stream.
|
||||
candidates: Vec<OptimizationId>,
|
||||
// Optimizations that we find during the `updates`, but none of their `end_conditions` matches the
|
||||
candidates: Vec<ExecutionPlanId>,
|
||||
// Optimizations that we find during the `updates`, but none of their `trigger` matches the
|
||||
// current stream.
|
||||
availables: Vec<(OptimizationId, usize)>,
|
||||
// Optimization that we find during the `updates` where one of its `end_condition` matches the
|
||||
availables: Vec<(ExecutionPlanId, usize)>,
|
||||
// Optimization that we find during the `updates` where one of its `triggers` matches the
|
||||
// current stream.
|
||||
found: Option<(OptimizationId, usize)>,
|
||||
// The size of the stream currently analyzed.
|
||||
stream_size: usize,
|
||||
found: Option<(ExecutionPlanId, usize)>,
|
||||
// The number of operations analyzed.
|
||||
num_operations: usize,
|
||||
_item_type: PhantomData<O>,
|
||||
}
|
||||
|
||||
impl<O> Policy<O> {
|
||||
/// Create a new policy.
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
candidates: Vec::new(),
|
||||
availables: Vec::new(),
|
||||
found: None,
|
||||
stream_size: 0,
|
||||
num_operations: 0,
|
||||
_item_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -44,17 +45,12 @@ impl<O> Policy<O> {
|
|||
/// Returns the [action](Action) that should be taken given the state of the policy.
|
||||
pub fn action(
|
||||
&self,
|
||||
optimizations: &OptimizationStore<O>,
|
||||
stream: &[TensorOpsDescription],
|
||||
store: &ExecutionPlanStore<O>,
|
||||
operations: &[OperationDescription],
|
||||
mode: ExecutionMode,
|
||||
) -> Action {
|
||||
let num_minimum_analyzed = match mode {
|
||||
ExecutionMode::Lazy => self.stream_size - 1,
|
||||
ExecutionMode::Sync => self.stream_size,
|
||||
};
|
||||
|
||||
if num_minimum_analyzed < stream.len() {
|
||||
panic!("Internal Error: Can't retrieve the policy action when the number of operations analyzed is lower than the stream itself.");
|
||||
if self.num_operations < operations.len() {
|
||||
panic!("Internal Error: Can't retrieve the policy action on a list of operations bigger than what is analyzed.");
|
||||
}
|
||||
|
||||
if let Some((id, _length)) = self.found {
|
||||
|
@ -63,30 +59,27 @@ impl<O> Policy<O> {
|
|||
|
||||
match mode {
|
||||
ExecutionMode::Lazy => {
|
||||
if self.candidates.is_empty() {
|
||||
// Even if there are optimizations available, we aren't sure if they are the best ones
|
||||
// we can use. Exploring more optimizations might find a new `end_condition` or
|
||||
// even find a better optimization.
|
||||
return Action::Explore;
|
||||
if !self.candidates.is_empty() {
|
||||
return Action::Defer;
|
||||
}
|
||||
|
||||
Action::Defer
|
||||
Action::Explore
|
||||
}
|
||||
ExecutionMode::Sync => {
|
||||
// If an optimization covers the _whole_ stream, we return it, else we explore new
|
||||
// optimizations.
|
||||
// If an execution plan covers the _whole_ operation list, we return it, else we explore new
|
||||
// plans.
|
||||
for (id, length) in self.availables.iter() {
|
||||
if *length == stream.len() {
|
||||
if *length == operations.len() {
|
||||
return Action::Execute(*id);
|
||||
}
|
||||
}
|
||||
|
||||
for candidate in self.candidates.iter() {
|
||||
let item = optimizations.get_unchecked(*candidate);
|
||||
let item = store.get_unchecked(*candidate);
|
||||
|
||||
// The candidate can actually be executed, since the stream is of the same
|
||||
// size.
|
||||
if item.stream.len() == stream.len() {
|
||||
if item.operations.len() == operations.len() {
|
||||
return Action::Execute(*candidate);
|
||||
}
|
||||
}
|
||||
|
@ -97,61 +90,68 @@ impl<O> Policy<O> {
|
|||
}
|
||||
|
||||
/// Update the policy state.
|
||||
pub fn update(&mut self, store: &OptimizationStore<O>, ops: &TensorOpsDescription) {
|
||||
if self.stream_size == 0 {
|
||||
self.candidates = store.find(SearchQuery::OptimizationsStartingWith(ops));
|
||||
pub fn update(&mut self, store: &ExecutionPlanStore<O>, ops: &OperationDescription) {
|
||||
if self.num_operations == 0 {
|
||||
self.candidates = store.find(SearchQuery::PlansStartingWith(ops));
|
||||
} else {
|
||||
self.analyze_candidates(store, ops, self.stream_size);
|
||||
self.analyze_candidates(store, ops);
|
||||
}
|
||||
|
||||
self.stream_size += 1;
|
||||
self.num_operations += 1;
|
||||
}
|
||||
|
||||
// Reset the state of the policy.
|
||||
pub fn reset(&mut self) {
|
||||
self.candidates.clear();
|
||||
self.availables.clear();
|
||||
self.stream_size = 0;
|
||||
self.num_operations = 0;
|
||||
self.found = None;
|
||||
}
|
||||
|
||||
fn analyze_candidates(
|
||||
&mut self,
|
||||
optimizations: &OptimizationStore<O>,
|
||||
next_ops: &TensorOpsDescription,
|
||||
stream_size: usize,
|
||||
store: &ExecutionPlanStore<O>,
|
||||
operation: &OperationDescription,
|
||||
) {
|
||||
// The index starts at zero.
|
||||
let mut invalidated_candidates = Vec::new();
|
||||
|
||||
for id in self.candidates.iter() {
|
||||
let item = optimizations.get_unchecked(*id);
|
||||
let item = store.get_unchecked(*id);
|
||||
|
||||
if item.stream.len() == stream_size {
|
||||
if item.end_conditions.contains(next_ops) {
|
||||
self.found = Some((*id, item.stream.len()));
|
||||
if item.operations.len() == self.num_operations + 1
|
||||
&& item.operations.last().unwrap() == operation
|
||||
&& item.triggers.contains(&ExecutionTrigger::Always)
|
||||
{
|
||||
self.found = Some((*id, item.operations.len()));
|
||||
break;
|
||||
}
|
||||
|
||||
if item.operations.len() == self.num_operations {
|
||||
if item.should_stop_async(operation) {
|
||||
self.found = Some((*id, item.operations.len()));
|
||||
break;
|
||||
} else {
|
||||
// The optimization is available, but the current operation isn't an existing
|
||||
// end_condition for this optimization, so we may find a better optimization by
|
||||
// still growing the stream.
|
||||
self.availables.push((*id, item.stream.len()));
|
||||
// The plan is available, but the current operation isn't an existing
|
||||
// trigger for this plan, so we may find a better plan by
|
||||
// still growing the operation list.
|
||||
self.availables.push((*id, item.operations.len()));
|
||||
invalidated_candidates.push(*id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let next_ops_candidate = match item.stream.get(stream_size) {
|
||||
let operation_candidate = match item.operations.get(self.num_operations) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
// Stream of different size, invalidated.
|
||||
// Operation list of different size, invalidated.
|
||||
invalidated_candidates.push(*id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if next_ops_candidate != next_ops {
|
||||
// Stream with different node at the current position, invalidated.
|
||||
if operation_candidate != operation {
|
||||
// Operation list with different node at the current position, invalidated.
|
||||
invalidated_candidates.push(*id);
|
||||
continue;
|
||||
}
|
||||
|
@ -170,31 +170,34 @@ impl<O> Policy<O> {
|
|||
/// Action to be made depending on the stream.
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
pub enum Action {
|
||||
/// Continue exploring optimizations using the [builder](crate::OptimizationBuilder).
|
||||
/// Continue exploring using the [builder](crate::OptimizationBuilder).
|
||||
Explore,
|
||||
/// The current policy indicates that an optimization may be possible in the future, so the
|
||||
/// The current policy indicates that an explocation may be possible in the future, so the
|
||||
/// best action is to defer any execution.
|
||||
///
|
||||
/// Sometimes, it can be a false positive and a new optimization should be built from scratch.
|
||||
/// Sometimes, it can be a false positive and a new exploration should be built from scratch.
|
||||
/// Therefore it's important to keep the previous operations to rebuild the state if it
|
||||
/// happens.
|
||||
Defer,
|
||||
/// An optimization has been found, and the best action is to execute it!
|
||||
Execute(OptimizationId),
|
||||
/// An exploration has been found, and the best action is to execute it!
|
||||
Execute(ExecutionPlanId),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
stream::{store::OptimizationItem, FloatOpsDescription, UnaryOpsDescription},
|
||||
stream::{
|
||||
store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger},
|
||||
FloatOperationDescription, UnaryOperationDescription,
|
||||
},
|
||||
TensorDescription, TensorId, TensorStatus,
|
||||
};
|
||||
use std::ops::Range;
|
||||
|
||||
#[test]
|
||||
fn given_no_optimization_should_explore() {
|
||||
let store = OptimizationStore::default();
|
||||
let store = ExecutionPlanStore::default();
|
||||
let mut policy = Policy::new();
|
||||
let stream = TestStream::new(3);
|
||||
|
||||
|
@ -208,14 +211,14 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn given_existing_optimization_when_sync_should_execute_optim() {
|
||||
let mut store = OptimizationStore::default();
|
||||
let mut store = ExecutionPlanStore::default();
|
||||
let mut policy = Policy::new();
|
||||
let stream = TestStream::new(2);
|
||||
|
||||
let id = store.add(OptimizationItem {
|
||||
stream: stream.operations.clone(),
|
||||
end_conditions: Vec::new(),
|
||||
value: (),
|
||||
let id = store.add(ExecutionPlan {
|
||||
operations: stream.operations.clone(),
|
||||
triggers: Vec::new(),
|
||||
strategy: ExecutionStrategy::Operations,
|
||||
});
|
||||
|
||||
stream.assert_updates(
|
||||
|
@ -230,15 +233,18 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn given_existing_optimization_when_found_end_condition_should_execute_optim() {
|
||||
let mut store = OptimizationStore::default();
|
||||
fn given_existing_plan_when_found_trigger_should_execute_plan() {
|
||||
let mut store = ExecutionPlanStore::default();
|
||||
let mut policy = Policy::new();
|
||||
|
||||
let stream = TestStream::new(3);
|
||||
let id = store.add(OptimizationItem {
|
||||
stream: stream.operations[0..2].to_vec(),
|
||||
end_conditions: stream.operations[2..3].to_vec(),
|
||||
value: (),
|
||||
let id = store.add(ExecutionPlan {
|
||||
operations: stream.operations[0..2].to_vec(),
|
||||
triggers: stream.operations[2..3]
|
||||
.iter()
|
||||
.map(|desc| ExecutionTrigger::OnOperation(desc.clone()))
|
||||
.collect(),
|
||||
strategy: ExecutionStrategy::Operations,
|
||||
});
|
||||
|
||||
stream.assert_updates(
|
||||
|
@ -256,8 +262,8 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_multiple_end_conditions() {
|
||||
let mut store = OptimizationStore::default();
|
||||
fn should_support_multiple_triggers() {
|
||||
let mut store = ExecutionPlanStore::default();
|
||||
let mut policy_1 = Policy::new();
|
||||
let mut policy_2 = Policy::new();
|
||||
|
||||
|
@ -265,18 +271,18 @@ mod tests {
|
|||
let mut stream_2 = TestStream::new(2);
|
||||
|
||||
// Create different end operation for each stream.
|
||||
let end_condition_id_1 = 5;
|
||||
let end_condition_id_2 = 5;
|
||||
stream_1.new_ops(end_condition_id_1);
|
||||
stream_2.new_ops(end_condition_id_2);
|
||||
let trigger_id_1 = 5;
|
||||
let trigger_id_2 = 6;
|
||||
stream_1.new_ops(trigger_id_1);
|
||||
stream_2.new_ops(trigger_id_2);
|
||||
|
||||
let id = store.add(OptimizationItem {
|
||||
stream: stream_1.operations[0..2].to_vec(),
|
||||
end_conditions: vec![
|
||||
stream_1.operations[2].clone(),
|
||||
stream_2.operations[2].clone(),
|
||||
let id = store.add(ExecutionPlan {
|
||||
operations: stream_1.operations[0..2].to_vec(),
|
||||
triggers: vec![
|
||||
ExecutionTrigger::OnOperation(stream_1.operations[2].clone()),
|
||||
ExecutionTrigger::OnOperation(stream_2.operations[2].clone()),
|
||||
],
|
||||
value: (),
|
||||
strategy: ExecutionStrategy::Operations,
|
||||
});
|
||||
|
||||
stream_1.assert_updates(
|
||||
|
@ -295,20 +301,20 @@ mod tests {
|
|||
stream_1.assert_updates(
|
||||
&store,
|
||||
&mut policy_1,
|
||||
AssertUpdatesOptions::OperationsIndex(2..3), // First end condition.
|
||||
AssertUpdatesOptions::OperationsIndex(2..3), // First trigger.
|
||||
Action::Execute(id),
|
||||
);
|
||||
stream_2.assert_updates(
|
||||
&store,
|
||||
&mut policy_2,
|
||||
AssertUpdatesOptions::OperationsIndex(2..3), // Second end condition.
|
||||
AssertUpdatesOptions::OperationsIndex(2..3), // Second trigger.
|
||||
Action::Execute(id),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_select_right_optimization() {
|
||||
let mut store = OptimizationStore::default();
|
||||
let mut store = ExecutionPlanStore::default();
|
||||
let mut policy_1 = Policy::new();
|
||||
let mut policy_2 = Policy::new();
|
||||
|
||||
|
@ -322,15 +328,21 @@ mod tests {
|
|||
stream_2.new_ops(5);
|
||||
stream_2.new_ops(6);
|
||||
|
||||
let optimization_stream_1 = store.add(OptimizationItem {
|
||||
stream: stream_1.operations[0..3].to_vec(),
|
||||
end_conditions: stream_1.operations[3..4].to_vec(),
|
||||
value: (),
|
||||
let optimization_stream_1 = store.add(ExecutionPlan {
|
||||
operations: stream_1.operations[0..3].to_vec(),
|
||||
triggers: stream_1.operations[3..4]
|
||||
.iter()
|
||||
.map(|desc| ExecutionTrigger::OnOperation(desc.clone()))
|
||||
.collect(),
|
||||
strategy: ExecutionStrategy::Operations,
|
||||
});
|
||||
let optimization_stream_2 = store.add(OptimizationItem {
|
||||
stream: stream_2.operations[0..3].to_vec(),
|
||||
end_conditions: stream_2.operations[3..4].to_vec(),
|
||||
value: (),
|
||||
let optimization_stream_2 = store.add(ExecutionPlan {
|
||||
operations: stream_2.operations[0..3].to_vec(),
|
||||
triggers: stream_2.operations[3..4]
|
||||
.iter()
|
||||
.map(|desc| ExecutionTrigger::OnOperation(desc.clone()))
|
||||
.collect(),
|
||||
strategy: ExecutionStrategy::Operations,
|
||||
});
|
||||
assert_ne!(optimization_stream_1, optimization_stream_2);
|
||||
|
||||
|
@ -363,16 +375,19 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn should_invalidate_wrong_optimizations() {
|
||||
let mut store = OptimizationStore::default();
|
||||
let mut store = ExecutionPlanStore::default();
|
||||
let stream_1 = TestStream::new(4);
|
||||
let mut stream_2 = TestStream::new(2);
|
||||
stream_2.new_ops(6);
|
||||
stream_2.new_ops(7);
|
||||
|
||||
store.add(OptimizationItem {
|
||||
stream: stream_1.operations[0..3].to_vec(),
|
||||
end_conditions: stream_1.operations[3..4].to_vec(),
|
||||
value: (),
|
||||
store.add(ExecutionPlan {
|
||||
operations: stream_1.operations[0..3].to_vec(),
|
||||
triggers: stream_1.operations[3..4]
|
||||
.iter()
|
||||
.map(|desc| ExecutionTrigger::OnOperation(desc.clone()))
|
||||
.collect(),
|
||||
strategy: ExecutionStrategy::Operations,
|
||||
});
|
||||
|
||||
let mut policy = Policy::new();
|
||||
|
@ -396,7 +411,7 @@ mod tests {
|
|||
#[derive(Default, Debug)]
|
||||
struct TestStream {
|
||||
tensors: Vec<TensorDescription>,
|
||||
operations: Vec<TensorOpsDescription>,
|
||||
operations: Vec<OperationDescription>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -418,7 +433,7 @@ mod tests {
|
|||
/// The first follow should only be cache miss.
|
||||
pub fn assert_updates(
|
||||
&self,
|
||||
optimizations: &OptimizationStore<()>,
|
||||
optimizations: &ExecutionPlanStore<()>,
|
||||
policy: &mut Policy<()>,
|
||||
options: AssertUpdatesOptions,
|
||||
action: Action,
|
||||
|
@ -448,7 +463,7 @@ mod tests {
|
|||
self.new_empty_node(out_id);
|
||||
|
||||
self.operations
|
||||
.push(TensorOpsDescription::FloatOps(FloatOpsDescription::Log(
|
||||
.push(OperationDescription::Float(FloatOperationDescription::Log(
|
||||
self.unary_description(),
|
||||
)));
|
||||
}
|
||||
|
@ -461,10 +476,10 @@ mod tests {
|
|||
});
|
||||
}
|
||||
|
||||
fn unary_description(&self) -> UnaryOpsDescription {
|
||||
fn unary_description(&self) -> UnaryOperationDescription {
|
||||
let size = self.tensors.len();
|
||||
|
||||
UnaryOpsDescription {
|
||||
UnaryOperationDescription {
|
||||
input: self.tensors[size - 2].clone(),
|
||||
out: self.tensors[size - 1].clone(),
|
||||
}
|
||||
|
|
|
@ -1,44 +1,53 @@
|
|||
use super::{ExecutionMode, Exploration, Explorer};
|
||||
use crate::stream::execution::{Action, Policy};
|
||||
use crate::stream::store::{OptimizationId, OptimizationItem, OptimizationStore};
|
||||
use crate::stream::{Stream, TensorOpsDescription};
|
||||
use crate::{FusionBackend, HandleContainer, OptimizationBuilder};
|
||||
use crate::stream::store::{
|
||||
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
|
||||
};
|
||||
use crate::stream::OperationDescription;
|
||||
use crate::OptimizationBuilder;
|
||||
|
||||
/// Process the [stream](Stream) following a [policy](Policy).
|
||||
///
|
||||
/// Explore and create new optimizations using explorations
|
||||
pub(crate) struct Processor<B: FusionBackend> {
|
||||
policy: Policy<B::Optimization>,
|
||||
explorer: Explorer<B>,
|
||||
/// Process a [stream segment](StreamSegment) following a [policy](Policy).
|
||||
pub(crate) struct Processor<O> {
|
||||
policy: Policy<O>,
|
||||
explorer: Explorer<O>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Processor<B> {
|
||||
/// A part of a stream that can be executed partially using [execution plan](ExecutionPlan).
|
||||
pub(crate) trait StreamSegment<O> {
|
||||
/// The operations in the segment.
|
||||
fn operations(&self) -> &[OperationDescription];
|
||||
/// Execute part of the segment using the given plan id.
|
||||
fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<O>);
|
||||
}
|
||||
|
||||
impl<O> Processor<O> {
|
||||
/// Create a new stream processor.
|
||||
pub fn new(optimizations: Vec<Box<dyn OptimizationBuilder<B>>>) -> Self {
|
||||
pub fn new(optimizations: Vec<Box<dyn OptimizationBuilder<O>>>) -> Self {
|
||||
Self {
|
||||
policy: Policy::new(),
|
||||
explorer: Explorer::new(optimizations),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process the [stream](Stream) with the provided mode.
|
||||
pub fn process(
|
||||
/// Process the [stream segment](StreamSegment) with the provided [mode](ExecutionMode).
|
||||
pub fn process<Segment>(
|
||||
&mut self,
|
||||
stream: &mut Stream<B>,
|
||||
optimizations: &mut OptimizationStore<B::Optimization>,
|
||||
handles: &mut HandleContainer<B>,
|
||||
mut segment: Segment,
|
||||
store: &mut ExecutionPlanStore<O>,
|
||||
mode: ExecutionMode,
|
||||
) {
|
||||
) where
|
||||
Segment: StreamSegment<O>,
|
||||
{
|
||||
loop {
|
||||
if stream.is_empty() {
|
||||
if segment.operations().is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
match self.action(optimizations, stream, mode) {
|
||||
match self.action(store, segment.operations(), mode) {
|
||||
Action::Explore => {
|
||||
self.explore(stream, optimizations, handles, mode);
|
||||
self.explore(&mut segment, store, mode);
|
||||
|
||||
if self.explorer.up_to_date() {
|
||||
if self.explorer.is_up_to_date() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -51,8 +60,12 @@ impl<B: FusionBackend> Processor<B> {
|
|||
};
|
||||
}
|
||||
Action::Execute(id) => {
|
||||
stream.execute(Some(id), handles, optimizations);
|
||||
self.reset(optimizations, stream);
|
||||
if let ExecutionMode::Sync = mode {
|
||||
store.add_trigger(id, ExecutionTrigger::OnSync);
|
||||
}
|
||||
|
||||
segment.execute(id, store);
|
||||
self.reset(store, segment.operations());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -62,21 +75,34 @@ impl<B: FusionBackend> Processor<B> {
|
|||
}
|
||||
}
|
||||
|
||||
fn explore(
|
||||
fn explore<Item: StreamSegment<O>>(
|
||||
&mut self,
|
||||
stream: &mut Stream<B>,
|
||||
optimizations: &mut OptimizationStore<B::Optimization>,
|
||||
handles: &mut HandleContainer<B>,
|
||||
item: &mut Item,
|
||||
store: &mut ExecutionPlanStore<O>,
|
||||
mode: ExecutionMode,
|
||||
) {
|
||||
match self.explorer.explore(stream, mode) {
|
||||
Exploration::OptimizationFound(optim) => {
|
||||
let id = optim.map(|optim| {
|
||||
Self::on_new_optimization(&self.policy, stream, optimizations, optim, mode)
|
||||
});
|
||||
|
||||
stream.execute(id, handles, optimizations);
|
||||
self.reset(optimizations, stream);
|
||||
match self.explorer.explore(item.operations(), mode) {
|
||||
Exploration::Found(optim) => {
|
||||
let id = Self::on_optimization_found(
|
||||
&self.policy,
|
||||
item.operations(),
|
||||
store,
|
||||
optim,
|
||||
mode,
|
||||
);
|
||||
item.execute(id, store);
|
||||
self.reset(store, item.operations());
|
||||
}
|
||||
Exploration::NotFound { num_explored } => {
|
||||
let id = Self::on_optimization_not_found(
|
||||
&self.policy,
|
||||
item.operations(),
|
||||
store,
|
||||
mode,
|
||||
num_explored,
|
||||
);
|
||||
item.execute(id, store);
|
||||
self.reset(store, item.operations());
|
||||
}
|
||||
Exploration::Continue => {
|
||||
if let ExecutionMode::Sync = mode {
|
||||
|
@ -86,87 +112,105 @@ impl<B: FusionBackend> Processor<B> {
|
|||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self, store: &mut OptimizationStore<B::Optimization>, stream: &Stream<B>) {
|
||||
self.explorer.reset(stream);
|
||||
fn reset(&mut self, store: &mut ExecutionPlanStore<O>, operations: &[OperationDescription]) {
|
||||
self.explorer.reset(operations);
|
||||
self.policy.reset();
|
||||
|
||||
// Reset the policy state.
|
||||
for i in 0..stream.relative.len() {
|
||||
self.policy.update(store, &stream.relative[i]);
|
||||
for operation in operations.iter() {
|
||||
self.policy.update(store, operation);
|
||||
}
|
||||
}
|
||||
|
||||
fn action(
|
||||
&mut self,
|
||||
cache: &OptimizationStore<B::Optimization>,
|
||||
stream: &Stream<B>,
|
||||
store: &ExecutionPlanStore<O>,
|
||||
operations: &[OperationDescription],
|
||||
mode: ExecutionMode,
|
||||
) -> Action {
|
||||
let (stream, next_ops) = Self::split_stream_ref(stream, mode);
|
||||
if let ExecutionMode::Lazy = mode {
|
||||
// We update the policy in lazy mode, since
|
||||
self.policy.update(
|
||||
store,
|
||||
operations
|
||||
.last()
|
||||
.expect("At least one operation in the operation list."),
|
||||
);
|
||||
};
|
||||
|
||||
if let Some(next_ops) = next_ops {
|
||||
self.policy.update(cache, next_ops)
|
||||
}
|
||||
|
||||
self.policy.action(cache, stream, mode)
|
||||
self.policy.action(store, operations, mode)
|
||||
}
|
||||
|
||||
fn split_stream_owned(
|
||||
stream: &Stream<B>,
|
||||
fn on_optimization_found(
|
||||
policy: &Policy<O>,
|
||||
operations: &[OperationDescription],
|
||||
store: &mut ExecutionPlanStore<O>,
|
||||
builder: &dyn OptimizationBuilder<O>,
|
||||
mode: ExecutionMode,
|
||||
) -> (Vec<TensorOpsDescription>, Option<TensorOpsDescription>) {
|
||||
) -> ExecutionPlanId {
|
||||
let num_fused = builder.len();
|
||||
let relative = &operations[0..num_fused];
|
||||
|
||||
match mode {
|
||||
ExecutionMode::Lazy => {
|
||||
let stream = stream.split_relative_stream();
|
||||
(stream.0.to_vec(), stream.1.cloned())
|
||||
}
|
||||
ExecutionMode::Sync => (stream.relative.clone(), None),
|
||||
}
|
||||
}
|
||||
let next_ops = operations.get(num_fused);
|
||||
|
||||
fn split_stream_ref(
|
||||
stream: &Stream<B>,
|
||||
mode: ExecutionMode,
|
||||
) -> (&[TensorOpsDescription], Option<&TensorOpsDescription>) {
|
||||
match mode {
|
||||
ExecutionMode::Lazy => stream.split_relative_stream(),
|
||||
ExecutionMode::Sync => (stream.relative.as_slice(), None),
|
||||
}
|
||||
}
|
||||
let trigger = if let Some(next_ops) = next_ops {
|
||||
ExecutionTrigger::OnOperation(next_ops.clone())
|
||||
} else {
|
||||
// Happens if the next ops is included in the fused operation, and there is no
|
||||
// way the builder can still continue fusing.
|
||||
ExecutionTrigger::Always
|
||||
};
|
||||
|
||||
fn on_new_optimization(
|
||||
policy: &Policy<B::Optimization>,
|
||||
stream: &Stream<B>,
|
||||
store: &mut OptimizationStore<B::Optimization>,
|
||||
builder: &dyn OptimizationBuilder<B>,
|
||||
mode: ExecutionMode,
|
||||
) -> OptimizationId {
|
||||
let (stream_relative, next_ops) = Self::split_stream_owned(stream, mode);
|
||||
|
||||
// Check if an optimization is available for this stream before creating a new optimization.
|
||||
//
|
||||
// Specify a sync execution mode signaling that we want to know if an optimization is
|
||||
// available right now even if it isn't the best one.
|
||||
match policy.action(store, &stream_relative, ExecutionMode::Sync) {
|
||||
Action::Execute(id) => {
|
||||
// When we are in lazy mode, a next operation will be available.
|
||||
//
|
||||
// Since we are adding new optimization only when the policy action is explore, we
|
||||
// know the existing optimization wasn't flagged as optimal, since the `next_ops'
|
||||
// wasn't included in the `end_conditions`.
|
||||
//
|
||||
// But in this case, we aren't able to actually find a better optimization, so we
|
||||
// flag the next ops as a stopping criteria, so we won't enter exploration mode the
|
||||
// next time we see a similar stream following the same pattern.
|
||||
if let Some(next_ops) = next_ops {
|
||||
store.add_end_condition(id, next_ops);
|
||||
match policy.action(store, relative, ExecutionMode::Sync) {
|
||||
Action::Execute(id) => {
|
||||
store.add_trigger(id, trigger);
|
||||
id
|
||||
}
|
||||
_ => store.add(ExecutionPlan {
|
||||
operations: relative.to_vec(),
|
||||
triggers: vec![trigger],
|
||||
strategy: ExecutionStrategy::Optimization(builder.build()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
ExecutionMode::Sync => match policy.action(store, relative, ExecutionMode::Sync) {
|
||||
Action::Execute(id) => {
|
||||
store.add_trigger(id, ExecutionTrigger::OnSync);
|
||||
id
|
||||
}
|
||||
_ => store.add(ExecutionPlan {
|
||||
operations: operations.to_vec(),
|
||||
triggers: vec![ExecutionTrigger::OnSync],
|
||||
strategy: ExecutionStrategy::Optimization(builder.build()),
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn on_optimization_not_found(
|
||||
policy: &Policy<O>,
|
||||
operations: &[OperationDescription],
|
||||
store: &mut ExecutionPlanStore<O>,
|
||||
mode: ExecutionMode,
|
||||
num_explored: usize,
|
||||
) -> ExecutionPlanId {
|
||||
let relative = &operations[0..num_explored];
|
||||
let trigger = match mode {
|
||||
ExecutionMode::Lazy => ExecutionTrigger::Always,
|
||||
ExecutionMode::Sync => ExecutionTrigger::OnSync,
|
||||
};
|
||||
|
||||
match policy.action(store, relative, ExecutionMode::Sync) {
|
||||
Action::Execute(id) => {
|
||||
store.add_trigger(id, trigger);
|
||||
id
|
||||
}
|
||||
_ => store.add(OptimizationItem {
|
||||
stream: stream_relative,
|
||||
end_conditions: next_ops.map(|op| vec![op]).unwrap_or_default(),
|
||||
value: builder.build(),
|
||||
_ => store.add(ExecutionPlan {
|
||||
operations: relative.to_vec(),
|
||||
triggers: vec![trigger],
|
||||
strategy: ExecutionStrategy::Operations,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,408 @@
|
|||
//! A testing module that ensures the correctness of the explorer, policy, and processor.
|
||||
//!
|
||||
//! The primary focus is on validating the seamless interaction between these three components to
|
||||
//! execute and optimize a stream of operations accurately.
|
||||
//!
|
||||
//! To test these components effectively, we create mock types for the stream, optimization,
|
||||
//! optimization builder, and stream segment. These mock types aid in comprehensively
|
||||
//! understanding the process of optimizing streams.
|
||||
use crate::{
|
||||
stream::{
|
||||
store::{
|
||||
ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,
|
||||
},
|
||||
BinaryOperationDescription, NumericOperationDescription, OperationDescription,
|
||||
ScalarOperationDescription,
|
||||
},
|
||||
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription, TensorId,
|
||||
TensorStatus,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
/// A fake stream of operations for testing purpose.
|
||||
struct TestStream {
|
||||
processor: Processor<TestOptimization>,
|
||||
store: ExecutionPlanStore<TestOptimization>,
|
||||
executed: Vec<ExecutionPlanId>,
|
||||
operations: Vec<OperationDescription>,
|
||||
}
|
||||
|
||||
/// A fake [optimization builder](OptimizationBuilder) for testing purpose.
|
||||
struct TestOptimizationBuilder {
|
||||
builder_id: usize,
|
||||
expected_operations: Vec<OperationDescription>,
|
||||
expected_trigger: ExecutionTrigger,
|
||||
actual: Vec<OperationDescription>,
|
||||
}
|
||||
|
||||
/// A fake optimization for testing purpose.
|
||||
#[derive(new, Debug, PartialEq)]
|
||||
struct TestOptimization {
|
||||
builder_id: usize,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
/// A fake [stream segment](StreamSegment) for testing purpose.
|
||||
#[derive(new)]
|
||||
struct TestSegment<'i> {
|
||||
operations: &'i mut Vec<OperationDescription>,
|
||||
executed: &'i mut Vec<ExecutionPlanId>,
|
||||
}
|
||||
|
||||
/// This is a substantial test case that examines a lengthy scenario with a diverse set of conditions.
|
||||
///
|
||||
/// While it's usually preferable to split tests into multiple independent scenarios, in this case, it is
|
||||
/// crucial to verify that the stream's state is correctly updated when various cases occur consecutively.
|
||||
///
|
||||
/// Although it might complicate identifying the source of a bug in the code, having this comprehensive
|
||||
/// test case covers nearly all aspects of the implementation, while remaining easy to read and
|
||||
/// maintainable.
|
||||
#[test]
|
||||
fn should_support_complex_stream() {
|
||||
// We have 2 different optimization builders in this test case.
|
||||
let builder_id_1 = 0;
|
||||
let builder_id_2 = 1;
|
||||
|
||||
// We will have a total of 3 execution plans to execute.
|
||||
let plan_id_1 = 0;
|
||||
let plan_id_2 = 1;
|
||||
let plan_id_3 = 2;
|
||||
|
||||
// The first builder only contains 2 operations, and the optimization is always available when
|
||||
// the pattern is met.
|
||||
let builder_1 = TestOptimizationBuilder::new(
|
||||
builder_id_1,
|
||||
vec![operation_1(), operation_2()],
|
||||
ExecutionTrigger::Always,
|
||||
);
|
||||
// The second builder also contains 2 operations, but only becomes available when an operation
|
||||
// is met.
|
||||
let builder_2 = TestOptimizationBuilder::new(
|
||||
builder_id_2,
|
||||
vec![operation_2(), operation_2()],
|
||||
ExecutionTrigger::OnOperation(operation_1()),
|
||||
);
|
||||
|
||||
// We finally build the stream with those optimization builders.
|
||||
let mut stream = TestStream::new(vec![Box::new(builder_1), Box::new(builder_2)]);
|
||||
|
||||
// builder_1 is still waiting to see next op is operation_2
|
||||
// builder_2 is closed because it's not the right operation
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(0);
|
||||
|
||||
// No optimization found for the first two operations.
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(1);
|
||||
stream.assert_last_executed(plan_id_1);
|
||||
stream.assert_plan(
|
||||
plan_id_1,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_1()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
strategy: ExecutionStrategy::Operations,
|
||||
},
|
||||
);
|
||||
|
||||
// Nothing to execute.
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(1);
|
||||
|
||||
// Now we should trigger the first optimization builder.
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(2);
|
||||
stream.assert_last_executed(plan_id_2);
|
||||
stream.assert_plan(
|
||||
plan_id_2,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_2()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
strategy: ExecutionStrategy::Optimization(TestOptimization::new(builder_id_1, 2)),
|
||||
},
|
||||
);
|
||||
|
||||
// Nothing to execute.
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(2);
|
||||
|
||||
// Nothing to execute.
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(2);
|
||||
stream.assert_number_of_executions(2);
|
||||
|
||||
// Now we should trigger the second optimization builder.
|
||||
stream.add(operation_1());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(3);
|
||||
stream.assert_last_executed(plan_id_3);
|
||||
stream.assert_plan(
|
||||
plan_id_3,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_2(), operation_2()],
|
||||
triggers: vec![ExecutionTrigger::OnOperation(operation_1())],
|
||||
strategy: ExecutionStrategy::Optimization(TestOptimization::new(builder_id_2, 2)),
|
||||
},
|
||||
);
|
||||
|
||||
// Now we should trigger the first optimization builder (second plan).
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(4);
|
||||
stream.assert_last_executed(plan_id_2);
|
||||
stream.assert_plan(
|
||||
plan_id_2,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_1(), operation_2()],
|
||||
triggers: vec![ExecutionTrigger::Always],
|
||||
strategy: ExecutionStrategy::Optimization(TestOptimization::new(builder_id_1, 2)),
|
||||
},
|
||||
);
|
||||
|
||||
// Nothing to execute.
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(1);
|
||||
stream.assert_number_of_executions(4);
|
||||
|
||||
// Nothing to execute.
|
||||
stream.add(operation_2());
|
||||
stream.assert_number_of_operations(2);
|
||||
stream.assert_number_of_executions(4);
|
||||
|
||||
// On sync we should execute all operations even if their trigger isn't met.
|
||||
// In this case the optimization from builder 2 (plan 3).
|
||||
stream.sync();
|
||||
stream.assert_number_of_operations(0);
|
||||
stream.assert_number_of_executions(5);
|
||||
stream.assert_last_executed(plan_id_3);
|
||||
stream.assert_plan(
|
||||
plan_id_3,
|
||||
ExecutionPlan {
|
||||
operations: vec![operation_2(), operation_2()],
|
||||
triggers: vec![
|
||||
ExecutionTrigger::OnOperation(operation_1()),
|
||||
ExecutionTrigger::OnSync, // We also add OnSync in the triggers.
|
||||
],
|
||||
strategy: ExecutionStrategy::Optimization(TestOptimization::new(builder_id_2, 2)),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
impl TestStream {
|
||||
/// Create a new stream with the given optimization builders.
|
||||
fn new(optimizations: Vec<Box<dyn OptimizationBuilder<TestOptimization>>>) -> Self {
|
||||
Self {
|
||||
processor: Processor::<TestOptimization>::new(optimizations),
|
||||
store: ExecutionPlanStore::<TestOptimization>::new(),
|
||||
executed: Vec::new(),
|
||||
operations: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an operation to the stream.
|
||||
fn add(&mut self, operation: OperationDescription) {
|
||||
self.operations.push(operation);
|
||||
self.processor.process(
|
||||
TestSegment::new(&mut self.operations, &mut self.executed),
|
||||
&mut self.store,
|
||||
ExecutionMode::Lazy,
|
||||
);
|
||||
}
|
||||
|
||||
/// Sync the stream.
|
||||
fn sync(&mut self) {
|
||||
self.processor.process(
|
||||
TestSegment::new(&mut self.operations, &mut self.executed),
|
||||
&mut self.store,
|
||||
ExecutionMode::Sync,
|
||||
);
|
||||
}
|
||||
|
||||
/// Assert that the plan has been executed as provided.
|
||||
fn assert_plan(&self, id: ExecutionPlanId, expected: ExecutionPlan<TestOptimization>) {
|
||||
let actual = self.store.get_unchecked(id);
|
||||
assert_eq!(actual.triggers, expected.triggers);
|
||||
assert_eq!(actual.operations, expected.operations);
|
||||
}
|
||||
|
||||
/// Assert that the given plan id has been the last executed.
|
||||
fn assert_last_executed(&self, id: ExecutionPlanId) {
|
||||
match self.executed.last() {
|
||||
Some(last_id) => assert_eq!(*last_id, id),
|
||||
None => panic!("No plan has been executed"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Assert the number of executions since the start of the stream.
|
||||
fn assert_number_of_executions(&self, number: usize) {
|
||||
assert_eq!(self.executed.len(), number);
|
||||
}
|
||||
|
||||
/// Assert the number of operations queued.
|
||||
fn assert_number_of_operations(&self, number: usize) {
|
||||
assert_eq!(self.operations.len(), number);
|
||||
}
|
||||
}
|
||||
|
||||
impl TestOptimizationBuilder {
|
||||
/// Create a new optimization builder that follows a pattern with a trigger.
|
||||
fn new(
|
||||
builder_id: usize,
|
||||
operations: Vec<OperationDescription>,
|
||||
trigger: ExecutionTrigger,
|
||||
) -> Self {
|
||||
Self {
|
||||
builder_id,
|
||||
expected_operations: operations,
|
||||
actual: Vec::new(),
|
||||
expected_trigger: trigger,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OptimizationBuilder<TestOptimization> for TestOptimizationBuilder {
|
||||
/// Register a new operation.
|
||||
fn register(&mut self, operation: &OperationDescription) {
|
||||
self.actual.push(operation.clone());
|
||||
}
|
||||
|
||||
/// Build the optimization.
|
||||
fn build(&self) -> TestOptimization {
|
||||
TestOptimization::new(self.builder_id, self.len())
|
||||
}
|
||||
|
||||
/// Reset the state.
|
||||
fn reset(&mut self) {
|
||||
self.actual.clear();
|
||||
}
|
||||
|
||||
/// Return the optimization status.
|
||||
fn status(&self) -> OptimizationStatus {
|
||||
let actual_equal_expected = self.actual == self.expected_operations;
|
||||
|
||||
if self.actual.len() < self.expected_operations.len() {
|
||||
let operations = &self.expected_operations[0..self.actual.len()];
|
||||
|
||||
return match self.actual == operations {
|
||||
// Still optimizing.
|
||||
true => OptimizationStatus::Open,
|
||||
// Never gonna be possible on that stream.
|
||||
false => OptimizationStatus::Closed,
|
||||
};
|
||||
}
|
||||
|
||||
if self.actual.len() == self.expected_operations.len() && actual_equal_expected {
|
||||
return match self.expected_trigger {
|
||||
// Stop right away.
|
||||
ExecutionTrigger::Always => OptimizationStatus::Closed,
|
||||
// Wait for the next operation to show up.
|
||||
ExecutionTrigger::OnOperation(_) => OptimizationStatus::Open,
|
||||
// Doesn't matter on sync, even open should trigger a build if possible.
|
||||
ExecutionTrigger::OnSync => OptimizationStatus::Open,
|
||||
};
|
||||
}
|
||||
|
||||
OptimizationStatus::Closed
|
||||
}
|
||||
|
||||
/// Return the properties of this optimization.
|
||||
fn properties(&self) -> OptimizationProperties {
|
||||
if self.actual.len() < self.expected_operations.len() {
|
||||
// Optimization not possible.
|
||||
return OptimizationProperties {
|
||||
score: 0,
|
||||
ready: false,
|
||||
};
|
||||
}
|
||||
|
||||
let stream_is_ok =
|
||||
self.actual[0..self.expected_operations.len()] == self.expected_operations;
|
||||
|
||||
if !stream_is_ok {
|
||||
// Optimization not possible.
|
||||
return OptimizationProperties {
|
||||
score: 0,
|
||||
ready: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Optimization possible.
|
||||
OptimizationProperties {
|
||||
score: 1,
|
||||
ready: true,
|
||||
}
|
||||
}
|
||||
|
||||
// The number of operations that should be handle by the optimization.
|
||||
fn len(&self) -> usize {
|
||||
self.expected_operations.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'i> StreamSegment<TestOptimization> for TestSegment<'i> {
|
||||
// The operations in the process.
|
||||
fn operations(&self) -> &[OperationDescription] {
|
||||
self.operations
|
||||
}
|
||||
|
||||
// Execute the process.
|
||||
fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<TestOptimization>) {
|
||||
let execution_plan = store.get_unchecked(id);
|
||||
|
||||
match &execution_plan.strategy {
|
||||
ExecutionStrategy::Optimization(optimization) => {
|
||||
self.operations.drain(0..optimization.size);
|
||||
}
|
||||
ExecutionStrategy::Operations => self.operations.clear(),
|
||||
};
|
||||
|
||||
self.executed.push(id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Just a simple operation.
|
||||
fn operation_1() -> OperationDescription {
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Add(
|
||||
BinaryOperationDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
rhs: TensorDescription {
|
||||
id: TensorId::new(1),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
out: TensorDescription {
|
||||
id: TensorId::new(2),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::NotInit,
|
||||
},
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
/// Just a simple operation.
|
||||
fn operation_2() -> OperationDescription {
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::AddScalar(
|
||||
ScalarOperationDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
rhs: 5.0,
|
||||
out: TensorDescription {
|
||||
id: TensorId::new(2),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::NotInit,
|
||||
},
|
||||
},
|
||||
))
|
||||
}
|
|
@ -4,9 +4,9 @@ pub(crate) mod store;
|
|||
mod base;
|
||||
mod context;
|
||||
mod multi;
|
||||
mod ops;
|
||||
mod operation;
|
||||
|
||||
pub use base::*;
|
||||
pub use context::*;
|
||||
pub use multi::*;
|
||||
pub use ops::*;
|
||||
pub use operation::*;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use super::{
|
||||
execution::{ExecutionMode, Processor},
|
||||
store::OptimizationStore,
|
||||
Ops, Stream, TensorOpsDescription,
|
||||
execution::{ExecutionMode, Processor, StreamSegment},
|
||||
store::{ExecutionPlanId, ExecutionPlanStore},
|
||||
Operation, OperationDescription, OperationQueue,
|
||||
};
|
||||
use crate::{FusionBackend, HandleContainer};
|
||||
|
||||
|
@ -9,37 +9,31 @@ use crate::{FusionBackend, HandleContainer};
|
|||
///
|
||||
/// TODO: Actually support multiple streams.
|
||||
pub struct MultiStream<B: FusionBackend> {
|
||||
items: Vec<Item<B>>,
|
||||
optimizations: OptimizationStore<B::Optimization>,
|
||||
}
|
||||
|
||||
struct Item<B: FusionBackend> {
|
||||
stream: Stream<B>,
|
||||
executor: Processor<B>,
|
||||
streams: Vec<Stream<B>>,
|
||||
optimizations: ExecutionPlanStore<B::Optimization>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> MultiStream<B> {
|
||||
pub(crate) fn new(device: B::FusionDevice) -> Self {
|
||||
Self {
|
||||
items: vec![Item::new(device)],
|
||||
optimizations: OptimizationStore::new(),
|
||||
streams: vec![Stream::new(device)],
|
||||
optimizations: ExecutionPlanStore::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new tensor operation.
|
||||
pub fn register(
|
||||
&mut self,
|
||||
ops_desc: TensorOpsDescription,
|
||||
ops: Box<dyn Ops<B>>,
|
||||
desc: OperationDescription,
|
||||
operation: Box<dyn Operation<B>>,
|
||||
handles: &mut HandleContainer<B>,
|
||||
) {
|
||||
// TODO: Support more than only one stream.
|
||||
if let Some(item) = self.items.first_mut() {
|
||||
item.stream.add(ops_desc, ops);
|
||||
item.executor.process(
|
||||
&mut item.stream,
|
||||
if let Some(item) = self.streams.first_mut() {
|
||||
item.queue.add(desc, operation);
|
||||
item.processor.process(
|
||||
Segment::new(&mut item.queue, handles),
|
||||
&mut self.optimizations,
|
||||
handles,
|
||||
ExecutionMode::Lazy,
|
||||
);
|
||||
};
|
||||
|
@ -47,22 +41,42 @@ impl<B: FusionBackend> MultiStream<B> {
|
|||
|
||||
/// Drain the streams.
|
||||
pub fn drain(&mut self, handles: &mut HandleContainer<B>) {
|
||||
self.items.iter_mut().for_each(|item| {
|
||||
item.executor.process(
|
||||
&mut item.stream,
|
||||
self.streams.iter_mut().for_each(|item| {
|
||||
item.processor.process(
|
||||
Segment::new(&mut item.queue, handles),
|
||||
&mut self.optimizations,
|
||||
handles,
|
||||
ExecutionMode::Sync,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Item<B> {
|
||||
struct Stream<B: FusionBackend> {
|
||||
queue: OperationQueue<B>,
|
||||
processor: Processor<B::Optimization>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct Segment<'a, B: FusionBackend> {
|
||||
queue: &'a mut OperationQueue<B>,
|
||||
handles: &'a mut HandleContainer<B>,
|
||||
}
|
||||
|
||||
impl<'i, B: FusionBackend> StreamSegment<B::Optimization> for Segment<'i, B> {
|
||||
fn operations(&self) -> &[OperationDescription] {
|
||||
&self.queue.relative
|
||||
}
|
||||
|
||||
fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<B::Optimization>) {
|
||||
self.queue.execute(id, self.handles, store)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Stream<B> {
|
||||
fn new(device: B::FusionDevice) -> Self {
|
||||
Self {
|
||||
executor: Processor::new(B::optimizations(device.into())),
|
||||
stream: Stream::new(),
|
||||
processor: Processor::new(B::optimizations(device.into())),
|
||||
queue: OperationQueue::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,111 @@
|
|||
use super::{ExecutionPlanIndex, InsertQuery, SearchQuery};
|
||||
use crate::stream::OperationDescription;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The store that contains all explorations done on a device.
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub(crate) struct ExecutionPlanStore<O> {
|
||||
plans: Vec<ExecutionPlan<O>>,
|
||||
index: ExecutionPlanIndex,
|
||||
}
|
||||
|
||||
/// How a list of operations should be executed.
|
||||
#[derive(PartialEq, Debug, Serialize, Deserialize, Clone)]
|
||||
pub(crate) enum ExecutionStrategy<O> {
|
||||
/// An optimization was found, and therefore should be executed.
|
||||
Optimization(O),
|
||||
/// No optimization was found, each operation should be executed individually.
|
||||
Operations,
|
||||
}
|
||||
|
||||
/// The trigger that indicates when to stop exploring.
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
// Triggers are stored in a list, and you can have many `OnOperation` entries,
|
||||
// but only one `OnSync` entry and one `Always` entry, therefore we don't care if it takes more
|
||||
// space to store them.
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub(crate) enum ExecutionTrigger {
|
||||
OnOperation(OperationDescription),
|
||||
OnSync,
|
||||
Always,
|
||||
}
|
||||
|
||||
/// The unique identifier for an exploration that was executed.
|
||||
pub(crate) type ExecutionPlanId = usize;
|
||||
|
||||
/// The outcome of an exploration that can be stored.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub(crate) struct ExecutionPlan<O> {
|
||||
/// The operations on which the exploration is related to.
|
||||
pub(crate) operations: Vec<OperationDescription>,
|
||||
/// The criteria that signal when this plan should be executed. Only one trigger is necessary.
|
||||
pub(crate) triggers: Vec<ExecutionTrigger>,
|
||||
/// The strategy that should be used when executing this plan.
|
||||
pub(crate) strategy: ExecutionStrategy<O>,
|
||||
}
|
||||
|
||||
impl<O> ExecutionPlan<O> {
|
||||
/// Whether exploration should be stop in an async mode.
|
||||
pub fn should_stop_async(&self, ops: &OperationDescription) -> bool {
|
||||
for item in self.triggers.iter() {
|
||||
match item {
|
||||
ExecutionTrigger::OnOperation(val) => {
|
||||
if val == ops {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
ExecutionTrigger::Always => return true,
|
||||
ExecutionTrigger::OnSync => continue,
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl<O> ExecutionPlanStore<O> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
plans: Vec::new(),
|
||||
index: ExecutionPlanIndex::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find(&self, query: SearchQuery<'_>) -> Vec<ExecutionPlanId> {
|
||||
self.index.find(query)
|
||||
}
|
||||
|
||||
pub fn add(&mut self, exploration: ExecutionPlan<O>) -> ExecutionPlanId {
|
||||
if exploration.operations.is_empty() {
|
||||
panic!("Can't add an empty optimization.");
|
||||
}
|
||||
|
||||
let id = self.plans.len();
|
||||
|
||||
self.index.insert(InsertQuery::NewPlan {
|
||||
operations: &exploration.operations,
|
||||
id,
|
||||
});
|
||||
|
||||
self.plans.push(exploration);
|
||||
|
||||
id
|
||||
}
|
||||
|
||||
pub fn get_mut_unchecked(&mut self, id: ExecutionPlanId) -> &mut ExecutionPlan<O> {
|
||||
&mut self.plans[id]
|
||||
}
|
||||
|
||||
pub fn get_unchecked(&self, id: ExecutionPlanId) -> &ExecutionPlan<O> {
|
||||
&self.plans[id]
|
||||
}
|
||||
|
||||
/// Add a new end condition for an optimization.
|
||||
pub fn add_trigger(&mut self, id: ExecutionPlanId, criterion: ExecutionTrigger) {
|
||||
let criteria = &mut self.plans[id].triggers;
|
||||
|
||||
if !criteria.contains(&criterion) {
|
||||
criteria.push(criterion);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::stream::{store::OptimizationId, TensorOpsDescription};
|
||||
use crate::stream::{store::ExecutionPlanId, OperationDescription};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::{hash_map::DefaultHasher, HashMap},
|
||||
|
@ -7,52 +7,51 @@ use std::{
|
|||
|
||||
/// Index used to search optimizations.
|
||||
#[derive(Default, Serialize, Deserialize, Clone)]
|
||||
pub struct OptimizationIndex {
|
||||
/// We can't use `HashMap<TensorOpsDescription, Vec<OptimizationId>>` since `TensorOpsDescription`
|
||||
pub struct ExecutionPlanIndex {
|
||||
/// We can't use `HashMap<OperationDescription, Vec<ExecutionPlanId>>` since `OperationDescription`
|
||||
/// doesn't implement [`Eq`](core::cmp::Eq).
|
||||
///
|
||||
/// `TensorOpsDescription` can't implement `Eq` since float types don't implement it.
|
||||
/// `OperationDescription` can't implement `Eq` since float types don't implement it.
|
||||
///
|
||||
/// We rely instead on [`PartialEq`](core::cmp::PartialEq) to manually handle hash collisions.
|
||||
/// This is OK because we use `relative` streams where any scalar values are set to zeros,
|
||||
/// This is OK because we use `relative` operations where any scalar values are set to zeros,
|
||||
/// see [`RelativeStreamConverter`](crate::stream::RelativeStreamConverter).
|
||||
mapping: HashMap<u64, Vec<(TensorOpsDescription, usize)>>,
|
||||
starters: Vec<Vec<OptimizationId>>,
|
||||
mapping: HashMap<u64, Vec<(OperationDescription, usize)>>,
|
||||
starters: Vec<Vec<ExecutionPlanId>>,
|
||||
}
|
||||
|
||||
pub enum SearchQuery<'a> {
|
||||
OptimizationsStartingWith(&'a TensorOpsDescription),
|
||||
PlansStartingWith(&'a OperationDescription),
|
||||
}
|
||||
|
||||
pub enum InsertQuery<'a> {
|
||||
NewOptimization {
|
||||
stream: &'a [TensorOpsDescription],
|
||||
id: OptimizationId,
|
||||
NewPlan {
|
||||
operations: &'a [OperationDescription],
|
||||
id: ExecutionPlanId,
|
||||
},
|
||||
}
|
||||
|
||||
impl OptimizationIndex {
|
||||
impl ExecutionPlanIndex {
|
||||
/// Search optimizations with the given [query](SearchQuery).
|
||||
pub fn find(&self, query: SearchQuery<'_>) -> Vec<OptimizationId> {
|
||||
pub fn find(&self, query: SearchQuery<'_>) -> Vec<ExecutionPlanId> {
|
||||
match query {
|
||||
SearchQuery::OptimizationsStartingWith(ops) => self.find_starting_with(ops),
|
||||
SearchQuery::PlansStartingWith(ops) => self.find_starting_with(ops),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new optimization with the given [query](InsertQuery).
|
||||
pub fn insert(&mut self, query: InsertQuery<'_>) {
|
||||
match query {
|
||||
InsertQuery::NewOptimization { stream, id } => self.insert_new_ops(
|
||||
stream
|
||||
.first()
|
||||
.expect("An optimization should never have an empty stream."),
|
||||
id,
|
||||
),
|
||||
InsertQuery::NewPlan { operations, id } => {
|
||||
if let Some(operation) = operations.first() {
|
||||
self.insert_new_operation(operation, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn find_starting_with(&self, ops: &TensorOpsDescription) -> Vec<OptimizationId> {
|
||||
let key = self.stream_key(ops);
|
||||
fn find_starting_with(&self, operation: &OperationDescription) -> Vec<ExecutionPlanId> {
|
||||
let key = self.operation_key(operation);
|
||||
let values = match self.mapping.get(&key) {
|
||||
Some(val) => val,
|
||||
None => return Vec::new(),
|
||||
|
@ -62,7 +61,7 @@ impl OptimizationIndex {
|
|||
return Vec::new();
|
||||
}
|
||||
|
||||
let (_, index) = match values.iter().find(|value| &value.0 == ops) {
|
||||
let (_, index) = match values.iter().find(|value| &value.0 == operation) {
|
||||
Some(val) => val,
|
||||
None => return Vec::new(),
|
||||
};
|
||||
|
@ -75,8 +74,8 @@ impl OptimizationIndex {
|
|||
val
|
||||
}
|
||||
|
||||
fn insert_new_ops(&mut self, ops: &TensorOpsDescription, new_id: OptimizationId) {
|
||||
let key = self.stream_key(ops);
|
||||
fn insert_new_operation(&mut self, ops: &OperationDescription, new_id: ExecutionPlanId) {
|
||||
let key = self.operation_key(ops);
|
||||
let values = match self.mapping.get_mut(&key) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
|
@ -106,8 +105,8 @@ impl OptimizationIndex {
|
|||
.push(new_id);
|
||||
}
|
||||
|
||||
// Hash the value of the first operation in a stream.
|
||||
fn stream_key(&self, ops: &TensorOpsDescription) -> u64 {
|
||||
// Hash the value of the first operation in a list.
|
||||
fn operation_key(&self, ops: &OperationDescription) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
ops.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
|
@ -118,80 +117,82 @@ impl OptimizationIndex {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
stream::{BinaryOpsDescription, NumericOpsDescription, ScalarOpsDescription},
|
||||
stream::{
|
||||
BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription,
|
||||
},
|
||||
TensorDescription, TensorId, TensorStatus,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn should_find_optimization_id_based_on_tensor_ops() {
|
||||
let mut index = OptimizationIndex::default();
|
||||
let mut index = ExecutionPlanIndex::default();
|
||||
let stream_1 = [ops_1()];
|
||||
let optimization_id_1 = 0;
|
||||
|
||||
index.insert(InsertQuery::NewOptimization {
|
||||
stream: &stream_1,
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_1,
|
||||
id: optimization_id_1,
|
||||
});
|
||||
|
||||
let found = index.find(SearchQuery::OptimizationsStartingWith(&stream_1[0]));
|
||||
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
|
||||
|
||||
assert_eq!(found, vec![optimization_id_1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_multiple_optimization_ids_with_same_starting_ops() {
|
||||
let mut index = OptimizationIndex::default();
|
||||
let mut index = ExecutionPlanIndex::default();
|
||||
let stream_1 = [ops_1(), ops_2(), ops_1()];
|
||||
let stream_2 = [ops_1(), ops_1(), ops_2()];
|
||||
let optimization_id_1 = 0;
|
||||
let optimization_id_2 = 1;
|
||||
|
||||
index.insert(InsertQuery::NewOptimization {
|
||||
stream: &stream_1,
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_1,
|
||||
id: optimization_id_1,
|
||||
});
|
||||
index.insert(InsertQuery::NewOptimization {
|
||||
stream: &stream_2,
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_2,
|
||||
id: optimization_id_2,
|
||||
});
|
||||
|
||||
let found = index.find(SearchQuery::OptimizationsStartingWith(&stream_1[0]));
|
||||
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
|
||||
|
||||
assert_eq!(found, vec![optimization_id_1, optimization_id_2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_only_find_optimization_with_correct_starting_ops() {
|
||||
let mut index = OptimizationIndex::default();
|
||||
let mut index = ExecutionPlanIndex::default();
|
||||
let stream_1 = [ops_1(), ops_1()];
|
||||
let stream_2 = [ops_2(), ops_1()];
|
||||
let optimization_id_1 = 0;
|
||||
let optimization_id_2 = 1;
|
||||
|
||||
index.insert(InsertQuery::NewOptimization {
|
||||
stream: &stream_1,
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_1,
|
||||
id: optimization_id_1,
|
||||
});
|
||||
index.insert(InsertQuery::NewOptimization {
|
||||
stream: &stream_2,
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_2,
|
||||
id: optimization_id_2,
|
||||
});
|
||||
|
||||
let found = index.find(SearchQuery::OptimizationsStartingWith(&stream_1[0]));
|
||||
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
|
||||
|
||||
assert_eq!(found, vec![optimization_id_1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_handle_hash_collisions() {
|
||||
let mut index = OptimizationIndex::default();
|
||||
let mut index = ExecutionPlanIndex::default();
|
||||
let stream_1 = [ops_1(), ops_1()];
|
||||
let stream_2 = [ops_3(), ops_1()];
|
||||
let optimization_id_1 = 0;
|
||||
let optimization_id_2 = 1;
|
||||
|
||||
let stream_1_key = index.stream_key(&stream_1[0]);
|
||||
let stream_2_key = index.stream_key(&stream_2[0]);
|
||||
let stream_1_key = index.operation_key(&stream_1[0]);
|
||||
let stream_2_key = index.operation_key(&stream_2[0]);
|
||||
|
||||
assert_eq!(
|
||||
stream_1_key, stream_2_key,
|
||||
|
@ -199,43 +200,45 @@ mod tests {
|
|||
);
|
||||
assert_ne!(stream_1[0], stream_2[0], "Ops 1 and Ops 3 are different.");
|
||||
|
||||
index.insert(InsertQuery::NewOptimization {
|
||||
stream: &stream_1,
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_1,
|
||||
id: optimization_id_1,
|
||||
});
|
||||
index.insert(InsertQuery::NewOptimization {
|
||||
stream: &stream_2,
|
||||
index.insert(InsertQuery::NewPlan {
|
||||
operations: &stream_2,
|
||||
id: optimization_id_2,
|
||||
});
|
||||
|
||||
let found = index.find(SearchQuery::OptimizationsStartingWith(&stream_1[0]));
|
||||
let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));
|
||||
|
||||
assert_eq!(found, vec![optimization_id_1]);
|
||||
}
|
||||
|
||||
fn ops_1() -> TensorOpsDescription {
|
||||
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::Add(BinaryOpsDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
fn ops_1() -> OperationDescription {
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Add(
|
||||
BinaryOperationDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
rhs: TensorDescription {
|
||||
id: TensorId::new(1),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
out: TensorDescription {
|
||||
id: TensorId::new(2),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::NotInit,
|
||||
},
|
||||
},
|
||||
rhs: TensorDescription {
|
||||
id: TensorId::new(1),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
out: TensorDescription {
|
||||
id: TensorId::new(2),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::NotInit,
|
||||
},
|
||||
}))
|
||||
))
|
||||
}
|
||||
|
||||
fn ops_2() -> TensorOpsDescription {
|
||||
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::AddScalar(
|
||||
ScalarOpsDescription {
|
||||
fn ops_2() -> OperationDescription {
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::AddScalar(
|
||||
ScalarOperationDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
|
@ -251,23 +254,25 @@ mod tests {
|
|||
))
|
||||
}
|
||||
|
||||
fn ops_3() -> TensorOpsDescription {
|
||||
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::Sub(BinaryOpsDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
fn ops_3() -> OperationDescription {
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Sub(
|
||||
BinaryOperationDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
rhs: TensorDescription {
|
||||
id: TensorId::new(1),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
out: TensorDescription {
|
||||
id: TensorId::new(2),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::NotInit,
|
||||
},
|
||||
},
|
||||
rhs: TensorDescription {
|
||||
id: TensorId::new(1),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
},
|
||||
out: TensorDescription {
|
||||
id: TensorId::new(2),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::NotInit,
|
||||
},
|
||||
}))
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
mod base;
|
||||
mod index;
|
||||
mod optimization;
|
||||
|
||||
pub(crate) use base::*;
|
||||
pub(super) use index::*;
|
||||
pub(crate) use optimization::*;
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
use super::{InsertQuery, OptimizationIndex, SearchQuery};
|
||||
use crate::stream::TensorOpsDescription;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub(crate) struct OptimizationStore<O> {
|
||||
pub(super) optimizations: Vec<OptimizationItem<O>>,
|
||||
pub(super) index: OptimizationIndex,
|
||||
}
|
||||
|
||||
pub(crate) type OptimizationId = usize;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub(crate) struct OptimizationItem<O> {
|
||||
pub(crate) stream: Vec<TensorOpsDescription>,
|
||||
pub(crate) end_conditions: Vec<TensorOpsDescription>,
|
||||
pub(crate) value: O,
|
||||
}
|
||||
|
||||
impl<O> OptimizationStore<O> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
optimizations: Vec::new(),
|
||||
index: OptimizationIndex::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find(&self, query: SearchQuery<'_>) -> Vec<OptimizationId> {
|
||||
self.index.find(query)
|
||||
}
|
||||
|
||||
pub fn add(&mut self, optimization: OptimizationItem<O>) -> OptimizationId {
|
||||
let id = self.optimizations.len();
|
||||
|
||||
self.index.insert(InsertQuery::NewOptimization {
|
||||
stream: &optimization.stream,
|
||||
id,
|
||||
});
|
||||
|
||||
self.optimizations.push(optimization);
|
||||
|
||||
id
|
||||
}
|
||||
|
||||
pub fn get_mut_unchecked(&mut self, id: OptimizationId) -> &mut OptimizationItem<O> {
|
||||
&mut self.optimizations[id]
|
||||
}
|
||||
|
||||
pub fn get_unchecked(&self, id: OptimizationId) -> &OptimizationItem<O> {
|
||||
&self.optimizations[id]
|
||||
}
|
||||
|
||||
/// Add a new end condition for an optimization.
|
||||
pub fn add_end_condition(&mut self, id: OptimizationId, end_condition: TensorOpsDescription) {
|
||||
self.optimizations[id].end_conditions.push(end_condition)
|
||||
}
|
||||
}
|
|
@ -82,7 +82,9 @@ where
|
|||
type Handle = WgpuFusionHandle;
|
||||
type FusionClient = MutexFusionClient<Self>;
|
||||
|
||||
fn optimizations(device: WgpuDevice) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self>>> {
|
||||
fn optimizations(
|
||||
device: WgpuDevice,
|
||||
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
|
||||
vec![Box::new(ElementWiseBuilder::new(device))]
|
||||
}
|
||||
|
||||
|
|
|
@ -7,8 +7,9 @@ use crate::{
|
|||
};
|
||||
use burn_fusion::{
|
||||
stream::{
|
||||
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
|
||||
ScalarOpsDescription, TensorOpsDescription, UnaryOpsDescription,
|
||||
BaseOperationDescription, BinaryOperationDescription, FloatOperationDescription,
|
||||
NumericOperationDescription, OperationDescription, ScalarOperationDescription,
|
||||
UnaryOperationDescription,
|
||||
},
|
||||
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription, TensorId,
|
||||
};
|
||||
|
@ -35,43 +36,43 @@ where
|
|||
pub(crate) device: Device<Wgpu<G, F, I>>,
|
||||
}
|
||||
|
||||
impl<G, F, I> OptimizationBuilder<Wgpu<G, F, I>> for ElementWiseBuilder<G, F, I>
|
||||
impl<G, F, I> OptimizationBuilder<WgpuOptimization<G, F, I>> for ElementWiseBuilder<G, F, I>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn register(&mut self, ops: &TensorOpsDescription) {
|
||||
fn register(&mut self, ops: &OperationDescription) {
|
||||
if let OptimizationStatus::Closed = self.status {
|
||||
return;
|
||||
}
|
||||
|
||||
match ops {
|
||||
TensorOpsDescription::BaseOpsFloat(ops) => {
|
||||
OperationDescription::BaseFloat(ops) => {
|
||||
if !self.register_base::<F>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
TensorOpsDescription::BaseOpsInt(ops) => {
|
||||
OperationDescription::BaseInt(ops) => {
|
||||
if !self.register_base::<I>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
TensorOpsDescription::FloatOps(ops) => {
|
||||
OperationDescription::Float(ops) => {
|
||||
if !self.register_float::<F>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
TensorOpsDescription::NumericOpsFloat(ops) => {
|
||||
OperationDescription::NumericFloat(ops) => {
|
||||
if !self.register_numeric::<F, _>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
TensorOpsDescription::NumericOpsInt(ops) => {
|
||||
OperationDescription::NumericInt(ops) => {
|
||||
if !self.register_numeric::<I, _>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
|
@ -107,6 +108,10 @@ where
|
|||
WgpuOptimization::ElementWise(op.compile())
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.operators.len()
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.inputs.clear();
|
||||
self.locals.drain();
|
||||
|
@ -394,9 +399,9 @@ where
|
|||
Variable::Local(local_index, Item::Scalar(elem))
|
||||
}
|
||||
|
||||
fn register_base<E: WgpuElement>(&mut self, ops: &BaseOpsDescription) -> bool {
|
||||
fn register_base<E: WgpuElement>(&mut self, ops: &BaseOperationDescription) -> bool {
|
||||
match ops {
|
||||
BaseOpsDescription::Equal(desc) => self.register_binary_ops(
|
||||
BaseOperationDescription::Equal(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Equal { lhs, rhs, out },
|
||||
|
@ -405,49 +410,49 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn register_float<E: WgpuElement>(&mut self, ops: &FloatOpsDescription) -> bool {
|
||||
fn register_float<E: WgpuElement>(&mut self, ops: &FloatOperationDescription) -> bool {
|
||||
match ops {
|
||||
FloatOpsDescription::Exp(desc) => {
|
||||
FloatOperationDescription::Exp(desc) => {
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Exp { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Log(desc) => {
|
||||
FloatOperationDescription::Log(desc) => {
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Log { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Log1p(desc) => {
|
||||
FloatOperationDescription::Log1p(desc) => {
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Log1p { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Cos(desc) => {
|
||||
FloatOperationDescription::Cos(desc) => {
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Cos { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Sin(desc) => {
|
||||
FloatOperationDescription::Sin(desc) => {
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Sin { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Powf(desc) => self.register_scalar_ops(
|
||||
FloatOperationDescription::Powf(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Powf { lhs, rhs, out },
|
||||
),
|
||||
FloatOpsDescription::Tanh(desc) => {
|
||||
FloatOperationDescription::Tanh(desc) => {
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Tanh { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Erf(desc) => {
|
||||
FloatOperationDescription::Erf(desc) => {
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Erf { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Recip(desc) => {
|
||||
FloatOperationDescription::Recip(desc) => {
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Recip { input, out }
|
||||
})
|
||||
|
@ -458,100 +463,100 @@ where
|
|||
|
||||
fn register_numeric<E: WgpuElement, EDesc: WgpuElement>(
|
||||
&mut self,
|
||||
ops: &NumericOpsDescription<EDesc>,
|
||||
ops: &NumericOperationDescription<EDesc>,
|
||||
) -> bool {
|
||||
match ops {
|
||||
NumericOpsDescription::Add(desc) => self.register_binary_ops(
|
||||
NumericOperationDescription::Add(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Add { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::AddScalar(desc) => self.register_scalar_ops(
|
||||
NumericOperationDescription::AddScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Add { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::Sub(desc) => self.register_binary_ops(
|
||||
NumericOperationDescription::Sub(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Sub { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::SubScalar(desc) => self.register_scalar_ops(
|
||||
NumericOperationDescription::SubScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Sub { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::Mul(desc) => self.register_binary_ops(
|
||||
NumericOperationDescription::Mul(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Mul { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::MulScalar(desc) => self.register_scalar_ops(
|
||||
NumericOperationDescription::MulScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Mul { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::Div(desc) => self.register_binary_ops(
|
||||
NumericOperationDescription::Div(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Div { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::DivScalar(desc) => self.register_scalar_ops(
|
||||
NumericOperationDescription::DivScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Div { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::Abs(desc) => {
|
||||
NumericOperationDescription::Abs(desc) => {
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Abs { input, out }
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Lower(desc) => self.register_binary_ops(
|
||||
NumericOperationDescription::Lower(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Lower { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::LowerElem(desc) => self.register_scalar_ops(
|
||||
NumericOperationDescription::LowerElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Lower { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::Greater(desc) => self.register_binary_ops(
|
||||
NumericOperationDescription::Greater(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Greater { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::GreaterElem(desc) => self.register_scalar_ops(
|
||||
NumericOperationDescription::GreaterElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Greater { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::LowerEqual(desc) => self.register_binary_ops(
|
||||
NumericOperationDescription::LowerEqual(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::LowerEqualElem(desc) => self.register_scalar_ops(
|
||||
NumericOperationDescription::LowerEqualElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::GreaterEqual(desc) => self.register_binary_ops(
|
||||
NumericOperationDescription::GreaterEqual(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::GreaterEqualElem(desc) => self.register_scalar_ops(
|
||||
NumericOperationDescription::GreaterEqualElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::EqualElem(desc) => self.register_scalar_ops(
|
||||
NumericOperationDescription::EqualElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Equal { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::MaskWhere(desc) => {
|
||||
NumericOperationDescription::MaskWhere(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
@ -571,7 +576,7 @@ where
|
|||
|
||||
true
|
||||
}
|
||||
NumericOpsDescription::MaskFill(desc) => {
|
||||
NumericOperationDescription::MaskFill(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
@ -596,7 +601,7 @@ where
|
|||
|
||||
fn register_binary_ops<Func>(
|
||||
&mut self,
|
||||
desc: &BinaryOpsDescription,
|
||||
desc: &BinaryOperationDescription,
|
||||
(elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
|
||||
func: Func,
|
||||
) -> bool
|
||||
|
@ -618,7 +623,7 @@ where
|
|||
|
||||
fn register_unary_ops<Func>(
|
||||
&mut self,
|
||||
desc: &UnaryOpsDescription,
|
||||
desc: &UnaryOperationDescription,
|
||||
(elem_input, elem_out): (Elem, Elem),
|
||||
func: Func,
|
||||
) -> bool
|
||||
|
@ -639,7 +644,7 @@ where
|
|||
|
||||
fn register_scalar_ops<Func, E: Element>(
|
||||
&mut self,
|
||||
desc: &ScalarOpsDescription<E>,
|
||||
desc: &ScalarOperationDescription<E>,
|
||||
(elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
|
||||
func: Func,
|
||||
) -> bool
|
||||
|
|
|
@ -283,7 +283,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_fusion::stream::Ops;
|
||||
use burn_fusion::stream::Operation;
|
||||
use burn_fusion::{Fusion, FusionBackend};
|
||||
use burn_tensor::Int;
|
||||
use burn_tensor::{backend::Backend, Data, Tensor};
|
||||
|
@ -419,7 +419,7 @@ mod tests {
|
|||
|
||||
struct FakeAddOps;
|
||||
|
||||
impl<B: FusionBackend> Ops<B> for FakeAddOps {
|
||||
impl<B: FusionBackend> Operation<B> for FakeAddOps {
|
||||
fn execute(self: Box<Self>, _: &mut burn_fusion::HandleContainer<B>) {
|
||||
panic!("Should always fused during tests.")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue