From c0859dde5935d32586be70062035fd039777086f Mon Sep 17 00:00:00 2001 From: Zsombor Date: Wed, 15 Nov 2023 23:15:01 +0100 Subject: [PATCH] Implement fusing for recip() (#959) --- burn-wgpu/src/fusion/codegen/operator.rs | 7 +++++++ burn-wgpu/src/fusion/elemwise/ops.rs | 11 +++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/burn-wgpu/src/fusion/codegen/operator.rs b/burn-wgpu/src/fusion/codegen/operator.rs index ab0b193a8..922f7cf81 100644 --- a/burn-wgpu/src/fusion/codegen/operator.rs +++ b/burn-wgpu/src/fusion/codegen/operator.rs @@ -61,6 +61,10 @@ pub enum Operator { input: Variable, out: Variable, }, + Recip { + input: Variable, + out: Variable, + }, AssignGlobal { input: Variable, out: Variable, @@ -102,6 +106,9 @@ impl Display for Operator { f.write_fmt(format_args!("let {out} = tanh({input});")) } Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")), + Operator::Recip { input, out } => { + f.write_fmt(format_args!("let {out} = 1.0 / {input};")) + } Operator::AssignGlobal { input, out } => { f.write_fmt(format_args!("{out}_global[id] = {input};")) } diff --git a/burn-wgpu/src/fusion/elemwise/ops.rs b/burn-wgpu/src/fusion/elemwise/ops.rs index 4a4eb45a9..78dec96fd 100644 --- a/burn-wgpu/src/fusion/elemwise/ops.rs +++ b/burn-wgpu/src/fusion/elemwise/ops.rs @@ -209,6 +209,10 @@ where mark(rhs, &mut local_tensor_ids_input); mark(out, &mut local_tensor_ids_output); } + Operator::Recip { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } } } @@ -309,6 +313,9 @@ where FloatOpsDescription::Erf(desc, _) => { self.register_unary_ops(desc, |input, out| Operator::Erf { input, out }) } + FloatOpsDescription::Recip(desc, _) => { + self.register_unary_ops(desc, |input, out| Operator::Recip { input, out }) + } _ => false, } } @@ -448,7 +455,7 @@ mod tests { let tensor_4 = tensor_3.clone() - tensor_1; let tensor_5 = tensor_4 + 5.0; let tensor_6 = tensor_5 + tensor_3; - let result_ref = tensor_6.into_data(); + let result_ref = tensor_6.recip().into_data(); let tensor_1 = Tensor::::from_data(data_1); let tensor_2 = Tensor::::from_data(data_2); @@ -456,7 +463,7 @@ mod tests { let tensor_4 = tensor_3.clone() - tensor_1; let tensor_5 = tensor_4 + 5.0; let tensor_6 = tensor_5 + tensor_3; - let result_fused = tensor_6.into_data(); + let result_fused = tensor_6.recip().into_data(); result_fused.assert_approx_eq(&result_ref, 3); }