mirror of https://github.com/tracel-ai/burn.git
Feat/fusion/cache (#1020)
This commit is contained in:
parent
b0de56da29
commit
670280dda2
|
@ -87,7 +87,7 @@ fn erf_positive<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
|
|||
fn bench<B: Backend>(device: &B::Device) {
|
||||
const D: usize = 3;
|
||||
let shape: Shape<D> = [32, 512, 2048].into();
|
||||
let num_repeats = 1;
|
||||
let num_repeats = 10;
|
||||
|
||||
let reference_gelu = CustomGeluBenchmark::<B, D>::new(
|
||||
shape.clone(),
|
||||
|
|
|
@ -17,5 +17,6 @@ std = []
|
|||
[dependencies]
|
||||
burn-tensor = {path = "../burn-tensor", version = "0.11.0", default-features = false }
|
||||
burn-common = {path = "../burn-common", version = "0.11.0" }
|
||||
hashbrown = { workspace = true }
|
||||
derive-new = {workspace = true}
|
||||
spin = {workspace = true}
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
use crate::{
|
||||
client::FusionClient, graph::TensorOpsDescription, FusionClientLocator, FusionTensor,
|
||||
HandleContainer,
|
||||
client::FusionClient,
|
||||
graph::{Context, OptimizationFactory, TensorOpsDescription},
|
||||
FusionClientLocator, FusionTensor,
|
||||
};
|
||||
use burn_tensor::{backend::Backend, Device, Shape};
|
||||
use core::marker::PhantomData;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();
|
||||
|
||||
|
@ -49,17 +50,18 @@ impl<B: FusionBackend> Backend for Fusion<B> {
|
|||
}
|
||||
}
|
||||
|
||||
/// The status of a [fusion ops](FusionOps).
|
||||
pub enum FusionStatus {
|
||||
/// The status of a [builder](OptimizationBuilder).
|
||||
#[derive(Clone, Debug, Copy)]
|
||||
pub enum OptimizationStatus {
|
||||
/// No more operations can be fused.
|
||||
Closed(FusionProperties),
|
||||
Closed,
|
||||
/// More operations can be fused.
|
||||
Open(FusionProperties),
|
||||
Open,
|
||||
}
|
||||
|
||||
/// The properties of a [fusion ops](FusionOps).
|
||||
/// The properties of a [builder](OptimizationProperties).
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct FusionProperties {
|
||||
pub struct OptimizationProperties {
|
||||
/// The score of the optimization, higher is better.
|
||||
pub score: u64,
|
||||
/// If the operation is ready to be executed.
|
||||
|
@ -78,28 +80,42 @@ pub struct FusionProperties {
|
|||
///
|
||||
/// Also, it is important to return (FusionStatus::Closed) when no more registered operation can
|
||||
/// improve the performance.
|
||||
pub trait FusionOps<B: FusionBackend>: Send {
|
||||
pub trait OptimizationBuilder<B: FusionBackend>: Send {
|
||||
/// Register a new [tensor operation](TensorOpsDescription).
|
||||
///
|
||||
/// The return value should be either [closed](FusionStatus::Closed) or
|
||||
/// [open](FusionStatus::Open).
|
||||
///
|
||||
/// When [closed](FusionStatus::Closed), it's assumed that no more operation can be added
|
||||
/// to the current fusion operation. No [tensor operation](TensorOpsDescription) can be
|
||||
/// ignored, they are either accepted or rejected, and the [status](FusionStatus) describes it.
|
||||
fn register(&mut self, ops: &TensorOpsDescription) -> FusionStatus;
|
||||
/// Execute the operation.
|
||||
fn execute(&mut self, handles: &mut HandleContainer<B>);
|
||||
fn register(&mut self, ops: &TensorOpsDescription);
|
||||
/// Finish the optimization and create a fusion operation.
|
||||
fn build(&self) -> Box<dyn Optimization<B>>;
|
||||
/// Reset the state.
|
||||
fn reset(&mut self);
|
||||
/// The size of operations fused.
|
||||
/// Return the builder [status](OptimizationStatus).
|
||||
fn status(&self) -> OptimizationStatus;
|
||||
/// Return the builder [properties](OptimizationProperties).
|
||||
fn properties(&self) -> OptimizationProperties;
|
||||
}
|
||||
|
||||
/// The operation created from the [builder](OptimizationBuilder).
|
||||
pub trait Optimization<B: FusionBackend>: Send {
|
||||
/// Execute the operation.
|
||||
fn execute(&self, context: &mut Context<'_, B>);
|
||||
/// The number of registered operations in this optimization.
|
||||
fn len(&self) -> usize;
|
||||
/// If the current operation is empty.
|
||||
/// If the current optimization is empty.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
// We implement the OptimizationFactory for all boxed optimization to be used with the Optimization
|
||||
// Cache. The factory is only used to simplify types and allows better testing. It isn't a public
|
||||
// crate.
|
||||
impl<B: FusionBackend> OptimizationFactory<Box<dyn Optimization<B>>>
|
||||
for Box<dyn OptimizationBuilder<B>>
|
||||
{
|
||||
fn create(&self) -> Box<dyn Optimization<B>> {
|
||||
OptimizationBuilder::build(self.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
/// The device id.
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
|
||||
pub struct DeviceId {
|
||||
|
@ -116,7 +132,7 @@ pub trait FusionDevice: Clone + Send + Sync + PartialEq {
|
|||
}
|
||||
|
||||
/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
|
||||
/// [fusion operation](crate::FusionOps).
|
||||
/// [operation builder](crate::OptimizationBuilder).
|
||||
pub trait FusionBackend: Backend {
|
||||
/// The device type that can return an ID.
|
||||
///
|
||||
|
@ -127,8 +143,8 @@ pub trait FusionBackend: Backend {
|
|||
/// What kind of client should be used.
|
||||
type FusionClient: FusionClient<FusionBackend = Self>;
|
||||
|
||||
/// The list of operations that will be used to optimize the computational graph.
|
||||
fn operations(device: &Device<Self>) -> Vec<Box<dyn FusionOps<Self>>>;
|
||||
/// The list of optimizations that will be used to optimize the computational graph.
|
||||
fn optimizations(device: &Device<Self>) -> Vec<Box<dyn OptimizationBuilder<Self>>>;
|
||||
|
||||
/// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive).
|
||||
fn float_tensor<const D: usize>(
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
graph::{GraphExecution, Ops, TensorOpsDescription},
|
||||
graph::{Ops, TensorOpsDescription},
|
||||
FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
|
||||
};
|
||||
use burn_tensor::{
|
||||
|
@ -11,8 +11,6 @@ use burn_tensor::{
|
|||
pub trait FusionClient: Send + Sync + Clone {
|
||||
/// The [fusion backend](FusionBackend) associated type.
|
||||
type FusionBackend: FusionBackend;
|
||||
/// The [graph execution](GraphExecution) associated type.
|
||||
type GraphExecution: GraphExecution<Self::FusionBackend>;
|
||||
|
||||
/// Create a new client for the given [fusion device](FusionBackend::FusionDevice).
|
||||
fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
|
||||
|
|
|
@ -1,26 +1,21 @@
|
|||
use super::FusionClient;
|
||||
use crate::{
|
||||
graph::{GraphExecution, TensorOpsDescription},
|
||||
FusionBackend, FusionServer, FusionTensor, Handle,
|
||||
};
|
||||
use crate::{graph::TensorOpsDescription, FusionBackend, FusionServer, FusionTensor, Handle};
|
||||
use burn_tensor::ops::FloatElem;
|
||||
use spin::Mutex;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Use a mutex to communicate with the fusion server.
|
||||
pub struct MutexFusionClient<B, G>
|
||||
pub struct MutexFusionClient<B>
|
||||
where
|
||||
B: FusionBackend,
|
||||
G: GraphExecution<B>,
|
||||
{
|
||||
server: Arc<Mutex<FusionServer<B, G>>>,
|
||||
server: Arc<Mutex<FusionServer<B>>>,
|
||||
device: B::FusionDevice,
|
||||
}
|
||||
|
||||
impl<B, G> Clone for MutexFusionClient<B, G>
|
||||
impl<B> Clone for MutexFusionClient<B>
|
||||
where
|
||||
B: FusionBackend,
|
||||
G: GraphExecution<B>,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
|
@ -30,13 +25,11 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<B, G> FusionClient for MutexFusionClient<B, G>
|
||||
impl<B> FusionClient for MutexFusionClient<B>
|
||||
where
|
||||
B: FusionBackend,
|
||||
G: GraphExecution<B>,
|
||||
{
|
||||
type FusionBackend = B;
|
||||
type GraphExecution = G;
|
||||
|
||||
fn new(device: B::FusionDevice) -> Self {
|
||||
Self {
|
||||
|
|
|
@ -1,100 +1,97 @@
|
|||
use super::Ops;
|
||||
use super::RelativeGraphConverter;
|
||||
use super::TensorOpsDescription;
|
||||
use crate::{FusionBackend, FusionOps, FusionProperties, FusionStatus, HandleContainer};
|
||||
use std::{ops::RangeBounds, sync::Arc};
|
||||
use crate::Optimization;
|
||||
use crate::{FusionBackend, HandleContainer};
|
||||
use std::ops::RangeBounds;
|
||||
|
||||
/// The computational graph containing a list of [tensor operation descriptions](TensorOpsDescription).
|
||||
pub struct Graph<B: FusionBackend> {
|
||||
operations: Vec<Arc<TensorOpsDescription>>,
|
||||
pub(crate) global: Vec<TensorOpsDescription>,
|
||||
pub(crate) relative: Vec<TensorOpsDescription>,
|
||||
converter: RelativeGraphConverter,
|
||||
ops: Vec<Box<dyn Ops<B>>>,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Graph<B> {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
operations: Vec::new(),
|
||||
global: Vec::new(),
|
||||
relative: Vec::new(),
|
||||
converter: RelativeGraphConverter::default(),
|
||||
ops: Vec::new(),
|
||||
}
|
||||
}
|
||||
pub(crate) fn add(&mut self, description: Arc<TensorOpsDescription>, ops: Box<dyn Ops<B>>) {
|
||||
self.operations.push(description);
|
||||
|
||||
pub(crate) fn split_relative_graph(
|
||||
&self,
|
||||
) -> (&[TensorOpsDescription], Option<&TensorOpsDescription>) {
|
||||
let len = self.relative.len();
|
||||
if len < 1 {
|
||||
return (&self.relative, None);
|
||||
}
|
||||
|
||||
(&self.relative[0..len - 1], self.relative.last())
|
||||
}
|
||||
|
||||
pub(crate) fn add(&mut self, global: TensorOpsDescription, ops: Box<dyn Ops<B>>) {
|
||||
let relative = global.to_relative(&mut self.converter);
|
||||
self.relative.push(relative);
|
||||
self.global.push(global);
|
||||
self.ops.push(ops);
|
||||
}
|
||||
|
||||
/// The size of the graph.
|
||||
pub fn len(&self) -> usize {
|
||||
self.operations.len()
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.global.len()
|
||||
}
|
||||
|
||||
/// If the graph is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.operations.len() == 0
|
||||
}
|
||||
|
||||
fn remove<R: RangeBounds<usize> + Clone>(
|
||||
&mut self,
|
||||
range: R,
|
||||
handles: &mut HandleContainer<B>,
|
||||
) {
|
||||
for ops in self.operations.drain(range.clone()) {
|
||||
ops.cleanup_tensor(handles)
|
||||
}
|
||||
self.ops.drain(range);
|
||||
}
|
||||
|
||||
fn nodes(&self) -> &[Arc<TensorOpsDescription>] {
|
||||
&self.operations
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
pub(crate) fn execute_optimization(
|
||||
&mut self,
|
||||
handles: &mut HandleContainer<B>,
|
||||
index: usize,
|
||||
optimizations: &mut [Optimization<B>],
|
||||
optimization: &dyn Optimization<B>,
|
||||
) {
|
||||
let optimization = optimizations.get_mut(index).unwrap();
|
||||
let num_keep = optimization.ops.len();
|
||||
optimization.ops.execute(handles);
|
||||
let num_keep = optimization.len();
|
||||
let mut context = self.converter.context(handles);
|
||||
optimization.execute(&mut context);
|
||||
|
||||
self.remove(0..num_keep, handles);
|
||||
|
||||
for optimization in optimizations.iter_mut() {
|
||||
optimization.reset();
|
||||
|
||||
for node in self.nodes() {
|
||||
optimization.register(node);
|
||||
}
|
||||
}
|
||||
self.cleanup_partial(0..num_keep, handles);
|
||||
}
|
||||
|
||||
pub(crate) fn execute(&mut self, handles: &mut HandleContainer<B>) {
|
||||
for (description, ops) in self.operations.drain(..).zip(self.ops.drain(..)) {
|
||||
pub(crate) fn execute_operations(&mut self, handles: &mut HandleContainer<B>) {
|
||||
for (description, ops) in self.global.drain(..).zip(self.ops.drain(..)) {
|
||||
ops.execute(handles);
|
||||
description.cleanup_tensor(handles);
|
||||
}
|
||||
self.cleanup_relative_graph();
|
||||
}
|
||||
}
|
||||
|
||||
/// An optimization that can be executed.
|
||||
#[derive(new)]
|
||||
pub struct Optimization<B: FusionBackend> {
|
||||
/// The [fusion operation](FusionOps) to potentially be executed.
|
||||
pub ops: Box<dyn FusionOps<B>>,
|
||||
/// The current status of the optimization.
|
||||
pub status: FusionStatus,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> Optimization<B> {
|
||||
pub(crate) fn register(&mut self, ops: &TensorOpsDescription) {
|
||||
if let FusionStatus::Closed(_) = self.status {
|
||||
return;
|
||||
fn cleanup_partial<R: RangeBounds<usize> + Clone>(
|
||||
&mut self,
|
||||
range: R,
|
||||
handles: &mut HandleContainer<B>,
|
||||
) {
|
||||
for ops in self.global.drain(range.clone()) {
|
||||
ops.cleanup_tensor(handles)
|
||||
}
|
||||
self.ops.drain(range);
|
||||
|
||||
self.status = self.ops.register(ops);
|
||||
// Rebuild the relative graph when partially removing the global graph.
|
||||
self.cleanup_relative_graph();
|
||||
|
||||
for node in self.global.iter() {
|
||||
let relative = node.to_relative(&mut self.converter);
|
||||
self.relative.push(relative);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn reset(&mut self) {
|
||||
self.ops.reset();
|
||||
self.status = FusionStatus::Open(FusionProperties::default());
|
||||
fn cleanup_relative_graph(&mut self) {
|
||||
self.relative.clear();
|
||||
self.converter.clear();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,845 @@
|
|||
use super::{
|
||||
AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription,
|
||||
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
|
||||
AvgPool2dBackwardDescription, AvgPool2dDescription, BaseOpsDescription, BinaryOpsDescription,
|
||||
BoolOpsDescription, ClampOpsDescription, Conv1dDescription, Conv2dDescription,
|
||||
ConvTranspose1dDescription, ConvTranspose2dDescription, EmbeddingBackwardDescription,
|
||||
EmbeddingDescription, FloatOpsDescription, GatherOpsDescription, IntOpsDescription,
|
||||
MaskFillOpsDescription, MaskWhereOpsDescription, MaxPool1dDescription,
|
||||
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, MaxPool2dDescription,
|
||||
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, ModuleOpsDescription,
|
||||
NumericOpsDescription, RandomOpsDescription, ReduceDimWithIndicesDescription,
|
||||
ReshapeDescription, ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription,
|
||||
SelectOpsDescription, SliceOpsDescription, SwapDimsDescription, TensorOpsDescription,
|
||||
UnaryOpsDescription,
|
||||
};
|
||||
use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId};
|
||||
use burn_tensor::{Element, ElementConversion};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// The context contains the relative graph tensor mapping so that a relative tensor id can be
|
||||
/// mapped to an existing tensor that can be fetched and updated with the
|
||||
/// [handle container](HandleContainer).
|
||||
///
|
||||
/// It also contains all scalar values, which can change even for the same graph. They are sorted
|
||||
/// in the order in which they appear in the graph.
|
||||
#[derive(new)]
|
||||
pub struct Context<'a, B: FusionBackend> {
|
||||
/// The tensor mapping where local tensor id points to the updated tensor description.
|
||||
pub tensors: &'a HashMap<TensorId, TensorDescription>,
|
||||
/// Handle container to retrieve tensors based on their description.
|
||||
pub handles: &'a mut HandleContainer<B>,
|
||||
/// Float scalars found in the graph in the order they appeared.
|
||||
pub scalar_floats: &'a Vec<f32>,
|
||||
/// Int scalars found in the graph in the order they appeared.
|
||||
pub scalar_ints: &'a Vec<i32>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct RelativeGraphConverter {
|
||||
tensors_relative2global: HashMap<TensorId, TensorDescription>,
|
||||
tensors_global2relative: HashMap<TensorId, TensorDescription>,
|
||||
/// Only useful to create new shape ID.
|
||||
/// You should use tensor descriptions to retrieve the proper shape.
|
||||
shapes_global2relative: HashMap<usize, usize>,
|
||||
scalar_floats: Vec<f32>,
|
||||
scalar_ints: Vec<i32>,
|
||||
}
|
||||
|
||||
impl RelativeGraphConverter {
|
||||
pub(crate) fn context<'a, B: FusionBackend>(
|
||||
&'a self,
|
||||
handles: &'a mut HandleContainer<B>,
|
||||
) -> Context<'a, B> {
|
||||
Context {
|
||||
handles,
|
||||
tensors: &self.tensors_relative2global,
|
||||
scalar_floats: &self.scalar_floats,
|
||||
scalar_ints: &self.scalar_ints,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn clear(&mut self) {
|
||||
self.tensors_relative2global.clear();
|
||||
self.tensors_global2relative.clear();
|
||||
self.shapes_global2relative.clear();
|
||||
self.scalar_floats.clear();
|
||||
self.scalar_ints.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn relative_float<E: Element>(&mut self, elem: &E) -> E {
|
||||
self.scalar_floats.push(elem.elem());
|
||||
// We return 0 so that the id from a scalar operation is the same no matter its scalar
|
||||
// value.
|
||||
0.elem()
|
||||
}
|
||||
|
||||
pub(crate) fn relative_int<E: Element>(&mut self, elem: &E) -> E {
|
||||
self.scalar_ints.push(elem.elem());
|
||||
// We return 0 so that the id from a scalar operation is the same no matter its scalar
|
||||
// value.
|
||||
0.elem()
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorOpsDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
|
||||
match self {
|
||||
TensorOpsDescription::BaseOpsFloat(ops) => {
|
||||
TensorOpsDescription::BaseOpsFloat(ops.to_relative(converter))
|
||||
}
|
||||
TensorOpsDescription::BaseOpsInt(ops) => {
|
||||
TensorOpsDescription::BaseOpsInt(ops.to_relative(converter))
|
||||
}
|
||||
TensorOpsDescription::BaseOpsBool(ops) => {
|
||||
TensorOpsDescription::BaseOpsBool(ops.to_relative(converter))
|
||||
}
|
||||
TensorOpsDescription::NumericOpsFloat(ops) => TensorOpsDescription::NumericOpsFloat(
|
||||
ops.to_relative(converter, |converter, e| converter.relative_float(e)),
|
||||
),
|
||||
TensorOpsDescription::NumericOpsInt(ops) => TensorOpsDescription::NumericOpsInt(
|
||||
ops.to_relative(converter, |converter, e| converter.relative_int(e)),
|
||||
),
|
||||
TensorOpsDescription::BoolOps(ops) => {
|
||||
TensorOpsDescription::BoolOps(ops.to_relative(converter))
|
||||
}
|
||||
TensorOpsDescription::IntOps(ops) => {
|
||||
TensorOpsDescription::IntOps(ops.to_relative(converter))
|
||||
}
|
||||
TensorOpsDescription::FloatOps(ops) => {
|
||||
TensorOpsDescription::FloatOps(ops.to_relative(converter))
|
||||
}
|
||||
TensorOpsDescription::ModuleOps(ops) => {
|
||||
TensorOpsDescription::ModuleOps(ops.to_relative(converter))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleOpsDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
|
||||
match self {
|
||||
ModuleOpsDescription::Embedding(desc) => {
|
||||
ModuleOpsDescription::Embedding(EmbeddingDescription {
|
||||
weights: desc.weights.to_relative(converter),
|
||||
indices: desc.indices.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::EmbeddingBackward(desc) => {
|
||||
ModuleOpsDescription::EmbeddingBackward(EmbeddingBackwardDescription {
|
||||
weights: desc.weights.to_relative(converter),
|
||||
out_grad: desc.out_grad.to_relative(converter),
|
||||
indices: desc.indices.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::Conv1d(desc) => ModuleOpsDescription::Conv1d(Conv1dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
weight: desc.weight.to_relative(converter),
|
||||
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
|
||||
options: desc.options.clone(),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
ModuleOpsDescription::Conv2d(desc) => ModuleOpsDescription::Conv2d(Conv2dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
weight: desc.weight.to_relative(converter),
|
||||
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
|
||||
options: desc.options.clone(),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
ModuleOpsDescription::ConvTranspose1d(desc) => {
|
||||
ModuleOpsDescription::ConvTranspose1d(ConvTranspose1dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
weight: desc.weight.to_relative(converter),
|
||||
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
|
||||
options: desc.options.clone(),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::ConvTranspose2d(desc) => {
|
||||
ModuleOpsDescription::ConvTranspose2d(ConvTranspose2dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
weight: desc.weight.to_relative(converter),
|
||||
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
|
||||
options: desc.options.clone(),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::AvgPool1d(desc) => {
|
||||
ModuleOpsDescription::AvgPool1d(super::AvgPool1dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
padding: desc.padding,
|
||||
count_include_pad: desc.count_include_pad,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::AvgPool2d(desc) => {
|
||||
ModuleOpsDescription::AvgPool2d(AvgPool2dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
padding: desc.padding,
|
||||
count_include_pad: desc.count_include_pad,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::AvgPool1dBackward(desc) => {
|
||||
ModuleOpsDescription::AvgPool1dBackward(super::AvgPool1dBackwardDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
grad: desc.grad.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
padding: desc.padding,
|
||||
count_include_pad: desc.count_include_pad,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::AvgPool2dBackward(desc) => {
|
||||
ModuleOpsDescription::AvgPool2dBackward(AvgPool2dBackwardDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
grad: desc.grad.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
padding: desc.padding,
|
||||
count_include_pad: desc.count_include_pad,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::AdaptiveAvgPool1d(desc) => {
|
||||
ModuleOpsDescription::AdaptiveAvgPool1d(AdaptiveAvgPool1dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
output_size: desc.output_size,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::AdaptiveAvgPool2d(desc) => {
|
||||
ModuleOpsDescription::AdaptiveAvgPool2d(AdaptiveAvgPool2dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
output_size: desc.output_size,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::AdaptiveAvgPool1dBackward(desc) => {
|
||||
ModuleOpsDescription::AdaptiveAvgPool1dBackward(
|
||||
AdaptiveAvgPool1dBackwardDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
grad: desc.grad.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
},
|
||||
)
|
||||
}
|
||||
ModuleOpsDescription::AdaptiveAvgPool2dBackward(desc) => {
|
||||
ModuleOpsDescription::AdaptiveAvgPool2dBackward(
|
||||
AdaptiveAvgPool2dBackwardDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
grad: desc.grad.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
},
|
||||
)
|
||||
}
|
||||
ModuleOpsDescription::MaxPool1d(desc) => {
|
||||
ModuleOpsDescription::MaxPool1d(MaxPool1dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
padding: desc.padding,
|
||||
dilation: desc.dilation,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::MaxPool1dWithIndices(desc) => {
|
||||
ModuleOpsDescription::MaxPool1dWithIndices(MaxPool1dWithIndicesDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
padding: desc.padding,
|
||||
dilation: desc.dilation,
|
||||
out: desc.out.to_relative(converter),
|
||||
out_indices: desc.out_indices.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::MaxPool1dWithIndicesBackward(desc) => {
|
||||
ModuleOpsDescription::MaxPool1dWithIndicesBackward(
|
||||
MaxPool1dWithIndicesBackwardDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
grad: desc.grad.to_relative(converter),
|
||||
indices: desc.indices.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
padding: desc.padding,
|
||||
dilation: desc.dilation,
|
||||
out: desc.out.to_relative(converter),
|
||||
},
|
||||
)
|
||||
}
|
||||
ModuleOpsDescription::MaxPool2d(desc) => {
|
||||
ModuleOpsDescription::MaxPool2d(MaxPool2dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
padding: desc.padding,
|
||||
dilation: desc.dilation,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::MaxPool2dWithIndices(desc) => {
|
||||
ModuleOpsDescription::MaxPool2dWithIndices(MaxPool2dWithIndicesDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
padding: desc.padding,
|
||||
dilation: desc.dilation,
|
||||
out: desc.out.to_relative(converter),
|
||||
out_indices: desc.out_indices.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOpsDescription::MaxPool2dWithIndicesBackward(desc) => {
|
||||
ModuleOpsDescription::MaxPool2dWithIndicesBackward(
|
||||
MaxPool2dWithIndicesBackwardDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
grad: desc.grad.to_relative(converter),
|
||||
indices: desc.indices.to_relative(converter),
|
||||
kernel_size: desc.kernel_size,
|
||||
stride: desc.stride,
|
||||
padding: desc.padding,
|
||||
dilation: desc.dilation,
|
||||
out: desc.out.to_relative(converter),
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FloatOpsDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
|
||||
match self {
|
||||
FloatOpsDescription::Exp(desc) => FloatOpsDescription::Exp(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
FloatOpsDescription::Log(desc) => FloatOpsDescription::Log(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
FloatOpsDescription::Log1p(desc) => FloatOpsDescription::Log1p(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
FloatOpsDescription::Erf(desc) => FloatOpsDescription::Erf(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
FloatOpsDescription::Powf(desc) => FloatOpsDescription::Powf(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: converter.relative_float(&desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
FloatOpsDescription::Sqrt(desc) => FloatOpsDescription::Sqrt(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
FloatOpsDescription::Cos(desc) => FloatOpsDescription::Cos(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
FloatOpsDescription::Sin(desc) => FloatOpsDescription::Sin(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
FloatOpsDescription::Tanh(desc) => FloatOpsDescription::Tanh(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
FloatOpsDescription::IntoInt(desc) => {
|
||||
FloatOpsDescription::IntoInt(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Matmul(desc) => {
|
||||
FloatOpsDescription::Matmul(BinaryOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Random(desc) => {
|
||||
FloatOpsDescription::Random(RandomOpsDescription {
|
||||
out: desc.out.to_relative(converter),
|
||||
distribution: desc.distribution,
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Recip(desc) => FloatOpsDescription::Recip(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BoolOpsDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
|
||||
match self {
|
||||
BoolOpsDescription::IntoFloat(desc) => {
|
||||
BoolOpsDescription::IntoFloat(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
BoolOpsDescription::IntoInt(desc) => BoolOpsDescription::IntoInt(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
BoolOpsDescription::Not(desc) => BoolOpsDescription::Not(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntOpsDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
|
||||
match self {
|
||||
IntOpsDescription::IntoFloat(desc) => {
|
||||
IntOpsDescription::IntoFloat(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Element> NumericOpsDescription<E> {
|
||||
pub(crate) fn to_relative<F>(
|
||||
&self,
|
||||
converter: &mut RelativeGraphConverter,
|
||||
local_elem: F,
|
||||
) -> Self
|
||||
where
|
||||
F: Fn(&mut RelativeGraphConverter, &E) -> E,
|
||||
{
|
||||
match self {
|
||||
NumericOpsDescription::Add(desc) => NumericOpsDescription::Add(BinaryOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
NumericOpsDescription::AddScalar(desc) => {
|
||||
NumericOpsDescription::AddScalar(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Sub(desc) => NumericOpsDescription::Sub(BinaryOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
NumericOpsDescription::SubScalar(desc) => {
|
||||
NumericOpsDescription::SubScalar(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Div(desc) => NumericOpsDescription::Div(BinaryOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
NumericOpsDescription::DivScalar(desc) => {
|
||||
NumericOpsDescription::DivScalar(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Mul(desc) => NumericOpsDescription::Mul(BinaryOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
NumericOpsDescription::MulScalar(desc) => {
|
||||
NumericOpsDescription::MulScalar(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Abs(desc) => NumericOpsDescription::Abs(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
NumericOpsDescription::Ones(desc) => {
|
||||
NumericOpsDescription::Ones(desc.to_relative(converter))
|
||||
}
|
||||
NumericOpsDescription::Zeros(desc) => {
|
||||
NumericOpsDescription::Zeros(desc.to_relative(converter))
|
||||
}
|
||||
NumericOpsDescription::Full(desc) => NumericOpsDescription::Full((
|
||||
desc.0.to_relative(converter),
|
||||
local_elem(converter, &desc.1),
|
||||
)),
|
||||
NumericOpsDescription::Gather(desc) => {
|
||||
NumericOpsDescription::Gather(GatherOpsDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
dim: desc.dim,
|
||||
indices: desc.indices.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Scatter(desc) => {
|
||||
NumericOpsDescription::Scatter(ScatterOpsDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
dim: desc.dim,
|
||||
indices: desc.indices.to_relative(converter),
|
||||
value: desc.value.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Select(desc) => {
|
||||
NumericOpsDescription::Select(SelectOpsDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
dim: desc.dim,
|
||||
indices: desc.indices.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::SelectAssign(desc) => {
|
||||
NumericOpsDescription::SelectAssign(SelectAssignOpsDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
dim: desc.dim,
|
||||
indices: desc.indices.to_relative(converter),
|
||||
value: desc.value.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::MaskWhere(desc) => {
|
||||
NumericOpsDescription::MaskWhere(MaskWhereOpsDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
mask: desc.mask.to_relative(converter),
|
||||
value: desc.value.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::MaskFill(desc) => {
|
||||
NumericOpsDescription::MaskFill(MaskFillOpsDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
mask: desc.mask.to_relative(converter),
|
||||
value: local_elem(converter, &desc.value),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::MeanDim(desc) => {
|
||||
NumericOpsDescription::MeanDim(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs, // Dim should stay the same.
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Mean(desc) => NumericOpsDescription::Mean(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
NumericOpsDescription::Sum(desc) => NumericOpsDescription::Sum(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
NumericOpsDescription::SumDim(desc) => {
|
||||
NumericOpsDescription::SumDim(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs, // Dim should stay the same.
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::EqualElem(desc) => {
|
||||
NumericOpsDescription::EqualElem(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Greater(desc) => {
|
||||
NumericOpsDescription::Greater(BinaryOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::GreaterElem(desc) => {
|
||||
NumericOpsDescription::GreaterElem(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::GreaterEqual(desc) => {
|
||||
NumericOpsDescription::GreaterEqual(BinaryOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::GreaterEqualElem(desc) => {
|
||||
NumericOpsDescription::GreaterEqualElem(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Lower(desc) => {
|
||||
NumericOpsDescription::Lower(BinaryOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::LowerElem(desc) => {
|
||||
NumericOpsDescription::LowerElem(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::LowerEqual(desc) => {
|
||||
NumericOpsDescription::LowerEqual(BinaryOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::LowerEqualElem(desc) => {
|
||||
NumericOpsDescription::LowerEqualElem(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::ArgMax(desc) => {
|
||||
NumericOpsDescription::ArgMax(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::ArgMin(desc) => {
|
||||
NumericOpsDescription::ArgMin(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Max(desc) => NumericOpsDescription::Max(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
NumericOpsDescription::MaxDimWithIndices(desc) => {
|
||||
NumericOpsDescription::MaxDimWithIndices(ReduceDimWithIndicesDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
dim: desc.dim,
|
||||
out: desc.out.to_relative(converter),
|
||||
out_indices: desc.out_indices.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::MinDimWithIndices(desc) => {
|
||||
NumericOpsDescription::MinDimWithIndices(ReduceDimWithIndicesDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
dim: desc.dim,
|
||||
out: desc.out.to_relative(converter),
|
||||
out_indices: desc.out_indices.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Min(desc) => NumericOpsDescription::Min(UnaryOpsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
NumericOpsDescription::MaxDim(desc) => {
|
||||
NumericOpsDescription::MaxDim(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::MinDim(desc) => {
|
||||
NumericOpsDescription::MinDim(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Clamp(desc) => {
|
||||
NumericOpsDescription::Clamp(ClampOpsDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
min: local_elem(converter, &desc.min),
|
||||
max: local_elem(converter, &desc.max),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::ClampMax(desc) => {
|
||||
NumericOpsDescription::ClampMax(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::ClampMin(desc) => {
|
||||
NumericOpsDescription::ClampMin(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BaseOpsDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
|
||||
match self {
|
||||
BaseOpsDescription::ToDevice(desc) => {
|
||||
BaseOpsDescription::ToDevice(desc.to_relative(converter))
|
||||
}
|
||||
BaseOpsDescription::Reshape(desc) => BaseOpsDescription::Reshape(ReshapeDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
BaseOpsDescription::SwapDims(desc) => {
|
||||
BaseOpsDescription::SwapDims(SwapDimsDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
dim1: desc.dim1,
|
||||
dim2: desc.dim2,
|
||||
})
|
||||
}
|
||||
BaseOpsDescription::Slice(desc) => BaseOpsDescription::Slice(SliceOpsDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
ranges: desc.ranges.clone(),
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
BaseOpsDescription::SliceAssign(desc) => {
|
||||
BaseOpsDescription::SliceAssign(super::SliceAssignOpsDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
ranges: desc.ranges.clone(),
|
||||
value: desc.value.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
BaseOpsDescription::Equal(desc) => {
|
||||
BaseOpsDescription::Equal(super::BinaryOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: desc.rhs.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
BaseOpsDescription::Repeat(desc) => {
|
||||
BaseOpsDescription::Repeat(super::RepeatOpsDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
dim: desc.dim,
|
||||
times: desc.times,
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
BaseOpsDescription::Cat(desc) => BaseOpsDescription::Cat(super::CatOpsDescription {
|
||||
tensors: desc
|
||||
.tensors
|
||||
.iter()
|
||||
.map(|tensor| tensor.to_relative(converter))
|
||||
.collect(),
|
||||
dim: desc.dim,
|
||||
out: desc.out.to_relative(converter),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorDescription {
|
||||
pub(crate) fn to_relative(&self, converter: &mut RelativeGraphConverter) -> Self {
|
||||
let relative_id = if let Some(value) = converter.tensors_global2relative.get(&self.id) {
|
||||
// If we already have the same tensor registered, we have to update its value, but not
|
||||
// its id.
|
||||
value.id.clone()
|
||||
} else {
|
||||
// We create a new relative id since we never seen this tensor in the graph before.
|
||||
TensorId::new(converter.tensors_relative2global.len() as u64)
|
||||
};
|
||||
|
||||
// We can create relative shapes by mapping each shape found to an ID, which is a `usize`.
|
||||
let mut relative_shape = Vec::with_capacity(self.shape.len());
|
||||
for dim in self.shape.iter() {
|
||||
if let Some(dim_id) = converter.shapes_global2relative.get(dim) {
|
||||
// We already saw that dim value before, so we retrieve its ID.
|
||||
relative_shape.push(*dim_id);
|
||||
} else {
|
||||
// We never saw this dim value before, therefore we create a new ID.
|
||||
let dim_id = converter.shapes_global2relative.len();
|
||||
relative_shape.push(dim_id);
|
||||
converter.shapes_global2relative.insert(*dim, dim_id);
|
||||
}
|
||||
}
|
||||
|
||||
// We create the relative tensor.
|
||||
let relative_tensor = TensorDescription {
|
||||
id: relative_id.clone(),
|
||||
shape: relative_shape,
|
||||
status: self.status.clone(),
|
||||
};
|
||||
|
||||
// We update both mappings.
|
||||
converter
|
||||
.tensors_relative2global
|
||||
.insert(relative_id, self.clone());
|
||||
converter
|
||||
.tensors_global2relative
|
||||
.insert(self.id.clone(), relative_tensor.clone());
|
||||
|
||||
relative_tensor
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::TensorStatus;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn tensor_description_to_relative() {
|
||||
let tensor1 = TensorDescription {
|
||||
id: TensorId::new(500),
|
||||
shape: vec![512, 32, 2048],
|
||||
status: TensorStatus::ReadOnly,
|
||||
};
|
||||
let tensor2 = TensorDescription {
|
||||
id: TensorId::new(501),
|
||||
shape: vec![512, 128, 2048],
|
||||
status: TensorStatus::ReadOnly,
|
||||
};
|
||||
let mut converter = RelativeGraphConverter::default();
|
||||
let tensor1_local = tensor1.to_relative(&mut converter);
|
||||
let tensor2_local = tensor2.to_relative(&mut converter);
|
||||
|
||||
assert_eq!(
|
||||
tensor1_local,
|
||||
TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![0, 1, 2],
|
||||
status: TensorStatus::ReadOnly
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
tensor2_local,
|
||||
TensorDescription {
|
||||
id: TensorId::new(1),
|
||||
shape: vec![0, 3, 2],
|
||||
status: TensorStatus::ReadOnly
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
|
@ -1,59 +1,199 @@
|
|||
use super::{Graph, Optimization};
|
||||
use crate::{FusionBackend, FusionStatus, HandleContainer};
|
||||
|
||||
/// The graph execution trait abstracts the way the graph is executing optimizations.
|
||||
pub trait GraphExecution<B: FusionBackend>: Default + Send {
|
||||
/// Execute the given graph using the list of potential [optimizations](Optimization).
|
||||
/// May do nothing if empty or not ready
|
||||
fn maybe_execute(
|
||||
&mut self,
|
||||
graph: &mut Graph<B>,
|
||||
handles: &mut HandleContainer<B>,
|
||||
optimizations: &mut [Optimization<B>],
|
||||
force: bool,
|
||||
);
|
||||
}
|
||||
use super::{CacheResult, Condition, Graph, OptimizationCache, TensorOpsDescription};
|
||||
use crate::{
|
||||
FusionBackend, HandleContainer, Optimization, OptimizationBuilder, OptimizationStatus,
|
||||
};
|
||||
|
||||
/// Execute an optimization following a greedy algorithm.
|
||||
#[derive(Default)]
|
||||
pub struct GreedyGraphExecution;
|
||||
pub(crate) struct GraphExecution<B: FusionBackend> {
|
||||
optimization_cache: OptimizationCache<Box<dyn Optimization<B>>>,
|
||||
optimizations: Vec<Box<dyn OptimizationBuilder<B>>>,
|
||||
num_skipped: usize,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> GraphExecution<B> for GreedyGraphExecution {
|
||||
fn maybe_execute(
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(crate) enum ExecutionMode {
|
||||
// Signal that we execute the graph after a new ops is added to the graph.
|
||||
NewOps,
|
||||
// Signal that we execute the graph because of a sync without any new ops added to the graph.
|
||||
Sync,
|
||||
}
|
||||
|
||||
impl<B: FusionBackend> GraphExecution<B> {
|
||||
/// Create a new graph execution with the given optimization builders.
|
||||
pub fn new(optimizations: Vec<Box<dyn OptimizationBuilder<B>>>) -> Self {
|
||||
Self {
|
||||
optimization_cache: OptimizationCache::new(),
|
||||
optimizations,
|
||||
num_skipped: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the graph with the provided mode.
|
||||
pub fn execute(
|
||||
&mut self,
|
||||
graph: &mut Graph<B>,
|
||||
handles: &mut HandleContainer<B>,
|
||||
optimizations: &mut [Optimization<B>],
|
||||
force: bool,
|
||||
mode: ExecutionMode,
|
||||
) {
|
||||
loop {
|
||||
if !force && still_optimizing(optimizations) {
|
||||
if graph.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
match find_best_optimization_index(optimizations) {
|
||||
Some(index) => {
|
||||
graph.execute_optimization(handles, index, optimizations);
|
||||
}
|
||||
None => {
|
||||
graph.execute(handles);
|
||||
optimizations.iter_mut().for_each(|ops| ops.reset());
|
||||
}
|
||||
}
|
||||
match self.cache(graph, mode) {
|
||||
CacheResult::Miss => {
|
||||
match self.build(graph, mode) {
|
||||
BuildAction::ExecuteOptimization(ops) => {
|
||||
graph.execute_optimization(handles, ops);
|
||||
self.reset(graph);
|
||||
}
|
||||
BuildAction::ExecuteOperations => {
|
||||
graph.execute_operations(handles);
|
||||
self.reset(graph);
|
||||
}
|
||||
BuildAction::ContinueBuilding => {
|
||||
if let ExecutionMode::Sync = mode {
|
||||
panic!("Can't continue building when sync is called.")
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if graph.is_empty() {
|
||||
// No more ops to fuse.
|
||||
if self.num_skipped == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
CacheResult::OnPath => {
|
||||
self.num_skipped += 1;
|
||||
|
||||
match mode {
|
||||
ExecutionMode::NewOps => break,
|
||||
ExecutionMode::Sync => panic!("Can't wait while sync"),
|
||||
};
|
||||
}
|
||||
CacheResult::Found(ops) => {
|
||||
graph.execute_optimization(handles, ops.as_ref());
|
||||
self.reset(graph);
|
||||
}
|
||||
};
|
||||
|
||||
if let ExecutionMode::NewOps = mode {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build(&mut self, graph: &mut Graph<B>, mode: ExecutionMode) -> BuildAction<'_, B> {
|
||||
// When we are executing with the new ops mode, we need to register the last ops of the
|
||||
// graph even when there is no skipped operation.
|
||||
let offset = match mode {
|
||||
ExecutionMode::NewOps => 1,
|
||||
ExecutionMode::Sync => 0,
|
||||
};
|
||||
|
||||
for i in (0..self.num_skipped + offset).rev() {
|
||||
let index = graph.relative.len() - 1 - i;
|
||||
let relative = &graph.relative[index];
|
||||
|
||||
for ops in self.optimizations.iter_mut() {
|
||||
ops.register(relative);
|
||||
}
|
||||
}
|
||||
self.num_skipped = 0;
|
||||
|
||||
// Can only be lazy when not sync.
|
||||
if let ExecutionMode::NewOps = mode {
|
||||
if still_optimizing(&self.optimizations) {
|
||||
return BuildAction::ContinueBuilding;
|
||||
}
|
||||
}
|
||||
|
||||
match find_best_optimization_index(&self.optimizations) {
|
||||
Some(index) => {
|
||||
let (relative, next_ops) = Self::split_relative_graph_owned(graph, mode);
|
||||
let optimization = &self.optimizations[index];
|
||||
let ops = self
|
||||
.optimization_cache
|
||||
.complete(optimization, relative, next_ops);
|
||||
BuildAction::ExecuteOptimization(ops.as_ref())
|
||||
}
|
||||
None => {
|
||||
// TODO: Cache this result too.
|
||||
BuildAction::ExecuteOperations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self, graph: &mut Graph<B>) {
|
||||
for ops in self.optimizations.iter_mut() {
|
||||
ops.reset();
|
||||
}
|
||||
self.num_skipped = graph.relative.len();
|
||||
|
||||
self.optimization_cache.reset();
|
||||
|
||||
// Reset the policy state.
|
||||
for i in 0..self.num_skipped {
|
||||
let _ = self.optimization_cache.follow(
|
||||
&graph.relative[0..i],
|
||||
Condition::NextOps(&graph.relative[i]),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn cache<'a>(
|
||||
&'a mut self,
|
||||
graph: &mut Graph<B>,
|
||||
mode: ExecutionMode,
|
||||
) -> CacheResult<'a, Box<dyn Optimization<B>>> {
|
||||
let (graph, next_ops) = Self::split_relative_graph_ref(graph, mode);
|
||||
let end_condition = next_ops.map(Condition::NextOps).unwrap_or(Condition::Sync);
|
||||
let action = self.optimization_cache.follow(graph, end_condition);
|
||||
|
||||
match mode {
|
||||
ExecutionMode::NewOps => action,
|
||||
ExecutionMode::Sync => match action {
|
||||
CacheResult::Miss => CacheResult::Miss,
|
||||
CacheResult::OnPath => CacheResult::Miss,
|
||||
CacheResult::Found(ops) => CacheResult::Found(ops),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn split_relative_graph_owned(
|
||||
graph: &Graph<B>,
|
||||
mode: ExecutionMode,
|
||||
) -> (Vec<TensorOpsDescription>, Option<TensorOpsDescription>) {
|
||||
match mode {
|
||||
ExecutionMode::NewOps => {
|
||||
let graph = graph.split_relative_graph();
|
||||
(graph.0.to_vec(), graph.1.cloned())
|
||||
}
|
||||
ExecutionMode::Sync => (graph.relative.clone(), None),
|
||||
}
|
||||
}
|
||||
|
||||
fn split_relative_graph_ref(
|
||||
graph: &Graph<B>,
|
||||
mode: ExecutionMode,
|
||||
) -> (&[TensorOpsDescription], Option<&TensorOpsDescription>) {
|
||||
match mode {
|
||||
ExecutionMode::NewOps => graph.split_relative_graph(),
|
||||
ExecutionMode::Sync => (graph.relative.as_slice(), None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn still_optimizing<B: FusionBackend>(optimizations: &[Optimization<B>]) -> bool {
|
||||
enum BuildAction<'a, B: FusionBackend> {
|
||||
ExecuteOptimization(&'a dyn Optimization<B>),
|
||||
ExecuteOperations,
|
||||
ContinueBuilding,
|
||||
}
|
||||
|
||||
fn still_optimizing<B: FusionBackend>(optimizations: &[Box<dyn OptimizationBuilder<B>>]) -> bool {
|
||||
let mut num_stopped = 0;
|
||||
|
||||
for optimization in optimizations.iter() {
|
||||
if let FusionStatus::Closed(_) = optimization.status {
|
||||
if let OptimizationStatus::Closed = optimization.status() {
|
||||
num_stopped += 1
|
||||
}
|
||||
}
|
||||
|
@ -62,16 +202,13 @@ fn still_optimizing<B: FusionBackend>(optimizations: &[Optimization<B>]) -> bool
|
|||
}
|
||||
|
||||
fn find_best_optimization_index<B: FusionBackend>(
|
||||
optimizations: &[Optimization<B>],
|
||||
optimizations: &[Box<dyn OptimizationBuilder<B>>],
|
||||
) -> Option<usize> {
|
||||
let mut best_index = None;
|
||||
let mut best_score = 0;
|
||||
|
||||
for (i, optimization) in optimizations.iter().enumerate() {
|
||||
let properties = match optimization.status {
|
||||
FusionStatus::Closed(properties) => properties,
|
||||
FusionStatus::Open(properties) => properties,
|
||||
};
|
||||
let properties = optimization.properties();
|
||||
|
||||
if properties.ready && properties.score >= best_score {
|
||||
best_index = Some(i);
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
pub(crate) mod execution;
|
||||
|
||||
mod base;
|
||||
mod execution;
|
||||
mod context;
|
||||
mod ops;
|
||||
mod path;
|
||||
|
||||
pub use base::*;
|
||||
pub use execution::*;
|
||||
pub use context::*;
|
||||
pub use ops::*;
|
||||
pub use path::*;
|
||||
|
|
|
@ -13,7 +13,7 @@ pub trait Ops<B: FusionBackend>: Send + Sync {
|
|||
}
|
||||
|
||||
/// Describe all tensor operations possible.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
pub enum TensorOpsDescription {
|
||||
/// Basic operation on a float tensor.
|
||||
BaseOpsFloat(BaseOpsDescription),
|
||||
|
@ -36,7 +36,7 @@ pub enum TensorOpsDescription {
|
|||
}
|
||||
|
||||
/// Operation description specific to a float tensor.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
pub enum FloatOpsDescription {
|
||||
/// Operation corresponding to [exp](burn_tensor::ops::TensorOps::exp).
|
||||
Exp(UnaryOpsDescription),
|
||||
|
@ -61,13 +61,13 @@ pub enum FloatOpsDescription {
|
|||
/// Operation corresponding to [matmul](burn_tensor::ops::TensorOps::matmul).
|
||||
Matmul(BinaryOpsDescription),
|
||||
/// Operation corresponding to [random](burn_tensor::ops::TensorOps::random).
|
||||
Random((TensorDescription, Distribution)),
|
||||
Random(RandomOpsDescription),
|
||||
/// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip).
|
||||
Recip(UnaryOpsDescription),
|
||||
}
|
||||
|
||||
/// Operation description specific to module.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
pub enum ModuleOpsDescription {
|
||||
/// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding).
|
||||
Embedding(EmbeddingDescription),
|
||||
|
@ -124,7 +124,7 @@ pub enum ModuleOpsDescription {
|
|||
}
|
||||
|
||||
/// Basic operations that can be done on any tensor type.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
pub enum BaseOpsDescription {
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
|
@ -177,8 +177,8 @@ pub enum BaseOpsDescription {
|
|||
}
|
||||
|
||||
/// Numeric operations on int and float tensors.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum NumericOpsDescription<E: Element> {
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum NumericOpsDescription<E> {
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [add](burn_tensor::ops::TensorOps::add).
|
||||
|
@ -392,14 +392,14 @@ pub enum NumericOpsDescription<E: Element> {
|
|||
}
|
||||
|
||||
/// Operation description specific to an int tensor.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
pub enum IntOpsDescription {
|
||||
/// Operation corresponding to [into float](burn_tensor::ops::IntTensorOps::int_into_float).
|
||||
IntoFloat(UnaryOpsDescription),
|
||||
}
|
||||
|
||||
/// Operation description specific to a bool tensor.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
pub enum BoolOpsDescription {
|
||||
/// Operation corresponding to [into float](burn_tensor::ops::BoolTensorOps::bool_into_float).
|
||||
IntoFloat(UnaryOpsDescription),
|
||||
|
@ -409,7 +409,7 @@ pub enum BoolOpsDescription {
|
|||
Not(UnaryOpsDescription),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
/// Swap dim operation description.
|
||||
pub struct SwapDimsDescription {
|
||||
/// Input tensor description.
|
||||
|
@ -422,15 +422,21 @@ pub struct SwapDimsDescription {
|
|||
pub dim2: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct RandomOpsDescription {
|
||||
pub out: TensorDescription,
|
||||
pub distribution: Distribution,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct ReshapeDescription {
|
||||
pub input: TensorDescription,
|
||||
pub out: TensorDescription,
|
||||
pub shape: Vec<usize>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct BinaryOpsDescription {
|
||||
pub lhs: TensorDescription,
|
||||
|
@ -438,14 +444,14 @@ pub struct BinaryOpsDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct UnaryOpsDescription {
|
||||
pub input: TensorDescription,
|
||||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct ScalarOpsDescription<E> {
|
||||
pub lhs: TensorDescription,
|
||||
|
@ -453,7 +459,7 @@ pub struct ScalarOpsDescription<E> {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct GatherOpsDescription {
|
||||
pub tensor: TensorDescription,
|
||||
|
@ -462,7 +468,7 @@ pub struct GatherOpsDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct ScatterOpsDescription {
|
||||
pub tensor: TensorDescription,
|
||||
|
@ -472,7 +478,7 @@ pub struct ScatterOpsDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct SelectOpsDescription {
|
||||
pub tensor: TensorDescription,
|
||||
|
@ -481,7 +487,7 @@ pub struct SelectOpsDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct SelectAssignOpsDescription {
|
||||
pub tensor: TensorDescription,
|
||||
|
@ -491,7 +497,7 @@ pub struct SelectAssignOpsDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct SliceOpsDescription {
|
||||
pub tensor: TensorDescription,
|
||||
|
@ -499,7 +505,7 @@ pub struct SliceOpsDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct SliceAssignOpsDescription {
|
||||
pub tensor: TensorDescription,
|
||||
|
@ -508,7 +514,7 @@ pub struct SliceAssignOpsDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct MaskWhereOpsDescription {
|
||||
pub tensor: TensorDescription,
|
||||
|
@ -517,7 +523,7 @@ pub struct MaskWhereOpsDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct MaskFillOpsDescription<E> {
|
||||
pub tensor: TensorDescription,
|
||||
|
@ -526,7 +532,7 @@ pub struct MaskFillOpsDescription<E> {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct ClampOpsDescription<E> {
|
||||
pub tensor: TensorDescription,
|
||||
|
@ -535,17 +541,16 @@ pub struct ClampOpsDescription<E> {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct RepeatOpsDescription {
|
||||
pub tensor: TensorDescription,
|
||||
pub dim: usize,
|
||||
pub times: usize,
|
||||
pub shape: Vec<usize>,
|
||||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct CatOpsDescription {
|
||||
pub tensors: Vec<TensorDescription>,
|
||||
|
@ -553,7 +558,7 @@ pub struct CatOpsDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct ReduceDimWithIndicesDescription {
|
||||
pub tensor: TensorDescription,
|
||||
|
@ -562,7 +567,7 @@ pub struct ReduceDimWithIndicesDescription {
|
|||
pub out_indices: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct EmbeddingDescription {
|
||||
pub weights: TensorDescription,
|
||||
|
@ -570,7 +575,7 @@ pub struct EmbeddingDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct EmbeddingBackwardDescription {
|
||||
pub weights: TensorDescription,
|
||||
|
@ -579,7 +584,7 @@ pub struct EmbeddingBackwardDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct Conv1dDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -589,7 +594,7 @@ pub struct Conv1dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct Conv2dDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -599,7 +604,7 @@ pub struct Conv2dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct ConvTranspose1dDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -609,7 +614,7 @@ pub struct ConvTranspose1dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct ConvTranspose2dDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -619,7 +624,7 @@ pub struct ConvTranspose2dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct AvgPool1dDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -630,7 +635,7 @@ pub struct AvgPool1dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct AvgPool2dDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -641,7 +646,7 @@ pub struct AvgPool2dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct AvgPool1dBackwardDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -653,7 +658,7 @@ pub struct AvgPool1dBackwardDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct AvgPool2dBackwardDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -665,7 +670,7 @@ pub struct AvgPool2dBackwardDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct AdaptiveAvgPool1dDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -673,7 +678,7 @@ pub struct AdaptiveAvgPool1dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct AdaptiveAvgPool2dDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -681,7 +686,7 @@ pub struct AdaptiveAvgPool2dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct AdaptiveAvgPool1dBackwardDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -689,7 +694,7 @@ pub struct AdaptiveAvgPool1dBackwardDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct AdaptiveAvgPool2dBackwardDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -697,7 +702,7 @@ pub struct AdaptiveAvgPool2dBackwardDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct MaxPool1dDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -708,7 +713,7 @@ pub struct MaxPool1dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct MaxPool1dWithIndicesDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -720,7 +725,7 @@ pub struct MaxPool1dWithIndicesDescription {
|
|||
pub out_indices: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct MaxPool1dWithIndicesBackwardDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -733,7 +738,7 @@ pub struct MaxPool1dWithIndicesBackwardDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct MaxPool2dDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -744,8 +749,8 @@ pub struct MaxPool2dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
pub struct MaxPool2dWithIndicesDescription {
|
||||
pub x: TensorDescription,
|
||||
pub kernel_size: [usize; 2],
|
||||
|
@ -756,7 +761,7 @@ pub struct MaxPool2dWithIndicesDescription {
|
|||
pub out_indices: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct MaxPool2dWithIndicesBackwardDescription {
|
||||
pub x: TensorDescription,
|
||||
|
@ -1112,3 +1117,85 @@ impl ModuleOpsDescription {
|
|||
}
|
||||
}
|
||||
}
|
||||
impl core::hash::Hash for RandomOpsDescription {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.out.hash(state);
|
||||
|
||||
match self.distribution {
|
||||
Distribution::Default => 1u8.hash(state),
|
||||
Distribution::Bernoulli(_) => 2u8.hash(state),
|
||||
Distribution::Uniform(_, _) => 3u8.hash(state),
|
||||
Distribution::Normal(_, _) => 4u8.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<E> core::hash::Hash for ScalarOpsDescription<E> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.lhs.hash(state);
|
||||
self.out.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> core::hash::Hash for MaskFillOpsDescription<E> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.tensor.hash(state);
|
||||
self.mask.hash(state);
|
||||
self.out.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> core::hash::Hash for ClampOpsDescription<E> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.tensor.hash(state);
|
||||
self.out.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> core::hash::Hash for NumericOpsDescription<E> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
NumericOpsDescription::Add(desc) => desc.hash(state),
|
||||
NumericOpsDescription::AddScalar(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Sub(desc) => desc.hash(state),
|
||||
NumericOpsDescription::SubScalar(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Div(desc) => desc.hash(state),
|
||||
NumericOpsDescription::DivScalar(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Mul(desc) => desc.hash(state),
|
||||
NumericOpsDescription::MulScalar(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Abs(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Ones(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Zeros(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Full(desc) => desc.0.hash(state),
|
||||
NumericOpsDescription::Gather(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Scatter(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Select(desc) => desc.hash(state),
|
||||
NumericOpsDescription::SelectAssign(desc) => desc.hash(state),
|
||||
NumericOpsDescription::MaskWhere(desc) => desc.hash(state),
|
||||
NumericOpsDescription::MaskFill(desc) => desc.hash(state),
|
||||
NumericOpsDescription::MeanDim(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Mean(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Sum(desc) => desc.hash(state),
|
||||
NumericOpsDescription::SumDim(desc) => desc.hash(state),
|
||||
NumericOpsDescription::EqualElem(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Greater(desc) => desc.hash(state),
|
||||
NumericOpsDescription::GreaterElem(desc) => desc.hash(state),
|
||||
NumericOpsDescription::GreaterEqual(desc) => desc.hash(state),
|
||||
NumericOpsDescription::GreaterEqualElem(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Lower(desc) => desc.hash(state),
|
||||
NumericOpsDescription::LowerElem(desc) => desc.hash(state),
|
||||
NumericOpsDescription::LowerEqual(desc) => desc.hash(state),
|
||||
NumericOpsDescription::LowerEqualElem(desc) => desc.hash(state),
|
||||
NumericOpsDescription::ArgMax(desc) => desc.hash(state),
|
||||
NumericOpsDescription::ArgMin(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Max(desc) => desc.hash(state),
|
||||
NumericOpsDescription::MaxDimWithIndices(desc) => desc.hash(state),
|
||||
NumericOpsDescription::MinDimWithIndices(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Min(desc) => desc.hash(state),
|
||||
NumericOpsDescription::MaxDim(desc) => desc.hash(state),
|
||||
NumericOpsDescription::MinDim(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Clamp(desc) => desc.hash(state),
|
||||
NumericOpsDescription::ClampMax(desc) => desc.hash(state),
|
||||
NumericOpsDescription::ClampMin(desc) => desc.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,535 @@
|
|||
use super::starter::Starters;
|
||||
use crate::graph::TensorOpsDescription;
|
||||
|
||||
/// The cache works by keeping track of all possible optimizations for the current graph path.
|
||||
///
|
||||
/// # Details
|
||||
///
|
||||
/// This is pretty different from a normal key-value cache.
|
||||
/// There is no key to access the cached values, since computing a key for a graph is very expensive.
|
||||
/// Instead, we keep track of each new edge added to the graph and invalidate potential optimizations
|
||||
/// when we see a different edge is added while keeping track of the current graph path.
|
||||
///
|
||||
/// Therefore, the overhead is very minimal, since the time-complexity of checking the cache
|
||||
/// scales with the number of concurrent potential optimizations for the current path, which isn't
|
||||
/// supposed to be big at any time.
|
||||
pub(crate) struct OptimizationCache<O> {
|
||||
candidates: Vec<OptimizationId>,
|
||||
availables: Vec<(OptimizationId, usize)>,
|
||||
optimizations: Vec<OptimizationItem<O>>,
|
||||
starters: Starters,
|
||||
found: Option<OptimizationId>,
|
||||
}
|
||||
|
||||
impl<O> OptimizationCache<O> {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
candidates: Vec::new(),
|
||||
availables: Vec::new(),
|
||||
optimizations: Vec::new(),
|
||||
starters: Starters::default(),
|
||||
found: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Follow the current path on the provided graph with the start/end condition.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// It is assumed that this function will be called for each new edge added to the graph (for
|
||||
/// each new operation). Only one graph can be cached at a time.
|
||||
pub(crate) fn follow<'a>(
|
||||
&'a mut self,
|
||||
graph: &[TensorOpsDescription],
|
||||
condition: Condition,
|
||||
) -> CacheResult<'a, O> {
|
||||
if graph.is_empty() {
|
||||
// When the graph is empty, we use the condition as the first operation to determine
|
||||
// the new possible opitmizations.
|
||||
let ops = match condition {
|
||||
Condition::NextOps(ops) => ops,
|
||||
Condition::Sync => return CacheResult::Miss, // Sync an empty graph doesn't make
|
||||
// sense.
|
||||
};
|
||||
let candidates = self.starters.get(ops);
|
||||
if candidates.is_empty() {
|
||||
return CacheResult::Miss;
|
||||
}
|
||||
self.candidates = candidates;
|
||||
return CacheResult::OnPath;
|
||||
}
|
||||
|
||||
if let Some(candidate) = self.found {
|
||||
return CacheResult::Found(&self.optimizations.get(candidate).unwrap().value);
|
||||
}
|
||||
|
||||
// Invalidate candidates.
|
||||
let mut invalidated_candidate = Vec::new();
|
||||
for id in self.candidates.iter() {
|
||||
let item = match self.optimizations.get(*id) {
|
||||
Some(item) => item,
|
||||
None => panic!("Should have an optimization"),
|
||||
};
|
||||
let next_ops = graph.last().expect("Validated earlier");
|
||||
let next_ops_index = graph.len() - 1;
|
||||
let next_ops_candidate = match item.graph.get(next_ops_index) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
// Graph of different size, invalidated.
|
||||
invalidated_candidate.push(*id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if next_ops_candidate != next_ops {
|
||||
// Graph with different node at the current position, invalidated.
|
||||
invalidated_candidate.push(*id);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Is it optimal?
|
||||
if item.graph.len() == graph.len() {
|
||||
let ops = match condition {
|
||||
Condition::NextOps(ops) => ops,
|
||||
Condition::Sync => {
|
||||
self.found = Some(*id);
|
||||
return CacheResult::Found(&item.value);
|
||||
}
|
||||
};
|
||||
|
||||
if item.end_conditions.contains(ops) {
|
||||
self.found = Some(*id);
|
||||
return CacheResult::Found(&item.value);
|
||||
} else {
|
||||
self.availables.push((*id, graph.len()));
|
||||
invalidated_candidate.push(*id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut updated_candidates = Vec::new();
|
||||
core::mem::swap(&mut updated_candidates, &mut self.candidates);
|
||||
|
||||
self.candidates = updated_candidates
|
||||
.into_iter()
|
||||
.filter(|candidate| !invalidated_candidate.contains(candidate))
|
||||
.collect();
|
||||
|
||||
if self.candidates.is_empty() {
|
||||
CacheResult::Miss
|
||||
} else {
|
||||
CacheResult::OnPath
|
||||
}
|
||||
}
|
||||
|
||||
/// Signal the completion of a graph path that reached a new optimization.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The optimization factory will only be called if the optimization is on a new graph.
|
||||
/// When the optimization already exists, but with a different end condition, a new end
|
||||
/// condition will be registered, but the old optimization will be used in following call. This
|
||||
/// is intended since we want to factory to be called only once per graph, but reused as much as
|
||||
/// possible.
|
||||
pub fn complete<'a, Factory: OptimizationFactory<O>>(
|
||||
&'a mut self,
|
||||
factory: &Factory,
|
||||
graph: Vec<TensorOpsDescription>,
|
||||
next_ops: Option<TensorOpsDescription>,
|
||||
) -> &'a O {
|
||||
let existing_optim = self
|
||||
.availables
|
||||
.iter()
|
||||
.find(|(_candidate, len)| *len == graph.len());
|
||||
|
||||
if let Some((id, _)) = existing_optim {
|
||||
let optimization = self.optimizations.get_mut(*id).unwrap();
|
||||
|
||||
if let Some(ops) = next_ops {
|
||||
optimization.end_conditions.push(ops)
|
||||
};
|
||||
|
||||
return &optimization.value;
|
||||
};
|
||||
|
||||
self.starters
|
||||
.insert(graph.first().unwrap(), self.optimizations.len());
|
||||
let optimization = OptimizationItem {
|
||||
graph,
|
||||
end_conditions: match next_ops {
|
||||
Some(val) => vec![val],
|
||||
None => Vec::new(),
|
||||
},
|
||||
value: factory.create(),
|
||||
};
|
||||
|
||||
self.optimizations.push(optimization);
|
||||
&self.optimizations.last().unwrap().value
|
||||
}
|
||||
|
||||
// Signal that a new path will begin.
|
||||
pub(crate) fn reset(&mut self) {
|
||||
self.candidates.clear();
|
||||
self.availables.clear();
|
||||
self.found = None;
|
||||
}
|
||||
}
|
||||
|
||||
/// Action to be made depending on the graph.
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum CacheResult<'a, T> {
|
||||
/// Continue exploring optimizations using the [builder](crate::OptimizationBuilder).
|
||||
Miss,
|
||||
/// The current graph indicates that an optimization may be possible in the future, so the
|
||||
/// best action is to wait for the optimization to become available.
|
||||
///
|
||||
/// Sometimes, it can be a false positive and a new optimization should be built from scratch.
|
||||
/// Therefore it's important to keep the previous operations to rebuild the state if it
|
||||
/// happens.
|
||||
OnPath,
|
||||
/// An optimization has been found, and the best action is to execute it!
|
||||
Found(&'a T),
|
||||
}
|
||||
|
||||
/// When checking if an optimization is possible, a start or an end condition ensures that this optimization is
|
||||
/// always optimal.
|
||||
#[derive(Clone)]
|
||||
pub enum Condition<'a> {
|
||||
/// The next operation that signals the start or end of the operation.
|
||||
NextOps(&'a TensorOpsDescription),
|
||||
/// When sync, we should execute the optimization if found no matter what comes next.
|
||||
Sync,
|
||||
}
|
||||
|
||||
impl<'a, T> core::fmt::Debug for CacheResult<'a, T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
CacheResult::Miss => f.write_str("CacheResult::Miss"),
|
||||
CacheResult::OnPath => f.write_str("CacheResult::OnPath"),
|
||||
CacheResult::Found(_) => f.write_str("CacheResult::Found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an optimization.
|
||||
pub(crate) trait OptimizationFactory<T> {
|
||||
/// Call only when a new optimization is found.
|
||||
fn create(&self) -> T;
|
||||
}
|
||||
|
||||
pub(super) type OptimizationId = usize;
|
||||
|
||||
struct OptimizationItem<O> {
|
||||
graph: Vec<TensorOpsDescription>,
|
||||
end_conditions: Vec<TensorOpsDescription>,
|
||||
value: O,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
graph::{FloatOpsDescription, UnaryOpsDescription},
|
||||
TensorDescription, TensorId, TensorStatus,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn should_cache_optimization_end_condition_forced() {
|
||||
// A graph with 3 ops.
|
||||
let graph = TestGraph::new(2);
|
||||
let mut path = OptimizationCache::new();
|
||||
|
||||
// First following
|
||||
graph.follow_misses(&mut path);
|
||||
|
||||
// Register the action.
|
||||
let optimization = path.complete(&Optimization1, graph.edges[0..2].to_vec(), None);
|
||||
|
||||
assert_eq!(optimization, &Optimization1.create());
|
||||
|
||||
// Second following on the same ops.
|
||||
path.reset();
|
||||
let result1 = path.follow(&[], Condition::NextOps(&graph.edges[0]));
|
||||
assert_eq!(result1, CacheResult::OnPath);
|
||||
|
||||
let result2 = path.follow(&graph.edges[0..1], Condition::NextOps(&graph.edges[1]));
|
||||
assert_eq!(result2, CacheResult::OnPath);
|
||||
|
||||
let result3 = path.follow(&graph.edges[0..2], Condition::Sync);
|
||||
match result3 {
|
||||
CacheResult::Found(ops) => assert_eq!(ops, &Optimization1.create()),
|
||||
_ => panic!("Should have found the cached operation"),
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn once_found_perfect_should_always_return_found() {
|
||||
let mut graph = TestGraph::new(2);
|
||||
let mut path = OptimizationCache::new();
|
||||
graph.follow_misses(&mut path);
|
||||
|
||||
// Register the action.
|
||||
let _optimization = path.complete(
|
||||
&Optimization1,
|
||||
graph.edges[0..1].to_vec(),
|
||||
Some(graph.edges[1].clone()),
|
||||
);
|
||||
|
||||
path.reset();
|
||||
graph.new_ops();
|
||||
graph.new_ops();
|
||||
|
||||
let result = path.follow(&[], Condition::NextOps(&graph.edges[0]));
|
||||
assert_eq!(result, CacheResult::OnPath);
|
||||
|
||||
let result = path.follow(&graph.edges[0..1], Condition::NextOps(&graph.edges[1]));
|
||||
match result {
|
||||
CacheResult::Found(ops) => assert_eq!(ops, &Optimization1.create()),
|
||||
_ => panic!("Should have found the cached operation"),
|
||||
}
|
||||
|
||||
let result = path.follow(&graph.edges[0..2], Condition::NextOps(&graph.edges[2]));
|
||||
match result {
|
||||
CacheResult::Found(ops) => assert_eq!(ops, &Optimization1.create()),
|
||||
_ => panic!("Should have found the cached operation"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_cache_optimization_end_condition_next_ops() {
|
||||
// A graph with 4 ops.
|
||||
let graph = TestGraph::new(3);
|
||||
let mut path = OptimizationCache::new();
|
||||
|
||||
// First following
|
||||
graph.follow_misses(&mut path);
|
||||
|
||||
// Register the action.
|
||||
let optimization = path.complete(
|
||||
&Optimization1,
|
||||
graph.edges[0..2].to_vec(),
|
||||
Some(graph.edges[2].clone()),
|
||||
);
|
||||
|
||||
assert_eq!(optimization, &Optimization1.create());
|
||||
|
||||
// Second following on the same ops.
|
||||
path.reset();
|
||||
let result1 = path.follow(&[], Condition::NextOps(&graph.edges[0]));
|
||||
assert_eq!(result1, CacheResult::OnPath);
|
||||
|
||||
let result2 = path.follow(&graph.edges[0..1], Condition::NextOps(&graph.edges[1]));
|
||||
assert_eq!(result2, CacheResult::OnPath);
|
||||
|
||||
let result3 = path.follow(&graph.edges[0..2], Condition::NextOps(&graph.edges[2]));
|
||||
match result3 {
|
||||
CacheResult::Found(ops) => assert_eq!(ops, &Optimization1.create()),
|
||||
_ => panic!("Should have found the cached operation"),
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_many_different_end_conditions() {
|
||||
let mut graph1 = TestGraph::new(2);
|
||||
graph1.register_ops(|desc| TensorOpsDescription::FloatOps(FloatOpsDescription::Exp(desc)));
|
||||
|
||||
let mut graph2 = TestGraph::new(2);
|
||||
graph2.register_ops(|desc| TensorOpsDescription::FloatOps(FloatOpsDescription::Log(desc)));
|
||||
|
||||
let mut path = OptimizationCache::<String>::new();
|
||||
let last_edge_index = graph1.edges.len() - 1;
|
||||
|
||||
// Follow graph 1 with only misses.
|
||||
graph1.follow_misses(&mut path);
|
||||
let _ = path.complete(
|
||||
&Optimization1,
|
||||
graph1.edges[0..last_edge_index].to_vec(),
|
||||
Some(graph1.edges[last_edge_index].clone()),
|
||||
);
|
||||
|
||||
// Follow graph 2.
|
||||
let result = path.follow(&[], Condition::NextOps(&graph2.edges[0]));
|
||||
assert_eq!(result, CacheResult::OnPath);
|
||||
|
||||
let result = path.follow(&graph2.edges[0..1], Condition::NextOps(&graph2.edges[1]));
|
||||
assert_eq!(result, CacheResult::OnPath);
|
||||
|
||||
let result = path.follow(&graph2.edges[0..2], Condition::NextOps(&graph2.edges[2]));
|
||||
assert_eq!(result, CacheResult::Miss);
|
||||
|
||||
let optimization = path.complete(
|
||||
&Optimization2,
|
||||
graph2.edges[0..last_edge_index].to_vec(),
|
||||
Some(graph2.edges[last_edge_index].clone()),
|
||||
);
|
||||
assert_eq!(
|
||||
optimization,
|
||||
&Optimization1.create(),
|
||||
"Optimization 1 should still be returned, since same graph but not same end condition."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_multiple_concurrent_paths() {
|
||||
// Two different graphs with a different second ops, but the same last ops.
|
||||
let mut graph1 = TestGraph::new(1);
|
||||
graph1.register_ops(|desc| TensorOpsDescription::FloatOps(FloatOpsDescription::Exp(desc)));
|
||||
graph1.new_ops();
|
||||
|
||||
let mut graph2 = TestGraph::new(1);
|
||||
graph2.register_ops(|desc| TensorOpsDescription::FloatOps(FloatOpsDescription::Cos(desc)));
|
||||
graph2.new_ops();
|
||||
|
||||
let mut path = OptimizationCache::<String>::new();
|
||||
|
||||
// Follow graph 1 with only misses.
|
||||
graph1.follow_misses(&mut path);
|
||||
|
||||
// Register the opitmization 1 for graph 1.
|
||||
let last_edge_index = graph1.edges.len() - 1;
|
||||
let _ = path.complete(
|
||||
&Optimization1,
|
||||
graph1.edges[0..last_edge_index].to_vec(),
|
||||
Some(graph1.edges[last_edge_index].clone()),
|
||||
);
|
||||
|
||||
// Follow graph 2 and register a new optimization.
|
||||
path.reset();
|
||||
|
||||
let result = path.follow(&[], Condition::NextOps(&graph2.edges[0]));
|
||||
assert_eq!(result, CacheResult::OnPath);
|
||||
|
||||
let result = path.follow(&graph2.edges[0..1], Condition::NextOps(&graph2.edges[1]));
|
||||
assert_eq!(result, CacheResult::OnPath);
|
||||
|
||||
let result = path.follow(&graph2.edges[0..2], Condition::NextOps(&graph2.edges[2]));
|
||||
assert_eq!(
|
||||
result,
|
||||
CacheResult::Miss,
|
||||
"Should invalidate the second operation"
|
||||
);
|
||||
|
||||
// Register new optimization for path 2.
|
||||
let _ = path.complete(
|
||||
&Optimization2,
|
||||
graph2.edges[0..last_edge_index].to_vec(),
|
||||
Some(graph2.edges[last_edge_index].clone()),
|
||||
);
|
||||
|
||||
// Now let's validate that the cache works.
|
||||
|
||||
// New path instance on graph 1.
|
||||
path.reset();
|
||||
|
||||
let result = path.follow(&[], Condition::NextOps(&graph1.edges[0]));
|
||||
assert_eq!(result, CacheResult::OnPath);
|
||||
|
||||
let result = path.follow(&graph1.edges[0..1], Condition::NextOps(&graph1.edges[1]));
|
||||
assert_eq!(result, CacheResult::OnPath);
|
||||
|
||||
let result = path.follow(&graph1.edges[0..2], Condition::NextOps(&graph1.edges[2]));
|
||||
match result {
|
||||
CacheResult::Found(ops) => assert_eq!(ops, &Optimization1.create()),
|
||||
_ => panic!("Should have found the cached operation"),
|
||||
};
|
||||
|
||||
// New path instance on graph 2.
|
||||
path.reset();
|
||||
|
||||
let result = path.follow(&[], Condition::NextOps(&graph2.edges[0]));
|
||||
assert_eq!(result, CacheResult::OnPath);
|
||||
|
||||
let result = path.follow(&graph2.edges[0..1], Condition::NextOps(&graph2.edges[1]));
|
||||
assert_eq!(result, CacheResult::OnPath);
|
||||
|
||||
let result = path.follow(&graph2.edges[0..2], Condition::NextOps(&graph2.edges[2]));
|
||||
match result {
|
||||
CacheResult::Found(ops) => assert_eq!(ops, &Optimization2.create()),
|
||||
_ => panic!("Should have found the cached operation"),
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct TestGraph {
|
||||
nodes: Vec<TensorDescription>,
|
||||
edges: Vec<TensorOpsDescription>,
|
||||
}
|
||||
|
||||
impl TestGraph {
|
||||
/// Create a new test graph with `num_ops` operations registered.
|
||||
pub fn new(num_ops: usize) -> Self {
|
||||
let mut graph = Self::default();
|
||||
for _ in 0..num_ops {
|
||||
graph.new_ops();
|
||||
}
|
||||
|
||||
graph
|
||||
}
|
||||
|
||||
/// The first follow should only be cache miss.
|
||||
pub fn follow_misses(&self, path: &mut OptimizationCache<String>) {
|
||||
for i in 0..self.edges.len() {
|
||||
let result = path.follow(&self.edges[0..i], Condition::NextOps(&self.edges[i]));
|
||||
assert_eq!(result, CacheResult::Miss);
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a unary operation in the graph.
|
||||
pub fn register_ops<F>(&mut self, func: F)
|
||||
where
|
||||
F: Fn(UnaryOpsDescription) -> TensorOpsDescription,
|
||||
{
|
||||
self.new_empty_node();
|
||||
let desc = self.unary_description();
|
||||
self.edges.push(func(desc));
|
||||
}
|
||||
|
||||
/// Add a simple operation to the graph.
|
||||
pub fn new_ops(&mut self) {
|
||||
if self.nodes.is_empty() {
|
||||
// Root node.
|
||||
self.new_empty_node();
|
||||
}
|
||||
|
||||
// Out node.
|
||||
self.new_empty_node();
|
||||
|
||||
self.edges
|
||||
.push(TensorOpsDescription::FloatOps(FloatOpsDescription::Log(
|
||||
self.unary_description(),
|
||||
)));
|
||||
}
|
||||
|
||||
fn new_empty_node(&mut self) {
|
||||
self.nodes.push(TensorDescription {
|
||||
id: TensorId::new(self.nodes.len() as u64),
|
||||
shape: vec![32, 32, 1],
|
||||
status: TensorStatus::NotInit,
|
||||
});
|
||||
}
|
||||
|
||||
fn unary_description(&self) -> UnaryOpsDescription {
|
||||
let size = self.nodes.len();
|
||||
|
||||
UnaryOpsDescription {
|
||||
input: self.nodes[size - 2].clone(),
|
||||
out: self.nodes[size - 1].clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Optimization1;
|
||||
struct Optimization2;
|
||||
|
||||
impl OptimizationFactory<String> for Optimization1 {
|
||||
fn create(&self) -> String {
|
||||
"Optimization1".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl OptimizationFactory<String> for Optimization2 {
|
||||
fn create(&self) -> String {
|
||||
"Optimization2".to_string()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
mod base;
|
||||
pub use base::*;
|
||||
|
||||
mod starter;
|
|
@ -0,0 +1,75 @@
|
|||
use super::OptimizationId;
|
||||
use crate::graph::TensorOpsDescription;
|
||||
use std::{
|
||||
collections::{hash_map::DefaultHasher, HashMap},
|
||||
hash::{Hash, Hasher},
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct Starters {
|
||||
starter_indices: HashMap<u64, Vec<(TensorOpsDescription, usize)>>,
|
||||
starters: Vec<Vec<OptimizationId>>,
|
||||
}
|
||||
|
||||
impl Starters {
|
||||
pub(crate) fn get(&self, ops: &TensorOpsDescription) -> Vec<OptimizationId> {
|
||||
let key = self.graph_key(ops);
|
||||
let values = match self.starter_indices.get(&key) {
|
||||
Some(val) => val,
|
||||
None => return Vec::new(),
|
||||
};
|
||||
|
||||
if values.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let (_, index) = match values.iter().find(|value| &value.0 == ops) {
|
||||
Some(val) => val,
|
||||
None => return Vec::new(),
|
||||
};
|
||||
|
||||
let val = match self.starters.get(*index) {
|
||||
Some(value) => value.clone(),
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
val
|
||||
}
|
||||
|
||||
pub(crate) fn insert(&mut self, ops: &TensorOpsDescription, new_id: OptimizationId) {
|
||||
let key = self.graph_key(ops);
|
||||
let values = match self.starter_indices.get_mut(&key) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
// New starter ops.
|
||||
let index = self.starters.len();
|
||||
self.starters.push(vec![new_id]);
|
||||
self.starter_indices.insert(key, vec![(ops.clone(), index)]);
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
let (_, index) = match values.iter_mut().find(|value| &value.0 == ops) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
// New with hash collision.
|
||||
let index = self.starters.len();
|
||||
self.starters.push(vec![new_id]);
|
||||
values.push((ops.clone(), index));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// New optimization for an existing starter.
|
||||
self.starters
|
||||
.get_mut(*index)
|
||||
.expect("Should exist")
|
||||
.push(new_id);
|
||||
}
|
||||
|
||||
fn graph_key(&self, ops: &TensorOpsDescription) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
ops.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
|
@ -138,17 +138,16 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for ReshapeDimsOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D1>(&self.desc.input);
|
||||
let output = B::bool_reshape::<D1, D2>(input, Shape::from(&self.desc.shape));
|
||||
let output = B::bool_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
|
||||
handles.register_bool_tensor(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let out = tensor.client.tensor_uninitialized(shape.clone());
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = ReshapeDescription {
|
||||
input: tensor.into_description(),
|
||||
shape,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
|
|
|
@ -5,10 +5,10 @@ use crate::{
|
|||
graph::{
|
||||
BaseOpsDescription, BinaryOpsDescription, CatOpsDescription, ClampOpsDescription,
|
||||
FloatOpsDescription, GatherOpsDescription, MaskFillOpsDescription, MaskWhereOpsDescription,
|
||||
NumericOpsDescription, Ops, ReduceDimWithIndicesDescription, ReshapeDescription,
|
||||
ScalarOpsDescription, ScatterOpsDescription, SelectAssignOpsDescription,
|
||||
SelectOpsDescription, SliceAssignOpsDescription, SliceOpsDescription, SwapDimsDescription,
|
||||
TensorOpsDescription, UnaryOpsDescription,
|
||||
NumericOpsDescription, Ops, RandomOpsDescription, ReduceDimWithIndicesDescription,
|
||||
ReshapeDescription, ScalarOpsDescription, ScatterOpsDescription,
|
||||
SelectAssignOpsDescription, SelectOpsDescription, SliceAssignOpsDescription,
|
||||
SliceOpsDescription, SwapDimsDescription, TensorOpsDescription, UnaryOpsDescription,
|
||||
},
|
||||
ops::binary::binary_ops_shape,
|
||||
scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, unary_float_ops, Fusion,
|
||||
|
@ -39,16 +39,15 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
|
|||
) -> FloatTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct RandomOps<const D: usize> {
|
||||
out: TensorDescription,
|
||||
distribution: Distribution,
|
||||
desc: RandomOpsDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Ops<B> for RandomOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let shape = Shape::from(self.out.shape.clone());
|
||||
let shape = Shape::from(self.desc.out.shape.clone());
|
||||
let output: B::TensorPrimitive<D> =
|
||||
B::random(shape, self.distribution, &handles.device);
|
||||
handles.register_float_tensor(&self.out.id, output);
|
||||
B::random(shape, self.desc.distribution, &handles.device);
|
||||
handles.register_float_tensor(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -56,10 +55,13 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
|
|||
let client = get_client::<B>(&device.clone().into());
|
||||
let out = client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = (out.to_description_out(), distribution);
|
||||
let desc = RandomOpsDescription {
|
||||
out: out.to_description_out(),
|
||||
distribution,
|
||||
};
|
||||
client.register(
|
||||
TensorOpsDescription::FloatOps(FloatOpsDescription::Random(desc.clone())),
|
||||
RandomOps::<D>::new(desc.0, desc.1),
|
||||
RandomOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
|
@ -548,17 +550,16 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
|
|||
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for ReshapeDimsOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D1>(&self.desc.input);
|
||||
let output = B::reshape::<D1, D2>(input, Shape::from(&self.desc.shape));
|
||||
let output = B::reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
|
||||
handles.register_float_tensor(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let out = tensor.client.tensor_uninitialized(shape.clone());
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = ReshapeDescription {
|
||||
input: tensor.into_description(),
|
||||
shape,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
|
|
|
@ -81,17 +81,16 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
impl<const D1: usize, const D2: usize, B: FusionBackend> Ops<B> for ReshapeDimsOps<D1, D2> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D1>(&self.desc.input);
|
||||
let output = B::int_reshape::<D1, D2>(input, Shape::from(&self.desc.shape));
|
||||
let output = B::int_reshape::<D1, D2>(input, Shape::from(&self.desc.out.shape));
|
||||
handles.register_int_tensor(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let shape: Vec<usize> = shape.dims.into();
|
||||
let out = tensor.client.tensor_uninitialized(shape.clone());
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = ReshapeDescription {
|
||||
input: tensor.into_description(),
|
||||
shape,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
|
|
|
@ -1,69 +1,48 @@
|
|||
use crate::{
|
||||
graph::{Graph, GraphExecution, Ops, Optimization, TensorOpsDescription},
|
||||
FusionBackend, FusionProperties, FusionStatus, HandleContainer, TensorId,
|
||||
graph::{
|
||||
execution::{ExecutionMode, GraphExecution},
|
||||
Graph, Ops, TensorOpsDescription,
|
||||
},
|
||||
FusionBackend, HandleContainer, TensorId,
|
||||
};
|
||||
use burn_tensor::ops::{FloatElem, IntElem};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct FusionServer<B, G>
|
||||
pub struct FusionServer<B>
|
||||
where
|
||||
B: FusionBackend,
|
||||
G: GraphExecution<B>,
|
||||
{
|
||||
optimizations: Vec<Optimization<B>>,
|
||||
execution: GraphExecution<B>,
|
||||
graph: Graph<B>,
|
||||
pub(crate) handles: HandleContainer<B>,
|
||||
execution: G,
|
||||
pub device: B::FusionDevice,
|
||||
pub num_skipped: usize,
|
||||
}
|
||||
|
||||
impl<B, G> FusionServer<B, G>
|
||||
impl<B> FusionServer<B>
|
||||
where
|
||||
B: FusionBackend,
|
||||
G: GraphExecution<B>,
|
||||
{
|
||||
pub fn new(device: B::FusionDevice) -> Self {
|
||||
let optimizations = B::operations(&device.clone().into())
|
||||
.into_iter()
|
||||
.map(|ops| Optimization::new(ops, FusionStatus::Open(FusionProperties::default())))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
optimizations,
|
||||
execution: GraphExecution::new(B::optimizations(&device.clone().into())),
|
||||
graph: Graph::new(),
|
||||
handles: HandleContainer::new(device.clone()),
|
||||
execution: G::default(),
|
||||
num_skipped: 0,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(&mut self, desc: TensorOpsDescription, op: Box<dyn Ops<B>>) {
|
||||
let ops = Arc::new(desc);
|
||||
self.graph.add(ops.clone(), op);
|
||||
|
||||
self.optimizations
|
||||
.iter_mut()
|
||||
.for_each(|optimization| optimization.register(&ops));
|
||||
|
||||
self.execution.maybe_execute(
|
||||
&mut self.graph,
|
||||
&mut self.handles,
|
||||
&mut self.optimizations,
|
||||
false,
|
||||
);
|
||||
pub fn register(&mut self, ops_desc: TensorOpsDescription, ops: Box<dyn Ops<B>>) {
|
||||
self.graph.add(ops_desc, ops);
|
||||
self.execution
|
||||
.execute(&mut self.graph, &mut self.handles, ExecutionMode::NewOps);
|
||||
}
|
||||
|
||||
pub fn drain_graph(&mut self) {
|
||||
if self.graph.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
self.execution.maybe_execute(
|
||||
&mut self.graph,
|
||||
&mut self.handles,
|
||||
&mut self.optimizations,
|
||||
true,
|
||||
);
|
||||
// Check if we can execute.
|
||||
self.execution
|
||||
.execute(&mut self.graph, &mut self.handles, ExecutionMode::Sync);
|
||||
}
|
||||
|
||||
pub fn create_empty_handle(&mut self) -> Arc<TensorId> {
|
||||
|
|
|
@ -26,7 +26,7 @@ pub struct Data<E, const D: usize> {
|
|||
}
|
||||
|
||||
/// Distribution for random value of a tensor.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum Distribution {
|
||||
/// Uniform distribution from 0 (inclusive) to 1 (exclusive).
|
||||
Default,
|
||||
|
|
|
@ -66,7 +66,7 @@ pub struct Conv1dBackward<B: Backend> {
|
|||
}
|
||||
|
||||
/// Convolution options.
|
||||
#[derive(new, Debug, Clone, Hash)]
|
||||
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct ConvOptions<const N: usize> {
|
||||
/// Stride.
|
||||
pub stride: [usize; N],
|
||||
|
@ -82,7 +82,7 @@ pub struct ConvOptions<const N: usize> {
|
|||
}
|
||||
|
||||
/// Transposed convolution options.
|
||||
#[derive(new, Debug, Clone, Hash)]
|
||||
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct ConvTransposeOptions<const N: usize> {
|
||||
/// Stride.
|
||||
pub stride: [usize; N],
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
use crate::{
|
||||
compute::{WgpuComputeClient, WgpuHandle},
|
||||
element::WgpuElement,
|
||||
fusion::FloatElementWiseFusionOps,
|
||||
fusion::FloatElementWiseBuilder,
|
||||
tensor::WgpuTensor,
|
||||
FloatElement, GraphicsApi, IntElement, Wgpu, WgpuDevice,
|
||||
};
|
||||
use burn_fusion::{
|
||||
client::MutexFusionClient, graph::GreedyGraphExecution, DeviceId, FusionBackend, FusionDevice,
|
||||
};
|
||||
use burn_fusion::{client::MutexFusionClient, DeviceId, FusionBackend, FusionDevice};
|
||||
use burn_tensor::Shape;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
|
@ -31,10 +29,10 @@ where
|
|||
{
|
||||
type FusionDevice = WgpuDevice;
|
||||
type Handle = WgpuFusionHandle;
|
||||
type FusionClient = MutexFusionClient<Self, GreedyGraphExecution>;
|
||||
type FusionClient = MutexFusionClient<Self>;
|
||||
|
||||
fn operations(device: &WgpuDevice) -> Vec<Box<dyn burn_fusion::FusionOps<Self>>> {
|
||||
vec![Box::new(FloatElementWiseFusionOps::new(device.clone()))]
|
||||
fn optimizations(device: &WgpuDevice) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self>>> {
|
||||
vec![Box::new(FloatElementWiseBuilder::new(device.clone()))]
|
||||
}
|
||||
|
||||
fn float_tensor<const D: usize>(
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use crate::{
|
||||
element::WgpuElement,
|
||||
fusion::codegen::{Elem, Operator, Variable},
|
||||
fusion::kernel::FusionKernel,
|
||||
FloatElement, GraphicsApi, IntElement, Wgpu,
|
||||
};
|
||||
use burn_fusion::{
|
||||
|
@ -9,13 +8,16 @@ use burn_fusion::{
|
|||
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
|
||||
ScalarOpsDescription, TensorOpsDescription, UnaryOpsDescription,
|
||||
},
|
||||
FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, TensorId,
|
||||
Optimization, OptimizationBuilder, OptimizationProperties, OptimizationStatus,
|
||||
TensorDescription, TensorId,
|
||||
};
|
||||
use burn_tensor::{Device, Element};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
use super::optimization::FloatElementWise;
|
||||
|
||||
/// Fused element wise operations that are normally memory bound.
|
||||
pub struct FloatElementWiseFusionOps<G, F, I>
|
||||
pub(crate) struct FloatElementWiseBuilder<G, F, I>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
|
@ -24,48 +26,56 @@ where
|
|||
pub(crate) inputs: Vec<TensorDescription>,
|
||||
pub(crate) locals: HashMap<TensorId, u16>,
|
||||
pub(crate) tensors: HashMap<TensorId, (TensorDescription, Elem)>,
|
||||
pub(crate) scalars_f32: Vec<f32>,
|
||||
pub(crate) scalars_i32: Vec<i32>,
|
||||
pub(crate) scalars_u32: Vec<u32>,
|
||||
pub(crate) booleans: Vec<bool>,
|
||||
pub(crate) scalars_f32: usize,
|
||||
pub(crate) scalars_i32: usize,
|
||||
pub(crate) scalars_u32: usize,
|
||||
pub(crate) booleans: usize,
|
||||
pub(crate) operators: Vec<Operator>,
|
||||
pub(crate) properties: FusionProperties,
|
||||
pub(crate) current_output_shape: Vec<usize>,
|
||||
device: Device<Wgpu<G, F, I>>,
|
||||
pub(crate) status: OptimizationStatus,
|
||||
pub(crate) device: Device<Wgpu<G, F, I>>,
|
||||
}
|
||||
|
||||
impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G, F, I>>
|
||||
for FloatElementWiseFusionOps<G, F, I>
|
||||
impl<G, F, I> OptimizationBuilder<Wgpu<G, F, I>> for FloatElementWiseBuilder<G, F, I>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn register(&mut self, ops: &TensorOpsDescription) -> FusionStatus {
|
||||
fn register(&mut self, ops: &TensorOpsDescription) {
|
||||
if let OptimizationStatus::Closed = self.status {
|
||||
return;
|
||||
}
|
||||
|
||||
match ops {
|
||||
TensorOpsDescription::BaseOpsFloat(ops) => {
|
||||
if !self.register_base::<F>(ops) {
|
||||
return FusionStatus::Closed(self.properties);
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
TensorOpsDescription::FloatOps(ops) => {
|
||||
if !self.register_float::<F>(ops) {
|
||||
return FusionStatus::Closed(self.properties);
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
TensorOpsDescription::NumericOpsFloat(ops) => {
|
||||
if !self.register_numeric(ops) {
|
||||
return FusionStatus::Closed(self.properties);
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return FusionStatus::Closed(self.properties);
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
self.properties.score += 1;
|
||||
self.properties.ready = self.operators.len() > 1;
|
||||
|
||||
FusionStatus::Open(self.properties)
|
||||
self.status = OptimizationStatus::Open;
|
||||
}
|
||||
|
||||
fn execute(&mut self, handles: &mut HandleContainer<Wgpu<G, F, I>>) {
|
||||
fn build(&self) -> Box<dyn Optimization<Wgpu<G, F, I>>> {
|
||||
let inputs = self.input_descriptions();
|
||||
let outputs = self.output_descriptions();
|
||||
let locals = outputs
|
||||
|
@ -73,32 +83,47 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
|
|||
.map(|out| *self.locals.get(&out.0.id).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
FusionKernel::new(&self.device)
|
||||
.inputs(&inputs, &self.scalars_f32)
|
||||
.body(&self.operators)
|
||||
.outputs(&outputs, &locals)
|
||||
.execute(handles);
|
||||
Box::new(FloatElementWise {
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
operators: self.operators.clone(),
|
||||
scalars_f32: self.scalars_f32,
|
||||
device: self.device.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.inputs.clear();
|
||||
self.locals.drain();
|
||||
self.tensors.clear();
|
||||
self.scalars_f32.clear();
|
||||
self.scalars_i32.clear();
|
||||
self.scalars_u32.clear();
|
||||
self.booleans.clear();
|
||||
self.scalars_f32 = 0;
|
||||
self.scalars_i32 = 0;
|
||||
self.scalars_u32 = 0;
|
||||
self.booleans = 0;
|
||||
self.operators.clear();
|
||||
self.properties = FusionProperties::default();
|
||||
self.status = OptimizationStatus::Open;
|
||||
self.current_output_shape.clear();
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.operators.len()
|
||||
fn status(&self) -> OptimizationStatus {
|
||||
self.status
|
||||
}
|
||||
|
||||
fn properties(&self) -> OptimizationProperties {
|
||||
let ready = match self.status {
|
||||
OptimizationStatus::Closed => false,
|
||||
OptimizationStatus::Open => self.operators.len() > 1,
|
||||
};
|
||||
|
||||
OptimizationProperties {
|
||||
ready,
|
||||
score: self.operators.len() as u64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<G, F, I> FloatElementWiseFusionOps<G, F, I>
|
||||
impl<G, F, I> FloatElementWiseBuilder<G, F, I>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
|
@ -109,28 +134,28 @@ where
|
|||
inputs: Vec::new(),
|
||||
locals: HashMap::new(),
|
||||
tensors: HashMap::new(),
|
||||
scalars_f32: Vec::new(),
|
||||
scalars_i32: Vec::new(),
|
||||
scalars_u32: Vec::new(),
|
||||
booleans: Vec::new(),
|
||||
scalars_f32: 0,
|
||||
scalars_i32: 0,
|
||||
scalars_u32: 0,
|
||||
booleans: 0,
|
||||
operators: Vec::new(),
|
||||
current_output_shape: Vec::new(),
|
||||
properties: FusionProperties::default(),
|
||||
status: OptimizationStatus::Open,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
fn input_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
|
||||
fn input_descriptions(&self) -> Vec<(TensorDescription, Elem)> {
|
||||
self.inputs
|
||||
.iter()
|
||||
.map(|input| {
|
||||
let updated_tensor = self.tensors.get(&input.id).unwrap();
|
||||
updated_tensor
|
||||
updated_tensor.clone()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn output_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
|
||||
fn output_descriptions(&self) -> Vec<(TensorDescription, Elem)> {
|
||||
let mut outputs = Vec::new();
|
||||
let mut local_tensor_ids_input = Vec::new();
|
||||
let mut local_tensor_ids_output = Vec::new();
|
||||
|
@ -271,7 +296,7 @@ where
|
|||
let is_read = local_tensor_ids_input.contains(&out);
|
||||
|
||||
if !is_read {
|
||||
outputs.push(self.tensors.get(&out).unwrap());
|
||||
outputs.push(self.tensors.get(&out).unwrap().clone());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -281,7 +306,7 @@ where
|
|||
let (tensor, _) = &entry;
|
||||
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
|
||||
if self.locals.contains_key(&tensor.id) {
|
||||
outputs.push(entry);
|
||||
outputs.push(entry.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -503,12 +528,13 @@ where
|
|||
let rhs = self.input_to_var(&desc.tensor, E::elem_type());
|
||||
let out = self.output_to_var(&desc.out, E::elem_type());
|
||||
|
||||
self.operators.push(Operator::ConditionalAssign {
|
||||
let ops = Operator::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
});
|
||||
};
|
||||
self.operators.push(ops);
|
||||
|
||||
true
|
||||
}
|
||||
|
@ -600,19 +626,19 @@ where
|
|||
true
|
||||
}
|
||||
|
||||
fn scalar_to_var<E: Element>(&mut self, value: &E, elem_type: Elem) -> Variable {
|
||||
fn scalar_to_var<E: Element>(&mut self, _value: &E, elem_type: Elem) -> Variable {
|
||||
match elem_type {
|
||||
Elem::F32 => {
|
||||
self.scalars_f32.push(value.elem());
|
||||
Variable::Scalar(self.scalars_f32.len() as u16 - 1, Elem::F32)
|
||||
self.scalars_f32 += 1;
|
||||
Variable::Scalar(self.scalars_f32 as u16 - 1, Elem::F32)
|
||||
}
|
||||
Elem::I32 => {
|
||||
self.scalars_i32.push(value.elem());
|
||||
Variable::Scalar(self.scalars_i32.len() as u16 - 1, Elem::I32)
|
||||
self.scalars_i32 += 1;
|
||||
Variable::Scalar(self.scalars_i32 as u16 - 1, Elem::I32)
|
||||
}
|
||||
Elem::U32 => {
|
||||
self.scalars_u32.push(value.elem());
|
||||
Variable::Scalar(self.scalars_u32.len() as u16 - 1, Elem::U32)
|
||||
self.scalars_u32 += 1;
|
||||
Variable::Scalar(self.scalars_u32 as u16 - 1, Elem::U32)
|
||||
}
|
||||
Elem::Bool => {
|
||||
panic!("Bool scalars not supported")
|
||||
|
@ -630,50 +656,3 @@ where
|
|||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_fusion::graph::Ops;
|
||||
use burn_fusion::{Fusion, FusionBackend};
|
||||
use burn_tensor::Tensor;
|
||||
|
||||
struct FakeAddOps;
|
||||
|
||||
impl<B: FusionBackend> Ops<B> for FakeAddOps {
|
||||
fn execute(self: Box<Self>, _: &mut HandleContainer<B>) {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fusion_same_behavior() {
|
||||
type Backend = Wgpu;
|
||||
type FusedBackend = Fusion<Wgpu>;
|
||||
|
||||
let data_1 =
|
||||
Tensor::<Backend, 2>::random([1, 32], burn_tensor::Distribution::Default).into_data();
|
||||
let data_2 =
|
||||
Tensor::<Backend, 2>::random([32, 32], burn_tensor::Distribution::Default).into_data();
|
||||
|
||||
let tensor_1 = Tensor::<Backend, 2>::from_data(data_1.clone());
|
||||
let tensor_2 = Tensor::<Backend, 2>::from_data(data_2.clone());
|
||||
let tensor_3 = tensor_1.clone() + tensor_2;
|
||||
let tensor_4 = tensor_3.clone() - tensor_1;
|
||||
let tensor_5 = tensor_4.clone() + 5.0;
|
||||
let tensor_6 = tensor_5 + tensor_3.clone();
|
||||
let mask = tensor_4.lower_equal(tensor_3);
|
||||
let result_ref = tensor_6.mask_fill(mask, 0.3).into_data();
|
||||
|
||||
let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
|
||||
let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
|
||||
let tensor_3 = tensor_1.clone() + tensor_2;
|
||||
let tensor_4 = tensor_3.clone() - tensor_1;
|
||||
let tensor_5 = tensor_4.clone() + 5.0;
|
||||
let tensor_6 = tensor_5 + tensor_3.clone();
|
||||
let mask = tensor_4.lower_equal(tensor_3);
|
||||
let result_fused = tensor_6.mask_fill(mask, 0.3).into_data();
|
||||
|
||||
result_fused.assert_approx_eq(&result_ref, 3);
|
||||
}
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
mod ops;
|
||||
mod builder;
|
||||
mod optimization;
|
||||
|
||||
pub use ops::*;
|
||||
pub(crate) use builder::*;
|
||||
|
|
|
@ -0,0 +1,170 @@
|
|||
use crate::{
|
||||
fusion::codegen::{Elem, Operator},
|
||||
fusion::kernel::FusionKernel,
|
||||
FloatElement, GraphicsApi, IntElement, Wgpu,
|
||||
};
|
||||
use burn_fusion::{graph::Context, Optimization, TensorDescription};
|
||||
use burn_tensor::Device;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct FloatElementWise<G, F, I>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
pub(crate) inputs: Vec<(TensorDescription, Elem)>,
|
||||
pub(crate) outputs: Vec<(TensorDescription, Elem)>,
|
||||
pub(crate) locals: Vec<u16>,
|
||||
pub(crate) operators: Vec<Operator>,
|
||||
pub(crate) scalars_f32: usize,
|
||||
pub(crate) device: Device<Wgpu<G, F, I>>,
|
||||
}
|
||||
|
||||
impl<G, F, I> Optimization<Wgpu<G, F, I>> for FloatElementWise<G, F, I>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn execute(&self, context: &mut Context<'_, Wgpu<G, F, I>>) {
|
||||
let inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|(tensor, elem)| (context.tensors.get(&tensor.id).unwrap(), *elem))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = self
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|(tensor, elem)| (context.tensors.get(&tensor.id).unwrap(), *elem))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// The context may contain scalars for the end condition, which may vary.
|
||||
let scalars_f32 = &context.scalar_floats[0..self.scalars_f32];
|
||||
|
||||
FusionKernel::new(&self.device)
|
||||
.inputs(&inputs, scalars_f32)
|
||||
.body(&self.operators)
|
||||
.outputs(&outputs, &self.locals)
|
||||
.execute(context.handles);
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.operators.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_fusion::graph::Ops;
|
||||
use burn_fusion::{Fusion, FusionBackend};
|
||||
use burn_tensor::{backend::Backend, Data, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_fusion_same_behavior() {
|
||||
type Backend = Wgpu;
|
||||
type FusedBackend = Fusion<Wgpu>;
|
||||
|
||||
let data_1 = Tensor::<FusedBackend, 2>::random([1, 32], burn_tensor::Distribution::Default)
|
||||
.into_data();
|
||||
let data_2 =
|
||||
Tensor::<Backend, 2>::random([32, 32], burn_tensor::Distribution::Default).into_data();
|
||||
|
||||
let result_ref = execute::<Backend>(
|
||||
data_1.clone(),
|
||||
data_2.clone(),
|
||||
ImplementationDetails::Variant1,
|
||||
);
|
||||
let result_fused = execute::<FusedBackend>(
|
||||
data_1.clone(),
|
||||
data_2.clone(),
|
||||
ImplementationDetails::Variant1,
|
||||
);
|
||||
|
||||
result_ref.assert_approx_eq(&result_fused, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fusion_same_behavior_different_variant() {
|
||||
type Backend = Wgpu;
|
||||
type FusedBackend = Fusion<Wgpu>;
|
||||
|
||||
let data_1 = Tensor::<FusedBackend, 2>::random([1, 32], burn_tensor::Distribution::Default)
|
||||
.into_data();
|
||||
let data_2 =
|
||||
Tensor::<Backend, 2>::random([32, 32], burn_tensor::Distribution::Default).into_data();
|
||||
|
||||
let result_ref = execute::<Backend>(
|
||||
data_1.clone(),
|
||||
data_2.clone(),
|
||||
ImplementationDetails::Variant2,
|
||||
);
|
||||
let result_fused_variant1 = execute::<FusedBackend>(
|
||||
data_1.clone(),
|
||||
data_2.clone(),
|
||||
ImplementationDetails::Variant1,
|
||||
);
|
||||
let result_fused_variant2 = execute::<FusedBackend>(
|
||||
data_1.clone(),
|
||||
data_2.clone(),
|
||||
ImplementationDetails::Variant2,
|
||||
);
|
||||
|
||||
result_ref.assert_approx_eq(&result_fused_variant1, 3);
|
||||
result_ref.assert_approx_eq(&result_fused_variant2, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_end_condition_scalar_ops() {
|
||||
type Backend = Fusion<Wgpu>;
|
||||
let tensor1 = Tensor::<Backend, 2>::ones([32, 32]);
|
||||
let tensor2 = Tensor::<Backend, 2>::ones([32, 42]);
|
||||
let output = tensor1.exp().log();
|
||||
|
||||
// This will add a scalar to the context, even if the actual operation can't be fused with
|
||||
// the preceding ones because of the shape difference.
|
||||
let _ = tensor2 + 2;
|
||||
|
||||
// When we try to execute the operations, the number of bindings can be different if we are
|
||||
// not careful.
|
||||
Backend::sync(&output.device());
|
||||
}
|
||||
|
||||
struct FakeAddOps;
|
||||
|
||||
impl<B: FusionBackend> Ops<B> for FakeAddOps {
|
||||
fn execute(self: Box<Self>, _: &mut burn_fusion::HandleContainer<B>) {
|
||||
panic!("Should always fused during tests.")
|
||||
}
|
||||
}
|
||||
|
||||
enum ImplementationDetails {
|
||||
Variant1,
|
||||
Variant2,
|
||||
}
|
||||
fn execute<B: Backend>(
|
||||
data_1: Data<f32, 2>,
|
||||
data_2: Data<f32, 2>,
|
||||
variant: ImplementationDetails,
|
||||
) -> Data<f32, 2> {
|
||||
let tensor_1 = Tensor::<B, 2>::from_data(data_1.convert());
|
||||
let tensor_2 = Tensor::<B, 2>::from_data(data_2.convert());
|
||||
let tensor_3 = tensor_1.clone() + tensor_2;
|
||||
let tensor_4 = tensor_3.clone() - tensor_1;
|
||||
let mut tensor_5 = tensor_4.clone() + 5.0;
|
||||
match variant {
|
||||
ImplementationDetails::Variant1 => {}
|
||||
ImplementationDetails::Variant2 => {
|
||||
tensor_5 = tensor_5 + 1;
|
||||
tensor_5 = tensor_5 - 1;
|
||||
}
|
||||
}
|
||||
let tensor_6 = burn_tensor::activation::gelu(tensor_5 + tensor_3.clone());
|
||||
let mask = tensor_4.lower_equal(tensor_3);
|
||||
let tmp = tensor_6.mask_fill(mask, 0.3);
|
||||
|
||||
tmp.into_data().convert()
|
||||
}
|
||||
}
|
|
@ -83,7 +83,7 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Input
|
|||
/// Register the inputs used by the kernel.
|
||||
pub fn inputs(
|
||||
mut self,
|
||||
inputs_tensor: &[&(TensorDescription, Elem)],
|
||||
inputs_tensor: &[(&TensorDescription, Elem)],
|
||||
inputs_scalar_f32: &[f32],
|
||||
) -> FusionKernel<G, F, I, BodyPhase> {
|
||||
for (i, (input, elem)) in inputs_tensor.iter().enumerate() {
|
||||
|
@ -198,7 +198,7 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Outpu
|
|||
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
|
||||
pub fn outputs(
|
||||
mut self,
|
||||
outputs: &[&(TensorDescription, Elem)],
|
||||
outputs: &[(&TensorDescription, Elem)],
|
||||
locals: &[u16],
|
||||
) -> FusionKernel<G, F, I, ExecutionPhase> {
|
||||
let mut num_elems_launch_option = 0;
|
||||
|
|
|
@ -5,4 +5,4 @@ pub(crate) mod codegen;
|
|||
pub(crate) mod kernel;
|
||||
|
||||
pub use base::*;
|
||||
pub use elemwise::*;
|
||||
pub(crate) use elemwise::*;
|
||||
|
|
Loading…
Reference in New Issue