Feat/fusion/cmp (#992)

Louis Fortier-Dubois 2023-11-23 12:52:37 -05:00 committed by GitHub
parent b86bc58761
commit 58273a8441
6 changed files with 495 additions and 148 deletions


@ -9,6 +9,8 @@ where
fn type_name() -> &'static str;
fn as_bytes(slice: &[Self]) -> &[u8];
fn from_bytes(bytes: &[u8]) -> &[Self];
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem;
}
/// The float element type for the wgpu backend.
@ -27,6 +29,10 @@ impl WgpuElement for u32 {
fn from_bytes(bytes: &[u8]) -> &[Self] {
bytemuck::cast_slice(bytes)
}
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
crate::fusion::codegen::Elem::U32
}
}
impl WgpuElement for i32 {
@ -39,6 +45,10 @@ impl WgpuElement for i32 {
fn from_bytes(bytes: &[u8]) -> &[Self] {
bytemuck::cast_slice(bytes)
}
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
crate::fusion::codegen::Elem::I32
}
}
impl WgpuElement for f32 {
@ -51,6 +61,11 @@ impl WgpuElement for f32 {
fn from_bytes(bytes: &[u8]) -> &[Self] {
bytemuck::cast_slice(bytes)
}
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
crate::fusion::codegen::Elem::F32
}
}
impl FloatElement for f32 {}
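
Taken together with the codegen changes below, elem_type is what lets the registration code pick per-operation element signatures generically. A minimal standalone sketch of the pattern (names mirror the diff; this is an illustration, not the burn-wgpu source):

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Elem { F32, I32, U32, Bool }

trait WgpuElement {
    fn elem_type() -> Elem;
}

impl WgpuElement for f32 { fn elem_type() -> Elem { Elem::F32 } }
impl WgpuElement for i32 { fn elem_type() -> Elem { Elem::I32 } }
impl WgpuElement for u32 { fn elem_type() -> Elem { Elem::U32 } }

// Comparisons compute on E but always produce a bool output, which is
// exactly the (lhs, rhs, out) tuple that register_base passes below.
fn comparison_signature<E: WgpuElement>() -> (Elem, Elem, Elem) {
    (E::elem_type(), E::elem_type(), Elem::Bool)
}

fn main() {
    assert_eq!(
        comparison_signature::<f32>(),
        (Elem::F32, Elem::F32, Elem::Bool)
    );
}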


@ -65,6 +65,37 @@ pub enum Operator {
input: Variable,
out: Variable,
},
Equal {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Lower {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Greater {
lhs: Variable,
rhs: Variable,
out: Variable,
},
LowerEqual {
lhs: Variable,
rhs: Variable,
out: Variable,
},
GreaterEqual {
lhs: Variable,
rhs: Variable,
out: Variable,
},
ConditionalAssign {
cond: Variable,
lhs: Variable,
rhs: Variable,
out: Variable,
},
AssignGlobal {
input: Variable,
out: Variable,
@ -109,22 +140,41 @@ impl Display for Operator {
Operator::Recip { input, out } => {
f.write_fmt(format_args!("let {out} = 1.0 / {input};"))
}
Operator::Equal { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} == {rhs};"))
}
Operator::Lower { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} < {rhs};"))
}
Operator::Greater { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} > {rhs};"))
}
Operator::LowerEqual { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} <= {rhs};"))
}
Operator::GreaterEqual { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} >= {rhs};"))
}
Operator::AssignGlobal { input, out } => {
f.write_fmt(format_args!("{out}_global[id] = {input};"))
let elem = out.elem();
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
}
Operator::ReadGlobal {
variable,
position,
position_out,
} => {
let (global, local) = match variable {
Variable::Input(number) => {
(format!("input_{number}_global"), format!("input_{number}"))
}
Variable::Local(_) => panic!("can't read globala local variable."),
Variable::Output(number) => (
let (global, local, elem) = match variable {
Variable::Input(number, elem) => (
format!("input_{number}_global"),
format!("input_{number}"),
elem,
),
Variable::Local(_, _) => panic!("can't read global local variable."),
Variable::Output(number, elem) => (
format!("output_{number}_global"),
format!("output_{number}"),
elem,
),
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
};
@ -144,7 +194,25 @@ for (var i: u32 = 1u; i <= rank; i++) {{
index_{local} += id / stride_out % shape * stride;
}}
let {local} = {global}[index_{local}];
let {local} = {elem}({global}[index_{local}]);
"
))
}
Operator::ConditionalAssign {
cond,
lhs,
rhs,
out,
} => {
let elem = out.elem();
f.write_fmt(format_args!(
"
var {out}: {elem};
if {cond} {{
{out} = {lhs};
}} else {{
{out} = {rhs};
}}
"
))
}
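
Instantiated with concrete variables, the new arms read naturally as WGSL. For example (derived from the format strings above, with variable names from the Display impl shown later in this commit), Lower { lhs: Input(0, F32), rhs: Scalar(0, F32), out: Local(0, Bool) } followed by ConditionalAssign { cond: Local(0, Bool), lhs: Scalar(1, F32), rhs: Input(1, F32), out: Local(1, F32) } renders as:

let local_0 = input_0 < scalars_f32[0];

var local_1: f32;
if local_0 {
    local_1 = scalars_f32[1];
} else {
    local_1 = input_1;
}

The casts added to ReadGlobal and AssignGlobal support the bool-as-i32 storage convention used by the kernel below: a typed read now emits

let input_1 = f32(input_1_global[index_input_1]);

and writing a bool local back out through Output(0, I32) emits

output_0_global[id] = i32(local_0);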


@ -19,12 +19,13 @@ pub enum Visibility {
ReadWrite,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
pub enum Elem {
F32,
#[allow(dead_code)]
I32,
U32,
Bool,
}
#[derive(Hash, PartialEq, Eq)]
@ -187,6 +188,7 @@ impl Display for Elem {
Elem::F32 => f.write_str("f32"),
Elem::I32 => f.write_str("i32"),
Elem::U32 => f.write_str("u32"),
Elem::Bool => f.write_str("bool"),
}
}
}
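
Note that bool here names the WGSL type used for local variables only; bool tensors are bound as i32 storage buffers (see the kernel changes below, where the casts added to the codegen bridge the two representations).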


@ -3,18 +3,29 @@ use std::fmt::Display;
#[derive(Debug, Hash, Clone)]
pub enum Variable {
Input(u16),
Input(u16, Elem),
Scalar(u16, Elem),
Local(u16),
Output(u16),
Local(u16, Elem),
Output(u16, Elem),
}
impl Variable {
pub fn elem(&self) -> &Elem {
match self {
Variable::Input(_, e) => e,
Variable::Scalar(_, e) => e,
Variable::Local(_, e) => e,
Variable::Output(_, e) => e,
}
}
}
impl Display for Variable {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Variable::Input(number) => f.write_fmt(format_args!("input_{number}")),
Variable::Local(number) => f.write_fmt(format_args!("local_{number}")),
Variable::Output(number) => f.write_fmt(format_args!("output_{number}")),
Variable::Input(number, _) => f.write_fmt(format_args!("input_{number}")),
Variable::Local(number, _) => f.write_fmt(format_args!("local_{number}")),
Variable::Output(number, _) => f.write_fmt(format_args!("output_{number}")),
Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")),
}
}
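
For reference, the element type does not change how inputs, locals, and outputs are named, so existing numbering stays stable; only scalars encode it, because they live in per-type bindings. Rendered examples (derived from the arms above):

Variable::Input(0, Elem::F32)    -> input_0
Variable::Local(2, Elem::Bool)   -> local_2
Variable::Output(1, Elem::F32)   -> output_1
Variable::Scalar(3, Elem::F32)   -> scalars_f32[3]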


@ -1,12 +1,13 @@
use crate::{
element::WgpuElement,
fusion::codegen::{Elem, Operator, Variable},
fusion::kernel::FusionKernel,
FloatElement, GraphicsApi, IntElement, Wgpu,
};
use burn_fusion::{
graph::{
BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription,
TensorOpsDescription, UnaryOpsDescription,
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
ScalarOpsDescription, TensorOpsDescription, UnaryOpsDescription,
},
FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, TensorId,
};
@ -22,8 +23,11 @@ where
{
pub(crate) inputs: Vec<TensorDescription>,
pub(crate) locals: HashMap<TensorId, u16>,
pub(crate) tensors: HashMap<TensorId, TensorDescription>,
pub(crate) tensors: HashMap<TensorId, (TensorDescription, Elem)>,
pub(crate) scalars_f32: Vec<f32>,
pub(crate) scalars_i32: Vec<i32>,
pub(crate) scalars_u32: Vec<u32>,
pub(crate) booleans: Vec<bool>,
pub(crate) operators: Vec<Operator>,
pub(crate) properties: FusionProperties,
pub(crate) current_output_shape: Vec<usize>,
@ -35,8 +39,13 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
{
fn register(&mut self, ops: &TensorOpsDescription) -> FusionStatus {
match ops {
TensorOpsDescription::BaseOpsFloat(ops) => {
if !self.register_base::<F>(ops) {
return FusionStatus::Closed(self.properties);
}
}
TensorOpsDescription::FloatOps(ops) => {
if !self.register_float(ops) {
if !self.register_float::<F>(ops) {
return FusionStatus::Closed(self.properties);
}
}
@ -61,7 +70,7 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
let outputs = self.output_descriptions();
let locals = outputs
.iter()
.map(|out| *self.locals.get(&out.id).unwrap())
.map(|out| *self.locals.get(&out.0.id).unwrap())
.collect::<Vec<_>>();
FusionKernel::new(&self.device)
@ -76,6 +85,9 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
self.locals.drain();
self.tensors.clear();
self.scalars_f32.clear();
self.scalars_i32.clear();
self.scalars_u32.clear();
self.booleans.clear();
self.operators.clear();
self.properties = FusionProperties::default();
self.current_output_shape.clear();
@ -98,6 +110,9 @@ where
locals: HashMap::new(),
tensors: HashMap::new(),
scalars_f32: Vec::new(),
scalars_i32: Vec::new(),
scalars_u32: Vec::new(),
booleans: Vec::new(),
operators: Vec::new(),
current_output_shape: Vec::new(),
properties: FusionProperties::default(),
@ -105,7 +120,7 @@ where
}
}
fn input_descriptions(&self) -> Vec<&TensorDescription> {
fn input_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
self.inputs
.iter()
.map(|input| {
@ -115,7 +130,7 @@ where
.collect::<Vec<_>>()
}
fn output_descriptions(&self) -> Vec<&TensorDescription> {
fn output_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
let mut outputs = Vec::new();
let mut local_tensor_ids_input = Vec::new();
let mut local_tensor_ids_output = Vec::new();
@ -124,7 +139,7 @@ where
//
// Only local variables can become outputs.
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
if let Variable::Local(index) = var {
if let Variable::Local(index, _) = var {
if let Some((id, _)) = self
.locals
.iter()
@ -211,6 +226,42 @@ where
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Lower { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Greater { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::LowerEqual { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::GreaterEqual { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Equal { lhs, rhs, out } => {
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::ConditionalAssign {
cond,
lhs,
rhs,
out,
} => {
mark(cond, &mut local_tensor_ids_input);
mark(lhs, &mut local_tensor_ids_input);
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
}
}
@ -226,10 +277,11 @@ where
// All tensors whose latest description is read-only should be written to, since they
// are going to be used after the fused kernel by other operations.
for tensor in self.tensors.values() {
for entry in self.tensors.values() {
let (tensor, _) = &entry;
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
if self.locals.contains_key(&tensor.id) {
outputs.push(tensor);
outputs.push(entry);
}
}
}
@ -237,19 +289,19 @@ where
outputs
}
fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable {
fn input_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
let already_exists = self.tensors.contains_key(&tensor.id);
let variable = match already_exists {
false => {
// New input
let var = Variable::Input(self.inputs.len() as u16);
let var = Variable::Input(self.inputs.len() as u16, elem);
self.inputs.push(tensor.clone());
var
}
true => match self.locals.get(&tensor.id) {
// Is a local variable.
Some(local_index) => Variable::Local(*local_index),
Some(local_index) => Variable::Local(*local_index, elem),
// Isn't a local variable, so must be an existing input.
None => {
let input = self
@ -259,134 +311,234 @@ where
.find(|(_, input)| input.id == tensor.id)
.unwrap();
let input_index = input.0;
Variable::Input(input_index as u16)
Variable::Input(input_index as u16, elem)
}
},
};
// Update the tensor description with the new version.
self.tensors.insert(tensor.id.clone(), tensor.clone());
self.tensors
.insert(tensor.id.clone(), (tensor.clone(), elem));
variable
}
fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable {
fn output_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
// Update the tensor description to the new version.
self.tensors.insert(tensor.id.clone(), tensor.clone());
self.tensors
.insert(tensor.id.clone(), (tensor.clone(), elem));
// Output already registered as a local variable.
if let Some(index) = self.locals.get(&tensor.id) {
return Variable::Local(*index);
return Variable::Local(*index, elem);
}
// New local variable.
let local_index = self.locals.len() as u16;
self.locals.insert(tensor.id.clone(), local_index);
Variable::Local(local_index)
Variable::Local(local_index, elem)
}
fn register_float(&mut self, ops: &FloatOpsDescription) -> bool {
fn register_base<E: WgpuElement>(&mut self, ops: &BaseOpsDescription) -> bool {
match ops {
BaseOpsDescription::Equal(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Equal { lhs, rhs, out },
),
_ => false,
}
}
fn register_float<E: WgpuElement>(&mut self, ops: &FloatOpsDescription) -> bool {
match ops {
FloatOpsDescription::Exp(desc) => {
self.register_unary_ops(desc, |input, out| Operator::Exp { input, out })
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Exp { input, out }
})
}
FloatOpsDescription::Log(desc) => {
self.register_unary_ops(desc, |input, out| Operator::Log { input, out })
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Log { input, out }
})
}
FloatOpsDescription::Log1p(desc) => {
self.register_unary_ops(desc, |input, out| Operator::Log1p { input, out })
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Log1p { input, out }
})
}
FloatOpsDescription::Cos(desc) => {
self.register_unary_ops(desc, |input, out| Operator::Cos { input, out })
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Cos { input, out }
})
}
FloatOpsDescription::Sin(desc) => {
self.register_unary_ops(desc, |input, out| Operator::Sin { input, out })
}
FloatOpsDescription::Powf(desc) => {
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out })
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Sin { input, out }
})
}
FloatOpsDescription::Powf(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Powf { lhs, rhs, out },
),
FloatOpsDescription::Tanh(desc) => {
self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out })
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Tanh { input, out }
})
}
FloatOpsDescription::Erf(desc) => {
self.register_unary_ops(desc, |input, out| Operator::Erf { input, out })
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Erf { input, out }
})
}
FloatOpsDescription::Recip(desc) => {
self.register_unary_ops(desc, |input, out| Operator::Recip { input, out })
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Recip { input, out }
})
}
_ => false,
}
}
fn register_numeric<E: Element>(&mut self, ops: &NumericOpsDescription<E>) -> bool {
fn register_numeric<E: WgpuElement>(&mut self, ops: &NumericOpsDescription<E>) -> bool {
match ops {
NumericOpsDescription::Add(desc) => {
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
}
NumericOpsDescription::AddScalar(desc) => {
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
}
NumericOpsDescription::Sub(desc) => {
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
}
NumericOpsDescription::SubScalar(desc) => {
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
}
NumericOpsDescription::Mul(desc) => {
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
}
NumericOpsDescription::MulScalar(desc) => {
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
}
NumericOpsDescription::Div(desc) => {
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
}
NumericOpsDescription::DivScalar(desc) => {
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
}
NumericOpsDescription::Add(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Add { lhs, rhs, out },
),
NumericOpsDescription::AddScalar(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Add { lhs, rhs, out },
),
NumericOpsDescription::Sub(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Sub { lhs, rhs, out },
),
NumericOpsDescription::SubScalar(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Sub { lhs, rhs, out },
),
NumericOpsDescription::Mul(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Mul { lhs, rhs, out },
),
NumericOpsDescription::MulScalar(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Mul { lhs, rhs, out },
),
NumericOpsDescription::Div(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Div { lhs, rhs, out },
),
NumericOpsDescription::DivScalar(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), E::elem_type()),
|lhs, rhs, out| Operator::Div { lhs, rhs, out },
),
NumericOpsDescription::Abs(desc) => {
self.register_unary_ops(desc, |input, out| Operator::Abs { input, out })
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
Operator::Abs { input, out }
})
}
NumericOpsDescription::Lower(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Lower { lhs, rhs, out },
),
NumericOpsDescription::LowerElem(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Lower { lhs, rhs, out },
),
NumericOpsDescription::Greater(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Greater { lhs, rhs, out },
),
NumericOpsDescription::GreaterElem(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Greater { lhs, rhs, out },
),
NumericOpsDescription::LowerEqual(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
),
NumericOpsDescription::LowerEqualElem(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
),
NumericOpsDescription::GreaterEqual(desc) => self.register_binary_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
),
NumericOpsDescription::GreaterEqualElem(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
),
NumericOpsDescription::EqualElem(desc) => self.register_scalar_ops(
desc,
(E::elem_type(), E::elem_type(), Elem::Bool),
|lhs, rhs, out| Operator::Equal { lhs, rhs, out },
),
NumericOpsDescription::MaskWhere(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
let cond = self.input_to_var(&desc.mask, Elem::Bool);
let lhs = self.input_to_var(&desc.value, E::elem_type());
let rhs = self.input_to_var(&desc.tensor, E::elem_type());
let out = self.output_to_var(&desc.out, E::elem_type());
self.operators.push(Operator::ConditionalAssign {
cond,
lhs,
rhs,
out,
});
true
}
NumericOpsDescription::MaskFill(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
let cond = self.input_to_var(&desc.mask, Elem::Bool);
let lhs = self.scalar_to_var(&desc.value, E::elem_type());
let rhs = self.input_to_var(&desc.tensor, E::elem_type());
let out = self.output_to_var(&desc.out, E::elem_type());
self.operators.push(Operator::ConditionalAssign {
cond,
lhs,
rhs,
out,
});
true
}
_ => false,
}
}
fn register_binary_ops<Func>(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool
where
Func: Fn(Variable, Variable, Variable) -> Operator,
{
if !self.output_is_compatible(&desc.out) {
return false;
}
let lhs = self.input_to_var(&desc.lhs);
let rhs = self.input_to_var(&desc.rhs);
let out = self.output_to_var(&desc.out);
self.operators.push(func(lhs, rhs, out));
true
}
fn register_unary_ops<Func>(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool
where
Func: Fn(Variable, Variable) -> Operator,
{
if !self.output_is_compatible(&desc.out) {
return false;
}
let input = self.input_to_var(&desc.input);
let out = self.output_to_var(&desc.out);
self.operators.push(func(input, out));
true
}
fn register_scalar_ops<Func, E: Element>(
fn register_binary_ops<Func>(
&mut self,
desc: &ScalarOpsDescription<E>,
desc: &BinaryOpsDescription,
(elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
func: Func,
) -> bool
where
@ -396,16 +548,78 @@ where
return false;
}
let lhs = self.input_to_var(&desc.lhs);
let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32);
self.scalars_f32.push(desc.rhs.elem());
let out = self.output_to_var(&desc.out);
let lhs = self.input_to_var(&desc.lhs, elem_lhs);
let rhs = self.input_to_var(&desc.rhs, elem_rhs);
let out = self.output_to_var(&desc.out, elem_out);
self.operators.push(func(lhs, rhs, out));
true
}
fn register_unary_ops<Func>(
&mut self,
desc: &UnaryOpsDescription,
(elem_input, elem_out): (Elem, Elem),
func: Func,
) -> bool
where
Func: Fn(Variable, Variable) -> Operator,
{
if !self.output_is_compatible(&desc.out) {
return false;
}
let input = self.input_to_var(&desc.input, elem_input);
let out = self.output_to_var(&desc.out, elem_out);
self.operators.push(func(input, out));
true
}
fn register_scalar_ops<Func, E: Element>(
&mut self,
desc: &ScalarOpsDescription<E>,
(elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
func: Func,
) -> bool
where
Func: Fn(Variable, Variable, Variable) -> Operator,
{
if !self.output_is_compatible(&desc.out) {
return false;
}
let lhs = self.input_to_var(&desc.lhs, elem_lhs);
let rhs = self.scalar_to_var(&desc.rhs, elem_rhs);
let out = self.output_to_var(&desc.out, elem_out);
self.operators.push(func(lhs, rhs, out));
true
}
fn scalar_to_var<E: Element>(&mut self, value: &E, elem_type: Elem) -> Variable {
match elem_type {
Elem::F32 => {
self.scalars_f32.push(value.elem());
Variable::Scalar(self.scalars_f32.len() as u16 - 1, Elem::F32)
}
Elem::I32 => {
self.scalars_i32.push(value.elem());
Variable::Scalar(self.scalars_i32.len() as u16 - 1, Elem::I32)
}
Elem::U32 => {
self.scalars_u32.push(value.elem());
Variable::Scalar(self.scalars_u32.len() as u16 - 1, Elem::U32)
}
Elem::Bool => {
panic!("Bool scalars not supported")
}
}
}
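
Because the index is taken as the buffer length minus one immediately after the push, each registered scalar addresses exactly the slot just written, and repeated values are appended rather than deduplicated. A minimal sketch of that contract (a hypothetical check, not code from the diff):

fn main() {
    // One f32 scalar already registered; registering another (e.g. the
    // mask_fill value 0.3 from the test below) lands in slot 1.
    let mut scalars_f32: Vec<f32> = vec![2.0];
    scalars_f32.push(0.3);
    let index = scalars_f32.len() as u16 - 1;
    assert_eq!(index, 1); // renders as scalars_f32[1] in the generated WGSL
}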
fn output_is_compatible(&mut self, out: &TensorDescription) -> bool {
if self.current_output_shape.is_empty() {
self.current_output_shape = out.shape.clone();
@ -446,17 +660,19 @@ mod tests {
let tensor_2 = Tensor::<Backend, 2>::from_data(data_2.clone());
let tensor_3 = tensor_1.clone() + tensor_2;
let tensor_4 = tensor_3.clone() - tensor_1;
let tensor_5 = tensor_4 + 5.0;
let tensor_6 = tensor_5 + tensor_3;
let result_ref = tensor_6.recip().into_data();
let tensor_5 = tensor_4.clone() + 5.0;
let tensor_6 = tensor_5 + tensor_3.clone();
let mask = tensor_4.lower_equal(tensor_3);
let result_ref = tensor_6.mask_fill(mask, 0.3).into_data();
let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
let tensor_3 = tensor_1.clone() + tensor_2;
let tensor_4 = tensor_3.clone() - tensor_1;
let tensor_5 = tensor_4 + 5.0;
let tensor_6 = tensor_5 + tensor_3;
let result_fused = tensor_6.recip().into_data();
let tensor_5 = tensor_4.clone() + 5.0;
let tensor_6 = tensor_5 + tensor_3.clone();
let mask = tensor_4.lower_equal(tensor_3);
let result_fused = tensor_6.mask_fill(mask, 0.3).into_data();
result_fused.assert_approx_eq(&result_ref, 3);
}


@ -83,25 +83,43 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Input
/// Register the inputs used by the kernel.
pub fn inputs(
mut self,
inputs_tensor: &[&TensorDescription],
inputs_tensor: &[&(TensorDescription, Elem)],
inputs_scalar_f32: &[f32],
) -> FusionKernel<G, F, I, BodyPhase> {
for (i, input) in inputs_tensor.iter().enumerate() {
self.input_bindings.push((
Binding {
elem: Elem::F32,
visibility: Visibility::Read,
location: Location::Storage,
size: None,
},
(*input).clone(),
));
for (i, (input, elem)) in inputs_tensor.iter().enumerate() {
if elem != &Elem::Bool {
self.input_bindings.push((
Binding {
elem: *elem,
visibility: Visibility::Read,
location: Location::Storage,
size: None,
},
(*input).clone(),
));
self.operations.push(Operator::ReadGlobal {
variable: Variable::Input(i as u16),
position: i,
position_out: inputs_tensor.len(), // First output
});
self.operations.push(Operator::ReadGlobal {
variable: Variable::Input(i as u16, *elem),
position: i,
position_out: inputs_tensor.len(), // First output
});
} else {
self.input_bindings.push((
Binding {
elem: Elem::I32,
visibility: Visibility::Read,
location: Location::Storage,
size: None,
},
(*input).clone(),
));
self.operations.push(Operator::ReadGlobal {
variable: Variable::Input(i as u16, *elem),
position: i,
position_out: inputs_tensor.len(), // First output
});
}
}
if !inputs_scalar_f32.is_empty() {
@ -180,31 +198,48 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Outpu
/// So the 4th operator registered creates the local variable 3 (N-1, since indexing starts at 0).
pub fn outputs(
mut self,
outputs: &[&TensorDescription],
outputs: &[&(TensorDescription, Elem)],
locals: &[u16],
) -> FusionKernel<G, F, I, ExecutionPhase> {
let mut num_elems_launch_option = 0;
for (i, (output, local)) in outputs.iter().zip(locals).enumerate() {
for (i, ((output, elem), local)) in outputs.iter().zip(locals).enumerate() {
let num_elems_output = calculate_num_elems_dyn_rank(&output.shape);
if num_elems_output > num_elems_launch_option {
num_elems_launch_option = num_elems_output;
}
self.output_bindings.push((
Binding {
elem: Elem::F32,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
},
(*output).clone(),
));
if elem != &Elem::Bool {
self.output_bindings.push((
Binding {
elem: *elem,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
},
(*output).clone(),
));
self.operations.push(Operator::AssignGlobal {
input: Variable::Local(*local),
out: Variable::Output(i as u16),
});
self.operations.push(Operator::AssignGlobal {
input: Variable::Local(*local, *elem),
out: Variable::Output(i as u16, *elem),
});
} else {
self.output_bindings.push((
Binding {
elem: Elem::I32, // i32 is used for bool tensors
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
},
(*output).clone(),
));
self.operations.push(Operator::AssignGlobal {
input: Variable::Local(*local, *elem),
out: Variable::Output(i as u16, Elem::I32),
});
}
}
self.num_elems_output = num_elems_launch_option;
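
Both inputs and outputs special-case Elem::Bool in the same way. Condensed, the mapping is just the following (a hypothetical helper for illustration; the diff inlines this logic in each branch):

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Elem { F32, I32, U32, Bool }

// WGSL bool is not host-shareable, so bool tensors cannot back a storage
// buffer; they are bound as i32 instead, and the codegen casts on read
// and write keep the kernel-side values boolean.
fn binding_elem(elem: Elem) -> Elem {
    match elem {
        Elem::Bool => Elem::I32,
        other => other,
    }
}

fn main() {
    assert_eq!(binding_elem(Elem::Bool), Elem::I32);
    assert_eq!(binding_elem(Elem::F32), Elem::F32);
}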