mirror of https://github.com/tracel-ai/burn.git
Implement fusing for recip() (#959)
This commit is contained in:
parent
24014aca33
commit
c0859dde59
|
@ -61,6 +61,10 @@ pub enum Operator {
|
|||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Recip {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
AssignGlobal {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
|
@ -102,6 +106,9 @@ impl Display for Operator {
|
|||
f.write_fmt(format_args!("let {out} = tanh({input});"))
|
||||
}
|
||||
Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")),
|
||||
Operator::Recip { input, out } => {
|
||||
f.write_fmt(format_args!("let {out} = 1.0 / {input};"))
|
||||
}
|
||||
Operator::AssignGlobal { input, out } => {
|
||||
f.write_fmt(format_args!("{out}_global[id] = {input};"))
|
||||
}
|
||||
|
|
|
@ -209,6 +209,10 @@ where
|
|||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Recip { input, out } => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -309,6 +313,9 @@ where
|
|||
FloatOpsDescription::Erf(desc, _) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Erf { input, out })
|
||||
}
|
||||
FloatOpsDescription::Recip(desc, _) => {
|
||||
self.register_unary_ops(desc, |input, out| Operator::Recip { input, out })
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
@ -448,7 +455,7 @@ mod tests {
|
|||
let tensor_4 = tensor_3.clone() - tensor_1;
|
||||
let tensor_5 = tensor_4 + 5.0;
|
||||
let tensor_6 = tensor_5 + tensor_3;
|
||||
let result_ref = tensor_6.into_data();
|
||||
let result_ref = tensor_6.recip().into_data();
|
||||
|
||||
let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
|
||||
let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
|
||||
|
@ -456,7 +463,7 @@ mod tests {
|
|||
let tensor_4 = tensor_3.clone() - tensor_1;
|
||||
let tensor_5 = tensor_4 + 5.0;
|
||||
let tensor_6 = tensor_5 + tensor_3;
|
||||
let result_fused = tensor_6.into_data();
|
||||
let result_fused = tensor_6.recip().into_data();
|
||||
|
||||
result_fused.assert_approx_eq(&result_ref, 3);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue