benchmarks

louisfd 2023-12-19 14:15:46 -05:00
parent b5c49c5bf7
commit 07e3fb3511
5 changed files with 279 additions and 1 deletion

View File

@@ -54,3 +54,7 @@ harness = false
[[bench]]
name = "custom_gelu"
harness = false

[[bench]]
name = "linear"
harness = false

View File

@@ -0,0 +1,109 @@
use backend_comparison::persistence::Persistence;
use burn::{
    nn::{LinearConfig, LinearTConfig},
    tensor::{backend::Backend, Distribution, Shape, Tensor},
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;

#[derive(new)]
struct LinearBenchmark<B: Backend, const D: usize> {
    shape: Shape<D>,
    weight_shape: [usize; 2],
    num_repeats: usize,
    device: B::Device,
}

impl<B: Backend, const D: usize> Benchmark for LinearBenchmark<B, D> {
    type Args = (Box<dyn Fn(Tensor<B, D>) -> ()>, Tensor<B, D>);

    fn name(&self) -> String {
        "Linear".to_string()
    }

    fn execute(&self, args: Self::Args) {
        for _ in 0..self.num_repeats {
            args.0(args.1.clone())
        }
    }

    fn prepare(&self) -> Self::Args {
        let conf = LinearConfig::new(self.weight_shape[0], self.weight_shape[1]);
        let lin = conf.init();
        let f = Box::new(move |x| {
            lin.forward(x);
        });
        let input = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);

        (f, input)
    }

    fn sync(&self) {
        B::sync(&self.device)
    }
}

#[derive(new)]
struct LinearTBenchmark<B: Backend, const D: usize> {
    shape: Shape<D>,
    weight_shape: [usize; 2],
    num_repeats: usize,
    device: B::Device,
}

impl<B: Backend, const D: usize> Benchmark for LinearTBenchmark<B, D> {
    type Args = (Box<dyn Fn(Tensor<B, D>) -> ()>, Tensor<B, D>);

    fn name(&self) -> String {
        "LinearT".to_string()
    }

    fn execute(&self, args: Self::Args) {
        for _ in 0..self.num_repeats {
            args.0(args.1.clone())
        }
    }

    fn prepare(&self) -> Self::Args {
        let conf = LinearTConfig::new(self.weight_shape[0], self.weight_shape[1]);
        let lint = conf.init();
        let f = Box::new(move |x| {
            lint.forward(x);
        });
        let input = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);

        (f, input)
    }

    fn sync(&self) {
        B::sync(&self.device)
    }
}

#[allow(dead_code)]
fn bench<B: Backend>(device: &B::Device) {
    const D: usize = 3;
    let weight_shape = [1024, 1024];
    let shape: Shape<D> = [32, 512, weight_shape[0]].into();
    let num_repeats = 10;

    let lin =
        LinearBenchmark::<B, D>::new(shape.clone(), weight_shape, num_repeats, device.clone());
    let lint =
        LinearTBenchmark::<B, D>::new(shape.clone(), weight_shape, num_repeats, device.clone());

    Persistence::persist::<B>(
        vec![
            run_benchmark(lin), //
            run_benchmark(lint), //
        ],
        device,
    )
}

fn main() {
    backend_comparison::bench_on_backend!();
}
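
For orientation, here is a rough sketch of what a single benchmark iteration does with the shapes chosen in `bench` above. This is illustrative only and not part of the commit; `single_iteration` is a hypothetical helper, it reuses the imports at the top of this file, and it stays generic over the backend.

// Illustrative sketch (not part of this commit): one iteration applies a
// 1024 -> 1024 linear layer to a [32, 512, 1024] input, i.e. roughly
// 32 * 512 * 1024 * 1024 ≈ 1.7e10 multiply-adds, and each benchmark run
// repeats this `num_repeats` (10) times before syncing the backend.
fn single_iteration<B: Backend>(device: &B::Device) {
    let layer = LinearConfig::new(1024, 1024).init::<B>();
    let shape: Shape<3> = [32, 512, 1024].into();
    let input = Tensor::<B, 3>::random_device(shape, Distribution::Default, device);
    let _output = layer.forward(input); // output shape: [32, 512, 1024]
}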

View File

@@ -0,0 +1,163 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::tensor::{backend::Backend, Tensor};
use libm::sqrt;
use super::Initializer;

/// Configuration to create a [LinearT](LinearT) layer.
#[derive(Config, Debug)]
pub struct LinearTConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// If a bias should be applied during the linear transformation.
    #[config(default = true)]
    pub bias: bool,
    /// The type of function used to initialize neural network parameters
    #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0), fan_out_only:false}")]
    pub initializer: Initializer,
}

/// Applies a linear transformation to the input tensor, with the weight stored in
/// transposed form (`[d_output, d_input]`):
///
/// `O = IWᵀ + b`
#[derive(Module, Debug)]
pub struct LinearT<B: Backend> {
    /// Matrix of shape `[d_output, d_input]` initialized from a uniform distribution:
    /// `U(-k, k)`, where `k = sqrt(1 / d_input)`
    pub weight: Param<Tensor<B, 2>>,
    /// Vector of size `d_output` initialized from a uniform distribution:
    /// `U(-k, k)`, where `k = sqrt(1 / d_input)`
    pub bias: Option<Param<Tensor<B, 1>>>,
}

impl LinearTConfig {
    /// Initialize a new [LinearT](LinearT) module.
    pub fn init<B: Backend>(&self) -> LinearT<B> {
        let shape = [self.d_output, self.d_input];
        let weight = self
            .initializer
            .init_with(shape, Some(self.d_output), Some(self.d_input));
        let bias = if self.bias {
            Some(self.initializer.init_with(
                [self.d_output],
                Some(self.d_input),
                Some(self.d_output),
            ))
        } else {
            None
        };

        LinearT {
            weight: Param::from(weight),
            bias: bias.map(Param::from),
        }
    }

    /// Initialize a new [LinearT](LinearT) module with a [record](LinearTRecord).
    pub fn init_with<B: Backend>(&self, record: LinearTRecord<B>) -> LinearT<B> {
        LinearT {
            weight: record.weight,
            bias: record.bias,
        }
    }
}

impl<B: Backend> LinearT<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_input]`
    /// - output: `[..., any, d_output]`
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let output = self
            .weight
            .val()
            .unsqueeze()
            .matmul(input.transpose())
            .transpose();

        match &self.bias {
            Some(bias) => output + bias.val().unsqueeze(),
            None => output,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_tensor::{Data, Shape};
    use libm::sqrt;

    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let config = LinearTConfig::new(5, 5);
        let k = sqrt(1.0 / config.d_input as f64) as f32;
        let linear = config.init::<TestBackend>();

        assert_eq!(
            config.initializer,
            Initializer::KaimingUniform {
                gain: 1.0 / sqrt(3.0),
                fan_out_only: false
            }
        );
        linear.weight.to_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let config = LinearTConfig::new(5, 5).with_initializer(Initializer::Zeros);
        let linear = config.init::<TestBackend>();

        assert_eq!(config.initializer, Initializer::Zeros);
        linear
            .weight
            .to_data()
            .assert_approx_eq(&Data::zeros(linear.weight.shape()), 3);
    }

    #[test]
    fn test_linear_forward_no_bias() {
        TestBackend::seed(0);

        let value = 2.;
        let config = LinearTConfig::new(2, 3)
            .with_initializer(Initializer::Constant { value })
            .with_bias(false);
        let linear = config.init();

        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]));
        let result = linear.forward(input);
        let expected_result = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]]);

        assert_eq!(result.into_data(), expected_result.into_data());
    }

    #[test]
    fn test_linear_forward_with_bias() {
        TestBackend::seed(0);

        let value = 2.;
        let config = LinearTConfig::new(2, 3).with_initializer(Initializer::Constant { value });
        let linear = config.init();

        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]));
        let result = linear.forward(input);
        let expected_result = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]]);

        assert_eq!(result.into_data(), expected_result.into_data());
    }
}
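
For readers comparing the two modules, a minimal usage sketch follows (illustrative only, not part of the commit; it mirrors the tests above and assumes the same `TestBackend`). Because the weight is stored as `[d_output, d_input]` and the forward pass computes `(W · Iᵀ)ᵀ`, the output equals `I · Wᵀ`, i.e. the same result a standard `Linear` layer produces for the corresponding `[d_input, d_output]` weight.

// Illustrative usage (not part of this commit); mirrors test_linear_forward_with_bias above.
let layer = LinearTConfig::new(2, 3)
    .with_initializer(Initializer::Constant { value: 2. })
    .init::<TestBackend>();
let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]));
let output = layer.forward(input); // shape [1, 3]; each element is 2 * 2 + 2 = 6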

View File

@@ -21,6 +21,7 @@ mod embedding;
mod gelu;
mod initializer;
mod linear;
mod linear_transposed_weight;
mod norm;
mod padding;
mod pos_encoding;
@@ -33,6 +34,7 @@ pub use embedding::*;
pub use gelu::*;
pub use initializer::*;
pub use linear::*;
pub use linear_transposed_weight::*;
pub use norm::*;
pub use padding::*;
pub use pos_encoding::*;

View File

@@ -11,7 +11,7 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-wgpu"
version.workspace = true
[features]
default = ["autotune", "std"]
default = ["std"]
std = []
autotune = []
fusion = ["burn-fusion"]