Implement fusing for recip() (#959)

This commit is contained in:
Zsombor 2023-11-15 23:15:01 +01:00 committed by GitHub
parent 24014aca33
commit c0859dde59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions

View File

@ -61,6 +61,10 @@ pub enum Operator {
input: Variable,
out: Variable,
},
Recip {
input: Variable,
out: Variable,
},
AssignGlobal {
input: Variable,
out: Variable,
@ -102,6 +106,9 @@ impl Display for Operator {
f.write_fmt(format_args!("let {out} = tanh({input});"))
}
Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")),
Operator::Recip { input, out } => {
f.write_fmt(format_args!("let {out} = 1.0 / {input};"))
}
Operator::AssignGlobal { input, out } => {
f.write_fmt(format_args!("{out}_global[id] = {input};"))
}

View File

@ -209,6 +209,10 @@ where
mark(rhs, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
Operator::Recip { input, out } => {
mark(input, &mut local_tensor_ids_input);
mark(out, &mut local_tensor_ids_output);
}
}
}
@ -309,6 +313,9 @@ where
FloatOpsDescription::Erf(desc, _) => {
self.register_unary_ops(desc, |input, out| Operator::Erf { input, out })
}
FloatOpsDescription::Recip(desc, _) => {
self.register_unary_ops(desc, |input, out| Operator::Recip { input, out })
}
_ => false,
}
}
@ -448,7 +455,7 @@ mod tests {
let tensor_4 = tensor_3.clone() - tensor_1;
let tensor_5 = tensor_4 + 5.0;
let tensor_6 = tensor_5 + tensor_3;
let result_ref = tensor_6.into_data();
let result_ref = tensor_6.recip().into_data();
let tensor_1 = Tensor::<FusedBackend, 2>::from_data(data_1);
let tensor_2 = Tensor::<FusedBackend, 2>::from_data(data_2);
@ -456,7 +463,7 @@ mod tests {
let tensor_4 = tensor_3.clone() - tensor_1;
let tensor_5 = tensor_4 + 5.0;
let tensor_6 = tensor_5 + tensor_3;
let result_fused = tensor_6.into_data();
let result_fused = tensor_6.recip().into_data();
result_fused.assert_approx_eq(&result_ref, 3);
}