diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml
index f1228d7c1..b1a713af5 100644
--- a/backend-comparison/Cargo.toml
+++ b/backend-comparison/Cargo.toml
@@ -54,3 +54,7 @@ harness = false
 [[bench]]
 name = "custom_gelu"
 harness = false
+
+[[bench]]
+name = "linear"
+harness = false
diff --git a/backend-comparison/benches/linear.rs b/backend-comparison/benches/linear.rs
new file mode 100644
index 000000000..1889a3490
--- /dev/null
+++ b/backend-comparison/benches/linear.rs
@@ -0,0 +1,109 @@
+use backend_comparison::persistence::Persistence;
+use burn::{
+    nn::{LinearConfig, LinearTConfig},
+    tensor::{backend::Backend, Distribution, Shape, Tensor},
+};
+use burn_common::benchmark::{run_benchmark, Benchmark};
+use derive_new::new;
+
+#[derive(new)]
+struct LinearBenchmark<B: Backend, const D: usize> {
+    shape: Shape<D>,
+    weight_shape: [usize; 2],
+    num_repeats: usize,
+    device: B::Device,
+}
+
+impl<B: Backend, const D: usize> Benchmark for LinearBenchmark<B, D> {
+    type Args = (Box<dyn Fn(Tensor<B, D>) -> ()>, Tensor<B, D>);
+
+    fn name(&self) -> String {
+        "Linear".to_string()
+    }
+
+    fn execute(&self, args: Self::Args) {
+        for _ in 0..self.num_repeats {
+            args.0(args.1.clone())
+        }
+    }
+
+    fn prepare(&self) -> Self::Args {
+        let conf = LinearConfig::new(self.weight_shape[0], self.weight_shape[1]);
+        let lin = conf.init();
+        let f = Box::new(move |x| {
+            lin.forward(x);
+        });
+
+        let input = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);
+
+        (f, input)
+    }
+
+    fn sync(&self) {
+        B::sync(&self.device)
+    }
+}
+
+#[derive(new)]
+struct LinearTBenchmark<B: Backend, const D: usize> {
+    shape: Shape<D>,
+    weight_shape: [usize; 2],
+    num_repeats: usize,
+    device: B::Device,
+}
+
+impl<B: Backend, const D: usize> Benchmark for LinearTBenchmark<B, D> {
+    type Args = (Box<dyn Fn(Tensor<B, D>) -> ()>, Tensor<B, D>);
+
+    fn name(&self) -> String {
+        "LinearT".to_string()
+    }
+
+    fn execute(&self, args: Self::Args) {
+        for _ in 0..self.num_repeats {
+            args.0(args.1.clone())
+        }
+    }
+
+    fn prepare(&self) -> Self::Args {
+        let conf = LinearTConfig::new(self.weight_shape[0], self.weight_shape[1]);
+        let lint = conf.init();
+        let f = Box::new(move |x| {
+            lint.forward(x);
+        });
+
+        let input = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);
+
+        (f, input)
+    }
+
+    fn sync(&self) {
+        B::sync(&self.device)
+    }
+}
+
+#[allow(dead_code)]
+fn bench<B: Backend>(device: &B::Device) {
+    const D: usize = 3;
+    let weight_shape = [1024, 1024];
+    let shape: Shape<D> = [32, 512, weight_shape[0]].into();
+    let num_repeats = 10;
+
+    let lin =
+        LinearBenchmark::<B, D>::new(shape.clone(), weight_shape, num_repeats, device.clone());
+
+    let lint =
+        LinearTBenchmark::<B, D>::new(shape.clone(), weight_shape, num_repeats, device.clone());
+
+    Persistence::persist::<B>(
+        vec![
+            run_benchmark(lin), //
+            run_benchmark(lint), //
+        ],
+        device,
+    )
+}
+
+fn main() {
+    backend_comparison::bench_on_backend!();
+}
diff --git a/burn-core/src/nn/linear_transposed_weight.rs b/burn-core/src/nn/linear_transposed_weight.rs
new file mode 100644
index 000000000..b8c1509c0
--- /dev/null
+++ b/burn-core/src/nn/linear_transposed_weight.rs
@@ -0,0 +1,163 @@
+use crate as burn;
+
+use crate::config::Config;
+use crate::module::Module;
+use crate::module::Param;
+use crate::tensor::{backend::Backend, Tensor};
+use libm::sqrt;
+
+use super::Initializer;
+
+/// Configuration to create a [LinearT](LinearT) layer.
+#[derive(Config, Debug)]
+pub struct LinearTConfig {
+    /// The size of the input features.
+    pub d_input: usize,
+    /// The size of the output features.
+    pub d_output: usize,
+    /// If a bias should be applied during the linear transformation.
+    #[config(default = true)]
+    pub bias: bool,
+    /// The type of function used to initialize neural network parameters
+    #[config(default = "Initializer::KaimingUniform{gain:1.0/sqrt(3.0), fan_out_only:false}")]
+    pub initializer: Initializer,
+}
+
+/// Applies a linear transformation to the input tensor using a transposed weight matrix:
+///
+/// `O = IWᵀ + b`
+#[derive(Module, Debug)]
+pub struct LinearT<B: Backend> {
+    /// Matrix of shape `[d_output, d_input]` initialized from a uniform distribution:
+    /// `U(-k, k)`, where `k = sqrt(1 / d_input)`
+    pub weight: Param<Tensor<B, 2>>,
+    /// Vector of size `d_output` initialized from a uniform distribution:
+    /// `U(-k, k)`, where `k = sqrt(1 / d_input)`
+    pub bias: Option<Param<Tensor<B, 1>>>,
+}
+
+impl LinearTConfig {
+    /// Initialize a new [LinearT](LinearT) module.
+    pub fn init<B: Backend>(&self) -> LinearT<B> {
+        let shape = [self.d_output, self.d_input];
+        let weight = self
+            .initializer
+            .init_with(shape, Some(self.d_output), Some(self.d_input));
+        let bias = if self.bias {
+            Some(self.initializer.init_with(
+                [self.d_output],
+                Some(self.d_input),
+                Some(self.d_output),
+            ))
+        } else {
+            None
+        };
+
+        LinearT {
+            weight: Param::from(weight),
+            bias: bias.map(Param::from),
+        }
+    }
+
+    /// Initialize a new [LinearT](LinearT) module with a [record](LinearTRecord).
+    pub fn init_with<B: Backend>(&self, record: LinearTRecord<B>) -> LinearT<B> {
+        LinearT {
+            weight: record.weight,
+            bias: record.bias,
+        }
+    }
+}
+
+impl<B: Backend> LinearT<B> {
+    /// Applies the forward pass on the input tensor.
+    ///
+    /// # Shapes
+    ///
+    /// - input: `[..., any, d_input]`
+    /// - output: `[..., any, d_output]`
+    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
+        let output = self
+            .weight
+            .val()
+            .unsqueeze()
+            .matmul(input.transpose())
+            .transpose();
+
+        match &self.bias {
+            Some(bias) => output + bias.val().unsqueeze(),
+            None => output,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::TestBackend;
+    use burn_tensor::{Data, Shape};
+    use libm::sqrt;
+
+    #[test]
+    fn initializer_default() {
+        TestBackend::seed(0);
+
+        let config = LinearTConfig::new(5, 5);
+        let k = sqrt(1.0 / config.d_input as f64) as f32;
+        let linear = config.init::<TestBackend>();
+
+        assert_eq!(
+            config.initializer,
+            Initializer::KaimingUniform {
+                gain: 1.0 / sqrt(3.0),
+                fan_out_only: false
+            }
+        );
+        linear.weight.to_data().assert_within_range(-k..k);
+    }
+
+    #[test]
+    fn initializer_zeros() {
+        TestBackend::seed(0);
+
+        let config = LinearTConfig::new(5, 5).with_initializer(Initializer::Zeros);
+        let linear = config.init::<TestBackend>();
+
+        assert_eq!(config.initializer, Initializer::Zeros);
+        linear
+            .weight
+            .to_data()
+            .assert_approx_eq(&Data::zeros(linear.weight.shape()), 3);
+    }
+
+    #[test]
+    fn test_linear_forward_no_bias() {
+        TestBackend::seed(0);
+
+        let value = 2.;
+        let config = LinearTConfig::new(2, 3)
+            .with_initializer(Initializer::Constant { value })
+            .with_bias(false);
+        let linear = config.init();
+
+        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]));
+        let result = linear.forward(input);
+        let expected_result = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]]);
+
+        assert_eq!(result.into_data(), expected_result.into_data());
+    }
+
+    #[test]
+    fn test_linear_forward_with_bias() {
+        TestBackend::seed(0);
+
+        let value = 2.;
+        let config = LinearTConfig::new(2, 3).with_initializer(Initializer::Constant { value });
+        let linear = config.init();
+
+        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]));
+        let result = linear.forward(input);
+        let expected_result = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]]);
+
+        assert_eq!(result.into_data(), expected_result.into_data());
+    }
+}
diff --git a/burn-core/src/nn/mod.rs b/burn-core/src/nn/mod.rs
index 678bb0d49..909e5a44c 100644
--- a/burn-core/src/nn/mod.rs
+++ b/burn-core/src/nn/mod.rs
@@ -21,6 +21,7 @@ mod embedding;
 mod gelu;
 mod initializer;
 mod linear;
+mod linear_transposed_weight;
 mod norm;
 mod padding;
 mod pos_encoding;
@@ -33,6 +34,7 @@ pub use embedding::*;
 pub use gelu::*;
 pub use initializer::*;
 pub use linear::*;
+pub use linear_transposed_weight::*;
 pub use norm::*;
 pub use padding::*;
 pub use pos_encoding::*;
diff --git a/burn-wgpu/Cargo.toml b/burn-wgpu/Cargo.toml
index f3e6fb688..074c2753b 100644
--- a/burn-wgpu/Cargo.toml
+++ b/burn-wgpu/Cargo.toml
@@ -11,7 +11,7 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-wgpu"
 version.workspace = true
 
 [features]
-default = ["autotune", "std"]
+default = ["std"]
 std = []
 autotune = []
 fusion = ["burn-fusion"]