mirror of https://github.com/tracel-ai/burn.git
Feat/fusion/cmp (#992)
This commit is contained in:
parent
b86bc58761
commit
58273a8441
|
@ -9,6 +9,8 @@ where
|
||||||
fn type_name() -> &'static str;
|
fn type_name() -> &'static str;
|
||||||
fn as_bytes(slice: &[Self]) -> &[u8];
|
fn as_bytes(slice: &[Self]) -> &[u8];
|
||||||
fn from_bytes(bytes: &[u8]) -> &[Self];
|
fn from_bytes(bytes: &[u8]) -> &[Self];
|
||||||
|
#[cfg(any(feature = "fusion", test))]
|
||||||
|
fn elem_type() -> crate::fusion::codegen::Elem;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The float element type for the wgpu backend.
|
/// The float element type for the wgpu backend.
|
||||||
|
@ -27,6 +29,10 @@ impl WgpuElement for u32 {
|
||||||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||||
bytemuck::cast_slice(bytes)
|
bytemuck::cast_slice(bytes)
|
||||||
}
|
}
|
||||||
|
#[cfg(any(feature = "fusion", test))]
|
||||||
|
fn elem_type() -> crate::fusion::codegen::Elem {
|
||||||
|
crate::fusion::codegen::Elem::U32
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WgpuElement for i32 {
|
impl WgpuElement for i32 {
|
||||||
|
@ -39,6 +45,10 @@ impl WgpuElement for i32 {
|
||||||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||||
bytemuck::cast_slice(bytes)
|
bytemuck::cast_slice(bytes)
|
||||||
}
|
}
|
||||||
|
#[cfg(any(feature = "fusion", test))]
|
||||||
|
fn elem_type() -> crate::fusion::codegen::Elem {
|
||||||
|
crate::fusion::codegen::Elem::I32
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WgpuElement for f32 {
|
impl WgpuElement for f32 {
|
||||||
|
@ -51,6 +61,11 @@ impl WgpuElement for f32 {
|
||||||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||||
bytemuck::cast_slice(bytes)
|
bytemuck::cast_slice(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(any(feature = "fusion", test))]
|
||||||
|
fn elem_type() -> crate::fusion::codegen::Elem {
|
||||||
|
crate::fusion::codegen::Elem::F32
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FloatElement for f32 {}
|
impl FloatElement for f32 {}
|
||||||
|
|
|
@ -65,6 +65,37 @@ pub enum Operator {
|
||||||
input: Variable,
|
input: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
},
|
},
|
||||||
|
Equal {
|
||||||
|
lhs: Variable,
|
||||||
|
rhs: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
|
Lower {
|
||||||
|
lhs: Variable,
|
||||||
|
rhs: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
|
Greater {
|
||||||
|
lhs: Variable,
|
||||||
|
rhs: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
|
LowerEqual {
|
||||||
|
lhs: Variable,
|
||||||
|
rhs: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
|
GreaterEqual {
|
||||||
|
lhs: Variable,
|
||||||
|
rhs: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
|
ConditionalAssign {
|
||||||
|
cond: Variable,
|
||||||
|
lhs: Variable,
|
||||||
|
rhs: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
AssignGlobal {
|
AssignGlobal {
|
||||||
input: Variable,
|
input: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
|
@ -109,22 +140,41 @@ impl Display for Operator {
|
||||||
Operator::Recip { input, out } => {
|
Operator::Recip { input, out } => {
|
||||||
f.write_fmt(format_args!("let {out} = 1.0 / {input};"))
|
f.write_fmt(format_args!("let {out} = 1.0 / {input};"))
|
||||||
}
|
}
|
||||||
|
Operator::Equal { lhs, rhs, out } => {
|
||||||
|
f.write_fmt(format_args!("let {out} = {lhs} == {rhs};"))
|
||||||
|
}
|
||||||
|
Operator::Lower { lhs, rhs, out } => {
|
||||||
|
f.write_fmt(format_args!("let {out} = {lhs} < {rhs};"))
|
||||||
|
}
|
||||||
|
Operator::Greater { lhs, rhs, out } => {
|
||||||
|
f.write_fmt(format_args!("let {out} = {lhs} > {rhs};"))
|
||||||
|
}
|
||||||
|
Operator::LowerEqual { lhs, rhs, out } => {
|
||||||
|
f.write_fmt(format_args!("let {out} = {lhs} <= {rhs};"))
|
||||||
|
}
|
||||||
|
Operator::GreaterEqual { lhs, rhs, out } => {
|
||||||
|
f.write_fmt(format_args!("let {out} = {lhs} >= {rhs};"))
|
||||||
|
}
|
||||||
Operator::AssignGlobal { input, out } => {
|
Operator::AssignGlobal { input, out } => {
|
||||||
f.write_fmt(format_args!("{out}_global[id] = {input};"))
|
let elem = out.elem();
|
||||||
|
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
|
||||||
}
|
}
|
||||||
Operator::ReadGlobal {
|
Operator::ReadGlobal {
|
||||||
variable,
|
variable,
|
||||||
position,
|
position,
|
||||||
position_out,
|
position_out,
|
||||||
} => {
|
} => {
|
||||||
let (global, local) = match variable {
|
let (global, local, elem) = match variable {
|
||||||
Variable::Input(number) => {
|
Variable::Input(number, elem) => (
|
||||||
(format!("input_{number}_global"), format!("input_{number}"))
|
format!("input_{number}_global"),
|
||||||
}
|
format!("input_{number}"),
|
||||||
Variable::Local(_) => panic!("can't read globala local variable."),
|
elem,
|
||||||
Variable::Output(number) => (
|
),
|
||||||
|
Variable::Local(_, _) => panic!("can't read global local variable."),
|
||||||
|
Variable::Output(number, elem) => (
|
||||||
format!("output_{number}_global"),
|
format!("output_{number}_global"),
|
||||||
format!("output_{number}"),
|
format!("output_{number}"),
|
||||||
|
elem,
|
||||||
),
|
),
|
||||||
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
|
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
|
||||||
};
|
};
|
||||||
|
@ -144,7 +194,25 @@ for (var i: u32 = 1u; i <= rank; i++) {{
|
||||||
index_{local} += id / stride_out % shape * stride;
|
index_{local} += id / stride_out % shape * stride;
|
||||||
}}
|
}}
|
||||||
|
|
||||||
let {local} = {global}[index_{local}];
|
let {local} = {elem}({global}[index_{local}]);
|
||||||
|
"
|
||||||
|
))
|
||||||
|
}
|
||||||
|
Operator::ConditionalAssign {
|
||||||
|
cond,
|
||||||
|
lhs,
|
||||||
|
rhs,
|
||||||
|
out,
|
||||||
|
} => {
|
||||||
|
let elem = out.elem();
|
||||||
|
f.write_fmt(format_args!(
|
||||||
|
"
|
||||||
|
var {out}: {elem};
|
||||||
|
if {cond} {{
|
||||||
|
{out} = {lhs};
|
||||||
|
}} else {{
|
||||||
|
{out} = {rhs};
|
||||||
|
}}
|
||||||
"
|
"
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,12 +19,13 @@ pub enum Visibility {
|
||||||
ReadWrite,
|
ReadWrite,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
|
||||||
pub enum Elem {
|
pub enum Elem {
|
||||||
F32,
|
F32,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
I32,
|
I32,
|
||||||
U32,
|
U32,
|
||||||
|
Bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Hash, PartialEq, Eq)]
|
#[derive(Hash, PartialEq, Eq)]
|
||||||
|
@ -187,6 +188,7 @@ impl Display for Elem {
|
||||||
Elem::F32 => f.write_str("f32"),
|
Elem::F32 => f.write_str("f32"),
|
||||||
Elem::I32 => f.write_str("i32"),
|
Elem::I32 => f.write_str("i32"),
|
||||||
Elem::U32 => f.write_str("u32"),
|
Elem::U32 => f.write_str("u32"),
|
||||||
|
Elem::Bool => f.write_str("bool"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,18 +3,29 @@ use std::fmt::Display;
|
||||||
|
|
||||||
#[derive(Debug, Hash, Clone)]
|
#[derive(Debug, Hash, Clone)]
|
||||||
pub enum Variable {
|
pub enum Variable {
|
||||||
Input(u16),
|
Input(u16, Elem),
|
||||||
Scalar(u16, Elem),
|
Scalar(u16, Elem),
|
||||||
Local(u16),
|
Local(u16, Elem),
|
||||||
Output(u16),
|
Output(u16, Elem),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Variable {
|
||||||
|
pub fn elem(&self) -> &Elem {
|
||||||
|
match self {
|
||||||
|
Variable::Input(_, e) => e,
|
||||||
|
Variable::Scalar(_, e) => e,
|
||||||
|
Variable::Local(_, e) => e,
|
||||||
|
Variable::Output(_, e) => e,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for Variable {
|
impl Display for Variable {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
Variable::Input(number) => f.write_fmt(format_args!("input_{number}")),
|
Variable::Input(number, _) => f.write_fmt(format_args!("input_{number}")),
|
||||||
Variable::Local(number) => f.write_fmt(format_args!("local_{number}")),
|
Variable::Local(number, _) => f.write_fmt(format_args!("local_{number}")),
|
||||||
Variable::Output(number) => f.write_fmt(format_args!("output_{number}")),
|
Variable::Output(number, _) => f.write_fmt(format_args!("output_{number}")),
|
||||||
Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")),
|
Variable::Scalar(number, elem) => f.write_fmt(format_args!("scalars_{elem}[{number}]")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
|
element::WgpuElement,
|
||||||
fusion::codegen::{Elem, Operator, Variable},
|
fusion::codegen::{Elem, Operator, Variable},
|
||||||
fusion::kernel::FusionKernel,
|
fusion::kernel::FusionKernel,
|
||||||
FloatElement, GraphicsApi, IntElement, Wgpu,
|
FloatElement, GraphicsApi, IntElement, Wgpu,
|
||||||
};
|
};
|
||||||
use burn_fusion::{
|
use burn_fusion::{
|
||||||
graph::{
|
graph::{
|
||||||
BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, ScalarOpsDescription,
|
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
|
||||||
TensorOpsDescription, UnaryOpsDescription,
|
ScalarOpsDescription, TensorOpsDescription, UnaryOpsDescription,
|
||||||
},
|
},
|
||||||
FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, TensorId,
|
FusionOps, FusionProperties, FusionStatus, HandleContainer, TensorDescription, TensorId,
|
||||||
};
|
};
|
||||||
|
@ -22,8 +23,11 @@ where
|
||||||
{
|
{
|
||||||
pub(crate) inputs: Vec<TensorDescription>,
|
pub(crate) inputs: Vec<TensorDescription>,
|
||||||
pub(crate) locals: HashMap<TensorId, u16>,
|
pub(crate) locals: HashMap<TensorId, u16>,
|
||||||
pub(crate) tensors: HashMap<TensorId, TensorDescription>,
|
pub(crate) tensors: HashMap<TensorId, (TensorDescription, Elem)>,
|
||||||
pub(crate) scalars_f32: Vec<f32>,
|
pub(crate) scalars_f32: Vec<f32>,
|
||||||
|
pub(crate) scalars_i32: Vec<i32>,
|
||||||
|
pub(crate) scalars_u32: Vec<u32>,
|
||||||
|
pub(crate) booleans: Vec<bool>,
|
||||||
pub(crate) operators: Vec<Operator>,
|
pub(crate) operators: Vec<Operator>,
|
||||||
pub(crate) properties: FusionProperties,
|
pub(crate) properties: FusionProperties,
|
||||||
pub(crate) current_output_shape: Vec<usize>,
|
pub(crate) current_output_shape: Vec<usize>,
|
||||||
|
@ -35,8 +39,13 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
|
||||||
{
|
{
|
||||||
fn register(&mut self, ops: &TensorOpsDescription) -> FusionStatus {
|
fn register(&mut self, ops: &TensorOpsDescription) -> FusionStatus {
|
||||||
match ops {
|
match ops {
|
||||||
|
TensorOpsDescription::BaseOpsFloat(ops) => {
|
||||||
|
if !self.register_base::<F>(ops) {
|
||||||
|
return FusionStatus::Closed(self.properties);
|
||||||
|
}
|
||||||
|
}
|
||||||
TensorOpsDescription::FloatOps(ops) => {
|
TensorOpsDescription::FloatOps(ops) => {
|
||||||
if !self.register_float(ops) {
|
if !self.register_float::<F>(ops) {
|
||||||
return FusionStatus::Closed(self.properties);
|
return FusionStatus::Closed(self.properties);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -61,7 +70,7 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
|
||||||
let outputs = self.output_descriptions();
|
let outputs = self.output_descriptions();
|
||||||
let locals = outputs
|
let locals = outputs
|
||||||
.iter()
|
.iter()
|
||||||
.map(|out| *self.locals.get(&out.id).unwrap())
|
.map(|out| *self.locals.get(&out.0.id).unwrap())
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
FusionKernel::new(&self.device)
|
FusionKernel::new(&self.device)
|
||||||
|
@ -76,6 +85,9 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> FusionOps<Wgpu<G,
|
||||||
self.locals.drain();
|
self.locals.drain();
|
||||||
self.tensors.clear();
|
self.tensors.clear();
|
||||||
self.scalars_f32.clear();
|
self.scalars_f32.clear();
|
||||||
|
self.scalars_i32.clear();
|
||||||
|
self.scalars_u32.clear();
|
||||||
|
self.booleans.clear();
|
||||||
self.operators.clear();
|
self.operators.clear();
|
||||||
self.properties = FusionProperties::default();
|
self.properties = FusionProperties::default();
|
||||||
self.current_output_shape.clear();
|
self.current_output_shape.clear();
|
||||||
|
@ -98,6 +110,9 @@ where
|
||||||
locals: HashMap::new(),
|
locals: HashMap::new(),
|
||||||
tensors: HashMap::new(),
|
tensors: HashMap::new(),
|
||||||
scalars_f32: Vec::new(),
|
scalars_f32: Vec::new(),
|
||||||
|
scalars_i32: Vec::new(),
|
||||||
|
scalars_u32: Vec::new(),
|
||||||
|
booleans: Vec::new(),
|
||||||
operators: Vec::new(),
|
operators: Vec::new(),
|
||||||
current_output_shape: Vec::new(),
|
current_output_shape: Vec::new(),
|
||||||
properties: FusionProperties::default(),
|
properties: FusionProperties::default(),
|
||||||
|
@ -105,7 +120,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn input_descriptions(&self) -> Vec<&TensorDescription> {
|
fn input_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
|
||||||
self.inputs
|
self.inputs
|
||||||
.iter()
|
.iter()
|
||||||
.map(|input| {
|
.map(|input| {
|
||||||
|
@ -115,7 +130,7 @@ where
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn output_descriptions(&self) -> Vec<&TensorDescription> {
|
fn output_descriptions(&self) -> Vec<&(TensorDescription, Elem)> {
|
||||||
let mut outputs = Vec::new();
|
let mut outputs = Vec::new();
|
||||||
let mut local_tensor_ids_input = Vec::new();
|
let mut local_tensor_ids_input = Vec::new();
|
||||||
let mut local_tensor_ids_output = Vec::new();
|
let mut local_tensor_ids_output = Vec::new();
|
||||||
|
@ -124,7 +139,7 @@ where
|
||||||
//
|
//
|
||||||
// Only local variables can become outputs.
|
// Only local variables can become outputs.
|
||||||
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
|
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
|
||||||
if let Variable::Local(index) = var {
|
if let Variable::Local(index, _) = var {
|
||||||
if let Some((id, _)) = self
|
if let Some((id, _)) = self
|
||||||
.locals
|
.locals
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -211,6 +226,42 @@ where
|
||||||
mark(input, &mut local_tensor_ids_input);
|
mark(input, &mut local_tensor_ids_input);
|
||||||
mark(out, &mut local_tensor_ids_output);
|
mark(out, &mut local_tensor_ids_output);
|
||||||
}
|
}
|
||||||
|
Operator::Lower { lhs, rhs, out } => {
|
||||||
|
mark(lhs, &mut local_tensor_ids_input);
|
||||||
|
mark(rhs, &mut local_tensor_ids_input);
|
||||||
|
mark(out, &mut local_tensor_ids_output);
|
||||||
|
}
|
||||||
|
Operator::Greater { lhs, rhs, out } => {
|
||||||
|
mark(lhs, &mut local_tensor_ids_input);
|
||||||
|
mark(rhs, &mut local_tensor_ids_input);
|
||||||
|
mark(out, &mut local_tensor_ids_output);
|
||||||
|
}
|
||||||
|
Operator::LowerEqual { lhs, rhs, out } => {
|
||||||
|
mark(lhs, &mut local_tensor_ids_input);
|
||||||
|
mark(rhs, &mut local_tensor_ids_input);
|
||||||
|
mark(out, &mut local_tensor_ids_output);
|
||||||
|
}
|
||||||
|
Operator::GreaterEqual { lhs, rhs, out } => {
|
||||||
|
mark(lhs, &mut local_tensor_ids_input);
|
||||||
|
mark(rhs, &mut local_tensor_ids_input);
|
||||||
|
mark(out, &mut local_tensor_ids_output);
|
||||||
|
}
|
||||||
|
Operator::Equal { lhs, rhs, out } => {
|
||||||
|
mark(lhs, &mut local_tensor_ids_input);
|
||||||
|
mark(rhs, &mut local_tensor_ids_input);
|
||||||
|
mark(out, &mut local_tensor_ids_output);
|
||||||
|
}
|
||||||
|
Operator::ConditionalAssign {
|
||||||
|
cond,
|
||||||
|
lhs,
|
||||||
|
rhs,
|
||||||
|
out,
|
||||||
|
} => {
|
||||||
|
mark(cond, &mut local_tensor_ids_input);
|
||||||
|
mark(lhs, &mut local_tensor_ids_input);
|
||||||
|
mark(rhs, &mut local_tensor_ids_input);
|
||||||
|
mark(out, &mut local_tensor_ids_output);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,10 +277,11 @@ where
|
||||||
|
|
||||||
// All tensors where their latest description is read only should be written to since they
|
// All tensors where their latest description is read only should be written to since they
|
||||||
// are going to be used after the fused kernel by other operations.
|
// are going to be used after the fused kernel by other operations.
|
||||||
for tensor in self.tensors.values() {
|
for entry in self.tensors.values() {
|
||||||
|
let (tensor, _) = &entry;
|
||||||
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
|
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
|
||||||
if self.locals.contains_key(&tensor.id) {
|
if self.locals.contains_key(&tensor.id) {
|
||||||
outputs.push(tensor);
|
outputs.push(entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -237,19 +289,19 @@ where
|
||||||
outputs
|
outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
fn input_to_var(&mut self, tensor: &TensorDescription) -> Variable {
|
fn input_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
|
||||||
let already_exists = self.tensors.contains_key(&tensor.id);
|
let already_exists = self.tensors.contains_key(&tensor.id);
|
||||||
|
|
||||||
let variable = match already_exists {
|
let variable = match already_exists {
|
||||||
false => {
|
false => {
|
||||||
// New input
|
// New input
|
||||||
let var = Variable::Input(self.inputs.len() as u16);
|
let var = Variable::Input(self.inputs.len() as u16, elem);
|
||||||
self.inputs.push(tensor.clone());
|
self.inputs.push(tensor.clone());
|
||||||
var
|
var
|
||||||
}
|
}
|
||||||
true => match self.locals.get(&tensor.id) {
|
true => match self.locals.get(&tensor.id) {
|
||||||
// Is a local variable.
|
// Is a local variable.
|
||||||
Some(local_index) => Variable::Local(*local_index),
|
Some(local_index) => Variable::Local(*local_index, elem),
|
||||||
// Isn't a local variable, so must be an existing input.
|
// Isn't a local variable, so must be an existing input.
|
||||||
None => {
|
None => {
|
||||||
let input = self
|
let input = self
|
||||||
|
@ -259,134 +311,234 @@ where
|
||||||
.find(|(_, input)| input.id == tensor.id)
|
.find(|(_, input)| input.id == tensor.id)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let input_index = input.0;
|
let input_index = input.0;
|
||||||
Variable::Input(input_index as u16)
|
Variable::Input(input_index as u16, elem)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// Update the tensor description with the new version.
|
// Update the tensor description with the new version.
|
||||||
self.tensors.insert(tensor.id.clone(), tensor.clone());
|
self.tensors
|
||||||
|
.insert(tensor.id.clone(), (tensor.clone(), elem));
|
||||||
|
|
||||||
variable
|
variable
|
||||||
}
|
}
|
||||||
|
|
||||||
fn output_to_var(&mut self, tensor: &TensorDescription) -> Variable {
|
fn output_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
|
||||||
// Update the tensor description to the new version.
|
// Update the tensor description to the new version.
|
||||||
self.tensors.insert(tensor.id.clone(), tensor.clone());
|
self.tensors
|
||||||
|
.insert(tensor.id.clone(), (tensor.clone(), elem));
|
||||||
|
|
||||||
// Output already registered as a local variable.
|
// Output already registered as a local variable.
|
||||||
if let Some(index) = self.locals.get(&tensor.id) {
|
if let Some(index) = self.locals.get(&tensor.id) {
|
||||||
return Variable::Local(*index);
|
return Variable::Local(*index, elem);
|
||||||
}
|
}
|
||||||
|
|
||||||
// New local variable.
|
// New local variable.
|
||||||
let local_index = self.locals.len() as u16;
|
let local_index = self.locals.len() as u16;
|
||||||
self.locals.insert(tensor.id.clone(), local_index);
|
self.locals.insert(tensor.id.clone(), local_index);
|
||||||
Variable::Local(local_index)
|
Variable::Local(local_index, elem)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn register_float(&mut self, ops: &FloatOpsDescription) -> bool {
|
fn register_base<E: WgpuElement>(&mut self, ops: &BaseOpsDescription) -> bool {
|
||||||
|
match ops {
|
||||||
|
BaseOpsDescription::Equal(desc) => self.register_binary_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||||
|
|lhs, rhs, out| Operator::Equal { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_float<E: WgpuElement>(&mut self, ops: &FloatOpsDescription) -> bool {
|
||||||
match ops {
|
match ops {
|
||||||
FloatOpsDescription::Exp(desc) => {
|
FloatOpsDescription::Exp(desc) => {
|
||||||
self.register_unary_ops(desc, |input, out| Operator::Exp { input, out })
|
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||||
|
Operator::Exp { input, out }
|
||||||
|
})
|
||||||
}
|
}
|
||||||
FloatOpsDescription::Log(desc) => {
|
FloatOpsDescription::Log(desc) => {
|
||||||
self.register_unary_ops(desc, |input, out| Operator::Log { input, out })
|
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||||
|
Operator::Log { input, out }
|
||||||
|
})
|
||||||
}
|
}
|
||||||
FloatOpsDescription::Log1p(desc) => {
|
FloatOpsDescription::Log1p(desc) => {
|
||||||
self.register_unary_ops(desc, |input, out| Operator::Log1p { input, out })
|
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||||
|
Operator::Log1p { input, out }
|
||||||
|
})
|
||||||
}
|
}
|
||||||
FloatOpsDescription::Cos(desc) => {
|
FloatOpsDescription::Cos(desc) => {
|
||||||
self.register_unary_ops(desc, |input, out| Operator::Cos { input, out })
|
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||||
|
Operator::Cos { input, out }
|
||||||
|
})
|
||||||
}
|
}
|
||||||
FloatOpsDescription::Sin(desc) => {
|
FloatOpsDescription::Sin(desc) => {
|
||||||
self.register_unary_ops(desc, |input, out| Operator::Sin { input, out })
|
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||||
}
|
Operator::Sin { input, out }
|
||||||
FloatOpsDescription::Powf(desc) => {
|
})
|
||||||
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Powf { lhs, rhs, out })
|
|
||||||
}
|
}
|
||||||
|
FloatOpsDescription::Powf(desc) => self.register_scalar_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||||
|
|lhs, rhs, out| Operator::Powf { lhs, rhs, out },
|
||||||
|
),
|
||||||
FloatOpsDescription::Tanh(desc) => {
|
FloatOpsDescription::Tanh(desc) => {
|
||||||
self.register_unary_ops(desc, |input, out| Operator::Tanh { input, out })
|
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||||
|
Operator::Tanh { input, out }
|
||||||
|
})
|
||||||
}
|
}
|
||||||
FloatOpsDescription::Erf(desc) => {
|
FloatOpsDescription::Erf(desc) => {
|
||||||
self.register_unary_ops(desc, |input, out| Operator::Erf { input, out })
|
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||||
|
Operator::Erf { input, out }
|
||||||
|
})
|
||||||
}
|
}
|
||||||
FloatOpsDescription::Recip(desc) => {
|
FloatOpsDescription::Recip(desc) => {
|
||||||
self.register_unary_ops(desc, |input, out| Operator::Recip { input, out })
|
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||||
|
Operator::Recip { input, out }
|
||||||
|
})
|
||||||
}
|
}
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn register_numeric<E: Element>(&mut self, ops: &NumericOpsDescription<E>) -> bool {
|
fn register_numeric<E: WgpuElement>(&mut self, ops: &NumericOpsDescription<E>) -> bool {
|
||||||
match ops {
|
match ops {
|
||||||
NumericOpsDescription::Add(desc) => {
|
NumericOpsDescription::Add(desc) => self.register_binary_ops(
|
||||||
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
|
desc,
|
||||||
}
|
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||||
NumericOpsDescription::AddScalar(desc) => {
|
|lhs, rhs, out| Operator::Add { lhs, rhs, out },
|
||||||
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Add { lhs, rhs, out })
|
),
|
||||||
}
|
NumericOpsDescription::AddScalar(desc) => self.register_scalar_ops(
|
||||||
NumericOpsDescription::Sub(desc) => {
|
desc,
|
||||||
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
|
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||||
}
|
|lhs, rhs, out| Operator::Add { lhs, rhs, out },
|
||||||
NumericOpsDescription::SubScalar(desc) => {
|
),
|
||||||
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Sub { lhs, rhs, out })
|
NumericOpsDescription::Sub(desc) => self.register_binary_ops(
|
||||||
}
|
desc,
|
||||||
NumericOpsDescription::Mul(desc) => {
|
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||||
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
|
|lhs, rhs, out| Operator::Sub { lhs, rhs, out },
|
||||||
}
|
),
|
||||||
NumericOpsDescription::MulScalar(desc) => {
|
NumericOpsDescription::SubScalar(desc) => self.register_scalar_ops(
|
||||||
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Mul { lhs, rhs, out })
|
desc,
|
||||||
}
|
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||||
NumericOpsDescription::Div(desc) => {
|
|lhs, rhs, out| Operator::Sub { lhs, rhs, out },
|
||||||
self.register_binary_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
|
),
|
||||||
}
|
NumericOpsDescription::Mul(desc) => self.register_binary_ops(
|
||||||
NumericOpsDescription::DivScalar(desc) => {
|
desc,
|
||||||
self.register_scalar_ops(desc, |lhs, rhs, out| Operator::Div { lhs, rhs, out })
|
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||||
}
|
|lhs, rhs, out| Operator::Mul { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::MulScalar(desc) => self.register_scalar_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||||
|
|lhs, rhs, out| Operator::Mul { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::Div(desc) => self.register_binary_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||||
|
|lhs, rhs, out| Operator::Div { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::DivScalar(desc) => self.register_scalar_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), E::elem_type()),
|
||||||
|
|lhs, rhs, out| Operator::Div { lhs, rhs, out },
|
||||||
|
),
|
||||||
NumericOpsDescription::Abs(desc) => {
|
NumericOpsDescription::Abs(desc) => {
|
||||||
self.register_unary_ops(desc, |input, out| Operator::Abs { input, out })
|
self.register_unary_ops(desc, (E::elem_type(), E::elem_type()), |input, out| {
|
||||||
|
Operator::Abs { input, out }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
NumericOpsDescription::Lower(desc) => self.register_binary_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||||
|
|lhs, rhs, out| Operator::Lower { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::LowerElem(desc) => self.register_scalar_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||||
|
|lhs, rhs, out| Operator::Lower { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::Greater(desc) => self.register_binary_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||||
|
|lhs, rhs, out| Operator::Greater { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::GreaterElem(desc) => self.register_scalar_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||||
|
|lhs, rhs, out| Operator::Greater { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::LowerEqual(desc) => self.register_binary_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||||
|
|lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::LowerEqualElem(desc) => self.register_scalar_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||||
|
|lhs, rhs, out| Operator::LowerEqual { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::GreaterEqual(desc) => self.register_binary_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||||
|
|lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::GreaterEqualElem(desc) => self.register_scalar_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||||
|
|lhs, rhs, out| Operator::GreaterEqual { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::EqualElem(desc) => self.register_scalar_ops(
|
||||||
|
desc,
|
||||||
|
(E::elem_type(), E::elem_type(), Elem::Bool),
|
||||||
|
|lhs, rhs, out| Operator::Equal { lhs, rhs, out },
|
||||||
|
),
|
||||||
|
NumericOpsDescription::MaskWhere(desc) => {
|
||||||
|
if !self.output_is_compatible(&desc.out) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cond = self.input_to_var(&desc.mask, Elem::Bool);
|
||||||
|
let lhs = self.input_to_var(&desc.value, E::elem_type());
|
||||||
|
let rhs = self.input_to_var(&desc.tensor, E::elem_type());
|
||||||
|
let out = self.output_to_var(&desc.out, E::elem_type());
|
||||||
|
|
||||||
|
self.operators.push(Operator::ConditionalAssign {
|
||||||
|
cond,
|
||||||
|
lhs,
|
||||||
|
rhs,
|
||||||
|
out,
|
||||||
|
});
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
NumericOpsDescription::MaskFill(desc) => {
|
||||||
|
if !self.output_is_compatible(&desc.out) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cond = self.input_to_var(&desc.mask, Elem::Bool);
|
||||||
|
let lhs = self.scalar_to_var(&desc.value, E::elem_type());
|
||||||
|
let rhs = self.input_to_var(&desc.tensor, E::elem_type());
|
||||||
|
let out = self.output_to_var(&desc.out, E::elem_type());
|
||||||
|
|
||||||
|
self.operators.push(Operator::ConditionalAssign {
|
||||||
|
cond,
|
||||||
|
lhs,
|
||||||
|
rhs,
|
||||||
|
out,
|
||||||
|
});
|
||||||
|
|
||||||
|
true
|
||||||
}
|
}
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn register_binary_ops<Func>(&mut self, desc: &BinaryOpsDescription, func: Func) -> bool
|
fn register_binary_ops<Func>(
|
||||||
where
|
|
||||||
Func: Fn(Variable, Variable, Variable) -> Operator,
|
|
||||||
{
|
|
||||||
if !self.output_is_compatible(&desc.out) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let lhs = self.input_to_var(&desc.lhs);
|
|
||||||
let rhs = self.input_to_var(&desc.rhs);
|
|
||||||
let out = self.output_to_var(&desc.out);
|
|
||||||
|
|
||||||
self.operators.push(func(lhs, rhs, out));
|
|
||||||
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
fn register_unary_ops<Func>(&mut self, desc: &UnaryOpsDescription, func: Func) -> bool
|
|
||||||
where
|
|
||||||
Func: Fn(Variable, Variable) -> Operator,
|
|
||||||
{
|
|
||||||
if !self.output_is_compatible(&desc.out) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let input = self.input_to_var(&desc.input);
|
|
||||||
let out = self.output_to_var(&desc.out);
|
|
||||||
|
|
||||||
self.operators.push(func(input, out));
|
|
||||||
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
fn register_scalar_ops<Func, E: Element>(
|
|
||||||
&mut self,
|
&mut self,
|
||||||
desc: &ScalarOpsDescription<E>,
|
desc: &BinaryOpsDescription,
|
||||||
|
(elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
|
||||||
func: Func,
|
func: Func,
|
||||||
) -> bool
|
) -> bool
|
||||||
where
|
where
|
||||||
|
@ -396,16 +548,78 @@ where
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
let lhs = self.input_to_var(&desc.lhs);
|
let lhs = self.input_to_var(&desc.lhs, elem_lhs);
|
||||||
let rhs = Variable::Scalar(self.scalars_f32.len() as u16, Elem::F32);
|
let rhs = self.input_to_var(&desc.rhs, elem_rhs);
|
||||||
self.scalars_f32.push(desc.rhs.elem());
|
let out = self.output_to_var(&desc.out, elem_out);
|
||||||
let out = self.output_to_var(&desc.out);
|
|
||||||
|
|
||||||
self.operators.push(func(lhs, rhs, out));
|
self.operators.push(func(lhs, rhs, out));
|
||||||
|
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn register_unary_ops<Func>(
|
||||||
|
&mut self,
|
||||||
|
desc: &UnaryOpsDescription,
|
||||||
|
(elem_input, elem_out): (Elem, Elem),
|
||||||
|
func: Func,
|
||||||
|
) -> bool
|
||||||
|
where
|
||||||
|
Func: Fn(Variable, Variable) -> Operator,
|
||||||
|
{
|
||||||
|
if !self.output_is_compatible(&desc.out) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let input = self.input_to_var(&desc.input, elem_input);
|
||||||
|
let out = self.output_to_var(&desc.out, elem_out);
|
||||||
|
|
||||||
|
self.operators.push(func(input, out));
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_scalar_ops<Func, E: Element>(
|
||||||
|
&mut self,
|
||||||
|
desc: &ScalarOpsDescription<E>,
|
||||||
|
(elem_lhs, elem_rhs, elem_out): (Elem, Elem, Elem),
|
||||||
|
func: Func,
|
||||||
|
) -> bool
|
||||||
|
where
|
||||||
|
Func: Fn(Variable, Variable, Variable) -> Operator,
|
||||||
|
{
|
||||||
|
if !self.output_is_compatible(&desc.out) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let lhs = self.input_to_var(&desc.lhs, elem_lhs);
|
||||||
|
let rhs = self.scalar_to_var(&desc.rhs, elem_rhs);
|
||||||
|
let out = self.output_to_var(&desc.out, elem_out);
|
||||||
|
|
||||||
|
self.operators.push(func(lhs, rhs, out));
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scalar_to_var<E: Element>(&mut self, value: &E, elem_type: Elem) -> Variable {
|
||||||
|
match elem_type {
|
||||||
|
Elem::F32 => {
|
||||||
|
self.scalars_f32.push(value.elem());
|
||||||
|
Variable::Scalar(self.scalars_f32.len() as u16 - 1, Elem::F32)
|
||||||
|
}
|
||||||
|
Elem::I32 => {
|
||||||
|
self.scalars_i32.push(value.elem());
|
||||||
|
Variable::Scalar(self.scalars_i32.len() as u16 - 1, Elem::I32)
|
||||||
|
}
|
||||||
|
Elem::U32 => {
|
||||||
|
self.scalars_u32.push(value.elem());
|
||||||
|
Variable::Scalar(self.scalars_u32.len() as u16 - 1, Elem::U32)
|
||||||
|
}
|
||||||
|
Elem::Bool => {
|
||||||
|
panic!("Bool scalars not supported")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn output_is_compatible(&mut self, out: &TensorDescription) -> bool {
|
fn output_is_compatible(&mut self, out: &TensorDescription) -> bool {
|
||||||
if self.current_output_shape.is_empty() {
|
if self.current_output_shape.is_empty() {
|
||||||
self.current_output_shape = out.shape.clone();
|
self.current_output_shape = out.shape.clone();
|
||||||
|
@ -446,17 +660,19 @@ mod tests {
|
||||||
let tensor_2 = Tensor::<Backend, 2>::from_data(data_2.clone());
|
let tensor_2 = Tensor::<Backend, 2>::from_data(data_2.clone());
|
||||||
let tensor_3 = tensor_1.clone() + tensor_2;
|
let tensor_3 = tensor_1.clone() + tensor_2;
|
||||||
let tensor_4 = tensor_3.clone() - tensor_1;
|
let tensor_4 = tensor_3.clone() - tensor_1;
|
||||||
let tensor_5 = tensor_4 + 5.0;
|
let tensor_5 = tensor_4.clone() + 5.0;
|
||||||
let tensor_6 = tensor_5 + tensor_3;
|
let tensor_6 = tensor_5 + tensor_3.clone();
|
||||||
let result_ref = tensor_6.recip().into_data();
|
let mask = tensor_4.lower_equal(tensor_3);
|
||||||
|
let result_ref = tensor_6.mask_fill(mask, 0.3).into_data();
|
||||||
|
|
||||||
let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
|
let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
|
||||||
let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
|
let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
|
||||||
let tensor_3 = tensor_1.clone() + tensor_2;
|
let tensor_3 = tensor_1.clone() + tensor_2;
|
||||||
let tensor_4 = tensor_3.clone() - tensor_1;
|
let tensor_4 = tensor_3.clone() - tensor_1;
|
||||||
let tensor_5 = tensor_4 + 5.0;
|
let tensor_5 = tensor_4.clone() + 5.0;
|
||||||
let tensor_6 = tensor_5 + tensor_3;
|
let tensor_6 = tensor_5 + tensor_3.clone();
|
||||||
let result_fused = tensor_6.recip().into_data();
|
let mask = tensor_4.lower_equal(tensor_3);
|
||||||
|
let result_fused = tensor_6.mask_fill(mask, 0.3).into_data();
|
||||||
|
|
||||||
result_fused.assert_approx_eq(&result_ref, 3);
|
result_fused.assert_approx_eq(&result_ref, 3);
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,25 +83,43 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Input
|
||||||
/// Register the inputs used by the kernel.
|
/// Register the inputs used by the kernel.
|
||||||
pub fn inputs(
|
pub fn inputs(
|
||||||
mut self,
|
mut self,
|
||||||
inputs_tensor: &[&TensorDescription],
|
inputs_tensor: &[&(TensorDescription, Elem)],
|
||||||
inputs_scalar_f32: &[f32],
|
inputs_scalar_f32: &[f32],
|
||||||
) -> FusionKernel<G, F, I, BodyPhase> {
|
) -> FusionKernel<G, F, I, BodyPhase> {
|
||||||
for (i, input) in inputs_tensor.iter().enumerate() {
|
for (i, (input, elem)) in inputs_tensor.iter().enumerate() {
|
||||||
self.input_bindings.push((
|
if elem != &Elem::Bool {
|
||||||
Binding {
|
self.input_bindings.push((
|
||||||
elem: Elem::F32,
|
Binding {
|
||||||
visibility: Visibility::Read,
|
elem: *elem,
|
||||||
location: Location::Storage,
|
visibility: Visibility::Read,
|
||||||
size: None,
|
location: Location::Storage,
|
||||||
},
|
size: None,
|
||||||
(*input).clone(),
|
},
|
||||||
));
|
(*input).clone(),
|
||||||
|
));
|
||||||
|
|
||||||
self.operations.push(Operator::ReadGlobal {
|
self.operations.push(Operator::ReadGlobal {
|
||||||
variable: Variable::Input(i as u16),
|
variable: Variable::Input(i as u16, *elem),
|
||||||
position: i,
|
position: i,
|
||||||
position_out: inputs_tensor.len(), // First output
|
position_out: inputs_tensor.len(), // First output
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
self.input_bindings.push((
|
||||||
|
Binding {
|
||||||
|
elem: Elem::I32,
|
||||||
|
visibility: Visibility::Read,
|
||||||
|
location: Location::Storage,
|
||||||
|
size: None,
|
||||||
|
},
|
||||||
|
(*input).clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
self.operations.push(Operator::ReadGlobal {
|
||||||
|
variable: Variable::Input(i as u16, *elem),
|
||||||
|
position: i,
|
||||||
|
position_out: inputs_tensor.len(), // First output
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !inputs_scalar_f32.is_empty() {
|
if !inputs_scalar_f32.is_empty() {
|
||||||
|
@ -180,31 +198,48 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, Outpu
|
||||||
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
|
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
|
||||||
pub fn outputs(
|
pub fn outputs(
|
||||||
mut self,
|
mut self,
|
||||||
outputs: &[&TensorDescription],
|
outputs: &[&(TensorDescription, Elem)],
|
||||||
locals: &[u16],
|
locals: &[u16],
|
||||||
) -> FusionKernel<G, F, I, ExecutionPhase> {
|
) -> FusionKernel<G, F, I, ExecutionPhase> {
|
||||||
let mut num_elems_launch_option = 0;
|
let mut num_elems_launch_option = 0;
|
||||||
|
|
||||||
for (i, (output, local)) in outputs.iter().zip(locals).enumerate() {
|
for (i, ((output, elem), local)) in outputs.iter().zip(locals).enumerate() {
|
||||||
let num_elems_output = calculate_num_elems_dyn_rank(&output.shape);
|
let num_elems_output = calculate_num_elems_dyn_rank(&output.shape);
|
||||||
if num_elems_output > num_elems_launch_option {
|
if num_elems_output > num_elems_launch_option {
|
||||||
num_elems_launch_option = num_elems_output;
|
num_elems_launch_option = num_elems_output;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.output_bindings.push((
|
if elem != &Elem::Bool {
|
||||||
Binding {
|
self.output_bindings.push((
|
||||||
elem: Elem::F32,
|
Binding {
|
||||||
visibility: Visibility::ReadWrite,
|
elem: *elem,
|
||||||
location: Location::Storage,
|
visibility: Visibility::ReadWrite,
|
||||||
size: None,
|
location: Location::Storage,
|
||||||
},
|
size: None,
|
||||||
(*output).clone(),
|
},
|
||||||
));
|
(*output).clone(),
|
||||||
|
));
|
||||||
|
|
||||||
self.operations.push(Operator::AssignGlobal {
|
self.operations.push(Operator::AssignGlobal {
|
||||||
input: Variable::Local(*local),
|
input: Variable::Local(*local, *elem),
|
||||||
out: Variable::Output(i as u16),
|
out: Variable::Output(i as u16, *elem),
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
self.output_bindings.push((
|
||||||
|
Binding {
|
||||||
|
elem: Elem::I32, // I32 are used for bool tensors
|
||||||
|
visibility: Visibility::ReadWrite,
|
||||||
|
location: Location::Storage,
|
||||||
|
size: None,
|
||||||
|
},
|
||||||
|
(*output).clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
self.operations.push(Operator::AssignGlobal {
|
||||||
|
input: Variable::Local(*local, *elem),
|
||||||
|
out: Variable::Output(i as u16, Elem::I32),
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.num_elems_output = num_elems_launch_option;
|
self.num_elems_output = num_elems_launch_option;
|
||||||
|
|
Loading…
Reference in New Issue