mirror of https://github.com/tracel-ai/burn.git
Feat/fusion/cmp (#992)
This commit is contained in:
parent
b86bc58761
commit
58273a8441
|
@ -9,6 +9,8 @@ where
|
|||
fn type_name() -> &'static str;
|
||||
fn as_bytes(slice: &[Self]) -> &[u8];
|
||||
fn from_bytes(bytes: &[u8]) -> &[Self];
|
||||
/// Returns the fusion codegen element type that corresponds to this element.
///
/// Only compiled when the `fusion` feature is enabled or when building tests.
#[cfg(any(feature = "fusion", test))]
|
||||
fn elem_type() -> crate::fusion::codegen::Elem;
|
||||
}
|
||||
|
||||
/// The float element type for the wgpu backend.
|
||||
|
@ -27,6 +29,10 @@ impl WgpuElement for u32 {
|
|||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
/// Reports the fusion codegen element kind used for `u32` values.
///
/// Only compiled when the `fusion` feature is enabled or when building tests.
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
    use crate::fusion::codegen::Elem;

    Elem::U32
}
|
||||
}
|
||||
|
||||
impl WgpuElement for i32 {
|
||||
|
@ -39,6 +45,10 @@ impl WgpuElement for i32 {
|
|||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
/// Reports the fusion codegen element kind used for `i32` values.
///
/// Only compiled when the `fusion` feature is enabled or when building tests.
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
    use crate::fusion::codegen::Elem;

    Elem::I32
}
|
||||
}
|
||||
|
||||
impl WgpuElement for f32 {
|
||||
|
@ -51,6 +61,11 @@ impl WgpuElement for f32 {
|
|||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
|
||||
/// Reports the fusion codegen element kind used for `f32` values.
///
/// Only compiled when the `fusion` feature is enabled or when building tests.
#[cfg(any(feature = "fusion", test))]
fn elem_type() -> crate::fusion::codegen::Elem {
    use crate::fusion::codegen::Elem;

    Elem::F32
}
|
||||
}
|
||||
|
||||
impl FloatElement for f32 {}
|
||||
|
|
|
@ -65,6 +65,37 @@ pub enum Operator {
|
|||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Equal {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Lower {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Greater {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
LowerEqual {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
GreaterEqual {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
ConditionalAssign {
|
||||
cond: Variable,
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
AssignGlobal {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
|
@ -109,22 +140,41 @@ impl Display for Operator {
|
|||
Operator::Recip { input, out } => {
|
||||
f.write_fmt(format_args!("let {out} = 1.0 / {input};"))
|
||||
}
|
||||
Operator::Equal { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = {lhs} == {rhs};"))
|
||||
}
|
||||
Operator::Lower { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = {lhs} < {rhs};"))
|
||||
}
|
||||
Operator::Greater { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = {lhs} > {rhs};"))
|
||||
}
|
||||
Operator::LowerEqual { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = {lhs} <= {rhs};"))
|
||||
}
|
||||
Operator::GreaterEqual { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = {lhs} >= {rhs};"))
|
||||
}
|
||||
Operator::AssignGlobal { input, out } => {
|
||||
f.write_fmt(format_args!("{out}_global[id] = {input};"))
|
||||
let elem = out.elem();
|
||||
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
|
||||
}
|
||||
Operator::ReadGlobal {
|
||||
variable,
|
||||
position,
|
||||
position_out,
|
||||
} => {
|
||||
let (global, local) = match variable {
|
||||
Variable::Input(number) => {
|
||||
(format!("input_{number}_global"), format!("input_{number}"))
|
||||
}
|
||||
Variable::Local(_) => panic!("can't read globala local variable."),
|
||||
Variable::Output(number) => (
|
||||
let (global, local, elem) = match variable {
|
||||
Variable::Input(number, elem) => (
|
||||
format!("input_{number}_global"),
|
||||
format!("input_{number}"),
|
||||
elem,
|
||||
),
|
||||
Variable::Local(_, _) => panic!("can't read global local variable."),
|
||||
Variable::Output(number, elem) => (
|
||||
format!("output_{number}_global"),
|
||||
format!("output_{number}"),
|
||||
elem,
|
||||
),
|
||||
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
|
||||
};
|
||||
|
@ -144,7 +194,25 @@ for (var i: u32 = 1u; i <= rank; i++) {{
|
|||
index_{local} += id / stride_out % shape * stride;
|
||||
}}
|
||||
|
||||
let {local} = {global}[index_{local}];
|
||||
let {local} = {elem}({global}[index_{local}]);
|
||||
"
|
||||
))
|
||||
}
|
||||
Operator::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
} => {
|
||||
let elem = out.elem();
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
var {out}: {elem};
|
||||
if {cond} {{
|
||||
{out} = {lhs};
|
||||
}} else {{
|
||||
{out} = {rhs};
|
||||
}}
|
||||
"
|
||||
))
|
||||
}
|
||||
|
|
|
@ -19,12 +19,13 @@ pub enum Visibility {
|
|||
ReadWrite,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
|
||||
pub enum Elem {
|
||||
F32,
|
||||
#[allow(dead_code)]
|
||||
I32,
|
||||
U32,
|
||||
Bool,
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq)]
|
||||
|
@ -187,6 +188,7 @@ impl Display for Elem {
|
|||
Elem::F32 => f.write_str("f32"),
|
||||
Elem::I32 => f.write_str("i32"),
|
||||
Elem::U32 => f.write_str("u32"),
|
||||
Elem::Bool => f.write_str("bool"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,18 +3,29 @@ use std::fmt::Display;
|
|||
|
||||
#[derive(Debug, Hash, Clone)]
|
||||
pub enum Variable {
|
||||
Input(u16),
|
||||
Input(u16, Elem),
|
||||
Scalar(u16, Elem),
|
||||
Local(u16),
|
||||
Output(u16),
|
||||
Local(u16, Elem),
|
||||
Output(u16, Elem),
|
||||
}
|
||||
|
||||
impl Variable {
|
||||
pub fn elem(&self) -> &Elem {
|
||||
match self {
|
||||
Variable::Input(_, e) => e,
|
||||
Variable::Scalar(_, e) => e,
|
||||
Variable::Local(_, e) => e,
|
||||
Variable::Output(_, e) => e,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Variable {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Variable::Input(number) => f.write_fmt(format_args!("input_{number}")),
|
||||
Variable::Local(number) => f.write_fmt(format_args!("local_{number}")),
|
||||
Variable::Output(number) => f.write_fmt(format_args!("output_{number}")),
|
||||
Variable::Input(number, _) => f.write_fmt(format_args!("input_{number}")),
|
||||
Variable::Local(number, _) => f.write_fmt(format_args!("local_{number}")),
|
||||
Variable::Output(number, _) => f.write_fmt(format_args!("output_{number}")),
|
||||
Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
use crate::{
|
||||
element::WgpuElement,
|
||||
fusion::codegen::{Elem, Operator, Variable},
|
||||
fusion::kernel::FusionKernel,
|
||||
FloatElement, GraphicsApi, IntElement, Wgpu,
|
||||
};
|
||||
use burn_fusion::{
|
||||
graph::{
|
||||
BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription,
|
||||
TensorOpsDescription, UnaryOpsDescription,
|
||||
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
|
||||
ScalarOpsDescription, TensorOpsDescription, UnaryOpsDescription,
|
||||
},
|
||||
FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, TensorId,
|
||||
};
|
||||
|
@ -22,8 +23,11 @@ where
|
|||
{
|
||||
pub(crate) inputs: Vec<TensorDescription>,
|
||||
pub(crate) locals: HashMap<TensorId, u16>,
|
||||
pub(crate) tensors: HashMap<TensorId, TensorDescription>,
|
||||
pub(crate) tensors: HashMap<TensorId, (TensorDescription, Elem)>,
|
||||
pub(crate) scalars_f32: Vec<f32>,
|
||||
pub(crate) scalars_i32: Vec<i32>,
|
||||
pub(crate) scalars_u32: Vec<u32>,
|
||||
pub(crate) booleans: Vec<bool>,
|
||||
pub(crate) operators: Vec<Operator>,
|
||||
pub(crate) properties: FusionProperties,
|
||||
pub(crate) current_output_shape: Vec<usize>,
|
||||
|
@ -35,8 +39,13 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
|
|||
{
|
||||
fn register(&mut self, ops: &TensorOpsDescription) -> FusionStatus {
|
||||
match ops {
|
||||
TensorOpsDescription::BaseOpsFloat(ops) => {
|
||||
if !self.register_base::<F>(ops) {
|
||||
return FusionStatus::Closed(self.properties);
|
||||
}
|
||||
}
|
||||
TensorOpsDescription::FloatOps(ops) => {
|
||||
if !self.register_float(ops) {
|
||||
if !self.register_float::<F>(ops) {
|
||||
return FusionStatus::Closed(self.properties);
|
||||
}
|
||||
}
|
||||
|
@ -61,7 +70,7 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
|
|||
let outputs = self.output_descriptions();
|
||||
let locals = outputs
|
||||
.iter()
|
||||
.map(|out| *self.locals.get(&out.id).unwrap())
|
||||
.map(|out| *self.locals.get(&out.0.id).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
FusionKernel::new(&self.device)
|
||||
|
@ -76,6 +85,9 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
|
|||
self.locals.drain();
|
||||
self.tensors.clear();
|
||||
self.scalars_f32.clear();
|
||||
self.scalars_i32.clear();
|
||||
self.scalars_u32.clear();
|
||||
self.booleans.clear();
|
||||
self.operators.clear();
|
||||
self.properties = FusionProperties::default();
|
||||
self.current_output_shape.clear();
|
||||
|
@ -98,6 +110,9 @@ where
|
|||
locals: HashMap::new(),
|
||||
tensors: HashMap::new(),
|
||||
scalars_f32: Vec::new(),
|
||||
scalars_i32: Vec::new(),
|
||||
scalars_u32: Vec::new(),
|
||||
booleans: Vec::new(),
|
||||
operators: Vec::new(),
|
||||
current_output_shape: Vec::new(),
|
||||
properties: FusionProperties::default(),
|
||||
|
@ -105,7 +120,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn input_descriptions(&self) -> Vec<&TensorDescription> {
|
||||
fn input_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
|
||||
self.inputs
|
||||
.iter()
|
||||
.map(|input| {
|
||||
|
@ -115,7 +130,7 @@ where
|
|||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn output_descriptions(&self) -> Vec<&TensorDescription> {
|
||||
fn output_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
|
||||
let mut outputs = Vec::new();
|
||||
let mut local_tensor_ids_input = Vec::new();
|
||||
let mut local_tensor_ids_output = Vec::new();
|
||||
|
@ -124,7 +139,7 @@ where
|
|||
//
|
||||
// Only local variables can become outputs.
|
||||
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
|
||||
if let Variable::Local(index) = var {
|
||||
if let Variable::Local(index, _) = var {
|
||||
if let Some((id, _)) = self
|
||||
.locals
|
||||
.iter()
|
||||
|
@ -211,6 +226,42 @@ where
|
|||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Lower { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Greater { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::LowerEqual { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::GreaterEqual { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Equal { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
} => {
|
||||
mark(cond, &mut local_tensor_ids_input);
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -226,10 +277,11 @@ where
|
|||
|
||||
// All tensors where their latest description is read only should be written to since they
|
||||
// are going to be used after the fused kernel by other operations.
|
||||
for tensor in self.tensors.values() {
|
||||
for entry in self.tensors.values() {
|
||||
let (tensor, _) = &entry;
|
||||
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
|
||||
if self.locals.contains_key(&tensor.id) {
|
||||
outputs.push(tensor);
|
||||
outputs.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -237,19 +289,19 @@ where
|
|||
outputs
|
||||
}
|
||||
|
||||
fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable {
|
||||
fn input_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
|
||||
let already_exists = self.tensors.contains_key(&tensor.id);
|
||||
|
||||
let variable = match already_exists {
|
||||
false => {
|
||||
// New input
|
||||
let var = Variable::Input(self.inputs.len() as u16);
|
||||
let var = Variable::Input(self.inputs.len() as u16, elem);
|
||||
self.inputs.push(tensor.clone());
|
||||
var
|
||||
}
|
||||
true => match self.locals.get(&tensor.id) {
|
||||
// Is a local variable.
|
||||
Some(local_index) => Variable::Local(*local_index),
|
||||
Some(local_index) => Variable::Local(*local_index, elem),
|
||||
// Isn't a local variable, so must be an existing input.
|
||||
None => {
|
||||
let input = self
|
||||
|
@ -259,134 +311,234 @@ where
|
|||
.find(|(_, input)| input.id == tensor.id)
|
||||
.unwrap();
|
||||
let input_index = input.0;
|
||||
Variable::Input(input_index as u16)
|
||||
Variable::Input(input_index as u16, elem)
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// Update the tensor description with the new version.
|
||||
self.tensors.insert(tensor.id.clone(), tensor.clone());
|
||||
self.tensors
|
||||
.insert(tensor.id.clone(), (tensor.clone(), elem));
|
||||
|
||||
variable
|
||||
}
|
||||
|
||||
fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable {
|
||||
fn output_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
|
||||
// Update the tensor description to the new version.
|
||||
self.tensors.insert(tensor.id.clone(), tensor.clone());
|
||||
self.tensors
|
||||
.insert(tensor.id.clone(), (tensor.clone(), elem));
|
||||
|
||||
// Output already registered as a local variable.
|
||||
if let Some(index) = self.locals.get(&tensor.id) {
|
||||
return Variable::Local(*index);
|
||||
return Variable::Local(*index, elem);
|
||||
}
|
||||
|
||||
// New local variable.
|
||||
let local_index = self.locals.len() as u16;
|
||||
self.locals.insert(tensor.id.clone(), local_index);
|
||||
Variable::Local(local_index)
|
||||
Variable::Local(local_index, elem)
|
||||
}
|
||||
|
||||
fn register_float(&mut self, ops: &FloatOpsDescription) -> bool {
|
||||
fn register_base<E: WgpuElement>(&mut self, ops: &BaseOpsDescription) -> bool {
|
||||
match ops {
|
||||
BaseOpsDescription::Equal(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Equal { lhs, rhs, out },
|
||||
),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn register_float<E: WgpuElement>(&mut self, ops: &FloatOpsDescription) -> bool {
|
||||
match ops {
|
||||
FloatOpsDescription::Exp(desc) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Exp { input, out })
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Exp { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Log(desc) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Log { input, out })
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Log { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Log1p(desc) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Log1p { input, out })
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Log1p { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Cos(desc) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Cos { input, out })
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Cos { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Sin(desc) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Sin { input, out })
|
||||
}
|
||||
FloatOpsDescription::Powf(desc) => {
|
||||
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out })
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Sin { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Powf(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Powf { lhs, rhs, out },
|
||||
),
|
||||
FloatOpsDescription::Tanh(desc) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out })
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Tanh { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Erf(desc) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Erf { input, out })
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Erf { input, out }
|
||||
})
|
||||
}
|
||||
FloatOpsDescription::Recip(desc) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Recip { input, out })
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Recip { input, out }
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn register_numeric<E: Element>(&mut self, ops: &NumericOpsDescription<E>) -> bool {
|
||||
fn register_numeric<E: WgpuElement>(&mut self, ops: &NumericOpsDescription<E>) -> bool {
|
||||
match ops {
|
||||
NumericOpsDescription::Add(desc) => {
|
||||
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
|
||||
}
|
||||
NumericOpsDescription::AddScalar(desc) => {
|
||||
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
|
||||
}
|
||||
NumericOpsDescription::Sub(desc) => {
|
||||
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
|
||||
}
|
||||
NumericOpsDescription::SubScalar(desc) => {
|
||||
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
|
||||
}
|
||||
NumericOpsDescription::Mul(desc) => {
|
||||
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
|
||||
}
|
||||
NumericOpsDescription::MulScalar(desc) => {
|
||||
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
|
||||
}
|
||||
NumericOpsDescription::Div(desc) => {
|
||||
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
|
||||
}
|
||||
NumericOpsDescription::DivScalar(desc) => {
|
||||
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
|
||||
}
|
||||
NumericOpsDescription::Add(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Add { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::AddScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Add { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::Sub(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Sub { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::SubScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Sub { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::Mul(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Mul { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::MulScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Mul { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::Div(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Div { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::DivScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||
|lhs, rhs, out| Operator::Div { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::Abs(desc) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Abs { input, out })
|
||||
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||
Operator::Abs { input, out }
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::Lower(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Lower { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::LowerElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Lower { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::Greater(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Greater { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::GreaterElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Greater { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::LowerEqual(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::LowerEqualElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::GreaterEqual(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::GreaterEqualElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::EqualElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||
|lhs, rhs, out| Operator::Equal { lhs, rhs, out },
|
||||
),
|
||||
NumericOpsDescription::MaskWhere(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let cond = self.input_to_var(&desc.mask, Elem::Bool);
|
||||
let lhs = self.input_to_var(&desc.value, E::elem_type());
|
||||
let rhs = self.input_to_var(&desc.tensor, E::elem_type());
|
||||
let out = self.output_to_var(&desc.out, E::elem_type());
|
||||
|
||||
self.operators.push(Operator::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
});
|
||||
|
||||
true
|
||||
}
|
||||
NumericOpsDescription::MaskFill(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let cond = self.input_to_var(&desc.mask, Elem::Bool);
|
||||
let lhs = self.scalar_to_var(&desc.value, E::elem_type());
|
||||
let rhs = self.input_to_var(&desc.tensor, E::elem_type());
|
||||
let out = self.output_to_var(&desc.out, E::elem_type());
|
||||
|
||||
self.operators.push(Operator::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
});
|
||||
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn register_binary_ops<Func>(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool
|
||||
where
|
||||
Func: Fn(Variable, Variable, Variable) -> Operator,
|
||||
{
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let lhs = self.input_to_var(&desc.lhs);
|
||||
let rhs = self.input_to_var(&desc.rhs);
|
||||
let out = self.output_to_var(&desc.out);
|
||||
|
||||
self.operators.push(func(lhs, rhs, out));
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn register_unary_ops<Func>(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool
|
||||
where
|
||||
Func: Fn(Variable, Variable) -> Operator,
|
||||
{
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let input = self.input_to_var(&desc.input);
|
||||
let out = self.output_to_var(&desc.out);
|
||||
|
||||
self.operators.push(func(input, out));
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn register_scalar_ops<Func, E: Element>(
|
||||
fn register_binary_ops<Func>(
|
||||
&mut self,
|
||||
desc: &ScalarOpsDescription<E>,
|
||||
desc: &BinaryOpsDescription,
|
||||
(elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
|
||||
func: Func,
|
||||
) -> bool
|
||||
where
|
||||
|
@ -396,16 +548,78 @@ where
|
|||
return false;
|
||||
}
|
||||
|
||||
let lhs = self.input_to_var(&desc.lhs);
|
||||
let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32);
|
||||
self.scalars_f32.push(desc.rhs.elem());
|
||||
let out = self.output_to_var(&desc.out);
|
||||
let lhs = self.input_to_var(&desc.lhs, elem_lhs);
|
||||
let rhs = self.input_to_var(&desc.rhs, elem_rhs);
|
||||
let out = self.output_to_var(&desc.out, elem_out);
|
||||
|
||||
self.operators.push(func(lhs, rhs, out));
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn register_unary_ops<Func>(
|
||||
&mut self,
|
||||
desc: &UnaryOpsDescription,
|
||||
(elem_input, elem_out): (Elem, Elem),
|
||||
func: Func,
|
||||
) -> bool
|
||||
where
|
||||
Func: Fn(Variable, Variable) -> Operator,
|
||||
{
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let input = self.input_to_var(&desc.input, elem_input);
|
||||
let out = self.output_to_var(&desc.out, elem_out);
|
||||
|
||||
self.operators.push(func(input, out));
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn register_scalar_ops<Func, E: Element>(
|
||||
&mut self,
|
||||
desc: &ScalarOpsDescription<E>,
|
||||
(elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
|
||||
func: Func,
|
||||
) -> bool
|
||||
where
|
||||
Func: Fn(Variable, Variable, Variable) -> Operator,
|
||||
{
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let lhs = self.input_to_var(&desc.lhs, elem_lhs);
|
||||
let rhs = self.scalar_to_var(&desc.rhs, elem_rhs);
|
||||
let out = self.output_to_var(&desc.out, elem_out);
|
||||
|
||||
self.operators.push(func(lhs, rhs, out));
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn scalar_to_var<E: Element>(&mut self, value: &E, elem_type: Elem) -> Variable {
|
||||
match elem_type {
|
||||
Elem::F32 => {
|
||||
self.scalars_f32.push(value.elem());
|
||||
Variable::Scalar(self.scalars_f32.len() as u16 - 1, Elem::F32)
|
||||
}
|
||||
Elem::I32 => {
|
||||
self.scalars_i32.push(value.elem());
|
||||
Variable::Scalar(self.scalars_i32.len() as u16 - 1, Elem::I32)
|
||||
}
|
||||
Elem::U32 => {
|
||||
self.scalars_u32.push(value.elem());
|
||||
Variable::Scalar(self.scalars_u32.len() as u16 - 1, Elem::U32)
|
||||
}
|
||||
Elem::Bool => {
|
||||
panic!("Bool scalars not supported")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn output_is_compatible(&mut self, out: &TensorDescription) -> bool {
|
||||
if self.current_output_shape.is_empty() {
|
||||
self.current_output_shape = out.shape.clone();
|
||||
|
@ -446,17 +660,19 @@ mod tests {
|
|||
let tensor_2 = Tensor::<Backend, 2>::from_data(data_2.clone());
|
||||
let tensor_3 = tensor_1.clone() + tensor_2;
|
||||
let tensor_4 = tensor_3.clone() - tensor_1;
|
||||
let tensor_5 = tensor_4 + 5.0;
|
||||
let tensor_6 = tensor_5 + tensor_3;
|
||||
let result_ref = tensor_6.recip().into_data();
|
||||
let tensor_5 = tensor_4.clone() + 5.0;
|
||||
let tensor_6 = tensor_5 + tensor_3.clone();
|
||||
let mask = tensor_4.lower_equal(tensor_3);
|
||||
let result_ref = tensor_6.mask_fill(mask, 0.3).into_data();
|
||||
|
||||
let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
|
||||
let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
|
||||
let tensor_3 = tensor_1.clone() + tensor_2;
|
||||
let tensor_4 = tensor_3.clone() - tensor_1;
|
||||
let tensor_5 = tensor_4 + 5.0;
|
||||
let tensor_6 = tensor_5 + tensor_3;
|
||||
let result_fused = tensor_6.recip().into_data();
|
||||
let tensor_5 = tensor_4.clone() + 5.0;
|
||||
let tensor_6 = tensor_5 + tensor_3.clone();
|
||||
let mask = tensor_4.lower_equal(tensor_3);
|
||||
let result_fused = tensor_6.mask_fill(mask, 0.3).into_data();
|
||||
|
||||
result_fused.assert_approx_eq(&result_ref, 3);
|
||||
}
|
||||
|
|
|
@ -83,25 +83,43 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Input
|
|||
/// Register the inputs used by the kernel.
|
||||
pub fn inputs(
|
||||
mut self,
|
||||
inputs_tensor: &[&TensorDescription],
|
||||
inputs_tensor: &[&(TensorDescription, Elem)],
|
||||
inputs_scalar_f32: &[f32],
|
||||
) -> FusionKernel<G, F, I, BodyPhase> {
|
||||
for (i, input) in inputs_tensor.iter().enumerate() {
|
||||
self.input_bindings.push((
|
||||
Binding {
|
||||
elem: Elem::F32,
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
},
|
||||
(*input).clone(),
|
||||
));
|
||||
for (i, (input, elem)) in inputs_tensor.iter().enumerate() {
|
||||
if elem != &Elem::Bool {
|
||||
self.input_bindings.push((
|
||||
Binding {
|
||||
elem: *elem,
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
},
|
||||
(*input).clone(),
|
||||
));
|
||||
|
||||
self.operations.push(Operator::ReadGlobal {
|
||||
variable: Variable::Input(i as u16),
|
||||
position: i,
|
||||
position_out: inputs_tensor.len(), // First output
|
||||
});
|
||||
self.operations.push(Operator::ReadGlobal {
|
||||
variable: Variable::Input(i as u16, *elem),
|
||||
position: i,
|
||||
position_out: inputs_tensor.len(), // First output
|
||||
});
|
||||
} else {
|
||||
self.input_bindings.push((
|
||||
Binding {
|
||||
elem: Elem::I32,
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
},
|
||||
(*input).clone(),
|
||||
));
|
||||
|
||||
self.operations.push(Operator::ReadGlobal {
|
||||
variable: Variable::Input(i as u16, *elem),
|
||||
position: i,
|
||||
position_out: inputs_tensor.len(), // First output
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if !inputs_scalar_f32.is_empty() {
|
||||
|
@ -180,31 +198,48 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Outpu
|
|||
/// So the 4th operator registered creates the local variable 3 (N-1, since the first index is 0).
|
||||
pub fn outputs(
|
||||
mut self,
|
||||
outputs: &[&TensorDescription],
|
||||
outputs: &[&(TensorDescription, Elem)],
|
||||
locals: &[u16],
|
||||
) -> FusionKernel<G, F, I, ExecutionPhase> {
|
||||
let mut num_elems_launch_option = 0;
|
||||
|
||||
for (i, (output, local)) in outputs.iter().zip(locals).enumerate() {
|
||||
for (i, ((output, elem), local)) in outputs.iter().zip(locals).enumerate() {
|
||||
let num_elems_output = calculate_num_elems_dyn_rank(&output.shape);
|
||||
if num_elems_output > num_elems_launch_option {
|
||||
num_elems_launch_option = num_elems_output;
|
||||
}
|
||||
|
||||
self.output_bindings.push((
|
||||
Binding {
|
||||
elem: Elem::F32,
|
||||
visibility: Visibility::ReadWrite,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
},
|
||||
(*output).clone(),
|
||||
));
|
||||
if elem != &Elem::Bool {
|
||||
self.output_bindings.push((
|
||||
Binding {
|
||||
elem: *elem,
|
||||
visibility: Visibility::ReadWrite,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
},
|
||||
(*output).clone(),
|
||||
));
|
||||
|
||||
self.operations.push(Operator::AssignGlobal {
|
||||
input: Variable::Local(*local),
|
||||
out: Variable::Output(i as u16),
|
||||
});
|
||||
self.operations.push(Operator::AssignGlobal {
|
||||
input: Variable::Local(*local, *elem),
|
||||
out: Variable::Output(i as u16, *elem),
|
||||
});
|
||||
} else {
|
||||
self.output_bindings.push((
|
||||
Binding {
|
||||
elem: Elem::I32, // I32 are used for bool tensors
|
||||
visibility: Visibility::ReadWrite,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
},
|
||||
(*output).clone(),
|
||||
));
|
||||
|
||||
self.operations.push(Operator::AssignGlobal {
|
||||
input: Variable::Local(*local, *elem),
|
||||
out: Variable::Output(i as u16, Elem::I32),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
self.num_elems_output = num_elems_launch_option;
|
||||
|
|
Loading…
Reference in New Issue