Nathaniel Simard 2022-11-26 15:48:26 -05:00 committed by GitHub
parent 46d06f0c90
commit 8bd0b17296
10 changed files with 383 additions and 2 deletions

View File

@@ -55,6 +55,13 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::bool_into_data(tensor)
}
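// Bool tensors are not differentiable, so the autodiff decorator simply
// forwards the reshape to the inner backend `B`.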
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1>,
shape: Shape<D2>,
) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D2> {
B::bool_reshape(tensor, shape)
}
fn device<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::Device {

View File

@@ -153,6 +153,21 @@ macro_rules! to_nd_array_tensor {
let array: ndarray::ArcArray<E, Dim<[usize; $n]>> = $array.reshape(dim);
let array = array.into_dyn();
NdArrayTensor {
array,
shape: $shape,
}
}};
(
bool,
$n:expr,
$shape:expr,
$array:expr
) => {{
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
let array: ndarray::ArcArray<bool, Dim<[usize; $n]>> = $array.reshape(dim);
let array = array.into_dyn();
NdArrayTensor {
array,
shape: $shape,

View File

@@ -69,6 +69,22 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
let values = tensor.array.into_iter().collect();
Data::new(values, tensor.shape)
}
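// `ndarray` needs the output rank as a concrete type (`Dim<[usize; N]>`),
// so the const-generic `D2` is dispatched at runtime to one arm per supported rank.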
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: &NdArrayTensor<bool, D1>,
shape: Shape<D2>,
) -> NdArrayTensor<bool, D2> {
match D2 {
1 => to_nd_array_tensor!(bool, 1, shape, tensor.array),
2 => to_nd_array_tensor!(bool, 2, shape, tensor.array),
3 => to_nd_array_tensor!(bool, 3, shape, tensor.array),
4 => to_nd_array_tensor!(bool, 4, shape, tensor.array),
5 => to_nd_array_tensor!(bool, 5, shape, tensor.array),
6 => to_nd_array_tensor!(bool, 6, shape, tensor.array),
_ => panic!("NdArrayTensor supports only 6 dimensions."),
}
}
fn device<const D: usize>(_tensor: &NdArrayTensor<E, D>) -> NdArrayDevice {
NdArrayDevice::Cpu
}

View File

@@ -40,6 +40,22 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
let values: Vec<bool> = tensor.tensor.into();
Data::new(values, tensor.shape)
}
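// LibTorch performs the reshape; the resulting shape is read back from the
// returned tensor rather than recomputed from the requested one.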
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: &TchTensor<bool, D1>,
shape: Shape<D2>,
) -> TchTensor<bool, D2> {
let shape_tch: TchShape<D2> = shape.into();
let tensor = tensor.tensor.reshape(&shape_tch.dims);
let shape = Shape::from(tensor.size());
TchTensor {
tensor,
shape,
kind: TchKind::new(),
}
}
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
match tensor.tensor.device() {
tch::Device::Cpu => TchDevice::Cpu,

View File

@@ -98,8 +98,8 @@ where
/// Returns the dimensions of the current tensor.
///
/// Equivalent to `tensor.shape().dims`.
- pub fn dims(&self) -> &[usize; D] {
-     &B::shape(&self.value).dims
+ pub fn dims(&self) -> [usize; D] {
+     B::shape(&self.value).dims
}
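Returning the `dims` array by value (instead of a reference) lets callers destructure it directly, a pattern the new attention module below relies on, e.g. `let [batch_size, seq_length_1, d_model] = input.query.dims();`.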
/// Returns the data of the current tensor.

View File

@@ -2,6 +2,7 @@ use super::Tensor;
use crate::tensor::backend::Backend;
use crate::tensor::{Data, Shape};
#[derive(Debug, Clone)]
pub struct BoolTensor<B: Backend, const D: usize> {
pub(crate) value: B::BoolTensorPrimitive<D>,
}
@@ -18,6 +19,13 @@
B::bool_shape(&self.value)
}
/// Returns the dimensions of the current tensor.
///
/// Equivalent to `tensor.shape().dims`.
pub fn dims(&self) -> [usize; D] {
self.shape().dims
}
pub fn into_data(self) -> Data<bool, D> {
B::bool_into_data(self.value)
}
@@ -35,4 +43,13 @@
let data = B::bool_to_data(&self.value);
Tensor::from_data(data.convert())
}
/// Reshape the tensor to have the given shape.
///
/// # Panics
///
/// If the tensor cannot be reshaped to the given shape.
pub fn reshape<const D2: usize, S: Into<Shape<D2>>>(&self, shape: S) -> BoolTensor<B, D2> {
BoolTensor::new(B::bool_reshape(&self.value, shape.into()))
}
}
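A minimal sketch of the new API, based on how the attention module below uses it (the bindings are illustrative, not part of the commit):

// Insert two singleton axes into a [batch_size, seq_length] padding mask so it
// broadcasts against [batch_size, n_heads, seq_length, seq_length] attention scores.
let [batch_size, seq_length] = mask_pad.dims();
let mask: BoolTensor<B, 4> = mask_pad.reshape([batch_size, 1, 1, seq_length]);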

View File

@@ -20,6 +20,10 @@ pub trait TensorOps<B: Backend> {
fn bool_shape<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> &Shape<D>;
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D>;
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
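// Reshape a bool tensor from rank `D1` to rank `D2`; each backend supplies
// its own implementation.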
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: &B::BoolTensorPrimitive<D1>,
shape: Shape<D2>,
) -> B::BoolTensorPrimitive<D2>;
fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
fn to_device<const D: usize>(
tensor: &B::TensorPrimitive<D>,

View File

@@ -0,0 +1,301 @@
use crate as burn;
use crate::{
config::Config,
module::{Module, Param},
nn,
tensor::{activation, backend::Backend, BoolTensor, Tensor},
};
/// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer.
#[derive(Config)]
pub struct MultiHeadAttentionConfig {
/// The size of each linear layer.
d_model: usize,
/// The number of heads.
n_heads: usize,
/// The dropout rate. Default: 0.1
#[config(default = 0.1)]
dropout: f64,
/// The minimum value a float can take, used to mask attention scores before
/// calculating the attention weights. Default: -1.0e4
/// A value that is too low might result in NaN.
#[config(default = -1.0e4)]
min_float: f64,
}
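// A construction sketch with illustrative values (the tests below use
// `MultiHeadAttentionConfig::new`; the `with_*` setter is assumed to come
// from the `Config` derive and is not shown in this commit):
//
//     let config = MultiHeadAttentionConfig::new(512, 8).with_dropout(0.2);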
/// The multihead attention module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - query: [Linear](nn::Linear) layer with `d_model` input and output features.
/// - key: [Linear](nn::Linear) layer with `d_model` input and output features.
/// - value: [Linear](nn::Linear) layer with `d_model` input and output features.
/// - output: [Linear](nn::Linear) layer with `d_model` input and output features.
#[derive(Module, Debug)]
pub struct MultiHeadAttention<B: Backend> {
query: Param<nn::Linear<B>>,
key: Param<nn::Linear<B>>,
value: Param<nn::Linear<B>>,
output: Param<nn::Linear<B>>,
dropout: nn::Dropout,
activation: nn::GELU,
n_heads: usize,
d_k: usize,
min_float: f64,
}
/// [Multihead attention](MultiHeadAttention) forward pass input argument.
#[derive(Debug)]
pub struct MhaInput<B: Backend> {
query: Tensor<B, 3>,
key: Tensor<B, 3>,
value: Tensor<B, 3>,
mask_pad: Option<BoolTensor<B, 2>>,
mask_attn: Option<BoolTensor<B, 3>>,
}
impl<B: Backend> MhaInput<B> {
/// Create a [multihead attention](MultiHeadAttention) input argument
/// by setting the query, key and value to the given tensor.
pub fn self_attn(tensor: Tensor<B, 3>) -> Self {
Self {
query: tensor.clone(),
key: tensor.clone(),
value: tensor,
mask_pad: None,
mask_attn: None,
}
}
/// Create a [multihead attention](MultiHeadAttention) input argument.
pub fn new(query: Tensor<B, 3>, key: Tensor<B, 3>, value: Tensor<B, 3>) -> Self {
Self {
query,
key,
value,
mask_pad: None,
mask_attn: None,
}
}
/// Register the padding mask.
pub fn mask_pad(mut self, mask_pad: BoolTensor<B, 2>) -> Self {
self.mask_pad = Some(mask_pad);
self
}
/// Register the attention mask.
pub fn mask_attn(mut self, mask_attn: BoolTensor<B, 3>) -> Self {
self.mask_attn = Some(mask_attn);
self
}
}
/// [Multihead attention](MultiHeadAttention) outputs.
#[derive(Debug)]
pub struct MhaOutput<B: Backend> {
/// The attention weights [batch_size, n_heads, seq_length_1, seq_length_2].
pub weights: Tensor<B, 4>,
/// The context tensor [batch_size, seq_length_1, d_model].
pub context: Tensor<B, 3>,
}
impl<B: Backend> MultiHeadAttention<B> {
/// Create the module from the given configuration.
pub fn new(config: &MultiHeadAttentionConfig) -> Self {
let linear = |config: &MultiHeadAttentionConfig| {
Param::new(nn::Linear::new(&nn::LinearConfig::new(
config.d_model,
config.d_model,
)))
};
Self {
query: linear(config),
key: linear(config),
value: linear(config),
output: linear(config),
dropout: nn::Dropout::new(&nn::DropoutConfig::new(config.dropout)),
activation: nn::GELU::new(),
n_heads: config.n_heads,
d_k: config.d_model / config.n_heads,
min_float: config.min_float,
}
}
/// Applies the forward pass on the input tensors.
///
/// # Shapes
///
/// - query: `[batch_size, seq_length_1, d_model]`
/// - key: `[batch_size, seq_length_2, d_model]`
/// - value: `[batch_size, seq_length_2, d_model]`
/// - output: `[batch_size, seq_length_1, d_model]`
pub fn forward(&self, input: MhaInput<B>) -> MhaOutput<B> {
let [batch_size, seq_length_1, d_model] = input.query.dims();
let query = self.attention_linear(input.query, &self.query);
let key = self.attention_linear(input.key, &self.key);
let value = self.attention_linear(input.value, &self.value);
let attn_scores = self.attn_scores(query, key);
let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
let context = weights.matmul(&value);
let context = context
.swap_dims(1, 2)
.reshape([batch_size, seq_length_1, d_model]);
let context = self.output.forward(context);
MhaOutput { weights, context }
}
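// Scaled dot-product attention from the paper:
//     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
// Note that dropout is applied to the raw scores here, before the softmax
// in `attn_weights`.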
fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
let attn_scores = query
.matmul(&key.transpose())
.div_scalar((self.d_k as f32).sqrt());
self.dropout.forward(attn_scores)
}
fn attn_weights(
&self,
mut attn_scores: Tensor<B, 4>,
mask_pad: Option<BoolTensor<B, 2>>,
mask_attn: Option<BoolTensor<B, 3>>,
) -> Tensor<B, 4> {
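// Masked positions are filled with `min_float`, so their weight after the
// softmax is effectively zero.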
if let Some(mask_pad) = mask_pad {
let [batch_size, seq_length] = mask_pad.dims();
attn_scores = attn_scores.mask_fill(
&mask_pad.reshape([batch_size, 1, 1, seq_length]),
self.min_float,
);
}
if let Some(mask_attn) = mask_attn {
let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims();
attn_scores = attn_scores.mask_fill(
&mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]),
self.min_float,
);
}
activation::softmax(&attn_scores, 3)
}
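// Projects the input and splits `d_model` into `n_heads` heads of size `d_k`,
// moving the head axis before the sequence axis: [batch_size, n_heads, seq_length, d_k].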
fn attention_linear(&self, x: Tensor<B, 3>, linear: &Param<nn::Linear<B>>) -> Tensor<B, 4> {
let [batch_size, seq_length, _d_model] = x.dims();
linear
.forward(x)
.reshape([batch_size, seq_length, self.n_heads, self.d_k])
.swap_dims(1, 2)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::{Distribution, Shape};
#[test]
fn test_self_attention_shapes() {
let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4];
let mha = MultiHeadAttention::<TestBackend>::new(&MultiHeadAttentionConfig::new(
d_model, n_heads,
));
let input = MhaInput::self_attn(Tensor::random(
[batch_size, seq_length, d_model],
Distribution::Standard,
));
let output = mha.forward(input);
assert_eq!(
output.context.shape(),
&Shape::new([batch_size, seq_length, d_model]),
"Context should have the correct shape",
);
assert_eq!(
output.weights.shape(),
&Shape::new([batch_size, n_heads, seq_length, seq_length]),
"Weights should have the correct shape",
);
}
#[test]
fn test_generic_mha_shapes() {
let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4];
let mha = MultiHeadAttention::<TestBackend>::new(&MultiHeadAttentionConfig::new(
d_model, n_heads,
));
let input = MhaInput::new(
Tensor::random([batch_size, seq_length_1, d_model], Distribution::Standard),
Tensor::random([batch_size, seq_length_2, d_model], Distribution::Standard),
Tensor::random([batch_size, seq_length_2, d_model], Distribution::Standard),
);
let output = mha.forward(input);
assert_eq!(
output.context.shape(),
&Shape::new([batch_size, seq_length_1, d_model]),
"Context should have the correct shape",
);
assert_eq!(
output.weights.shape(),
&Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]),
"Weights should have the correct shape",
);
}
#[test]
pub fn test_self_attention_mask_pad() {
let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2];
let mha = MultiHeadAttention::new(&MultiHeadAttentionConfig::new(d_model, n_heads));
// Create a padding mask
let mask_pad = Tensor::zeros([batch_size, seq_length]);
let mask_pad = mask_pad.index_assign(
[0..batch_size, seq_length - num_padded..seq_length],
&Tensor::ones([batch_size, num_padded]),
);
let mask_pad = mask_pad.equal_scalar(1);
let tensor_1 = Tensor::<TestBackend, 3>::random(
[batch_size, seq_length, d_model],
Distribution::Standard,
);
// Change the end of the tensor
let tensor_2 = tensor_1.index_assign(
[
0..batch_size,
seq_length - num_padded..seq_length,
0..d_model,
],
&Tensor::random([batch_size, num_padded, d_model], Distribution::Standard),
);
let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone());
let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad);
let output_1 = mha.forward(input_1);
let output_2 = mha.forward(input_2);
// Check that the beginning of each tensor is the same
output_1
.context
.index([0..batch_size, 0..seq_length - num_padded, 0..d_model])
.into_data()
.assert_approx_eq(
&output_2
.context
.index([0..batch_size, 0..seq_length - num_padded, 0..d_model])
.into_data(),
3,
);
}
}
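
An end-to-end usage sketch of the new module, mirroring the tests above (sizes are illustrative):

let config = MultiHeadAttentionConfig::new(64, 4);
let mha = MultiHeadAttention::<TestBackend>::new(&config);
// Self-attention over a [batch_size=2, seq_length=10, d_model=64] tensor.
let tensor = Tensor::random([2, 10, 64], Distribution::Standard);
let output = mha.forward(MhaInput::self_attn(tensor));
// output.context: [2, 10, 64]; output.weights: [2, 4, 10, 10].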

View File

@@ -0,0 +1,3 @@
mod mha;
pub use mha::*;

View File

@@ -1,3 +1,5 @@
pub mod attention;
mod dropout;
mod embedding;
mod gelu;