Feat/fusion/cmp (#992)

Author: Louis Fortier-Dubois, 2023-11-23 12:52:37 -05:00 (committed by GitHub)
Parent: b86bc58761
Commit: 58273a8441
6 changed files with 495 additions and 148 deletions

View File

@@ -9,6 +9,8 @@ where
     fn type_name() -> &'static str;
     fn as_bytes(slice: &[Self]) -> &[u8];
     fn from_bytes(bytes: &[u8]) -> &[Self];
+    #[cfg(any(feature = "fusion", test))]
+    fn elem_type() -> crate::fusion::codegen::Elem;
 }
 
 /// The float element type for the wgpu backend.
@@ -27,6 +29,10 @@ impl WgpuElement for u32 {
     fn from_bytes(bytes: &[u8]) -> &[Self] {
         bytemuck::cast_slice(bytes)
     }
+    #[cfg(any(feature = "fusion", test))]
+    fn elem_type() -> crate::fusion::codegen::Elem {
+        crate::fusion::codegen::Elem::U32
+    }
 }
 
 impl WgpuElement for i32 {
@@ -39,6 +45,10 @@ impl WgpuElement for i32 {
     fn from_bytes(bytes: &[u8]) -> &[Self] {
         bytemuck::cast_slice(bytes)
     }
+    #[cfg(any(feature = "fusion", test))]
+    fn elem_type() -> crate::fusion::codegen::Elem {
+        crate::fusion::codegen::Elem::I32
+    }
 }
 
 impl WgpuElement for f32 {
@@ -51,6 +61,11 @@ impl WgpuElement for f32 {
     fn from_bytes(bytes: &[u8]) -> &[Self] {
         bytemuck::cast_slice(bytes)
     }
+
+    #[cfg(any(feature = "fusion", test))]
+    fn elem_type() -> crate::fusion::codegen::Elem {
+        crate::fusion::codegen::Elem::F32
+    }
 }
 
 impl FloatElement for f32 {}
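
The `elem_type` addition (gated behind the `fusion` feature or tests) gives generic code a way to map a concrete Rust element type to the codegen element tag. A minimal sketch of the pattern this enables, assuming the `fusion` feature is on; `wgsl_type_name` is a hypothetical helper, not part of this commit:

    // Hypothetical helper: maps a WgpuElement type to its WGSL type name.
    // `Elem` implements `Display`, rendering as "f32", "i32", "u32"
    // (and, with this commit, "bool").
    #[cfg(feature = "fusion")]
    fn wgsl_type_name<E: WgpuElement>() -> String {
        format!("{}", E::elem_type())
    }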

View File

@@ -65,6 +65,37 @@ pub enum Operator {
         input: Variable,
         out: Variable,
     },
+    Equal {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    Lower {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    Greater {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    LowerEqual {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    GreaterEqual {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    ConditionalAssign {
+        cond: Variable,
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
     AssignGlobal {
         input: Variable,
         out: Variable,
@@ -109,22 +140,41 @@ impl Display for Operator {
             Operator::Recip { input, out } => {
                 f.write_fmt(format_args!("let {out} = 1.0 / {input};"))
             }
+            Operator::Equal { lhs, rhs, out } => {
+                f.write_fmt(format_args!("let {out} = {lhs} == {rhs};"))
+            }
+            Operator::Lower { lhs, rhs, out } => {
+                f.write_fmt(format_args!("let {out} = {lhs} < {rhs};"))
+            }
+            Operator::Greater { lhs, rhs, out } => {
+                f.write_fmt(format_args!("let {out} = {lhs} > {rhs};"))
+            }
+            Operator::LowerEqual { lhs, rhs, out } => {
+                f.write_fmt(format_args!("let {out} = {lhs} <= {rhs};"))
+            }
+            Operator::GreaterEqual { lhs, rhs, out } => {
+                f.write_fmt(format_args!("let {out} = {lhs} >= {rhs};"))
+            }
             Operator::AssignGlobal { input, out } => {
-                f.write_fmt(format_args!("{out}_global[id] = {input};"))
+                let elem = out.elem();
+                f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
             }
             Operator::ReadGlobal {
                 variable,
                 position,
                 position_out,
             } => {
-                let (global, local) = match variable {
-                    Variable::Input(number) => {
-                        (format!("input_{number}_global"), format!("input_{number}"))
-                    }
-                    Variable::Local(_) => panic!("can't read globala local variable."),
-                    Variable::Output(number) => (
-                        format!("output_{number}_global"),
-                        format!("output_{number}"),
-                    ),
+                let (global, local, elem) = match variable {
+                    Variable::Input(number, elem) => (
+                        format!("input_{number}_global"),
+                        format!("input_{number}"),
+                        elem,
+                    ),
+                    Variable::Local(_, _) => panic!("can't read global local variable."),
+                    Variable::Output(number, elem) => (
+                        format!("output_{number}_global"),
+                        format!("output_{number}"),
+                        elem,
+                    ),
                     Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
                 };
@@ -144,7 +194,25 @@ for (var i: u32 = 1u; i <= rank; i++) {{
     index_{local} += id / stride_out % shape * stride;
 }}
-let {local} = {global}[index_{local}];
+let {local} = {elem}({global}[index_{local}]);
+"
+                ))
+            }
+            Operator::ConditionalAssign {
+                cond,
+                lhs,
+                rhs,
+                out,
+            } => {
+                let elem = out.elem();
+                f.write_fmt(format_args!(
+                    "
+var {out}: {elem};
+if {cond} {{
+    {out} = {lhs};
+}} else {{
+    {out} = {rhs};
+}}
 "
                 ))
             }
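
For reference, a sketch of the WGSL these new operators render to through the `Display` impl above (hand-derived from the format strings, not compiler output; the variable indices are arbitrary):

    let cmp = Operator::Lower {
        lhs: Variable::Input(0, Elem::F32),
        rhs: Variable::Input(1, Elem::F32),
        out: Variable::Local(0, Elem::Bool),
    };
    println!("{cmp}"); // let local_0 = input_0 < input_1;

    let select = Operator::ConditionalAssign {
        cond: Variable::Local(0, Elem::Bool),
        lhs: Variable::Input(0, Elem::F32),
        rhs: Variable::Input(1, Elem::F32),
        out: Variable::Local(1, Elem::F32),
    };
    println!("{select}");
    // var local_1: f32;
    // if local_0 {
    //     local_1 = input_0;
    // } else {
    //     local_1 = input_1;
    // }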

View File

@@ -19,12 +19,13 @@ pub enum Visibility {
     ReadWrite,
 }
 
-#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
 pub enum Elem {
     F32,
     #[allow(dead_code)]
     I32,
     U32,
+    Bool,
 }
 
 #[derive(Hash, PartialEq, Eq)]
@@ -187,6 +188,7 @@ impl Display for Elem {
             Elem::F32 => f.write_str("f32"),
             Elem::I32 => f.write_str("i32"),
             Elem::U32 => f.write_str("u32"),
+            Elem::Bool => f.write_str("bool"),
         }
     }
 }

View File

@@ -3,18 +3,29 @@ use std::fmt::Display;
 #[derive(Debug, Hash, Clone)]
 pub enum Variable {
-    Input(u16),
+    Input(u16, Elem),
     Scalar(u16, Elem),
-    Local(u16),
-    Output(u16),
+    Local(u16, Elem),
+    Output(u16, Elem),
+}
+
+impl Variable {
+    pub fn elem(&self) -> &Elem {
+        match self {
+            Variable::Input(_, e) => e,
+            Variable::Scalar(_, e) => e,
+            Variable::Local(_, e) => e,
+            Variable::Output(_, e) => e,
+        }
+    }
 }
 
 impl Display for Variable {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
-            Variable::Input(number) => f.write_fmt(format_args!("input_{number}")),
-            Variable::Local(number) => f.write_fmt(format_args!("local_{number}")),
-            Variable::Output(number) => f.write_fmt(format_args!("output_{number}")),
+            Variable::Input(number, _) => f.write_fmt(format_args!("input_{number}")),
+            Variable::Local(number, _) => f.write_fmt(format_args!("local_{number}")),
+            Variable::Output(number, _) => f.write_fmt(format_args!("output_{number}")),
             Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")),
         }
     }
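
Carrying the element type inside every `Variable` lets the codegen query it with `elem()` while the rendered names stay stable; a small sketch, assuming the types above are in scope:

    let out = Variable::Local(2, Elem::Bool);
    assert_eq!(out.to_string(), "local_2");     // the rendered name is unchanged
    assert_eq!(out.elem().to_string(), "bool"); // the type now travels with it

    // Scalars still render as an indexed access into a per-type array:
    let scalar = Variable::Scalar(0, Elem::F32);
    assert_eq!(scalar.to_string(), "scalars_f32[0]");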

View File

@@ -1,12 +1,13 @@
 use crate::{
+    element::WgpuElement,
     fusion::codegen::{Elem, Operator, Variable},
     fusion::kernel::FusionKernel,
     FloatElement, GraphicsApi, IntElement, Wgpu,
 };
 use burn_fusion::{
     graph::{
-        BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription,
-        TensorOpsDescription, UnaryOpsDescription,
+        BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
+        ScalarOpsDescription, TensorOpsDescription, UnaryOpsDescription,
     },
     FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, TensorId,
 };
@@ -22,8 +23,11 @@ where
 {
     pub(crate) inputs: Vec<TensorDescription>,
     pub(crate) locals: HashMap<TensorId, u16>,
-    pub(crate) tensors: HashMap<TensorId, TensorDescription>,
+    pub(crate) tensors: HashMap<TensorId, (TensorDescription, Elem)>,
     pub(crate) scalars_f32: Vec<f32>,
+    pub(crate) scalars_i32: Vec<i32>,
+    pub(crate) scalars_u32: Vec<u32>,
+    pub(crate) booleans: Vec<bool>,
     pub(crate) operators: Vec<Operator>,
     pub(crate) properties: FusionProperties,
     pub(crate) current_output_shape: Vec<usize>,
@@ -35,8 +39,13 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
 {
     fn register(&mut self, ops: &TensorOpsDescription) -> FusionStatus {
         match ops {
+            TensorOpsDescription::BaseOpsFloat(ops) => {
+                if !self.register_base::<F>(ops) {
+                    return FusionStatus::Closed(self.properties);
+                }
+            }
             TensorOpsDescription::FloatOps(ops) => {
-                if !self.register_float(ops) {
+                if !self.register_float::<F>(ops) {
                     return FusionStatus::Closed(self.properties);
                 }
             }
@@ -61,7 +70,7 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
         let outputs = self.output_descriptions();
         let locals = outputs
             .iter()
-            .map(|out| *self.locals.get(&out.id).unwrap())
+            .map(|out| *self.locals.get(&out.0.id).unwrap())
             .collect::<Vec<_>>();
 
         FusionKernel::new(&self.device)
@@ -76,6 +85,9 @@
         self.locals.drain();
         self.tensors.clear();
         self.scalars_f32.clear();
+        self.scalars_i32.clear();
+        self.scalars_u32.clear();
+        self.booleans.clear();
         self.operators.clear();
         self.properties = FusionProperties::default();
         self.current_output_shape.clear();
@@ -98,6 +110,9 @@ where
             locals: HashMap::new(),
             tensors: HashMap::new(),
             scalars_f32: Vec::new(),
+            scalars_i32: Vec::new(),
+            scalars_u32: Vec::new(),
+            booleans: Vec::new(),
             operators: Vec::new(),
             current_output_shape: Vec::new(),
             properties: FusionProperties::default(),
@@ -105,7 +120,7 @@
         }
     }
 
-    fn input_descriptions(&self) -> Vec<&TensorDescription> {
+    fn input_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
         self.inputs
             .iter()
             .map(|input| {
@@ -115,7 +130,7 @@
             .collect::<Vec<_>>()
     }
 
-    fn output_descriptions(&self) -> Vec<&TensorDescription> {
+    fn output_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
         let mut outputs = Vec::new();
         let mut local_tensor_ids_input = Vec::new();
         let mut local_tensor_ids_output = Vec::new();
@@ -124,7 +139,7 @@
         //
         // Only local variables can become outputs.
         let mark = |var: &Variable, list: &mut Vec<TensorId>| {
-            if let Variable::Local(index) = var {
+            if let Variable::Local(index, _) = var {
                 if let Some((id, _)) = self
                     .locals
                     .iter()
@@ -211,6 +226,42 @@ where
                     mark(input, &mut local_tensor_ids_input);
                     mark(out, &mut local_tensor_ids_output);
                 }
+                Operator::Lower { lhs, rhs, out } => {
+                    mark(lhs, &mut local_tensor_ids_input);
+                    mark(rhs, &mut local_tensor_ids_input);
+                    mark(out, &mut local_tensor_ids_output);
+                }
+                Operator::Greater { lhs, rhs, out } => {
+                    mark(lhs, &mut local_tensor_ids_input);
+                    mark(rhs, &mut local_tensor_ids_input);
+                    mark(out, &mut local_tensor_ids_output);
+                }
+                Operator::LowerEqual { lhs, rhs, out } => {
+                    mark(lhs, &mut local_tensor_ids_input);
+                    mark(rhs, &mut local_tensor_ids_input);
+                    mark(out, &mut local_tensor_ids_output);
+                }
+                Operator::GreaterEqual { lhs, rhs, out } => {
+                    mark(lhs, &mut local_tensor_ids_input);
+                    mark(rhs, &mut local_tensor_ids_input);
+                    mark(out, &mut local_tensor_ids_output);
+                }
+                Operator::Equal { lhs, rhs, out } => {
+                    mark(lhs, &mut local_tensor_ids_input);
+                    mark(rhs, &mut local_tensor_ids_input);
+                    mark(out, &mut local_tensor_ids_output);
+                }
+                Operator::ConditionalAssign {
+                    cond,
+                    lhs,
+                    rhs,
+                    out,
+                } => {
+                    mark(cond, &mut local_tensor_ids_input);
+                    mark(lhs, &mut local_tensor_ids_input);
+                    mark(rhs, &mut local_tensor_ids_input);
+                    mark(out, &mut local_tensor_ids_output);
+                }
             }
         }
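
Every new operator must mark its reads and writes so `output_descriptions` can decide which locals are written back to global memory. The rule it implements: a local that some operator writes but no later fused operator reads (or whose tensor stays `ReadOnly`, handled just below) must become a kernel output. A standalone sketch of that rule, hypothetical and simplified to bare local indices:

    use std::collections::HashSet;

    // A local is an output if some operator writes it and no later
    // operator reads it inside the fused kernel.
    fn kernel_outputs(written: &[u16], read: &[u16]) -> Vec<u16> {
        let read: HashSet<u16> = read.iter().copied().collect();
        written
            .iter()
            .copied()
            .filter(|local| !read.contains(local))
            .collect()
    }

    // local_0 = a + b; local_1 = local_0 - c  ->  only local_1 is an output.
    assert_eq!(kernel_outputs(&[0, 1], &[0]), vec![1]);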
@@ -226,10 +277,11 @@ where
         // All tensors where their latest description is read only should be written to since they
        // are going to be used after the fused kernel by other operations.
-        for tensor in self.tensors.values() {
+        for entry in self.tensors.values() {
+            let (tensor, _) = &entry;
             if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
                 if self.locals.contains_key(&tensor.id) {
-                    outputs.push(tensor);
+                    outputs.push(entry);
                 }
             }
         }
@@ -237,19 +289,19 @@
 
         outputs
     }
 
-    fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable {
+    fn input_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
         let already_exists = self.tensors.contains_key(&tensor.id);
 
         let variable = match already_exists {
             false => {
                 // New input
-                let var = Variable::Input(self.inputs.len() as u16);
+                let var = Variable::Input(self.inputs.len() as u16, elem);
                 self.inputs.push(tensor.clone());
                 var
             }
             true => match self.locals.get(&tensor.id) {
                 // Is a local variable.
-                Some(local_index) => Variable::Local(*local_index),
+                Some(local_index) => Variable::Local(*local_index, elem),
                 // Isn't a local variable, so must be an existing input.
                 None => {
                     let input = self
@@ -259,134 +311,234 @@ where
                         .find(|(_, input)| input.id == tensor.id)
                         .unwrap();
                     let input_index = input.0;
-                    Variable::Input(input_index as u16)
+                    Variable::Input(input_index as u16, elem)
                 }
             },
         };
 
         // Update the tensor description with the new version.
-        self.tensors.insert(tensor.id.clone(), tensor.clone());
+        self.tensors
+            .insert(tensor.id.clone(), (tensor.clone(), elem));
 
         variable
     }
 
-    fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable {
+    fn output_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
         // Update the tensor description to the new version.
-        self.tensors.insert(tensor.id.clone(), tensor.clone());
+        self.tensors
+            .insert(tensor.id.clone(), (tensor.clone(), elem));
 
         // Output already registered as a local variable.
         if let Some(index) = self.locals.get(&tensor.id) {
-            return Variable::Local(*index);
+            return Variable::Local(*index, elem);
         }
 
         // New local variable.
         let local_index = self.locals.len() as u16;
         self.locals.insert(tensor.id.clone(), local_index);
 
-        Variable::Local(local_index)
+        Variable::Local(local_index, elem)
     }
 
-    fn register_float(&mut self, ops: &FloatOpsDescription) -> bool {
+    fn register_base<E: WgpuElement>(&mut self, ops: &BaseOpsDescription) -> bool {
+        match ops {
+            BaseOpsDescription::Equal(desc) => self.register_binary_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), Elem::Bool),
+                |lhs, rhs, out| Operator::Equal { lhs, rhs, out },
+            ),
+            _ => false,
+        }
+    }
+
+    fn register_float<E: WgpuElement>(&mut self, ops: &FloatOpsDescription) -> bool {
         match ops {
             FloatOpsDescription::Exp(desc) => {
-                self.register_unary_ops(desc, |input, out| Operator::Exp { input, out })
+                self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
+                    Operator::Exp { input, out }
+                })
             }
             FloatOpsDescription::Log(desc) => {
-                self.register_unary_ops(desc, |input, out| Operator::Log { input, out })
+                self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
+                    Operator::Log { input, out }
+                })
             }
             FloatOpsDescription::Log1p(desc) => {
-                self.register_unary_ops(desc, |input, out| Operator::Log1p { input, out })
+                self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
+                    Operator::Log1p { input, out }
+                })
             }
             FloatOpsDescription::Cos(desc) => {
-                self.register_unary_ops(desc, |input, out| Operator::Cos { input, out })
+                self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
+                    Operator::Cos { input, out }
+                })
             }
             FloatOpsDescription::Sin(desc) => {
-                self.register_unary_ops(desc, |input, out| Operator::Sin { input, out })
-            }
-            FloatOpsDescription::Powf(desc) => {
-                self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out })
+                self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
+                    Operator::Sin { input, out }
+                })
             }
+            FloatOpsDescription::Powf(desc) => self.register_scalar_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), E::elem_type()),
+                |lhs, rhs, out| Operator::Powf { lhs, rhs, out },
+            ),
             FloatOpsDescription::Tanh(desc) => {
-                self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out })
+                self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
+                    Operator::Tanh { input, out }
+                })
             }
             FloatOpsDescription::Erf(desc) => {
-                self.register_unary_ops(desc, |input, out| Operator::Erf { input, out })
+                self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
+                    Operator::Erf { input, out }
+                })
            }
             FloatOpsDescription::Recip(desc) => {
-                self.register_unary_ops(desc, |input, out| Operator::Recip { input, out })
+                self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
+                    Operator::Recip { input, out }
+                })
             }
             _ => false,
         }
     }
 
-    fn register_numeric<E: Element>(&mut self, ops: &NumericOpsDescription<E>) -> bool {
+    fn register_numeric<E: WgpuElement>(&mut self, ops: &NumericOpsDescription<E>) -> bool {
         match ops {
-            NumericOpsDescription::Add(desc) => {
-                self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
-            }
-            NumericOpsDescription::AddScalar(desc) => {
-                self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
-            }
-            NumericOpsDescription::Sub(desc) => {
-                self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
-            }
-            NumericOpsDescription::SubScalar(desc) => {
-                self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
-            }
-            NumericOpsDescription::Mul(desc) => {
-                self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
-            }
-            NumericOpsDescription::MulScalar(desc) => {
-                self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
-            }
-            NumericOpsDescription::Div(desc) => {
-                self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
-            }
-            NumericOpsDescription::DivScalar(desc) => {
-                self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
-            }
+            NumericOpsDescription::Add(desc) => self.register_binary_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), E::elem_type()),
+                |lhs, rhs, out| Operator::Add { lhs, rhs, out },
+            ),
+            NumericOpsDescription::AddScalar(desc) => self.register_scalar_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), E::elem_type()),
+                |lhs, rhs, out| Operator::Add { lhs, rhs, out },
+            ),
+            NumericOpsDescription::Sub(desc) => self.register_binary_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), E::elem_type()),
+                |lhs, rhs, out| Operator::Sub { lhs, rhs, out },
+            ),
+            NumericOpsDescription::SubScalar(desc) => self.register_scalar_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), E::elem_type()),
+                |lhs, rhs, out| Operator::Sub { lhs, rhs, out },
+            ),
+            NumericOpsDescription::Mul(desc) => self.register_binary_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), E::elem_type()),
+                |lhs, rhs, out| Operator::Mul { lhs, rhs, out },
+            ),
+            NumericOpsDescription::MulScalar(desc) => self.register_scalar_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), E::elem_type()),
+                |lhs, rhs, out| Operator::Mul { lhs, rhs, out },
+            ),
+            NumericOpsDescription::Div(desc) => self.register_binary_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), E::elem_type()),
+                |lhs, rhs, out| Operator::Div { lhs, rhs, out },
+            ),
+            NumericOpsDescription::DivScalar(desc) => self.register_scalar_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), E::elem_type()),
+                |lhs, rhs, out| Operator::Div { lhs, rhs, out },
+            ),
             NumericOpsDescription::Abs(desc) => {
-                self.register_unary_ops(desc, |input, out| Operator::Abs { input, out })
+                self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
+                    Operator::Abs { input, out }
+                })
             }
+            NumericOpsDescription::Lower(desc) => self.register_binary_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), Elem::Bool),
+                |lhs, rhs, out| Operator::Lower { lhs, rhs, out },
+            ),
+            NumericOpsDescription::LowerElem(desc) => self.register_scalar_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), Elem::Bool),
+                |lhs, rhs, out| Operator::Lower { lhs, rhs, out },
+            ),
+            NumericOpsDescription::Greater(desc) => self.register_binary_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), Elem::Bool),
+                |lhs, rhs, out| Operator::Greater { lhs, rhs, out },
+            ),
+            NumericOpsDescription::GreaterElem(desc) => self.register_scalar_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), Elem::Bool),
+                |lhs, rhs, out| Operator::Greater { lhs, rhs, out },
+            ),
+            NumericOpsDescription::LowerEqual(desc) => self.register_binary_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), Elem::Bool),
+                |lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
+            ),
+            NumericOpsDescription::LowerEqualElem(desc) => self.register_scalar_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), Elem::Bool),
+                |lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
+            ),
+            NumericOpsDescription::GreaterEqual(desc) => self.register_binary_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), Elem::Bool),
+                |lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
+            ),
+            NumericOpsDescription::GreaterEqualElem(desc) => self.register_scalar_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), Elem::Bool),
+                |lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
+            ),
+            NumericOpsDescription::EqualElem(desc) => self.register_scalar_ops(
+                desc,
+                (E::elem_type(), E::elem_type(), Elem::Bool),
+                |lhs, rhs, out| Operator::Equal { lhs, rhs, out },
+            ),
+            NumericOpsDescription::MaskWhere(desc) => {
+                if !self.output_is_compatible(&desc.out) {
+                    return false;
+                }
+
+                let cond = self.input_to_var(&desc.mask, Elem::Bool);
+                let lhs = self.input_to_var(&desc.value, E::elem_type());
+                let rhs = self.input_to_var(&desc.tensor, E::elem_type());
+                let out = self.output_to_var(&desc.out, E::elem_type());
+
+                self.operators.push(Operator::ConditionalAssign {
+                    cond,
+                    lhs,
+                    rhs,
+                    out,
+                });
+
+                true
+            }
+            NumericOpsDescription::MaskFill(desc) => {
+                if !self.output_is_compatible(&desc.out) {
+                    return false;
+                }
+
+                let cond = self.input_to_var(&desc.mask, Elem::Bool);
+                let lhs = self.scalar_to_var(&desc.value, E::elem_type());
+                let rhs = self.input_to_var(&desc.tensor, E::elem_type());
+                let out = self.output_to_var(&desc.out, E::elem_type());
+
+                self.operators.push(Operator::ConditionalAssign {
+                    cond,
+                    lhs,
+                    rhs,
+                    out,
+                });
+
+                true
+            }
             _ => false,
         }
     }
 
-    fn register_binary_ops<Func>(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool
-    where
-        Func: Fn(Variable, Variable, Variable) -> Operator,
-    {
-        if !self.output_is_compatible(&desc.out) {
-            return false;
-        }
-
-        let lhs = self.input_to_var(&desc.lhs);
-        let rhs = self.input_to_var(&desc.rhs);
-        let out = self.output_to_var(&desc.out);
-
-        self.operators.push(func(lhs, rhs, out));
-
-        true
-    }
-
-    fn register_unary_ops<Func>(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool
-    where
-        Func: Fn(Variable, Variable) -> Operator,
-    {
-        if !self.output_is_compatible(&desc.out) {
-            return false;
-        }
-
-        let input = self.input_to_var(&desc.input);
-        let out = self.output_to_var(&desc.out);
-
-        self.operators.push(func(input, out));
-
-        true
-    }
-
-    fn register_scalar_ops<Func, E: Element>(
+    fn register_binary_ops<Func>(
         &mut self,
-        desc: &ScalarOpsDescription<E>,
+        desc: &BinaryOpsDescription,
+        (elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
         func: Func,
     ) -> bool
     where
@@ -396,16 +548,78 @@ where
             return false;
         }
 
-        let lhs = self.input_to_var(&desc.lhs);
-        let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32);
-        self.scalars_f32.push(desc.rhs.elem());
-        let out = self.output_to_var(&desc.out);
+        let lhs = self.input_to_var(&desc.lhs, elem_lhs);
+        let rhs = self.input_to_var(&desc.rhs, elem_rhs);
+        let out = self.output_to_var(&desc.out, elem_out);
 
         self.operators.push(func(lhs, rhs, out));
 
         true
     }
 
+    fn register_unary_ops<Func>(
+        &mut self,
+        desc: &UnaryOpsDescription,
+        (elem_input, elem_out): (Elem, Elem),
+        func: Func,
+    ) -> bool
+    where
+        Func: Fn(Variable, Variable) -> Operator,
+    {
+        if !self.output_is_compatible(&desc.out) {
+            return false;
+        }
+
+        let input = self.input_to_var(&desc.input, elem_input);
+        let out = self.output_to_var(&desc.out, elem_out);
+
+        self.operators.push(func(input, out));
+
+        true
+    }
+
+    fn register_scalar_ops<Func, E: Element>(
+        &mut self,
+        desc: &ScalarOpsDescription<E>,
+        (elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
+        func: Func,
+    ) -> bool
+    where
+        Func: Fn(Variable, Variable, Variable) -> Operator,
+    {
+        if !self.output_is_compatible(&desc.out) {
+            return false;
+        }
+
+        let lhs = self.input_to_var(&desc.lhs, elem_lhs);
+        let rhs = self.scalar_to_var(&desc.rhs, elem_rhs);
+        let out = self.output_to_var(&desc.out, elem_out);
+
+        self.operators.push(func(lhs, rhs, out));
+
+        true
+    }
+
+    fn scalar_to_var<E: Element>(&mut self, value: &E, elem_type: Elem) -> Variable {
+        match elem_type {
+            Elem::F32 => {
+                self.scalars_f32.push(value.elem());
+                Variable::Scalar(self.scalars_f32.len() as u16 - 1, Elem::F32)
+            }
+            Elem::I32 => {
+                self.scalars_i32.push(value.elem());
+                Variable::Scalar(self.scalars_i32.len() as u16 - 1, Elem::I32)
+            }
+            Elem::U32 => {
+                self.scalars_u32.push(value.elem());
+                Variable::Scalar(self.scalars_u32.len() as u16 - 1, Elem::U32)
+            }
+            Elem::Bool => {
+                panic!("Bool scalars not supported")
+            }
+        }
+    }
+
     fn output_is_compatible(&mut self, out: &TensorDescription) -> bool {
         if self.current_output_shape.is_empty() {
             self.current_output_shape = out.shape.clone();
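
`scalar_to_var` assigns indices per element-type pool, so `scalars_f32` and `scalars_i32` number independently; in WGSL the resulting variable renders as `scalars_f32[0]`, `scalars_i32[0]`, and so on. A standalone sketch of the indexing scheme (a hypothetical mirror of the real method, f32/i32 pools only):

    struct Pools {
        f32s: Vec<f32>,
        i32s: Vec<i32>,
    }

    impl Pools {
        // Push a value and return its index within that type's pool.
        fn push_f32(&mut self, v: f32) -> u16 {
            self.f32s.push(v);
            (self.f32s.len() - 1) as u16
        }
        fn push_i32(&mut self, v: i32) -> u16 {
            self.i32s.push(v);
            (self.i32s.len() - 1) as u16
        }
    }

    let mut pools = Pools { f32s: vec![], i32s: vec![] };
    assert_eq!(pools.push_f32(2.0), 0); // renders as scalars_f32[0]
    assert_eq!(pools.push_i32(3), 0);   // scalars_i32[0]: independent counter
    assert_eq!(pools.push_f32(4.0), 1); // scalars_f32[1]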
@@ -446,17 +660,19 @@ mod tests {
         let tensor_2 = Tensor::<Backend, 2>::from_data(data_2.clone());
 
         let tensor_3 = tensor_1.clone() + tensor_2;
         let tensor_4 = tensor_3.clone() - tensor_1;
-        let tensor_5 = tensor_4 + 5.0;
-        let tensor_6 = tensor_5 + tensor_3;
-        let result_ref = tensor_6.recip().into_data();
+        let tensor_5 = tensor_4.clone() + 5.0;
+        let tensor_6 = tensor_5 + tensor_3.clone();
+        let mask = tensor_4.lower_equal(tensor_3);
+        let result_ref = tensor_6.mask_fill(mask, 0.3).into_data();
 
         let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
         let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
 
         let tensor_3 = tensor_1.clone() + tensor_2;
         let tensor_4 = tensor_3.clone() - tensor_1;
-        let tensor_5 = tensor_4 + 5.0;
-        let tensor_6 = tensor_5 + tensor_3;
-        let result_fused = tensor_6.recip().into_data();
+        let tensor_5 = tensor_4.clone() + 5.0;
+        let tensor_6 = tensor_5 + tensor_3.clone();
+        let mask = tensor_4.lower_equal(tensor_3);
+        let result_fused = tensor_6.mask_fill(mask, 0.3).into_data();
 
         result_fused.assert_approx_eq(&result_ref, 3);
     }
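
The updated test now exercises the comparison path end to end: two adds, a sub, a scalar add, a `lower_equal`, and a `mask_fill` all fuse into one kernel. A rough, hand-written sketch of the fused WGSL body this should produce (illustrative only; actual variable numbering and outputs come from registration order):

    // let local_0 = input_0 + input_1;        // tensor_3
    // let local_1 = local_0 - input_0;        // tensor_4
    // let local_2 = local_1 + scalars_f32[0]; // tensor_5 = tensor_4 + 5.0
    // let local_3 = local_2 + local_0;        // tensor_6
    // let local_4 = local_1 <= local_0;       // mask = tensor_4.lower_equal(tensor_3)
    // var local_5: f32;                       // mask_fill(tensor_6, mask, 0.3)
    // if local_4 { local_5 = scalars_f32[1]; } else { local_5 = local_3; }
    // output_0_global[id] = f32(local_5);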

View File

@@ -83,25 +83,43 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Input
     /// Register the inputs used by the kernel.
     pub fn inputs(
         mut self,
-        inputs_tensor: &[&TensorDescription],
+        inputs_tensor: &[&(TensorDescription, Elem)],
         inputs_scalar_f32: &[f32],
     ) -> FusionKernel<G, F, I, BodyPhase> {
-        for (i, input) in inputs_tensor.iter().enumerate() {
-            self.input_bindings.push((
-                Binding {
-                    elem: Elem::F32,
-                    visibility: Visibility::Read,
-                    location: Location::Storage,
-                    size: None,
-                },
-                (*input).clone(),
-            ));
+        for (i, (input, elem)) in inputs_tensor.iter().enumerate() {
+            if elem != &Elem::Bool {
+                self.input_bindings.push((
+                    Binding {
+                        elem: *elem,
+                        visibility: Visibility::Read,
+                        location: Location::Storage,
+                        size: None,
+                    },
+                    (*input).clone(),
+                ));
 
-            self.operations.push(Operator::ReadGlobal {
-                variable: Variable::Input(i as u16),
-                position: i,
-                position_out: inputs_tensor.len(), // First output
-            });
+                self.operations.push(Operator::ReadGlobal {
+                    variable: Variable::Input(i as u16, *elem),
+                    position: i,
+                    position_out: inputs_tensor.len(), // First output
+                });
+            } else {
+                self.input_bindings.push((
+                    Binding {
+                        elem: Elem::I32,
+                        visibility: Visibility::Read,
+                        location: Location::Storage,
+                        size: None,
+                    },
+                    (*input).clone(),
+                ));
+
+                self.operations.push(Operator::ReadGlobal {
+                    variable: Variable::Input(i as u16, *elem),
+                    position: i,
+                    position_out: inputs_tensor.len(), // First output
+                });
+            }
         }
 
         if !inputs_scalar_f32.is_empty() {
@@ -180,31 +198,48 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Outpu
     /// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
     pub fn outputs(
         mut self,
-        outputs: &[&TensorDescription],
+        outputs: &[&(TensorDescription, Elem)],
         locals: &[u16],
     ) -> FusionKernel<G, F, I, ExecutionPhase> {
         let mut num_elems_launch_option = 0;
 
-        for (i, (output, local)) in outputs.iter().zip(locals).enumerate() {
+        for (i, ((output, elem), local)) in outputs.iter().zip(locals).enumerate() {
             let num_elems_output = calculate_num_elems_dyn_rank(&output.shape);
             if num_elems_output > num_elems_launch_option {
                 num_elems_launch_option = num_elems_output;
             }
 
-            self.output_bindings.push((
-                Binding {
-                    elem: Elem::F32,
-                    visibility: Visibility::ReadWrite,
-                    location: Location::Storage,
-                    size: None,
-                },
-                (*output).clone(),
-            ));
+            if elem != &Elem::Bool {
+                self.output_bindings.push((
+                    Binding {
+                        elem: *elem,
+                        visibility: Visibility::ReadWrite,
+                        location: Location::Storage,
+                        size: None,
+                    },
+                    (*output).clone(),
+                ));
 
-            self.operations.push(Operator::AssignGlobal {
-                input: Variable::Local(*local),
-                out: Variable::Output(i as u16),
-            });
+                self.operations.push(Operator::AssignGlobal {
+                    input: Variable::Local(*local, *elem),
+                    out: Variable::Output(i as u16, *elem),
+                });
+            } else {
+                self.output_bindings.push((
+                    Binding {
+                        elem: Elem::I32, // I32 are used for bool tensors
+                        visibility: Visibility::ReadWrite,
+                        location: Location::Storage,
+                        size: None,
+                    },
+                    (*output).clone(),
+                ));
+
+                self.operations.push(Operator::AssignGlobal {
+                    input: Variable::Local(*local, *elem),
+                    out: Variable::Output(i as u16, Elem::I32),
+                });
+            }
         }
 
         self.num_elems_output = num_elems_launch_option;
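
Bool tensors get special treatment in both `inputs` and `outputs` because WGSL storage buffers cannot hold `bool` (it is not host-shareable), so their bindings are declared as `i32` holding 0/1 while locals keep the `bool` type inside the kernel. A sketch of the binding rule as a hypothetical free function (the commit inlines it in the two branches above):

    fn binding_elem(elem: Elem) -> Elem {
        match elem {
            Elem::Bool => Elem::I32, // bool tensors are stored as 0/1 integers
            other => other,
        }
    }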