mirror of https://github.com/tracel-ai/burn.git
Merge branch 'main' into feat/cube/tiling2d
commit babeac6d80
@@ -7,7 +7,10 @@ use burn::{
         Distribution, Tensor,
     },
 };
-use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_common::{
+    benchmark::{run_benchmark, Benchmark},
+    sync_type::SyncType,
+};
 
 pub struct AutodiffOverheadBenchmark<B: AutodiffBackend> {
     config: nn::LstmConfig,
@@ -47,7 +50,7 @@ impl<B: AutodiffBackend> Benchmark for AutodiffOverheadBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait)
     }
 }

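The same two-part change repeats across every benchmark below: the imports gain `sync_type::SyncType`, and each `Benchmark::sync` body passes `SyncType::Wait` so that the measurement only completes once all queued device work has finished. A self-contained sketch of the pattern; every name here is a stand-in for illustration, not the burn API:

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum SyncType {
        Flush,
        Wait,
    }

    struct Device;

    // Stand-in for `B::sync(device, sync_type)` on a backend.
    fn backend_sync(_device: &Device, sync_type: SyncType) {
        if sync_type == SyncType::Wait {
            // A real backend blocks here until all submitted tasks finish.
        }
    }

    // Stand-in for a benchmark's `sync` implementation: always `Wait`,
    // because a timing that returns before the device is idle is meaningless.
    fn benchmark_sync(device: &Device) {
        backend_sync(device, SyncType::Wait);
    }

    fn main() {
        backend_sync(&Device, SyncType::Flush); // submit without blocking
        benchmark_sync(&Device); // wait: required before reading the clock
    }
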
@@ -1,6 +1,9 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
-use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_common::{
+    benchmark::{run_benchmark, Benchmark},
+    sync_type::SyncType,
+};
 
 pub struct BinaryBenchmark<B: Backend, const D: usize> {
     shape: Shape<D>,
@@ -31,7 +34,7 @@ impl<B: Backend, const D: usize> Benchmark for BinaryBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait);
     }
 }

@@ -2,7 +2,10 @@ use backend_comparison::persistence::save;
 use burn::tensor::{
     backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor,
 };
-use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_common::{
+    benchmark::{run_benchmark, Benchmark},
+    sync_type::SyncType,
+};
 
 pub struct Conv2dBenchmark<B: Backend> {
     input_shape: Shape<4>,
@@ -48,7 +51,7 @@ impl<B: Backend> Benchmark for Conv2dBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait)
     }
 }

@@ -3,7 +3,10 @@ use burn::tensor::{
     backend::Backend, module::conv_transpose2d, ops::ConvTransposeOptions, Distribution, Shape,
     Tensor,
 };
-use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_common::{
+    benchmark::{run_benchmark, Benchmark},
+    sync_type::SyncType,
+};
 
 pub struct ConvTranspose2dBenchmark<B: Backend> {
     input_shape: Shape<4>,
@@ -49,7 +52,7 @@ impl<B: Backend> Benchmark for ConvTranspose2dBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait)
     }
 }

@@ -2,6 +2,7 @@ use backend_comparison::persistence::save;
 use burn::backend::Autodiff;
 use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
 use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_common::sync_type::SyncType;
 use core::f64::consts::SQRT_2;
 use derive_new::new;
 
@@ -68,7 +69,7 @@ impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait)
     }
 
     fn num_samples(&self) -> usize {

@@ -1,6 +1,9 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, Data, Distribution, Shape, Tensor};
-use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_common::{
+    benchmark::{run_benchmark, Benchmark},
+    sync_type::SyncType,
+};
 use derive_new::new;
 
 #[derive(new)]
@@ -29,7 +32,7 @@ impl<B: Backend, const D: usize> Benchmark for ToDataBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait)
     }
 }
@@ -66,7 +69,7 @@ impl<B: Backend, const D: usize> Benchmark for FromDataBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait)
     }
 }

@@ -3,6 +3,7 @@ use burn::tensor::backend::Backend;
 use burn::tensor::Device;
 use burn::{config::Config, module::Module, nn};
 use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_common::sync_type::SyncType;
 use derive_new::new;
 
 #[derive(Module, Debug)]
@@ -93,7 +94,7 @@ impl<B: Backend> Benchmark for LoadRecordBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait)
     }
 }

@@ -1,6 +1,9 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
-use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_common::{
+    benchmark::{run_benchmark, Benchmark},
+    sync_type::SyncType,
+};
 use derive_new::new;
 
 #[derive(new)]
@@ -37,7 +40,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait)
     }
 }

@@ -1,6 +1,9 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, module::max_pool2d, Distribution, Shape, Tensor};
-use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_common::{
+    benchmark::{run_benchmark, Benchmark},
+    sync_type::SyncType,
+};
 
 pub struct MaxPool2dBenchmark<B: Backend> {
     shape: Shape<4>,
@@ -37,7 +40,7 @@ impl<B: Backend> Benchmark for MaxPool2dBenchmark<B> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait)
     }
 }

@@ -1,6 +1,9 @@
 use backend_comparison::persistence::save;
 use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
-use burn_common::benchmark::{run_benchmark, Benchmark};
+use burn_common::{
+    benchmark::{run_benchmark, Benchmark},
+    sync_type::SyncType,
+};
 use derive_new::new;
 
 #[derive(new)]
@@ -30,7 +33,7 @@ impl<B: Backend, const D: usize> Benchmark for UnaryBenchmark<B, D> {
     }
 
     fn sync(&self) {
-        B::sync(&self.device)
+        B::sync(&self.device, SyncType::Wait)
     }
 }

@@ -5,6 +5,7 @@ use crate::{
     tensor::AutodiffTensor,
     AutodiffBridge,
 };
+use burn_common::sync_type::SyncType;
 use burn_tensor::backend::{AutodiffBackend, Backend};
 use core::marker::PhantomData;
 
@@ -43,8 +44,8 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
         B::seed(seed)
     }
 
-    fn sync(device: &B::Device) {
-        B::sync(device);
+    fn sync(device: &B::Device, sync_type: SyncType) {
+        B::sync(device, sync_type)
     }
 }

@@ -1,7 +1,7 @@
 use std::marker::PhantomData;
 
 use burn_tensor::{
-    backend::{Backend, DeviceId, DeviceOps},
+    backend::{Backend, DeviceId, DeviceOps, SyncType},
     Device,
 };
 use candle_core::DeviceLocation;
@@ -105,7 +105,9 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
         panic!("Manual seed not supported by Candle. ")
     }
 
-    fn sync(device: &Device<Self>) {
+    fn sync(device: &Device<Self>, sync_type: SyncType) {
+        match sync_type {
+            SyncType::Wait => {
                 let device: candle_core::Device = (*device).into();
 
                 match device {
@@ -121,4 +123,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
                     }
                 }
+            }
+            SyncType::Flush => (), // Nothing to flush.
+        };
     }
 }

@@ -25,6 +25,9 @@ pub mod benchmark;
 /// notation.
 pub mod reader;
 
+/// Synchronization type module, used both by ComputeServer and Backends.
+pub mod sync_type;
+
 extern crate alloc;
 
 /// Network utilities.

@@ -0,0 +1,8 @@
+/// What kind of synchronization to use.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum SyncType {
+    /// Submit all outstanding tasks to the task queue if any.
+    Flush,
+    /// Submit all tasks to the task queue and wait for all of them to complete.
+    Wait,
+}

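The new enum separates two costs that the old zero-argument `sync()` conflated: flushing submits recorded work without blocking, while waiting is a full barrier. A self-contained sketch of how a compute server might interpret the two variants; the struct and its method names are illustrative, not burn's:

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum SyncType {
        Flush,
        Wait,
    }

    /// Illustrative stand-in for a compute server's internals.
    struct Server {
        pending: Vec<String>, // queued task descriptions
    }

    impl Server {
        fn submit_pending(&mut self) {
            for task in self.pending.drain(..) {
                println!("submitting {task}");
            }
        }

        fn wait_idle(&self) {
            // A real server would block on a fence or event here.
        }

        fn sync(&mut self, sync_type: SyncType) {
            match sync_type {
                // Submit queued work to the device, but return immediately.
                SyncType::Flush => self.submit_pending(),
                // Submit everything and block until the device is idle.
                SyncType::Wait => {
                    self.submit_pending();
                    self.wait_idle();
                }
            }
        }
    }

    fn main() {
        let mut server = Server { pending: vec!["matmul".into(), "relu".into()] };
        server.sync(SyncType::Flush); // cheap: no blocking
        server.sync(SyncType::Wait); // full barrier
    }
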
@@ -3,7 +3,7 @@ use crate::{
     storage::ComputeStorage,
 };
 use alloc::vec::Vec;
-use burn_common::reader::Reader;
+use burn_common::{reader::Reader, sync_type::SyncType};
 
 /// The ComputeChannel trait links the ComputeClient to the ComputeServer
 /// while ensuring thread-safety
@@ -26,6 +26,6 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
     /// Executes the `kernel` over the given `bindings`.
     fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>);
 
-    /// Wait for the completion of every task in the server.
-    fn sync(&self);
+    /// Perform some synchronization of commands on the server.
+    fn sync(&self, sync_type: SyncType);
 }

@@ -4,6 +4,7 @@ use crate::storage::ComputeStorage;
 use alloc::sync::Arc;
 use alloc::vec::Vec;
 use burn_common::reader::Reader;
+use burn_common::sync_type::SyncType;
 
 /// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability.
 ///
@@ -68,8 +69,8 @@ where
             .execute(kernel_description, bindings)
     }
 
-    fn sync(&self) {
-        self.server.borrow_mut().sync()
+    fn sync(&self, sync_type: SyncType) {
+        self.server.borrow_mut().sync(sync_type)
     }
 }

@@ -3,7 +3,7 @@ use std::{
     thread,
 };
 
-use burn_common::reader::Reader;
+use burn_common::{reader::Reader, sync_type::SyncType};
 
 use super::ComputeChannel;
 use crate::{
@@ -44,7 +44,7 @@ where
     Create(Vec<u8>, Callback<Handle<Server>>),
     Empty(usize, Callback<Handle<Server>>),
     ExecuteKernel(Server::Kernel, Vec<Binding<Server>>),
-    Sync(Callback<()>),
+    Sync(SyncType, Callback<()>),
 }
 
 impl<Server> MpscComputeChannel<Server>
@@ -77,8 +77,8 @@ where
                     Message::ExecuteKernel(kernel, bindings) => {
                         server.execute(kernel, bindings);
                     }
-                    Message::Sync(callback) => {
-                        server.sync();
+                    Message::Sync(sync_type, callback) => {
+                        server.sync(sync_type);
                         callback.send(()).unwrap();
                     }
                 };
@@ -157,11 +157,12 @@ where
             .unwrap()
     }
 
-    fn sync(&self) {
+    fn sync(&self, sync_type: SyncType) {
         let (callback, response) = mpsc::channel();
 
-        self.state.sender.send(Message::Sync(callback)).unwrap();
-
+        self.state
+            .sender
+            .send(Message::Sync(sync_type, callback))
+            .unwrap();
         self.response(response)
     }
 }

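In the mpsc channel, the sync kind now travels inside the message itself, and the reply channel is what turns the asynchronous server loop into a blocking call for the sender. A runnable stand-in of that round trip using only the standard library:

    use std::sync::mpsc;
    use std::thread;

    #[derive(Debug, Clone, Copy)]
    enum SyncType {
        Flush,
        Wait,
    }

    // Mirrors the commit's `Message::Sync(SyncType, Callback<()>)`: the sync
    // kind rides along with the request, and the empty-tuple callback lets
    // the sender block until the server has handled it.
    enum Message {
        Sync(SyncType, mpsc::Sender<()>),
    }

    fn main() {
        let (sender, receiver) = mpsc::channel::<Message>();

        let server = thread::spawn(move || {
            for message in receiver {
                match message {
                    Message::Sync(sync_type, callback) => {
                        println!("server syncing: {sync_type:?}");
                        callback.send(()).unwrap(); // unblock the caller
                    }
                }
            }
        });

        // A flush goes through the same path; the caller just ignores the reply.
        let (callback, _keep_alive) = mpsc::channel();
        sender.send(Message::Sync(SyncType::Flush, callback)).unwrap();

        // A wait blocks on the reply channel until the server responds.
        let (callback, response) = mpsc::channel();
        sender.send(Message::Sync(SyncType::Wait, callback)).unwrap();
        response.recv().unwrap(); // returns once the sync was handled

        drop(sender);
        server.join().unwrap();
    }
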
@@ -4,6 +4,7 @@ use crate::storage::ComputeStorage;
 use alloc::sync::Arc;
 use alloc::vec::Vec;
 use burn_common::reader::Reader;
+use burn_common::sync_type::SyncType;
 use spin::Mutex;
 
 /// The MutexComputeChannel ensures thread-safety by locking the server
@@ -59,7 +60,7 @@ where
         self.server.lock().execute(kernel, handles)
     }
 
-    fn sync(&self) {
-        self.server.lock().sync()
+    fn sync(&self, sync_type: SyncType) {
+        self.server.lock().sync(sync_type)
     }
 }

@@ -6,8 +6,8 @@ use crate::{
 };
 use alloc::vec::Vec;
 use alloc::{boxed::Box, sync::Arc};
-use burn_common::reader::Reader;
 use burn_common::stub::RwLock;
+use burn_common::{reader::Reader, sync_type::SyncType};
 
 /// The ComputeClient is the entry point to require tasks from the ComputeServer.
 /// It should be obtained for a specific device via the Compute struct.
@@ -69,8 +69,8 @@ where
     }
 
     /// Wait for the completion of every task in the server.
-    pub fn sync(&self) {
-        self.channel.sync()
+    pub fn sync(&self, sync_type: SyncType) {
+        self.channel.sync(sync_type)
     }
 
     /// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks

@@ -4,7 +4,7 @@ use crate::{
     tune::AutotuneKey,
 };
 use alloc::vec::Vec;
-use burn_common::reader::Reader;
+use burn_common::{reader::Reader, sync_type::SyncType};
 use core::fmt::Debug;
 
 /// The compute server is responsible for handling resources and computations over resources.
@@ -46,7 +46,7 @@ where
     fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>);
 
     /// Wait for the completion of every task in the server.
-    fn sync(&mut self);
+    fn sync(&mut self, command: SyncType);
 }
 
 /// Server handle containing the [memory handle](MemoryManagement::Handle).

@@ -1,4 +1,5 @@
 use burn_common::benchmark::Benchmark;
+use burn_common::sync_type::SyncType;
 
 use crate::channel::ComputeChannel;
 use crate::client::ComputeClient;
@@ -41,6 +42,7 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
     }
 
     fn sync(&self) {
-        self.client.sync();
+        // For benchmarks - we need to wait for all tasks to complete before returning.
+        self.client.sync(SyncType::Wait);
     }
 }

@@ -1,6 +1,6 @@
 use std::sync::Arc;
 
-use burn_common::reader::Reader;
+use burn_common::{reader::Reader, sync_type::SyncType};
 use burn_compute::{
     memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement},
     server::{Binding, ComputeServer, Handle},
@@ -62,7 +62,7 @@ where
         kernel.compute(&mut resources);
     }
 
-    fn sync(&mut self) {
+    fn sync(&mut self, _: SyncType) {
         // Nothing to do with dummy backend.
     }
 }

@@ -246,6 +246,8 @@ fn autotune_cache_different_keys_return_a_cache_miss() {
 #[serial]
 #[cfg(feature = "std")]
 fn autotune_cache_different_checksums_return_a_cache_miss() {
+    use burn_common::sync_type::SyncType;
+
     type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
     let runtime = Runtime::new();
     let client = runtime.client(&DummyDevice, dummy::init_client);
@@ -260,7 +262,7 @@ fn autotune_cache_different_checksums_return_a_cache_miss() {
     let cache_test_autotune_kernel_1 =
         dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1);
     client.autotune_execute(Box::new(cache_test_autotune_kernel_1));
-    client.sync();
+    client.sync(SyncType::Wait);
 
     // we use a second compute client in order to have freshly initialized autotune cache
     // and test invalidation of the cache when the checksum of the operation set is
@@ -278,7 +280,7 @@ fn autotune_cache_different_checksums_return_a_cache_miss() {
         dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2);
     cache_test_autotune_kernel_2.generate_random_checksum = true;
     client.autotune_execute(Box::new(cache_test_autotune_kernel_2));
-    client.sync();
+    client.sync(SyncType::Wait);
 
     let obtained_resource = client.read(out_2.binding());
 

@@ -1,6 +1,6 @@
 use alloc::vec::Vec;
 
-use burn_tensor::{backend::Backend, Bool, Data, ElementConversion, Int, Shape, Tensor};
+use crate::tensor::{backend::Backend, Bool, Data, ElementConversion, Int, Shape, Tensor};
 
 /// Generate an autoregressive attention mask.
 ///
@@ -89,9 +89,9 @@ pub fn generate_padding_mask<B: Backend>(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Data;
     use crate::TestBackend;
     use alloc::vec;
-    use burn_tensor::Data;
 
     #[test]
     fn test_generate_autoregressive_mask() {

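From here on the commit's second theme appears: burn-core files stop importing `burn_tensor::...` directly and go through the crate's own `crate::tensor` re-export, so the underlying dependency is named in exactly one place. A minimal stand-alone sketch of the idea, with invented module names:

    mod tensor_backend {
        // Stand-in for a type from the underlying tensor crate.
        pub struct Data;
    }

    mod core_crate {
        // The single re-export: code below `core_crate` now says
        // `tensor::...` instead of naming the backing crate directly.
        pub use crate::tensor_backend as tensor;
    }

    fn main() {
        // Swapping the backing crate later only touches the re-export line.
        let _data = core_crate::tensor::Data;
    }
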
@@ -12,21 +12,21 @@ use crate::{
 #[cfg(not(feature = "std"))]
 use num_traits::Float;
 
-/// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer.
+/// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer using the [init function](MultiHeadAttentionConfig::init).
 #[derive(Config)]
 pub struct MultiHeadAttentionConfig {
     /// The size of each linear layer.
-    d_model: usize,
+    pub d_model: usize,
     /// The number of heads.
-    n_heads: usize,
+    pub n_heads: usize,
     /// The dropout rate. Default: 0.1
     #[config(default = 0.1)]
-    dropout: f64,
+    pub dropout: f64,
     /// The minimum value a float can take. Default: -1.0e4
     /// This is used to mask attention scores before calculating attention weights.
     /// A value too low might result in NaN.
     #[config(default = -1.0e4)]
-    min_float: f64,
+    pub min_float: f64,
     /// Use "quiet softmax" instead of regular softmax.
     ///
     /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
@@ -34,7 +34,7 @@ pub struct MultiHeadAttentionConfig {
     ///
     /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
     #[config(default = false)]
-    quiet_softmax: bool,
+    pub quiet_softmax: bool,
     /// The type of function used to initialize neural network parameters
     #[config(
         default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
@@ -50,6 +50,8 @@ pub struct MultiHeadAttentionConfig {
 /// - key: [Linear](nn::Linear) layer with `d_model` input and output features.
 /// - value: [Linear](nn::Linear) layer with `d_model` input and output features.
 /// - output: [Linear](nn::Linear) layer with `d_model` input and output features.
+///
+/// Should be created with [MultiHeadAttentionConfig].
 #[derive(Module, Debug)]
 pub struct MultiHeadAttention<B: Backend> {
     query: nn::Linear<B>,
@@ -67,8 +69,11 @@ pub struct MultiHeadAttention<B: Backend> {
 /// [Multihead attention](MultiHeadAttention) forward pass input argument.
 #[derive(Debug, Clone)]
 pub struct MhaInput<B: Backend> {
+    /// Shape `[batch_size, seq_length_1, d_model]`
     query: Tensor<B, 3>,
+    /// Shape `[batch_size, seq_length_2, d_model]`
     key: Tensor<B, 3>,
+    /// Shape `[batch_size, seq_length_2, d_model]`
     value: Tensor<B, 3>,
     mask_pad: Option<Tensor<B, 2, Bool>>,
     mask_attn: Option<Tensor<B, 3, Bool>>,
@@ -101,6 +106,9 @@ impl MultiHeadAttentionConfig {
 impl<B: Backend> MhaInput<B> {
     /// Create a [multihead attention](MultiHeadAttention) input argument
     /// by setting the query, key and value to the given tensor.
+    ///
+    /// # Shape
+    /// - tensor: `[batch_size, seq_length, d_model]`
     pub fn self_attn(tensor: Tensor<B, 3>) -> Self {
         Self {
             query: tensor.clone(),
@@ -138,15 +146,17 @@ impl<B: Backend> MhaInput<B> {
 /// [Multihead attention](MultiHeadAttention) outputs.
 #[derive(Debug, Clone)]
 pub struct MhaOutput<B: Backend> {
-    /// The attention weights [batch_size, n_heads, seq_length_1, seq_length_2].
+    /// The attention weights `[batch_size, n_heads, seq_length_1, seq_length_2]`.
     pub weights: Tensor<B, 4>,
-    /// The context tensor [batch_size, seq_length_1, d_model].
+    /// The context tensor `[batch_size, seq_length_1, d_model]`.
     pub context: Tensor<B, 3>,
 }
 
 impl<B: Backend> MultiHeadAttention<B> {
     /// Applies the forward pass on the input tensors.
     ///
+    /// See [MultiHeadAttention](MultiHeadAttention) for more information.
+    ///
     /// # Shapes
     ///
     /// - query: `[batch_size, seq_length_1, d_model]`
@@ -310,10 +320,10 @@ impl<B: Backend, const D: usize> MhaLinearCache<B, D> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Int;
+    use crate::tensor::{Distribution, Shape};
     use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
     use alloc::vec::Vec;
-    use burn::tensor::{Distribution, Shape};
-    use burn_tensor::Int;
 
     #[test]
     fn test_self_attention_shapes() {

@@ -3,15 +3,14 @@ use crate as burn;
 use crate::config::Config;
 use crate::module::Module;
 use crate::module::Param;
+use crate::nn::conv::checks;
 use crate::nn::{Initializer, PaddingConfig1d};
 use crate::tensor::backend::Backend;
+use crate::tensor::module::conv1d;
+use crate::tensor::ops::ConvOptions;
 use crate::tensor::Tensor;
-use burn_tensor::module::conv1d;
-use burn_tensor::ops::ConvOptions;
-
-use super::checks;
 
-/// Configuration to create an [1D convolution](Conv1d) layer.
+/// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init).
 #[derive(Config, Debug)]
 pub struct Conv1dConfig {
     /// The number of input channels.
@@ -44,14 +43,10 @@ pub struct Conv1dConfig {
 
 /// Applies a 1D convolution over input tensors.
 ///
-/// # Params
-///
-/// - weight: Tensor of shape [channels_out, channels_in / groups, kernel_size]
-///
-/// - bias: Tensor of shape `[channels_out]`
+/// Should be created with [Conv1dConfig].
 #[derive(Module, Debug)]
 pub struct Conv1d<B: Backend> {
-    /// Tensor of shape [channels_out, channels_in / groups, kernel_size]
+    /// Tensor of shape `[channels_out, channels_in / groups, kernel_size]`
     pub weight: Param<Tensor<B, 3>>,
     /// Tensor of shape `[channels_out]`
    pub bias: Option<Param<Tensor<B, 1>>>,
@@ -102,10 +97,12 @@ impl Conv1dConfig {
 impl<B: Backend> Conv1d<B> {
     /// Applies the forward pass on the input tensor.
     ///
+    /// See [conv1d](crate::tensor::module::conv1d) for more information.
+    ///
     /// # Shapes
     ///
-    /// - input: [batch_size, channels_in, length_in],
-    /// - output: [batch_size, channels_out, length_out],
+    /// - input: `[batch_size, channels_in, length_in]`
+    /// - output: `[batch_size, channels_out, length_out]`
     pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
         let [_batch_size, _channels, length] = input.dims();
         let padding = self
@@ -124,8 +121,8 @@ impl<B: Backend> Conv1d<B> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Data;
     use crate::TestBackend;
-    use burn_tensor::Data;
 
     #[test]
     fn initializer_default() {

@@ -6,13 +6,13 @@ use crate::module::Param;
 use crate::nn::Initializer;
 use crate::nn::PaddingConfig2d;
 use crate::tensor::backend::Backend;
+use crate::tensor::module::conv2d;
+use crate::tensor::ops::ConvOptions;
 use crate::tensor::Tensor;
-use burn_tensor::module::conv2d;
-use burn_tensor::ops::ConvOptions;
 
-use super::checks;
+use crate::nn::conv::checks;
 
-/// Configuration to create an [2D convolution](Conv2d) layer.
+/// Configuration to create a [2D convolution](Conv2d) layer, using the [init function](Conv2dConfig::init).
 #[derive(Config, Debug)]
 pub struct Conv2dConfig {
     /// The number of channels.
@@ -43,11 +43,7 @@ pub struct Conv2dConfig {
 
 /// Applies a 2D convolution over input tensors.
 ///
-/// # Params
-///
-/// - weight: Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
-///
-/// - bias: Tensor of shape `[channels_out]`
+/// Should be created with [Conv2dConfig].
 #[derive(Module, Debug)]
 pub struct Conv2d<B: Backend> {
     /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
@@ -106,10 +102,12 @@ impl Conv2dConfig {
 impl<B: Backend> Conv2d<B> {
     /// Applies the forward pass on the input tensor.
     ///
+    /// See [conv2d](crate::tensor::module::conv2d) for more information.
+    ///
     /// # Shapes
     ///
-    /// - input: [batch_size, channels_in, height_in, width_in],
-    /// - output: [batch_size, channels_out, height_out, width_out],
+    /// - input: `[batch_size, channels_in, height_in, width_in]`
+    /// - output: `[batch_size, channels_out, height_out, width_out]`
     pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
         let [_batch_size, _channels_in, height_in, width_in] = input.dims();
         let padding =
@@ -127,8 +125,8 @@ impl<B: Backend> Conv2d<B> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Data;
     use crate::TestBackend;
-    use burn_tensor::Data;
 
     #[test]
     fn initializer_default() {

@@ -3,15 +3,15 @@ use crate as burn;
 use crate::config::Config;
 use crate::module::Module;
 use crate::module::Param;
+use crate::nn::conv::checks;
 use crate::nn::Initializer;
 use crate::tensor::backend::Backend;
+use crate::tensor::module::conv_transpose1d;
+use crate::tensor::ops::ConvTransposeOptions;
 use crate::tensor::Tensor;
-use burn_tensor::module::conv_transpose1d;
-use burn_tensor::ops::ConvTransposeOptions;
-
-use super::checks;
 
-/// Configuration to create an [1D transposed convolution](ConvTranspose1d) layer.
+/// Configuration to create an [1D transposed convolution](ConvTranspose1d) layer
+/// using the [init function](ConvTranspose1dConfig::init).
 #[derive(Config, Debug)]
 pub struct ConvTranspose1dConfig {
     /// The number of channels.
@@ -44,12 +44,6 @@ pub struct ConvTranspose1dConfig {
 }
 
 /// Applies a 1D transposed convolution over input tensors.
-///
-/// # Params
-///
-/// - weight: Tensor of shape `[channels_in, channels_out / groups, kernel_size]`
-///
-/// - bias: Tensor of shape `[channels_out]`
 #[derive(Module, Debug)]
 pub struct ConvTranspose1d<B: Backend> {
     /// Tensor of shape `[channels_in, channels_out / groups, kernel_size]`
@@ -104,10 +98,12 @@ impl ConvTranspose1dConfig {
 impl<B: Backend> ConvTranspose1d<B> {
     /// Applies the forward pass on the input tensor.
     ///
+    /// See also [conv_transpose1d](crate::tensor::module::conv_transpose1d).
+    ///
     /// # Shapes
     ///
-    /// - input: [batch_size, channels_in, length_in],
-    /// - output: [batch_size, channels_out, length_out],
+    /// - input: `[batch_size, channels_in, length_in]`
+    /// - output: `[batch_size, channels_out, length_out]`
     pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
         conv_transpose1d(
             input,
@@ -127,8 +123,8 @@ impl<B: Backend> ConvTranspose1d<B> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Data;
     use crate::TestBackend;
-    use burn_tensor::Data;
 
     #[test]
     fn initializer_default() {

@@ -1,17 +1,17 @@
 use crate as burn;
 
-use super::checks;
 use crate::config::Config;
 use crate::module::Module;
 use crate::module::Param;
+use crate::nn::conv::checks;
 use crate::nn::Initializer;
 use crate::tensor::backend::Backend;
+use crate::tensor::module::conv_transpose2d;
+use crate::tensor::ops::ConvTransposeOptions;
 use crate::tensor::Tensor;
 
-use burn_tensor::module::conv_transpose2d;
-use burn_tensor::ops::ConvTransposeOptions;
-
-/// Configuration to create an [2D transposed convolution](ConvTranspose2d) layer.
+/// Configuration to create an [2D transposed convolution](ConvTranspose2d) layer
+/// using the [init function](ConvTranspose2dConfig::init).
 #[derive(Config, Debug)]
 pub struct ConvTranspose2dConfig {
     /// The number of channels.
@@ -44,12 +44,6 @@ pub struct ConvTranspose2dConfig {
 }
 
 /// Applies a 2D transposed convolution over input tensors.
-///
-/// # Params
-///
-/// - weight: Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]`
-///
-/// - bias: Tensor of shape `[channels_out]`
 #[derive(Module, Debug)]
 pub struct ConvTranspose2d<B: Backend> {
     /// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]`
@@ -105,10 +99,12 @@ impl ConvTranspose2dConfig {
 impl<B: Backend> ConvTranspose2d<B> {
     /// Applies the forward pass on the input tensor.
     ///
+    /// See also [conv_transpose2d](crate::tensor::module::conv_transpose2d).
+    ///
     /// # Shapes
     ///
-    /// - input: [batch_size, channels_in, height_in, width_in],
-    /// - output: [batch_size, channels_out, height_out, width_out],
+    /// - input: `[batch_size, channels_in, height_in, width_in]`
+    /// - output: `[batch_size, channels_out, height_out, width_out]`
     pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
         conv_transpose2d(
             input,
@@ -128,8 +124,8 @@ impl<B: Backend> ConvTranspose2d<B> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Data;
     use crate::TestBackend;
-    use burn_tensor::Data;
 
     #[test]
     fn initializer_default() {

@@ -5,7 +5,7 @@ use crate::module::Module;
 use crate::tensor::backend::Backend;
 use crate::tensor::{Distribution, Tensor};
 
-/// Configuration to create a [Dropout](Dropout) layer.
+/// Configuration to create a [Dropout](Dropout) layer using the [init function](DropoutConfig::init).
 #[derive(Config, Debug)]
 pub struct DropoutConfig {
     /// The probability of randomly zeroing some elements of the input tensor during training.
@@ -18,6 +18,8 @@ pub struct DropoutConfig {
 /// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580).
 ///
 /// The input is also scaled during training to `1 / (1 - prob_keep)`.
+///
+/// Should be created with [DropoutConfig].
 #[derive(Module, Clone, Debug)]
 pub struct Dropout {
     prob: f64,
@@ -33,6 +35,8 @@ impl DropoutConfig {
 impl Dropout {
     /// Applies the forward pass on the input tensor.
     ///
+    /// See [Dropout](Dropout) for more information.
+    ///
     /// # Shapes
     ///
     /// - input: `[..., any]`

@@ -5,16 +5,18 @@ use crate::config::Config;
 use crate::module::Module;
 use crate::module::Param;
 use crate::tensor::backend::Backend;
+use crate::tensor::Int;
 use crate::tensor::Tensor;
-use burn_tensor::Int;
 
-/// Configuration to create an [Embedding](Embedding) layer.
+use crate::tensor::module::embedding;
+
+/// Configuration to create an [Embedding](Embedding) layer using the [init function](EmbeddingConfig::init).
 #[derive(Config)]
 pub struct EmbeddingConfig {
     /// The number of embedding vectors.
-    n_embedding: usize,
+    pub n_embedding: usize,
     /// The size of each vector.
-    d_model: usize,
+    pub d_model: usize,
     /// The type of function used to initialize neural network parameters
     #[config(default = "Initializer::Normal{mean:0.0, std:1.0}")]
     pub initializer: Initializer,
@@ -22,13 +24,10 @@ pub struct EmbeddingConfig {
 
 /// Lookup table to store a fixed number of vectors.
 ///
-/// # Params
-///
-/// - weight: Matrix of shape `[n_embedding, d_model]` initialized from a normal distribution:
-///   `N(0, 1)`
+/// Should be created with [EmbeddingConfig].
 #[derive(Module, Debug)]
 pub struct Embedding<B: Backend> {
-    /// The learnable weights of the module of shape [n_embedding, d_model] initialized
+    /// The learnable weights of the module of shape `[n_embedding, d_model]` initialized
     /// from a normal distribution `N(0, 1)`.
     pub weight: Param<Tensor<B, 2>>,
 }
@@ -47,20 +46,22 @@ impl EmbeddingConfig {
 impl<B: Backend> Embedding<B> {
     /// Applies the forward pass on the input tensor.
     ///
+    /// See also [embedding](crate::tensor::module::embedding).
+    ///
     /// # Shapes
     ///
-    /// - input: [batch_size, seq_length]
-    /// - output: [batch_size, d_model]
+    /// - input: `[batch_size, seq_length]`
+    /// - output: `[batch_size, d_model]`
     pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
-        burn_tensor::module::embedding(self.weight.val(), input)
+        embedding(self.weight.val(), input)
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Data;
     use crate::TestBackend;
-    use burn_tensor::Data;
 
     #[test]
     fn initializer_default() {

@@ -5,6 +5,7 @@ use crate::tensor::backend::Backend;
 use crate::tensor::Tensor;
 
 /// Applies the Gaussian Error Linear Units function element-wise.
+/// See also [gelu](burn::tensor::activation::gelu)
 #[derive(Module, Clone, Debug, Default)]
 pub struct Gelu {}
 

@@ -1,4 +1,4 @@
-use burn_tensor::Shape;
+use crate::tensor::Shape;
 
 use crate::config::Config;
 use crate::module::{Param, ParamId};
@@ -200,7 +200,7 @@ fn normal_draw<B: Backend, const D: usize, S: Into<Shape<D>>>(
 mod tests {
     use super::*;
 
-    use burn_tensor::{Data, ElementConversion};
+    use crate::tensor::{Data, ElementConversion};
     use num_traits::Pow;
 
     pub type TB = burn_ndarray::NdArray<f32>;

@@ -1,19 +1,20 @@
-use core::marker::PhantomData;
-
 use crate as burn;
 use crate::config::Config;
 use crate::module::Module;
 use crate::tensor::backend::Backend;
 use crate::tensor::Tensor;
 
+use crate::tensor::activation::leaky_relu;
+
 /// Leaky ReLu layer.
-#[derive(Module, Debug)]
-pub struct LeakyRelu<B: Backend> {
+///
+/// Should be created with [LeakyReluConfig](LeakyReluConfig).
+#[derive(Module, Clone, Debug)]
+pub struct LeakyRelu {
     /// The negative slope.
     pub negative_slope: f64,
-    phantom: PhantomData<B>,
 }
-/// Configuration to create a [Leaky Relu](LeakyRelu) layer.
+/// Configuration to create a [Leaky Relu](LeakyRelu) layer using the [init function](LeakyReluConfig::init).
 #[derive(Config, Debug)]
 pub struct LeakyReluConfig {
     /// The negative slope. Default is 0.01
@@ -22,39 +23,36 @@ pub struct LeakyReluConfig {
 }
 impl LeakyReluConfig {
     /// Initialize a new [Leaky Relu](LeakyRelu) Layer
-    pub fn init<B: Backend>(&self) -> LeakyRelu<B> {
+    pub fn init(&self) -> LeakyRelu {
         LeakyRelu {
             negative_slope: self.negative_slope,
-            phantom: PhantomData,
         }
     }
 }
 
-impl<B: Backend> LeakyRelu<B> {
+impl LeakyRelu {
     /// Forward pass for the Leaky ReLu layer.
     ///
-    /// # Arguments
+    /// See [leaky_relu](crate::tensor::activation::leaky_relu) for more information.
     ///
-    /// * `input` - The input tensor.
-    ///
-    /// # Returns
-    ///
-    /// The output tensor.
-    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
-        crate::tensor::activation::leaky_relu(input, self.negative_slope)
+    /// # Shapes
+    /// - input: `[..., any]`
+    /// - output: `[..., any]`
+    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
+        leaky_relu(input, self.negative_slope)
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Data;
     use crate::TestBackend;
-    use burn_tensor::Data;
 
     #[test]
     fn test_leaky_relu_forward() {
         let device = <TestBackend as Backend>::Device::default();
-        let model: LeakyRelu<TestBackend> = LeakyReluConfig::new().init();
+        let model = LeakyReluConfig::new().init();
         let input = Tensor::<TestBackend, 2>::from_data(Data::from([[0.4410, -0.2507]]), &device);
         let out = model.forward(input);
         assert_eq!(out.to_data(), Data::from([[0.4410, -0.002507]]));
@@ -87,7 +85,7 @@ mod tests {
         ];
 
         let device = <TestBackend as Backend>::Device::default();
-        let model: LeakyRelu<TestBackend> = LeakyReluConfig::new().init();
+        let model = LeakyReluConfig::new().init();
         let input_data = Tensor::<TestBackend, 3>::from_data(Data::from(input), &device);
         let actual_output = model.forward(input_data);
         actual_output

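The LeakyRelu rewrite drops the backend type parameter (and its `PhantomData` field) from the module and moves it onto `forward`, which is why the tests above can now write `LeakyReluConfig::new().init()` with no type annotation. A stand-alone sketch of that design move, with all tensor machinery stubbed out:

    // Stand-ins for the tensor/backend machinery, to show just the design
    // change: the backend parameter moves from the struct to the method.
    trait Backend {}
    struct Tensor<B: Backend, const D: usize>(std::marker::PhantomData<B>);

    // After this commit the module holds only plain data...
    #[derive(Clone, Debug)]
    struct LeakyRelu {
        negative_slope: f64,
    }

    impl LeakyRelu {
        // ...and the backend is chosen per call, so one module value works
        // with any backend, and no PhantomData field is needed.
        fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
            let _ = self.negative_slope; // a real impl applies max(x, slope * x)
            input
        }
    }

    struct NdArray;
    impl Backend for NdArray {}

    fn main() {
        let relu = LeakyRelu { negative_slope: 0.01 };
        let t: Tensor<NdArray, 2> = Tensor(std::marker::PhantomData);
        let _ = relu.forward(t);
    }
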
@@ -7,7 +7,7 @@ use crate::tensor::{backend::Backend, Tensor};
 
 use super::Initializer;
 
-/// Configuration to create a [Linear](Linear) layer.
+/// Configuration to create a [Linear](Linear) layer using the [init function](LinearConfig::init).
 #[derive(Config, Debug)]
 pub struct LinearConfig {
     /// The size of the input features.
@@ -26,6 +26,8 @@ pub struct LinearConfig {
 
 /// Applies a linear transformation to the input tensor:
 ///
+/// Should be created with [LinearConfig]
+///
 /// `O = IW + b`
 #[derive(Module, Debug)]
 pub struct Linear<B: Backend> {
@@ -84,8 +86,8 @@ impl<B: Backend> Linear<B> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::{Data, Shape};
     use crate::TestBackend;
-    use burn_tensor::{Data, Shape};
 
     #[test]
     fn initializer_default() {

@@ -1,11 +1,11 @@
 use crate as burn;
 
+use crate::tensor::activation::log_sigmoid;
+use crate::tensor::{backend::Backend, Int, Tensor};
 use crate::{config::Config, module::Module};
 use alloc::vec::Vec;
-use burn_tensor::activation::log_sigmoid;
-use burn_tensor::{backend::Backend, Int, Tensor};
 
-/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss).
+/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss) using the [init function](BinaryCrossEntropyLossConfig::init).
 #[derive(Config, Debug)]
 pub struct BinaryCrossEntropyLossConfig {
     /// Create weighted binary cross-entropy with a weight for each class.
@@ -17,11 +17,11 @@ pub struct BinaryCrossEntropyLossConfig {
     ///
     /// Hard labels {0, 1} will be changed to `y_smoothed = y(1 - a) + a / num_classes`.
     /// Alpha = 0 would be the same as default.
-    smoothing: Option<f32>,
+    pub smoothing: Option<f32>,
 
     /// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
     #[config(default = false)]
-    logits: bool,
+    pub logits: bool,
 }
 
 impl BinaryCrossEntropyLossConfig {
@@ -56,6 +56,8 @@ impl BinaryCrossEntropyLossConfig {
 }
 
 /// Calculate the binary cross entropy loss from the input logits and the targets.
+///
+/// Should be created using [BinaryCrossEntropyLossConfig]
 #[derive(Module, Debug)]
 pub struct BinaryCrossEntropyLoss<B: Backend> {
     /// Weights for cross-entropy.
@@ -146,8 +148,8 @@ impl<B: Backend> BinaryCrossEntropyLoss<B> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::{activation::sigmoid, Data};
     use crate::TestBackend;
-    use burn_tensor::{activation::sigmoid, Data};
 
     #[test]
     fn test_binary_cross_entropy() {

@@ -1,18 +1,18 @@
 use crate as burn;
 
+use crate::tensor::activation::log_softmax;
+use crate::tensor::{backend::Backend, Bool, Int, Tensor};
 use crate::{config::Config, module::Module};
 use alloc::vec;
 use alloc::vec::Vec;
-use burn_tensor::activation::log_softmax;
-use burn_tensor::{backend::Backend, Bool, Int, Tensor};
 
-/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss).
+/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).
 #[derive(Config, Debug)]
 pub struct CrossEntropyLossConfig {
     /// Create padded cross entropy.
     ///
     /// Prevents pad tokens from impacting loss calculation.
-    pad_tokens: Option<Vec<usize>>,
+    pub pad_tokens: Option<Vec<usize>>,
 
     /// Create weighted cross-entropy.
     ///
@@ -21,18 +21,18 @@ pub struct CrossEntropyLossConfig {
     /// # Pre-conditions
     /// - The order of the weight vector should correspond to the label integer assignment.
     /// - Targets assigned negative Int's will not be allowed.
-    weights: Option<Vec<f32>>,
+    pub weights: Option<Vec<f32>>,
 
     /// Create cross-entropy with label smoothing.
     ///
     /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes.
     /// Alpha = 0 would be the same as default.
-    smoothing: Option<f32>,
+    pub smoothing: Option<f32>,
 
     /// Create cross-entropy with probabilities as input instead of logits.
     ///
     #[config(default = true)]
-    logits: bool,
+    pub logits: bool,
 }
 
 impl CrossEntropyLossConfig {
@@ -68,6 +68,8 @@ impl CrossEntropyLossConfig {
 }
 
 /// Calculate the cross entropy loss from the input logits and the targets.
+///
+/// Should be created using [CrossEntropyLossConfig]
 #[derive(Module, Debug)]
 pub struct CrossEntropyLoss<B: Backend> {
     pad_tokens: Option<Vec<usize>>,
@@ -214,8 +216,8 @@ impl<B: Backend> CrossEntropyLoss<B> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::{loss::cross_entropy_with_logits, Data, Distribution};
     use crate::TestBackend;
-    use burn_tensor::{loss::cross_entropy_with_logits, Data, Distribution};
 
     macro_rules! setup {
         () => {{

@@ -1,8 +1,8 @@
 use crate as burn;
 
+use crate::tensor::backend::Backend;
+use crate::tensor::Tensor;
 use crate::{config::Config, module::Module};
-use burn_tensor::backend::Backend;
-use burn_tensor::Tensor;
 use core::marker::PhantomData;
 
 use super::Reduction;
@@ -124,8 +124,8 @@ impl<B: Backend> HuberLoss<B> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Data;
     use crate::TestBackend;
-    use burn_tensor::Data;
     type TestTensor<const D: usize> = Tensor<TestBackend, D>;
 
     #[test]

@@ -1,7 +1,7 @@
 use crate::nn::loss::reduction::Reduction;
 use core::marker::PhantomData;
 
-use burn_tensor::{backend::Backend, Tensor};
+use crate::tensor::{backend::Backend, Tensor};
 
 /// Calculate the mean squared error loss from the input logits and the targets.
 #[derive(Clone, Debug)]
@@ -55,8 +55,8 @@ impl<B: Backend> MseLoss<B> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Data;
     use crate::TestBackend;
-    use burn_tensor::Data;
 
     #[test]
     fn test_mse_loss() {

@@ -7,7 +7,7 @@ use crate::{
     tensor::{backend::Backend, Tensor},
 };
 
-/// Configuration to create a [BatchNorm](BatchNorm) layer.
+/// Configuration to create a [BatchNorm](BatchNorm) layer using the [init function](BatchNormConfig::init).
 #[derive(Config, Debug)]
 pub struct BatchNormConfig {
     /// The number of features.
@@ -23,18 +23,31 @@ pub struct BatchNormConfig {
 /// Applies Batch Normalization over a tensor as described in the paper [Batch Normalization](https://arxiv.org/abs/1502.03167)
 ///
+/// `Y = norm(X) * γ + β`
+///
+/// Where:
+/// - `X` is the input tensor
+/// - `Y` is the output tensor
+/// - `norm` is the normalization function
+/// - `γ` is the learnable weight
+/// - `β` is the learnable bias
+///
+/// Should be created using [BatchNormConfig].
 #[derive(Module, Debug)]
 pub struct BatchNorm<B: Backend, const D: usize> {
-    gamma: Param<Tensor<B, 1>>,
-    beta: Param<Tensor<B, 1>>,
-    running_mean: RunningState<Tensor<B, 1>>,
-    running_var: RunningState<Tensor<B, 1>>,
+    /// The learnable weight gamma.
+    pub gamma: Param<Tensor<B, 1>>,
+    /// The learnable weight beta.
+    pub beta: Param<Tensor<B, 1>>,
+    /// The running mean.
+    pub running_mean: RunningState<Tensor<B, 1>>,
+    /// The running variance.
+    pub running_var: RunningState<Tensor<B, 1>>,
     momentum: f64,
     epsilon: f64,
 }
 
 impl BatchNormConfig {
-    /// Initialize a new [batch norm](BatchNorm) module.
+    /// Initializes a new [batch norm](BatchNorm) module.
     pub fn init<B: Backend, const D: usize>(&self, device: &B::Device) -> BatchNorm<B, D> {
         let gamma = Initializer::Ones.init([self.num_features], device);
         let beta = Initializer::Zeros.init([self.num_features], device);
@@ -56,10 +69,16 @@ impl BatchNormConfig {
 impl<const D: usize, B: Backend> BatchNorm<B, D> {
     /// Applies the forward pass on the input tensor.
     ///
+    /// See [BatchNorm](BatchNorm) for more information.
+    ///
     /// # Shapes
     ///
     /// - input: `[batch_size, channels, ...]`
     /// - output: `[batch_size, channels, ...]`
+    ///
+    /// # Panics
+    ///
+    /// This function will panic if the input tensor has a dimension different from `D + 2`.
     pub fn forward<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
         // Should be moved to a compilation error when const generics support that kind of
        // validation. https://github.com/rust-lang/rust/issues/76560
@@ -168,8 +187,8 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
 #[cfg(test)]
 mod tests_1d {
     use super::*;
+    use crate::tensor::Data;
     use crate::{module::AutodiffModule, TestAutodiffBackend};
-    use burn_tensor::Data;
 
     #[test]
     fn batch_norm_forward_train() {
@@ -228,8 +247,8 @@ mod tests_1d {
 #[cfg(test)]
 mod tests_2d {
     use super::*;
+    use crate::tensor::Data;
     use crate::{module::AutodiffModule, TestAutodiffBackend};
-    use burn_tensor::Data;
 
     #[test]
     fn batch_norm_forward_train() {

@@ -7,7 +7,7 @@ use crate::module::Param;
 use crate::tensor::backend::Backend;
 use crate::tensor::Tensor;
 
-/// Configuration to create a [GroupNorm](GroupNorm) layer.
+/// Configuration to create a [GroupNorm](GroupNorm) layer using the [init function](GroupNormConfig::init).
 #[derive(Debug, Config)]
 pub struct GroupNormConfig {
     /// The number of groups to separate the channels into
@@ -24,17 +24,28 @@ pub struct GroupNormConfig {
     pub affine: bool,
 }
 
-/// Applies Group Normalization over a mini-batch of inputs.
+/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
 ///
+/// `Y = groupnorm(X) * γ + β`
+///
+/// Where:
+/// - `X` is the input tensor
+/// - `Y` is the output tensor
+/// - `γ` is the learnable weight
+/// - `β` is the learnable bias
+///
+/// Should be created using [GroupNormConfig](GroupNormConfig).
 #[derive(Module, Debug)]
 pub struct GroupNorm<B: Backend> {
-    num_groups: usize,
-    num_channels: usize,
-    gamma: Option<Param<Tensor<B, 1>>>,
-    beta: Option<Param<Tensor<B, 1>>>,
-    epsilon: f64,
-    affine: bool,
+    /// The learnable weight
+    pub gamma: Option<Param<Tensor<B, 1>>>,
+    /// The learnable bias
+    pub beta: Option<Param<Tensor<B, 1>>>,
+
+    pub(crate) num_groups: usize,
+    pub(crate) num_channels: usize,
+    pub(crate) epsilon: f64,
+    pub(crate) affine: bool,
 }
 
 impl GroupNormConfig {
@@ -69,11 +80,57 @@ impl GroupNormConfig {
 impl<B: Backend> GroupNorm<B> {
     /// Applies the forward pass on the input tensor.
     ///
+    /// See [GroupNorm](GroupNorm) for more information.
+    ///
     /// # Shapes
     ///
-    /// - input: `[..., any, d_model]`
-    /// - output: `[..., any, d_model]`
+    /// - input: `[batch_size, num_channels, *]`
+    /// - output: `[batch_size, num_channels, *]`
     pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
+        if input.shape().dims[1] != self.num_channels {
+            panic!(
+                "The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}",
+                self.num_channels,
+                input.shape().dims[1]
+            );
+        }
+
+        let gamma = self.gamma.as_ref().map(|x| x.val());
+        let beta = self.beta.as_ref().map(|x| x.val());
+
+        group_norm(
+            input,
+            gamma,
+            beta,
+            self.num_groups,
+            self.epsilon,
+            self.affine,
+        )
+    }
+}
+
+/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
+///
+/// `Y = groupnorm(X) * γ + β`
+///
+/// Where:
+/// - `X` is the input tensor
+/// - `Y` is the output tensor
+/// - `γ` is the learnable weight
+/// - `β` is the learnable bias
+pub(crate) fn group_norm<B: Backend, const D: usize>(
+    input: Tensor<B, D>,
+    gamma: Option<Tensor<B, 1>>,
+    beta: Option<Tensor<B, 1>>,
+    num_groups: usize,
+    epsilon: f64,
+    affine: bool,
+) -> Tensor<B, D> {
+    if (beta.is_none() || gamma.is_none()) && affine {
+        panic!("Affine is set to true, but gamma or beta is None");
+    }
+
     let shape = input.shape();
     if shape.num_elements() <= 2 {
         panic!(
@@ -85,42 +142,33 @@ impl<B: Backend> GroupNorm<B> {
     let batch_size = shape.dims[0];
     let num_channels = shape.dims[1];
 
-    if num_channels != self.num_channels {
-        panic!(
-            "expected {} channels but got {}",
-            self.num_channels, num_channels
-        );
-    }
-
-    let hidden_size =
-        shape.dims[2..].iter().product::<usize>() * num_channels / self.num_groups;
-    let input = input.reshape([batch_size, self.num_groups, hidden_size]);
+    let hidden_size = shape.dims[2..].iter().product::<usize>() * num_channels / num_groups;
+    let input = input.reshape([batch_size, num_groups, hidden_size]);
 
     let mean = input.clone().sum_dim(2) / hidden_size as f64;
     let input = input.sub(mean);
 
     let var = input.clone().powf_scalar(2.).sum_dim(2) / hidden_size as f64;
-    let input_normalized = input.div(var.sqrt().add_scalar(self.epsilon));
+    let input_normalized = input.div(var.sqrt().add_scalar(epsilon));
 
-    if self.affine {
+    if affine {
         let mut affine_shape = [1; D];
         affine_shape[1] = num_channels;
 
         input_normalized
             .reshape(shape)
-            .mul(self.gamma.clone().unwrap().val().reshape(affine_shape))
-            .add(self.beta.clone().unwrap().val().reshape(affine_shape))
+            .mul(gamma.clone().unwrap().reshape(affine_shape))
+            .add(beta.clone().unwrap().reshape(affine_shape))
     } else {
         input_normalized.reshape(shape)
     }
 }
-}
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::tensor::Data;
     use crate::TestBackend;
-    use burn_tensor::Data;
 
     #[test]
     fn group_norm_forward_affine_false() {

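The forward math, now living in the free `group_norm` function, reduces per group to: subtract the group mean, then divide by `sqrt(var) + epsilon` (note that the code above adds epsilon after the square root, not under it). A plain-Rust numeric sketch of that computation for a single group:

    fn group_norm_one_group(x: &[f64], epsilon: f64) -> Vec<f64> {
        let n = x.len() as f64;
        let mean = x.iter().sum::<f64>() / n;
        let centered: Vec<f64> = x.iter().map(|v| v - mean).collect();
        let var = centered.iter().map(|v| v * v).sum::<f64>() / n;
        // Mirrors the tensor code: divide by sqrt(var) + epsilon.
        centered.iter().map(|v| v / (var.sqrt() + epsilon)).collect()
    }

    fn main() {
        let out = group_norm_one_group(&[1.0, 2.0, 3.0, 4.0], 1e-5);
        println!("{out:?}"); // roughly zero mean, unit variance within the group
    }
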
@@ -1,45 +1,56 @@
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::{Module, Param};
use crate::nn::norm::group_norm;
use crate::nn::Initializer;
use crate::tensor::{backend::Backend, Tensor};

use super::{GroupNorm, GroupNormConfig};

/// Configuration to create a [InstanceNorm](InstanceNorm) layer.
/// Configuration to create a [InstanceNorm](InstanceNorm) layer using the [init function](InstanceNormConfig::init).
#[derive(Debug, Config)]
pub struct InstanceNormConfig {
    /// The number of channels expected in the input
    num_channels: usize,
    pub num_channels: usize,
    /// A value required for numerical stability. Default: 1e-5
    #[config(default = 1e-5)]
    epsilon: f64,
    pub epsilon: f64,
    /// A boolean value that when set to `true`, this module has learnable
    /// per-channel affine parameters initialized to ones (for weights)
    /// and zeros (for biases). Default: `true`
    #[config(default = true)]
    affine: bool,
    pub affine: bool,
}

/// Applies Instance Normalization over a tensor as described in the paper [Instance Normalization](https://arxiv.org/abs/1607.08022)
///
/// Should be created using [InstanceNormConfig](InstanceNormConfig).
#[derive(Module, Debug)]
pub struct InstanceNorm<B: Backend> {
    group_norm: GroupNorm<B>,
    /// The learnable weight
    pub gamma: Option<Param<Tensor<B, 1>>>,
    /// The learnable bias
    pub beta: Option<Param<Tensor<B, 1>>>,

    num_channels: usize,
    epsilon: f64,
    affine: bool,
}

impl InstanceNormConfig {
    /// Initialize a new [instance norm](InstanceNorm) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> InstanceNorm<B> {
        InstanceNorm {
            group_norm: self.to_group_norm().init(device),
        }
    }
        let (gamma, beta) = if self.affine {
            let gamma = Initializer::Ones.init([self.num_channels], device);
            let beta = Initializer::Zeros.init([self.num_channels], device);

    fn to_group_norm(&self) -> GroupNormConfig {
        GroupNormConfig {
            // Group norm is equivalent to instance norm, when the number of groups is
            // equal to the number of channels.
            num_groups: self.num_channels,
            (Some(gamma), Some(beta))
        } else {
            (None, None)
        };

        InstanceNorm {
            gamma,
            beta,
            num_channels: self.num_channels,
            epsilon: self.epsilon,
            affine: self.affine,

@@ -50,20 +61,28 @@ impl InstanceNormConfig {
impl<B: Backend> InstanceNorm<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// See also [InstanceNormConfig](InstanceNormConfig) for more information.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_model]`
    /// - output: `[..., any, d_model]`
    /// - input: `[batch_size, num_channels, *]`
    /// - output: `[batch_size, num_channels, *]`
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        self.group_norm.forward(input)
        // Instance norm is equivalent to group norm when the number of groups is equal to the number of channels.
        let num_groups = self.num_channels;

        let gamma = self.gamma.as_ref().map(|x| x.val());
        let beta = self.beta.as_ref().map(|x| x.val());

        group_norm(input, gamma, beta, num_groups, self.epsilon, self.affine)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Data;
    use crate::TestBackend;
    use burn_tensor::Data;

    #[test]
    fn instance_norm_forward_affine_false() {

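Editor's note: the refactor above moves the affine parameters out of the wrapped GroupNorm and into InstanceNorm itself. A minimal usage sketch of the resulting API follows; the `burn::nn` re-export path, the `NdArray` backend, and its `ndarray` feature flag are assumptions, not part of this diff.

// --- illustrative sketch, not part of the diff ---
use burn::backend::NdArray;
use burn::nn::InstanceNormConfig;
use burn::tensor::{Distribution, Tensor};

fn instance_norm_demo() {
    type B = NdArray;
    let device = Default::default();

    // 6 channels; `affine` defaults to true, so gamma/beta are Some(..).
    let norm = InstanceNormConfig::new(6).init::<B>(&device);

    // Instance norm == group norm with num_groups == num_channels.
    let input = Tensor::<B, 3>::random([2, 6, 10], Distribution::Default, &device);
    let output = norm.forward(input);
    assert_eq!(output.dims(), [2, 6, 10]); // shape is preserved
}
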
@@ -1,13 +1,13 @@
use crate as burn;
use crate::nn::Initializer;

use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Configuration to create a [LayerNorm](LayerNorm) layer.
/// Configuration to create a [LayerNorm](LayerNorm) layer using the [init function](LayerNormConfig::init).
#[derive(Debug, Config)]
pub struct LayerNormConfig {
    /// The size of the input features.

@@ -20,9 +20,19 @@ pub struct LayerNormConfig {
/// Applies Layer Normalization over an input tensor as described in the paper [Layer Normalization](https://arxiv.org/abs/1607.06450).
///
/// `Y = norm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `γ` is the learnable weight
/// - `β` is the learnable bias
///
/// Should be created using [LayerNormConfig](LayerNormConfig).
#[derive(Module, Debug)]
pub struct LayerNorm<B: Backend> {
    /// The learnable weight.
    gamma: Param<Tensor<B, 1>>,
    /// The learnable bias.
    beta: Param<Tensor<B, 1>>,
    epsilon: f64,
}

@@ -44,6 +54,8 @@ impl LayerNormConfig {
impl<B: Backend> LayerNorm<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// See the [LayerNorm](LayerNorm) documentation for more information.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_model]`

@@ -62,7 +74,7 @@ impl<B: Backend> LayerNorm<B> {
#[cfg(test)]
mod tests {
    use super::*;
    use burn_tensor::Data;
    use crate::tensor::Data;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

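A matching sketch for `LayerNorm`, whose `Y = norm(X) * γ + β` contract is documented above (same backend and path assumptions as the previous sketch):

// --- illustrative sketch, not part of the diff ---
use burn::backend::NdArray;
use burn::nn::LayerNormConfig;
use burn::tensor::{Distribution, Tensor};

fn layer_norm_demo() {
    type B = NdArray;
    let device = Default::default();

    // d_model = 8; gamma starts at ones and beta at zeros.
    let norm = LayerNormConfig::new(8).init::<B>(&device);

    // Any leading shape works; normalization runs over the last axis.
    let input = Tensor::<B, 3>::random([2, 4, 8], Distribution::Default, &device);
    let output = norm.forward(input);
    assert_eq!(output.dims(), [2, 4, 8]);
}
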
@@ -7,18 +7,22 @@ use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Configuration to create a [RMS Norm](RmsNorm) layer.
/// Configuration to create a [RMS Norm](RmsNorm) layer using the [init function](RmsNormConfig::init).
#[derive(Config)]
pub struct RmsNormConfig {
    /// The size of the input features.
    d_model: usize,
    pub d_model: usize,
    /// A value required for numerical stability. Default: 1e-5
    #[config(default = 1e-5)]
    epsilon: f64,
    pub epsilon: f64,
}

impl RmsNormConfig {
    /// Initialize a new [RMS Norm](RmsNorm) module.
    ///
    /// # Panics
    ///
    /// Panics if `epsilon` is not positive.
    pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
        assert!(self.epsilon > 0.0, "epsilon must be positive.");

@@ -35,11 +39,18 @@ impl RmsNormConfig {
///
/// `Y = X / sqrt(mean(X^2) + eps) * gamma`
///
/// where `eps` is a small value to avoid division by zero.
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `gamma` is the learnable weight
/// - `mean` is the mean operation
/// - `eps` is a small value to avoid division by zero.
///
/// Should be created using the [RmsNormConfig](RmsNormConfig) configuration.
#[derive(Module, Debug)]
pub struct RmsNorm<B: Backend> {
    /// The learnable parameter to scale the normalized tensor
    gamma: Param<Tensor<B, 1>>,
    pub gamma: Param<Tensor<B, 1>>,
    /// A value required for numerical stability
    epsilon: f64,
}

@@ -47,6 +58,8 @@ pub struct RmsNorm<B: Backend> {
impl<B: Backend> RmsNorm<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// See the [RmsNorm](RmsNorm) documentation for more information.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_model]`

@@ -61,8 +74,8 @@ impl<B: Backend> RmsNorm<B> {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Data;
    use crate::TestBackend;
    use burn_tensor::Data;

    #[test]
    fn rms_norm_forward() {

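The `Y = X / sqrt(mean(X^2) + eps) * gamma` formula documented above can be restated with raw tensor ops. This is an illustrative re-derivation relying on the broadcasting behavior of burn's binary ops, not the crate's implementation:

// --- illustrative sketch, not part of the diff ---
use burn::tensor::{backend::Backend, Tensor};

fn rms_norm_manual<B: Backend, const D: usize>(
    x: Tensor<B, D>,
    gamma: Tensor<B, 1>,
    epsilon: f64,
) -> Tensor<B, D> {
    // mean(X^2) over the last axis, dimension kept for broadcasting.
    let mean_sq = (x.clone() * x.clone()).mean_dim(D - 1);
    let rms = mean_sq.add_scalar(epsilon).sqrt();
    x.div(rms) * gamma.unsqueeze()
}
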
@@ -1,6 +1,6 @@
use crate as burn;

use burn_tensor::ops::conv::calculate_conv_padding;
use crate::tensor::ops::conv::calculate_conv_padding;

use crate::config::Config;
use crate::module::Module;

@@ -4,9 +4,10 @@ use crate::config::Config;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::adaptive_avg_pool1d;

/// Configuration to create a [1D adaptive avg pooling](AdaptiveAvgPool1d) layer.
use crate::tensor::module::adaptive_avg_pool1d;

/// Configuration to create a [1D adaptive avg pooling](AdaptiveAvgPool1d) layer using the [init function](AdaptiveAvgPool1dConfig::init).
#[derive(Config)]
pub struct AdaptiveAvgPool1dConfig {
    /// The size of the output.

@@ -14,6 +15,8 @@ pub struct AdaptiveAvgPool1dConfig {
}

/// Applies a 1D adaptive avg pooling over input tensors.
///
/// Should be created with [AdaptiveAvgPool1dConfig].
#[derive(Module, Clone, Debug)]
pub struct AdaptiveAvgPool1d {
    output_size: usize,

@@ -31,10 +34,12 @@ impl AdaptiveAvgPool1dConfig {
impl AdaptiveAvgPool1d {
    /// Applies the forward pass on the input tensor.
    ///
    /// See [adaptive_avg_pool1d](crate::tensor::module::adaptive_avg_pool1d) for more information.
    ///
    /// # Shapes
    ///
    /// - input: [batch_size, channels, length],
    /// - output: [batch_size, channels, length_out],
    /// - input: `[batch_size, channels, length]`
    /// - output: `[batch_size, channels, length_out]`
    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        adaptive_avg_pool1d(input, self.output_size)
    }

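For context, the call path after the import move looks like this sketch (the `burn::nn::pool` path and `NdArray` backend are assumptions):

// --- illustrative sketch, not part of the diff ---
use burn::backend::NdArray;
use burn::nn::pool::AdaptiveAvgPool1dConfig;
use burn::tensor::{Distribution, Tensor};

fn adaptive_pool_demo() {
    type B = NdArray;
    let device = Default::default();

    // Pool any input length down to exactly 4 output steps.
    let pool = AdaptiveAvgPool1dConfig::new(4).init();

    let input = Tensor::<B, 3>::random([2, 3, 17], Distribution::Default, &device);
    let output = pool.forward(input);
    assert_eq!(output.dims(), [2, 3, 4]); // [batch_size, channels, output_size]
}
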
@@ -4,9 +4,10 @@ use crate::config::Config;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::adaptive_avg_pool2d;

/// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer.
use crate::tensor::module::adaptive_avg_pool2d;

/// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer using the [init function](AdaptiveAvgPool2dConfig::init).
#[derive(Config)]
pub struct AdaptiveAvgPool2dConfig {
    /// The size of the output.

@@ -14,6 +15,8 @@ pub struct AdaptiveAvgPool2dConfig {
}

/// Applies a 2D adaptive avg pooling over input tensors.
///
/// Should be created with [AdaptiveAvgPool2dConfig].
#[derive(Module, Clone, Debug)]
pub struct AdaptiveAvgPool2d {
    output_size: [usize; 2],

@@ -31,10 +34,12 @@ impl AdaptiveAvgPool2dConfig {
impl AdaptiveAvgPool2d {
    /// Applies the forward pass on the input tensor.
    ///
    /// See [adaptive_avg_pool2d](crate::tensor::module::adaptive_avg_pool2d) for more information.
    ///
    /// # Shapes
    ///
    /// - input: [batch_size, channels, height_in, width_in],
    /// - output: [batch_size, channels, height_out, width_out],
    /// - input: `[batch_size, channels, height_in, width_in]`
    /// - output: `[batch_size, channels, height_out, width_out]`
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        adaptive_avg_pool2d(input, self.output_size)
    }

@@ -5,9 +5,10 @@ use crate::module::Module;
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::avg_pool1d;

/// Configuration to create a [1D avg pooling](AvgPool1d) layer.
use crate::tensor::module::avg_pool1d;

/// Configuration to create a [1D avg pooling](AvgPool1d) layer using the [init function](AvgPool1dConfig::init).
#[derive(Config, Debug)]
pub struct AvgPool1dConfig {
    /// The size of the kernel.

@@ -25,7 +26,7 @@ pub struct AvgPool1dConfig {

/// Applies a 1D avg pooling over input tensors.
///
/// See [AvgPool1dConfig](AvgPool1dConfig) for details.
/// Should be created with [AvgPool1dConfig](AvgPool1dConfig).
///
/// # Remarks
///

@@ -61,10 +62,12 @@ impl AvgPool1dConfig {
impl AvgPool1d {
    /// Applies the forward pass on the input tensor.
    ///
    /// See [avg_pool1d](crate::tensor::module::avg_pool1d) for more information.
    ///
    /// # Shapes
    ///
    /// - input: [batch_size, channels, length_in],
    /// - output: [batch_size, channels, length_out],
    /// - input: `[batch_size, channels, length_in]`
    /// - output: `[batch_size, channels, length_out]`
    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let [_batch_size, _channels, length] = input.dims();
        let padding = self

@@ -5,9 +5,10 @@ use crate::module::Module;
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::avg_pool2d;

/// Configuration to create a [2D avg pooling](AvgPool2d) layer.
use crate::tensor::module::avg_pool2d;

/// Configuration to create a [2D avg pooling](AvgPool2d) layer using the [init function](AvgPool2dConfig::init).
#[derive(Config, Debug)]
pub struct AvgPool2dConfig {
    /// The size of the kernel.

@@ -25,7 +26,7 @@ pub struct AvgPool2dConfig {

/// Applies a 2D avg pooling over input tensors.
///
/// See [AvgPool2dConfig](AvgPool2dConfig) for details.
/// Should be created with [AvgPool2dConfig](AvgPool2dConfig).
///
/// # Remarks
///

@@ -60,10 +61,12 @@ impl AvgPool2dConfig {
impl AvgPool2d {
    /// Applies the forward pass on the input tensor.
    ///
    /// See [avg_pool2d](crate::tensor::module::avg_pool2d) for more information.
    ///
    /// # Shapes
    ///
    /// - input: [batch_size, channels, height_in, width_in],
    /// - output: [batch_size, channels, height_out, width_out],
    /// - input: `[batch_size, channels, height_in, width_in]`
    /// - output: `[batch_size, channels, height_out, width_out]`
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let [_batch_size, _channels_in, height_in, width_in] = input.dims();
        let padding =

@@ -5,9 +5,10 @@ use crate::module::Module;
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::max_pool1d;

/// Configuration to create a [1D max pooling](MaxPool1d) layer.
use crate::tensor::module::max_pool1d;

/// Configuration to create a [1D max pooling](MaxPool1d) layer using the [init function](MaxPool1dConfig::init).
#[derive(Config, Debug)]
pub struct MaxPool1dConfig {
    /// The size of the kernel.

@@ -24,6 +25,8 @@ pub struct MaxPool1dConfig {
}

/// Applies a 1D max pooling over input tensors.
///
/// Should be created with [MaxPool1dConfig](MaxPool1dConfig).
#[derive(Module, Clone, Debug)]
pub struct MaxPool1d {
    stride: usize,

@@ -47,10 +50,12 @@ impl MaxPool1dConfig {
impl MaxPool1d {
    /// Applies the forward pass on the input tensor.
    ///
    /// See [max_pool1d](crate::tensor::module::max_pool1d) for more information.
    ///
    /// # Shapes
    ///
    /// - input: [batch_size, channels, length_in],
    /// - output: [batch_size, channels, length_out],
    /// - input: `[batch_size, channels, length_in]`
    /// - output: `[batch_size, channels, length_out]`
    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let [_batch_size, _channels, length] = input.dims();
        let padding = self

@@ -5,9 +5,10 @@ use crate::module::Module;
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::max_pool2d;

/// Configuration to create an [2D max pooling](MaxPool2d) layer.
use crate::tensor::module::max_pool2d;

/// Configuration to create a [2D max pooling](MaxPool2d) layer using the [init function](MaxPool2dConfig::init).
#[derive(Debug, Config)]
pub struct MaxPool2dConfig {
    /// The size of the kernel.

@@ -24,6 +25,8 @@ pub struct MaxPool2dConfig {
}

/// Applies a 2D max pooling over input tensors.
///
/// Should be created with [MaxPool2dConfig](MaxPool2dConfig).
#[derive(Module, Clone, Debug)]
pub struct MaxPool2d {
    stride: [usize; 2],

@@ -47,10 +50,12 @@ impl MaxPool2dConfig {
impl MaxPool2d {
    /// Applies the forward pass on the input tensor.
    ///
    /// See [max_pool2d](crate::tensor::module::max_pool2d) for more information.
    ///
    /// # Shapes
    ///
    /// - input: [batch_size, channels, height_in, width_in],
    /// - output: [batch_size, channels, height_out, width_out],
    /// - input: `[batch_size, channels, height_in, width_in]`
    /// - output: `[batch_size, channels, height_out, width_out]`
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let [_batch_size, _channels_in, height_in, width_in] = input.dims();
        let padding =

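The four fixed-kernel pooling layers above all resolve their padding before calling into the tensor module, and their output lengths follow the usual convolution arithmetic. A self-contained restatement of the standard formula (not code from this repository):

// --- illustrative sketch, not part of the diff ---
// length_out = (length_in + 2 * padding - kernel_size) / stride + 1
fn pool_output_len(length_in: usize, padding: usize, kernel_size: usize, stride: usize) -> usize {
    (length_in + 2 * padding - kernel_size) / stride + 1
}

fn pool_arithmetic_demo() {
    // kernel 3, stride 2, padding 1 over 10 steps -> 5 outputs.
    assert_eq!(pool_output_len(10, 1, 3, 2), 5);
}
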
@@ -4,25 +4,25 @@ use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Data;
use crate::tensor::Tensor;
use burn_tensor::Data;

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// Configuration to create an [PositionalEncoding](PositionalEncoding) layer.
/// Configuration to create a [PositionalEncoding](PositionalEncoding) layer using the [init function](PositionalEncodingConfig::init).
#[derive(Config)]
pub struct PositionalEncodingConfig {
    /// Maximum sequence size to use.
    #[config(default = "5_000")]
    max_sequence_size: usize,
    pub max_sequence_size: usize,

    /// The size of each vector.
    d_model: usize,
    pub d_model: usize,

    /// Max time scale to use.
    #[config(default = "10_000")]
    max_timescale: usize,
    pub max_timescale: usize,
}

/// Positional encoding layer for transformer models.

@@ -37,6 +37,8 @@ pub struct PositionalEncodingConfig {
/// The reference implementation can be found here:
/// [LANGUAGE MODELING WITH NN.TRANSFORMER AND TORCHTEXT
/// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)
///
/// Should be created using [PositionalEncodingConfig]
#[derive(Module, Debug)]
pub struct PositionalEncoding<B: Backend> {
    sinusoids: Tensor<B, 3>,

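The `sinusoids` buffer above holds the standard transformer encoding, with the config's `max_timescale` playing the role of the paper's 10_000 base. An illustrative scalar version of one entry, assumed from the standard formulation rather than copied from the module:

// --- illustrative sketch, not part of the diff ---
// pe(pos, 2i)     = sin(pos / max_timescale^(2i / d_model))
// pe(pos, 2i + 1) = cos(pos / max_timescale^(2i / d_model))
fn sinusoid(pos: usize, i: usize, d_model: usize, max_timescale: usize) -> (f64, f64) {
    let rate = pos as f64 / (max_timescale as f64).powf((2 * i) as f64 / d_model as f64);
    (rate.sin(), rate.cos())
}
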
@@ -6,13 +6,15 @@ use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
/// Parametric Relu layer.
///
/// Should be created using [PReluConfig]
#[derive(Module, Debug)]
pub struct PRelu<B: Backend> {
    /// The weights learnt for PReLU. Can be of shape \[1\] or \[num_parameters\], in which case it must
    /// be the same as the number of channels in the input tensor.
    pub alpha: Param<Tensor<B, 1>>,
}
/// Configuration to create a [Parametric Relu](PRelu) layer.
/// Configuration to create a [Parametric Relu](PRelu) layer using the [init function](PReluConfig::init).
#[derive(Config, Debug)]
pub struct PReluConfig {
    /// The number of parameters.

@@ -39,6 +41,8 @@ impl<B: Backend> PRelu<B> {
    ///
    /// - input: `[..., any]`
    /// - output: `[..., any]`
    ///
    /// See also [prelu](crate::tensor::activation::prelu) for more information.
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        crate::tensor::activation::prelu(input, self.alpha.val())
    }

@@ -4,9 +4,9 @@ use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Applies the rectified linear unit function element-wise:
/// Applies the rectified linear unit function element-wise
/// See also [relu](burn::tensor::activation::relu)
///
/// `y = max(0, x)`
#[derive(Module, Clone, Debug, Default)]
pub struct Relu {}

@@ -2,7 +2,7 @@ use crate as burn;

use crate::module::Module;
use crate::nn::{Initializer, Linear, LinearConfig};
use burn_tensor::{backend::Backend, Tensor};
use crate::tensor::{backend::Backend, Tensor};

/// A GateController represents a gate in an LSTM cell. An
/// LSTM cell generally contains three gates: an input gate,

@@ -4,13 +4,13 @@ use crate::config::Config;
use crate::module::Module;
use crate::nn::rnn::gate_controller;
use crate::nn::Initializer;
use crate::tensor::activation;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::activation;

use super::gate_controller::GateController;

/// The configuration for a [gru](Gru) module.
/// Configuration to create a [gru](Gru) module using the [init function](GruConfig::init).
#[derive(Config)]
pub struct GruConfig {
    /// The size of the input features.

@@ -24,7 +24,11 @@ pub struct GruConfig {
    pub initializer: Initializer,
}

/// The Gru module. This implementation is for a unidirectional, stateless, Gru.
/// The Gru (Gated recurrent unit) module. This implementation is for a unidirectional, stateless, Gru.
///
/// Introduced in the paper: [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078).
///
/// Should be created with [GruConfig].
#[derive(Module, Debug)]
pub struct Gru<B: Backend> {
    update_gate: GateController<B>,

@@ -73,13 +77,11 @@ impl<B: Backend> Gru<B> {
    /// Applies the forward pass on the input tensor. This GRU implementation
    /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size].
    ///
    /// Parameters:
    /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size].
    /// state: An optional tensor representing an initial cell state with the same dimensions
    /// # Shapes
    /// - batched_input: `[batch_size, sequence_length, input_size]`.
    /// - state: An optional tensor representing an initial cell state with the same dimensions
    /// as batched_input. If none is provided, one will be generated.
    ///
    /// Returns:
    /// The resulting state tensor, with shape [batch_size, sequence_length, hidden_size].
    /// - output: `[batch_size, sequence_length, hidden_size]`.
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,

@@ -177,8 +179,8 @@ impl<B: Backend> Gru<B> {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Data, Distribution};
    use crate::{module::Param, nn::LinearRecord, TestBackend};
    use burn_tensor::{Data, Distribution};

    /// Test forward pass with simple input vector.
    ///

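With the reworked `# Shapes` section, a GRU call reads as below. The `GruConfig::new(d_input, d_hidden, bias)` argument order and the `burn::nn` path are assumptions inferred from the LSTM test elsewhere in this commit:

// --- illustrative sketch, not part of the diff ---
use burn::backend::NdArray;
use burn::nn::GruConfig;
use burn::tensor::{Distribution, Tensor};

fn gru_demo() {
    type B = NdArray;
    let device = Default::default();

    let gru = GruConfig::new(16, 32, true).init::<B>(&device);

    let input = Tensor::<B, 3>::random([8, 20, 16], Distribution::Default, &device);
    // No initial state: the module generates one.
    let state = gru.forward(input, None);
    assert_eq!(state.dims(), [8, 20, 32]); // [batch_size, sequence_length, hidden_size]
}
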
@@ -4,9 +4,9 @@ use crate::config::Config;
use crate::module::Module;
use crate::nn::rnn::gate_controller::GateController;
use crate::nn::Initializer;
use crate::tensor::activation;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::activation;

/// A LstmState is used to store cell state and hidden state in LSTM.
pub struct LstmState<B: Backend, const D: usize> {

@@ -23,7 +23,7 @@ impl<B: Backend, const D: usize> LstmState<B, D> {
    }
}

/// The configuration for a [lstm](Lstm) module.
/// Configuration to create a [Lstm](Lstm) module using the [init function](LstmConfig::init).
#[derive(Config)]
pub struct LstmConfig {
    /// The size of the input features.

@@ -38,6 +38,10 @@ pub struct LstmConfig {
}

/// The Lstm module. This implementation is for a unidirectional, stateless, Lstm.
///
/// Introduced in the paper: [Long Short-Term Memory](https://www.researchgate.net/publication/13853244).
///
/// Should be created with [LstmConfig].
#[derive(Module, Debug)]
pub struct Lstm<B: Backend> {
    /// The input gate regulates which information to update and store in the cell state at each time step.

@@ -171,7 +175,7 @@ impl<B: Backend> Lstm<B> {
    }
}

/// The configuration for a [Bidirectional LSTM](BiLstm) module.
/// Configuration to create a [BiLstm](BiLstm) module using the [init function](BiLstmConfig::init).
#[derive(Config)]
pub struct BiLstmConfig {
    /// The size of the input features.

@@ -186,6 +190,10 @@ pub struct BiLstmConfig {
}

/// The BiLstm module. This implementation is for Bidirectional LSTM.
///
/// Introduced in the paper: [Framewise phoneme classification with bidirectional LSTM and other neural network architectures](https://www.cs.toronto.edu/~graves/ijcnn_2005.pdf).
///
/// Should be created with [BiLstmConfig].
#[derive(Module, Debug)]
pub struct BiLstm<B: Backend> {
    /// LSTM for the forward direction.

@@ -298,8 +306,8 @@ impl<B: Backend> BiLstm<B> {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Data, Device, Distribution};
    use crate::{module::Param, nn::LinearRecord, TestBackend};
    use burn_tensor::{Data, Device, Distribution};

    #[cfg(feature = "std")]
    use crate::TestAutodiffBackend;

@@ -451,7 +459,7 @@ mod tests {
    #[test]
    #[cfg(feature = "std")]
    fn test_batched_backward_pass() {
        use burn_tensor::Shape;
        use crate::tensor::Shape;
        let device = Default::default();
        let lstm = LstmConfig::new(64, 32, true).init(&device);
        let shape: Shape<3> = [8, 10, 64].into();

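A companion sketch for the unidirectional LSTM, mirroring the `LstmConfig::new(64, 32, true)` call in the test above. The tuple-returning `forward` signature is an assumption based on the `LstmState` type shown earlier in this diff:

// --- illustrative sketch, not part of the diff ---
use burn::backend::NdArray;
use burn::nn::LstmConfig;
use burn::tensor::{Distribution, Tensor};

fn lstm_demo() {
    type B = NdArray;
    let device = Default::default();

    let lstm = LstmConfig::new(64, 32, true).init::<B>(&device);

    let input = Tensor::<B, 3>::random([8, 10, 64], Distribution::Default, &device);
    let (output, state) = lstm.forward(input, None);
    assert_eq!(output.dims(), [8, 10, 32]);
    let _ = state; // LstmState carries the final cell and hidden tensors.
}
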
@@ -2,25 +2,25 @@ use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Int;
use crate::tensor::Tensor;
use alloc::vec;
use burn_tensor::Int;

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer.
/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer using the [init function](RotaryEncodingConfig::init).
#[derive(Config, Debug)]
pub struct RotaryEncodingConfig {
    /// Maximum sequence length of input
    max_sequence_length: usize,
    pub max_sequence_length: usize,

    /// Size of the input embedding or hidden dimension
    d_model: usize,
    pub d_model: usize,

    /// Scaling factor for frequency computation. Defaults to 10000.0
    #[config(default = "10000.0")]
    theta: f32,
    pub theta: f32,
}

impl RotaryEncodingConfig {

@@ -84,6 +84,8 @@ impl RotaryEncodingConfig {
/// explicit relative position dependency in self-attention formulation.
///
/// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
///
/// Should be created using [RotaryEncodingConfig].
#[derive(Module, Debug)]
pub struct RotaryEncoding<B: Backend> {
    /// Frequency Tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components

@@ -7,7 +7,7 @@ use crate::tensor::{backend::Backend, Tensor};

use super::{Initializer, Linear, LinearConfig};

/// Configuration to create a [SwiGlu](SwiGlu) activation layer.
/// Configuration to create a [SwiGlu](SwiGlu) activation layer using the [init function](SwiGluConfig::init).
#[derive(Config, Debug)]
pub struct SwiGluConfig {
    /// The size of the input features.

@@ -29,16 +29,15 @@ pub struct SwiGluConfig {
/// The SwiGLU activation function is defined as:
/// `SwiGLU(x) = Swish(W_inner * x + b_inner) * (W_outer * x + b_outer)`
///
/// # Params
///
/// - linear inner: The inner linear layer for Swish activation function
/// with `d_input` input features and `d_output` output features.
/// - linear outer: Outer Linear layer for element wise multiplication
/// with `d_input` input features and `d_output` output features.
/// Should be created with [SwiGluConfig].
#[derive(Module, Debug)]
pub struct SwiGlu<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    /// The inner linear layer for Swish activation function
    /// with `d_input` input features and `d_output` output features.
    pub linear_inner: Linear<B>,
    /// The outer linear layer for element wise multiplication
    /// with `d_input` input features and `d_output` output features.
    pub linear_outer: Linear<B>,
}

impl SwiGluConfig {

@@ -58,11 +57,11 @@ impl SwiGluConfig {
}

impl<B: Backend> SwiGlu<B> {
    /// Applies the forward pass on the input tensor.
    /// Applies the Swish Gated Linear Unit to the input tensor.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_input]`
    /// - input: `[batch_size, seq_length, d_input]`
    /// - output: `[batch_size, seq_length, d_output]`
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input.clone());

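The `SwiGLU(x) = Swish(W_inner * x + b_inner) * (W_outer * x + b_outer)` formula above maps directly onto the two linear layers. An illustrative restatement (Swish is SiLU here; this is not the module's code):

// --- illustrative sketch, not part of the diff ---
use burn::nn::Linear;
use burn::tensor::{activation::silu, backend::Backend, Tensor};

fn swiglu_manual<B: Backend, const D: usize>(
    linear_inner: &Linear<B>,
    linear_outer: &Linear<B>,
    input: Tensor<B, D>,
) -> Tensor<B, D> {
    // Gate the outer projection with the Swish-activated inner projection.
    let gate = silu(linear_inner.forward(input.clone()));
    gate * linear_outer.forward(input)
}
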
@@ -1,5 +1,5 @@
use crate::tensor::Bool;
use alloc::vec::Vec;
use burn_tensor::Bool;

use crate::{
    self as burn,

@@ -17,7 +17,7 @@ use crate::{
    tensor::{backend::Backend, Tensor},
};

/// Configuration to create a [Transformer Decoder](TransformerDecoder) layer.
/// Configuration to create a [Transformer Decoder](TransformerDecoder) layer using the [init function](TransformerDecoderConfig::init).
#[derive(Config)]
pub struct TransformerDecoderConfig {
    /// The size of the model.

@@ -54,6 +54,8 @@ pub struct TransformerDecoderConfig {
/// # Params
///
/// - layers: transformer decoder layers with `d_model` input and output features.
///
/// Should be created using [TransformerDecoderConfig]
#[derive(Module, Debug)]
pub struct TransformerDecoder<B: Backend> {
    layers: Vec<TransformerDecoderLayer<B>>,

@@ -204,6 +206,7 @@ impl<B: Backend> TransformerDecoderLayer<B> {
    }
}

    /// Applies the TransformerDecoder forward pass to the input tensor.
    fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {
        // Self attention residual path.
        let x = input.target;

@@ -401,8 +404,8 @@ impl<B: Backend> TransformerDecoder<B> {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Distribution;
    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
    use burn_tensor::Distribution;

    #[test]
    fn test_autoregressive_norm_last() {

@@ -1,5 +1,5 @@
use crate::tensor::Bool;
use alloc::vec::Vec;
use burn_tensor::Bool;

use crate::{
    self as burn,

@@ -17,7 +17,7 @@ use crate::{
    tensor::{backend::Backend, Tensor},
};

/// Configuration to create a [Transformer Encoder](TransformerEncoder) layer.
/// Configuration to create a [Transformer Encoder](TransformerEncoder) layer using the [init function](TransformerEncoderConfig::init).
#[derive(Config)]
pub struct TransformerEncoderConfig {
    /// The size of the model.

@@ -54,6 +54,8 @@ pub struct TransformerEncoderConfig {
/// # Params
///
/// - layers: transformer encoder layers with `d_model` input and output features.
///
/// Should be created using [TransformerEncoderConfig]
#[derive(Module, Debug)]
pub struct TransformerEncoder<B: Backend> {
    layers: Vec<TransformerEncoderLayer<B>>,

@@ -338,8 +340,8 @@ impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Distribution;
    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
    use burn_tensor::Distribution;

    #[test]
    fn test_autoregressive_norm_last() {

@@ -8,7 +8,7 @@ use crate::{
    tensor::{backend::Backend, Tensor},
};

/// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer.
/// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer using the [init function](PositionWiseFeedForwardConfig::init).
#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
    /// The size of the input and output features.

@@ -25,12 +25,16 @@ pub struct PositionWiseFeedForwardConfig {
    pub initializer: Initializer,
}

/// Applies the position-wise feed-forward network to the input tensor.
/// Applies the position-wise feed-forward network to the input tensor from the paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762v7).
///
/// # Params
///
/// - linear inner: Linear layer with `d_model` input features and `d_ff` output features.
/// - linear outer: Linear layer with `d_ff` input features and `d_model` output features.
///
/// `FFN(x) = max(0, xW1 + b1)W2 + b2`
///
/// Should be created using [PositionWiseFeedForwardConfig]
#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Linear<B>,

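The `FFN(x) = max(0, xW1 + b1)W2 + b2` line added above corresponds to the composition below. This mirrors only the documented formula; the real module also interposes dropout and may use a different activation:

// --- illustrative sketch, not part of the diff ---
use burn::nn::Linear;
use burn::tensor::{activation::relu, backend::Backend, Tensor};

fn ffn_manual<B: Backend, const D: usize>(
    linear_inner: &Linear<B>,  // d_model -> d_ff
    linear_outer: &Linear<B>,  // d_ff -> d_model
    input: Tensor<B, D>,
) -> Tensor<B, D> {
    linear_outer.forward(relu(linear_inner.forward(input)))
}
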
@@ -2,12 +2,13 @@ use crate as burn;

use crate::config::Config;
use crate::module::Module;
use burn_tensor::backend::Backend;
use burn_tensor::module::unfold4d;
use burn_tensor::ops::UnfoldOptions;
use burn_tensor::Tensor;
use crate::tensor::backend::Backend;
use crate::tensor::ops::UnfoldOptions;
use crate::tensor::Tensor;

/// Configuration to create an [unfold 4D](Unfold4d) layer.
use crate::tensor::module::unfold4d;

/// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init).
#[derive(Config, Debug)]
pub struct Unfold4dConfig {
    /// The size of the kernel.

@@ -24,13 +25,15 @@ pub struct Unfold4dConfig {
}

/// Four-dimensional unfolding.
///
/// Should be created with [Unfold4dConfig].
#[derive(Module, Clone, Debug)]
pub struct Unfold4d {
    config: Unfold4dConfig,
}

impl Unfold4dConfig {
    /// Initialize a new [unfold 4k](Unfold4d) module.
    /// Initializes a new [Unfold4d] module.
    pub fn init(&self) -> Unfold4d {
        Unfold4d {
            config: self.clone(),

@@ -41,10 +44,12 @@ impl Unfold4dConfig {
impl Unfold4d {
    /// Applies the forward pass on the input tensor.
    ///
    /// See [unfold4d](crate::tensor::module::unfold4d) for more information.
    ///
    /// # Shapes
    ///
    /// input: `[batch_size, channels_in, height, width]`,
    /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`,
    /// input: `[batch_size, channels_in, height, width]`
    /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
        unfold4d(
            input,

@@ -7,6 +7,7 @@ use burn_compute::{
use burn_cube::ir::CubeDim;
use burn_cube::prelude::*;
use burn_jit::JitAutotuneKey;
use burn_tensor::backend::SyncType;
use cudarc::driver::sys::CUctx_st;
use cudarc::driver::sys::CUfunc_st;
use std::collections::HashMap;

@@ -110,10 +111,17 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
        // self.memory_management.storage().perform_deallocations();
    }

    fn sync(&mut self) {
    fn sync(&mut self, sync_type: SyncType) {
        match sync_type {
            // Synchronize the stream if waiting.
            SyncType::Wait => {
                let ctx = self.get_context();
                ctx.sync();
            }
            // Nothing to do - all tasks are already submitted to the stream.
            SyncType::Flush => (),
        }
    }

    fn get_resource(
        &mut self,

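For readers following the sync refactor: the `SyncType` threaded through every compute server comes from `burn_common::sync_type`. Its shape, as exercised by the `match` above, is the two-variant enum below; the derives and doc wording are assumptions, since the definition itself lives in burn-common:

// --- illustrative sketch, not part of the diff ---
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncType {
    /// Submit all outstanding work and block until the device reports completion.
    Wait,
    /// Submit all outstanding work without blocking.
    Flush,
}
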
@@ -1,4 +1,4 @@
use syn::{Attribute, Ident, Meta};
use syn::{Attribute, Meta};

pub struct AttributeAnalyzer {
    attr: Attribute,

@@ -6,7 +6,6 @@ pub struct AttributeAnalyzer {

#[derive(Clone)]
pub struct AttributeItem {
    pub ident: Ident,
    pub value: syn::Lit,
}

@@ -27,10 +26,7 @@ impl AttributeAnalyzer {
            _ => panic!("Only literal is supported"),
        };

        AttributeItem {
            ident: value.path.get_ident().unwrap().clone(),
            value: lit,
        }
        AttributeItem { value: lit }
    }

    pub fn has_name(&self, name: &str) -> bool {

@@ -2,7 +2,7 @@ use crate::{
    client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge,
};
use burn_tensor::{
    backend::{Backend, DeviceOps},
    backend::{Backend, DeviceOps, SyncType},
    ops::FloatTensor,
    repr::{OperationDescription, ReprBackend},
    Device,

@@ -45,10 +45,10 @@ impl<B: FusionBackend> Backend for Fusion<B> {
        B::seed(seed);
    }

    fn sync(device: &Self::Device) {
    fn sync(device: &Self::Device, sync_type: SyncType) {
        let client = CLIENTS.client::<B::FusionRuntime>(&device.clone());
        client.drain();
        B::sync(device)
        B::sync(device, sync_type);
    }

    fn ad_enabled() -> bool {

@@ -11,7 +11,6 @@ use std::sync::Arc;
pub struct FusionServer<R: FusionRuntime> {
    streams: MultiStream<R>,
    pub(crate) handles: HandleContainer<R::FusionHandle>,
    pub device: R::FusionDevice,
}

impl<R> FusionServer<R>

@@ -22,7 +21,6 @@ where
        Self {
            streams: MultiStream::new(device.clone()),
            handles: HandleContainer::new(),
            device,
        }
    }

@@ -2,7 +2,7 @@ use crate::{
    tensor::JitTensor, FloatElement, IntElement, JitAutotuneKey, JitRuntime, PrecisionBridge,
};
use burn_compute::server::ComputeServer;
use burn_tensor::backend::Backend;
use burn_tensor::backend::{Backend, SyncType};
use rand::{rngs::StdRng, SeedableRng};
use std::{marker::PhantomData, sync::Mutex};

@@ -48,9 +48,9 @@ where
        false
    }

    fn sync(device: &Self::Device) {
    fn sync(device: &Self::Device, sync_type: SyncType) {
        let client = R::client(device);
        client.sync();
        client.sync(sync_type);
    }
}

@@ -151,7 +151,7 @@ macro_rules! make_elem {

        #[inline(always)]
        fn int_abs_elem(self) -> Self {
            (self as i32).abs() as $ty
            (self as i32).unsigned_abs() as $ty
        }
    }
};

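The `unsigned_abs` swap above is a correctness fix, not a style change: `i32::abs` panics in debug builds (and wraps in release) on `i32::MIN`, whereas `unsigned_abs` is total. A standalone illustration:

// --- illustrative sketch, not part of the diff ---
fn unsigned_abs_demo() {
    let x: i32 = i32::MIN;
    // x.abs() would overflow: i32::MIN has no positive counterpart in i32.
    assert_eq!(x.unsigned_abs(), 2_147_483_648_u32); // always well-defined
}
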
@@ -2,7 +2,7 @@ use crate::PrecisionBridge;

use super::element::TchElement;
use super::TchTensor;
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
use burn_tensor::backend::{Backend, DeviceId, DeviceOps, SyncType};
use burn_tensor::ops::IntTensorOps;
use burn_tensor::{Int, Tensor};

@@ -114,7 +114,8 @@ impl<E: TchElement> Backend for LibTorch<E> {
        "tch".to_string()
    }

    fn sync(device: &Self::Device) {
    fn sync(device: &Self::Device, sync_type: SyncType) {
        if sync_type == SyncType::Wait {
            match device {
                LibTorchDevice::Cpu => (),
                LibTorchDevice::Cuda(index) => {

@@ -122,12 +123,12 @@ impl<E: TchElement> Backend for LibTorch<E> {
                }
                _ => {
                    // When there is no explicit way to synchronize, we write and read one value to sync
                    Tensor::<Self, 1, Int>::from_primitive(<Self as IntTensorOps<Self>>::int_zeros(
                        [1].into(),
                        device,
                    ))
                    Tensor::<Self, 1, Int>::from_primitive(
                        <Self as IntTensorOps<Self>>::int_zeros([1].into(), device),
                    )
                    .into_data();
                }
            }
        }
    }

@@ -2,7 +2,10 @@ use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{check, Tensor};

/// Applies the rectified linear unit function.
/// Applies the rectified linear unit function as described in the paper [Deep Learning using
/// Rectified Linear Units (ReLU)](https://arxiv.org/pdf/1803.08375).
///
/// `y = max(0, x)`
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.relu()
}

@@ -20,12 +23,12 @@ pub fn leaky_relu<const D: usize, B: Backend>(
    ))
}

/// Applies the Gaussian Error Linear Units function as described in the paper in [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
/// Applies the Gaussian Error Linear Units function as described in the paper [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::from_primitive(B::gelu(tensor.primitive))
}

/// Applies Parametric ReLu activation
/// Applies Parametric ReLu activation function as described in the paper [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/pdf/1502.01852).
/// `PReLu(x) = max(0,x) + \alpha * min(0,x)`
/// tensor is assumed to be of shape \[batch_size, channels, ...\]
/// alpha is assumed to be of shape \[channels\] or \[1\]

@@ -1,4 +1,5 @@
use alloc::string::String;
pub use burn_common::sync_type::SyncType;

use crate::ops::*;
use crate::tensor::Element;

@@ -96,7 +97,7 @@ pub trait Backend:
    fn seed(seed: u64);

    /// Sync the backend, ensuring that all computations are finished.
    fn sync(_device: &Self::Device) {}
    fn sync(_device: &Self::Device, _sync_type: SyncType) {}
}

/// Trait that allows a backend to support autodiff.

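Downstream code now picks the synchronization semantics explicitly. A minimal caller-side sketch of the pattern this default method enables (the helper and its naming are illustrative only):

// --- illustrative sketch, not part of the diff ---
use burn::tensor::backend::{Backend, SyncType};
use std::time::{Duration, Instant};

// Time a closure, blocking on the device only because a wall-clock
// measurement needs all submitted work to be finished.
fn timed<B: Backend>(device: &B::Device, f: impl FnOnce()) -> Duration {
    let start = Instant::now();
    f();
    B::sync(device, SyncType::Wait);
    start.elapsed()
}
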
@@ -8,7 +8,7 @@ use burn_compute::{
};
use burn_cube::prelude::*;
use burn_jit::JitAutotuneKey;
use burn_tensor::Reader;
use burn_tensor::{backend::SyncType, Reader};
use hashbrown::HashMap;
use wgpu::{
    util::{BufferInitDescriptor, DeviceExt, StagingBelt},

@@ -60,23 +60,6 @@ where
    }
}

    fn submit(&mut self) {
        self.staging_belt.finish();

        let mut new_encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        core::mem::swap(&mut new_encoder, &mut self.encoder);

        self.queue.submit(Some(new_encoder.finish()));
        self.tasks_count = 0;

        // Cleanup allocations and deallocations.
        self.memory_management.storage().perform_deallocations();

        self.staging_belt.recall();
    }

    fn register_compute(
        &mut self,
        pipeline: Arc<ComputePipeline>,

@@ -150,7 +133,7 @@ where
        );
        self.tasks_count += 1;

        self.submit();
        self.sync(SyncType::Flush);

        BufferReader::new(buffer_dest)
    }

@@ -304,12 +287,29 @@ where
        self.register_compute(pipeline, bind_group, work_group);

        if self.tasks_count >= self.tasks_max {
            self.submit();
            self.sync(SyncType::Flush);
        }
    }

    fn sync(&mut self) {
        self.submit();
    fn sync(&mut self, sync_type: SyncType) {
        // Flush commands to the queue.
        self.staging_belt.finish();

        let mut new_encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        core::mem::swap(&mut new_encoder, &mut self.encoder);

        self.queue.submit(Some(new_encoder.finish()));
        self.tasks_count = 0;

        // Cleanup allocations and deallocations.
        self.memory_management.storage().perform_deallocations();

        self.staging_belt.recall();

        if sync_type == SyncType::Wait {
            self.device.poll(wgpu::Maintain::Wait);
        }
    }
}