2024-01-13 00:15:00 +08:00
use backend_comparison ::persistence ::save ;
2024-02-27 06:19:09 +08:00
use burn ::backend ::Autodiff ;
2023-11-16 21:15:21 +08:00
use burn ::tensor ::{ backend ::Backend , Distribution , Shape , Tensor } ;
use burn_common ::benchmark ::{ run_benchmark , Benchmark } ;
use core ::f64 ::consts ::SQRT_2 ;
use derive_new ::new ;
/// Selects which GELU implementation a `CustomGeluBenchmark` exercises.
///
/// The `Debug` rendering of each variant is embedded in the benchmark name and options,
/// so renaming a variant changes the reported benchmark identifier.
#[derive(Debug)]
enum GeluKind {
    /// Burn's built-in `burn::tensor::activation::gelu`.
    Reference,
    /// `gelu_custom` driven by the backend's native `Tensor::erf`.
    WithReferenceErf,
    /// `gelu_custom` driven by the tensor-op approximation `erf_custom`.
    WithCustomErf,
}
/// Benchmark how well a backend executes a custom activation function with a lot of basic tensor
/// operations.
#[ derive(new) ]
struct CustomGeluBenchmark < B : Backend , const D : usize > {
2024-09-24 20:35:52 +08:00
shape : Shape ,
2023-11-16 21:15:21 +08:00
device : B ::Device ,
kind : GeluKind ,
2024-02-27 06:19:09 +08:00
autodiff : bool ,
2023-11-16 21:15:21 +08:00
}
impl < B : Backend , const D : usize > Benchmark for CustomGeluBenchmark < B , D > {
type Args = Tensor < B , D > ;
fn name ( & self ) -> String {
2024-02-27 06:19:09 +08:00
match self . autodiff {
2024-03-29 00:35:15 +08:00
true = > format! ( " gelu_autodiff_ {:?} " , self . kind ) ,
false = > format! ( " gelu_ {:?} " , self . kind ) ,
2024-02-27 06:19:09 +08:00
}
2024-01-13 00:15:00 +08:00
}
fn options ( & self ) -> Option < String > {
Some ( format! ( " {:?} " , self . kind ) )
}
fn shapes ( & self ) -> Vec < Vec < usize > > {
2024-09-24 20:35:52 +08:00
vec! [ self . shape . dims . clone ( ) ]
2023-11-16 21:15:21 +08:00
}
2024-02-27 06:19:09 +08:00
fn execute ( & self , tensor : Self ::Args ) {
match self . autodiff {
true = > {
let tensor : Tensor < Autodiff < B > , D > = Tensor ::from_inner ( tensor ) . require_grad ( ) ;
let output = match self . kind {
GeluKind ::Reference = > burn ::tensor ::activation ::gelu ( tensor . clone ( ) ) ,
GeluKind ::WithReferenceErf = > gelu_custom ( tensor . clone ( ) , Tensor ::erf ) ,
GeluKind ::WithCustomErf = > gelu_custom ( tensor . clone ( ) , erf_custom ) ,
} ;
let mut gradients = output . sum ( ) . backward ( ) ;
let _tmp = tensor . grad_remove ( & mut gradients ) . unwrap ( ) ;
}
false = > {
match self . kind {
GeluKind ::Reference = > burn ::tensor ::activation ::gelu ( tensor ) ,
GeluKind ::WithReferenceErf = > gelu_custom ( tensor , Tensor ::erf ) ,
GeluKind ::WithCustomErf = > gelu_custom ( tensor , erf_custom ) ,
} ;
}
2024-01-11 01:37:17 +08:00
} ;
2023-11-16 21:15:21 +08:00
}
fn prepare ( & self ) -> Self ::Args {
2023-12-21 06:49:59 +08:00
Tensor ::random ( self . shape . clone ( ) , Distribution ::Default , & self . device )
2023-11-16 21:15:21 +08:00
}
fn sync ( & self ) {
2024-10-17 03:56:12 +08:00
B ::sync ( & self . device )
2023-11-16 21:15:21 +08:00
}
2024-01-09 05:58:39 +08:00
fn num_samples ( & self ) -> usize {
2024-02-27 06:19:09 +08:00
10
2024-01-09 05:58:39 +08:00
}
2023-11-16 21:15:21 +08:00
}
/// GELU expressed through a caller-supplied error function:
/// `gelu(x) = x * (erf(x / √2) + 1) / 2`.
///
/// Passing different `erf` implementations lets the benchmark compare the backend's
/// native `erf` with a composite one while keeping the surrounding math identical.
fn gelu_custom<B, const D: usize, Erf>(x: Tensor<B, D>, erf: Erf) -> Tensor<B, D>
where
    B: Backend,
    Erf: Fn(Tensor<B, D>) -> Tensor<B, D>,
{
    let input = x.clone();
    let erf_plus_one = erf(x / SQRT_2) + 1;
    let product = input * erf_plus_one;
    product / 2
}
/// Error function built from `erf_positive` using oddness: `erf(x) = -erf(-x)`.
///
/// Both branches are evaluated over the whole tensor and then selected element-wise:
/// entries with `x > 0` take `erf_positive(x)`, all others take `-erf_positive(-x)`.
fn erf_custom<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    let mirrored_branch = -erf_positive(-x.clone());
    let direct_branch = erf_positive(x.clone());
    let is_positive = x.greater_elem(0);
    mirrored_branch.mask_where(is_positive, direct_branch)
}
/// Rational approximation of the error function, see
/// <https://en.wikipedia.org/wiki/Error_function#Numerical_approximations>.
///
/// > (maximum error: 1.5 × 10⁻⁷)
/// > All of these approximations are valid for x ≥ 0. To use these approximations for
/// > negative x, use the fact that erf(x) is an odd function, so erf(x) = −erf(−x).
///
/// Computes `1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵) · exp(-x²)` with
/// `t = 1 / (1 + p·|x|)`, evaluating the polynomial in Horner form.
fn erf_positive<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    const P: f64 = 0.3275911;
    const A1: f64 = 0.254829592;
    const A2: f64 = -0.284496736;
    const A3: f64 = 1.421413741;
    const A4: f64 = -1.453152027;
    const A5: f64 = 1.061405429;

    let t = (x.clone().abs() * P + 1).recip();

    // Horner steps, innermost coefficient first: ((((t·a5 + a4)·t + a3)·t + a2)·t + a1).
    let horner = [A3, A2, A1]
        .into_iter()
        .fold(t.clone() * A5 + A4, |acc, coeff| acc * t.clone() + coeff);

    let poly_times_t = horner * t;
    let gaussian = (-x.clone() * x).exp();
    -(poly_times_t * gaussian) + 1.0
}
#[ allow(dead_code) ]
2024-04-01 21:48:44 +08:00
fn bench < B : Backend > (
device : & B ::Device ,
feature_name : & str ,
url : Option < & str > ,
token : Option < & str > ,
) {
2023-11-16 21:15:21 +08:00
const D : usize = 3 ;
2024-09-24 20:35:52 +08:00
let shape : Shape = [ 32 , 512 , 2048 ] . into ( ) ;
2024-01-11 01:37:17 +08:00
2024-02-27 06:19:09 +08:00
let run = | autodiff : bool | {
let reference_gelu = CustomGeluBenchmark ::< B , D > ::new (
shape . clone ( ) ,
device . clone ( ) ,
GeluKind ::Reference ,
autodiff ,
) ;
let reference_erf_gelu = CustomGeluBenchmark ::< B , D > ::new (
shape . clone ( ) ,
device . clone ( ) ,
GeluKind ::WithReferenceErf ,
autodiff ,
) ;
let custom_erf_gelu = CustomGeluBenchmark ::< B , D > ::new (
shape . clone ( ) ,
device . clone ( ) ,
GeluKind ::WithCustomErf ,
autodiff ,
) ;
save ::< B > (
vec! [
run_benchmark ( reference_gelu ) ,
run_benchmark ( reference_erf_gelu ) ,
run_benchmark ( custom_erf_gelu ) ,
] ,
device ,
2024-04-01 21:48:44 +08:00
feature_name ,
2024-03-03 00:38:18 +08:00
url ,
token ,
2024-02-27 06:19:09 +08:00
)
. unwrap ( ) ;
} ;
run ( false ) ;
run ( true ) ;
2023-11-16 21:15:21 +08:00
}
/// Benchmark binary entry point.
fn main ( ) {
// Expansion is defined by `backend_comparison`; presumably it selects the backend from
// enabled cargo features and calls `bench` — NOTE(review): confirm against the macro.
backend_comparison ::bench_on_backend! ( ) ;
}