mirror of https://github.com/tracel-ai/burn.git
Implement padding for conv2d (#523)
This commit is contained in:
parent
b83203bc1c
commit
e066d95d2e
|
@ -4,10 +4,10 @@ use crate::config::Config;
|
|||
use crate::module::Module;
|
||||
use crate::module::Param;
|
||||
use crate::nn::Initializer;
|
||||
use crate::nn::PaddingConfig2d;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use burn_tensor::module::conv2d;
|
||||
use burn_tensor::ops::conv::calculate_conv_padding;
|
||||
use burn_tensor::ops::ConvOptions;
|
||||
use libm::sqrt;
|
||||
|
||||
|
@ -28,8 +28,8 @@ pub struct Conv2dConfig {
|
|||
#[config(default = "1")]
|
||||
pub groups: usize,
|
||||
/// The padding configuration.
|
||||
#[config(default = "Conv2dPaddingConfig::Valid")]
|
||||
pub padding: Conv2dPaddingConfig,
|
||||
#[config(default = "PaddingConfig2d::Valid")]
|
||||
pub padding: PaddingConfig2d,
|
||||
/// If bias should be added to the output.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
|
@ -38,18 +38,6 @@ pub struct Conv2dConfig {
|
|||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// Padding configuration for 2D convolution [config](Conv2dConfig).
|
||||
#[derive(Module, Config, Debug)]
|
||||
pub enum Conv2dPaddingConfig {
|
||||
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
|
||||
/// the same as the input.
|
||||
Same,
|
||||
/// Same as no padding.
|
||||
Valid,
|
||||
/// Applies the specified amount of padding to all inputs.
|
||||
Explicit(usize, usize),
|
||||
}
|
||||
|
||||
/// Applies a 2D convolution over input tensors.
|
||||
///
|
||||
/// # Params
|
||||
|
@ -67,7 +55,7 @@ pub struct Conv2d<B: Backend> {
|
|||
kernel_size: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
groups: usize,
|
||||
padding: Conv2dPaddingConfig,
|
||||
padding: PaddingConfig2d,
|
||||
}
|
||||
|
||||
impl Conv2dConfig {
|
||||
|
@ -136,29 +124,6 @@ impl<B: Backend> Conv2d<B> {
|
|||
}
|
||||
}
|
||||
|
||||
impl Conv2dPaddingConfig {
|
||||
pub(crate) fn calculate_padding_2d(
|
||||
&self,
|
||||
height: usize,
|
||||
width: usize,
|
||||
kernel_size: &[usize; 2],
|
||||
stride: &[usize; 2],
|
||||
) -> [usize; 2] {
|
||||
let same_padding = || {
|
||||
let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
|
||||
let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
|
||||
|
||||
[p1, p2]
|
||||
};
|
||||
|
||||
match self {
|
||||
Conv2dPaddingConfig::Same => same_padding(),
|
||||
Conv2dPaddingConfig::Valid => [0, 0],
|
||||
Conv2dPaddingConfig::Explicit(v1, v2) => [*v1, *v2],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn_tensor::Data;
|
||||
|
|
|
@ -22,6 +22,7 @@ mod gelu;
|
|||
mod initializer;
|
||||
mod linear;
|
||||
mod norm;
|
||||
mod padding;
|
||||
mod pos_encoding;
|
||||
mod relu;
|
||||
mod rnn;
|
||||
|
@ -32,6 +33,7 @@ pub use gelu::*;
|
|||
pub use initializer::*;
|
||||
pub use linear::*;
|
||||
pub use norm::*;
|
||||
pub use padding::*;
|
||||
pub use pos_encoding::*;
|
||||
pub use relu::*;
|
||||
pub use rnn::*;
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
use crate as burn;
|
||||
|
||||
use burn_tensor::ops::conv::calculate_conv_padding;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
|
||||
/// Padding configuration for 2D operators.
|
||||
#[derive(Module, Config, Debug)]
|
||||
pub enum PaddingConfig2d {
|
||||
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
|
||||
/// the same as the input.
|
||||
Same,
|
||||
/// Same as no padding.
|
||||
Valid,
|
||||
/// Applies the specified amount of padding to all inputs.
|
||||
Explicit(usize, usize),
|
||||
}
|
||||
|
||||
impl PaddingConfig2d {
|
||||
pub(crate) fn calculate_padding_2d(
|
||||
&self,
|
||||
height: usize,
|
||||
width: usize,
|
||||
kernel_size: &[usize; 2],
|
||||
stride: &[usize; 2],
|
||||
) -> [usize; 2] {
|
||||
let same_padding = || {
|
||||
let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
|
||||
let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
|
||||
|
||||
[p1, p2]
|
||||
};
|
||||
|
||||
match self {
|
||||
Self::Same => same_padding(),
|
||||
Self::Valid => [0, 0],
|
||||
Self::Explicit(v1, v2) => [*v1, *v2],
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,7 +2,7 @@ use crate as burn;
|
|||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::nn::conv::Conv2dPaddingConfig;
|
||||
use crate::nn::PaddingConfig2d;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use burn_tensor::module::avg_pool2d;
|
||||
|
@ -18,19 +18,16 @@ pub struct AvgPool2dConfig {
|
|||
#[config(default = "[1, 1]")]
|
||||
pub strides: [usize; 2],
|
||||
/// The padding configuration.
|
||||
#[config(default = "AvgPool2dPaddingConfig::Valid")]
|
||||
pub padding: AvgPool2dPaddingConfig,
|
||||
#[config(default = "PaddingConfig2d::Valid")]
|
||||
pub padding: PaddingConfig2d,
|
||||
}
|
||||
|
||||
/// Padding configuration for 2D avg pooling [config](AvgPool2dConfig).
|
||||
pub type AvgPool2dPaddingConfig = Conv2dPaddingConfig;
|
||||
|
||||
/// Applies a 2D avg pooling over input tensors.
|
||||
#[derive(Module, Debug, Clone)]
|
||||
pub struct AvgPool2d {
|
||||
stride: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
padding: AvgPool2dPaddingConfig,
|
||||
padding: PaddingConfig2d,
|
||||
}
|
||||
|
||||
impl AvgPool2dConfig {
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate as burn;
|
|||
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::nn::conv::Conv2dPaddingConfig;
|
||||
use crate::nn::PaddingConfig2d;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Tensor;
|
||||
use burn_tensor::module::max_pool2d;
|
||||
|
@ -18,19 +18,16 @@ pub struct MaxPool2dConfig {
|
|||
#[config(default = "[1, 1]")]
|
||||
pub strides: [usize; 2],
|
||||
/// The padding configuration.
|
||||
#[config(default = "MaxPool2dPaddingConfig::Valid")]
|
||||
pub padding: MaxPool2dPaddingConfig,
|
||||
#[config(default = "PaddingConfig2d::Valid")]
|
||||
pub padding: PaddingConfig2d,
|
||||
}
|
||||
|
||||
/// Padding configuration for 2D max pooling [config](MaxPool2dConfig).
|
||||
pub type MaxPool2dPaddingConfig = Conv2dPaddingConfig;
|
||||
|
||||
/// Applies a 2D max pooling over input tensors.
|
||||
#[derive(Module, Debug, Clone)]
|
||||
pub struct MaxPool2d {
|
||||
stride: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
padding: MaxPool2dPaddingConfig,
|
||||
padding: PaddingConfig2d,
|
||||
}
|
||||
|
||||
impl MaxPool2dConfig {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
use burn::nn::pool::MaxPool2dPaddingConfig;
|
||||
use burn::nn::PaddingConfig2d;
|
||||
|
||||
pub trait ToTokens {
|
||||
fn to_tokens(&self) -> TokenStream;
|
||||
|
@ -32,16 +32,16 @@ impl ToTokens for usize {
|
|||
}
|
||||
}
|
||||
|
||||
/// Padding configuration for MaxPool2dPaddingConfig
|
||||
impl ToTokens for MaxPool2dPaddingConfig {
|
||||
/// Padding configuration
|
||||
impl ToTokens for PaddingConfig2d {
|
||||
fn to_tokens(&self) -> TokenStream {
|
||||
match self {
|
||||
Self::Same => quote! { MaxPool2dPaddingConfig::Same },
|
||||
Self::Valid => quote! { MaxPool2dPaddingConfig::Valid },
|
||||
Self::Same => quote! { PaddingConfig2d::Same },
|
||||
Self::Valid => quote! { PaddingConfig2d::Valid },
|
||||
Self::Explicit(padding1, padding2) => {
|
||||
let padding1 = padding1.to_tokens();
|
||||
let padding2 = padding2.to_tokens();
|
||||
quote! { MaxPool2dPaddingConfig::Explicit(#padding1, #padding2) }
|
||||
quote! { PaddingConfig2d::Explicit(#padding1, #padding2) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -181,7 +181,9 @@ mod tests {
|
|||
node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens},
|
||||
TensorType,
|
||||
};
|
||||
use burn::{nn::conv::Conv2dConfig, record::FullPrecisionSettings, tensor::Data};
|
||||
use burn::{
|
||||
nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data,
|
||||
};
|
||||
use quote::quote;
|
||||
|
||||
#[test]
|
||||
|
@ -199,7 +201,7 @@ mod tests {
|
|||
TensorType::new_float("tensor4", 4),
|
||||
Data::from([2.]).serialize(),
|
||||
None,
|
||||
Conv2dConfig::new([3, 3], [3, 3]),
|
||||
Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
|
||||
));
|
||||
|
||||
let expected = quote! {
|
||||
|
@ -207,6 +209,7 @@ mod tests {
|
|||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
use burn::nn::PaddingConfig2d;
|
||||
use burn::nn::conv::Conv2d;
|
||||
use burn::nn::conv::Conv2dConfig;
|
||||
|
||||
|
@ -219,6 +222,7 @@ mod tests {
|
|||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
let conv2d = Conv2dConfig::new([3, 3], [3, 3])
|
||||
.with_stride([1, 1])
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.with_dilation([1, 1])
|
||||
.with_groups(1)
|
||||
.with_bias(true)
|
||||
|
@ -256,7 +260,7 @@ mod tests {
|
|||
TensorType::new_float("tensor4", 4),
|
||||
Data::from([2.]).serialize(),
|
||||
None,
|
||||
Conv2dConfig::new([3, 3], [3, 3]),
|
||||
Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
|
||||
));
|
||||
graph.register(MatmulNode::new(
|
||||
TensorType::new_float("tensor3", 4),
|
||||
|
@ -269,6 +273,7 @@ mod tests {
|
|||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
use burn::nn::PaddingConfig2d;
|
||||
use burn::nn::conv::Conv2d;
|
||||
use burn::nn::conv::Conv2dConfig;
|
||||
|
||||
|
@ -281,6 +286,7 @@ mod tests {
|
|||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
let conv2d = Conv2dConfig::new([3, 3], [3, 3])
|
||||
.with_stride([1, 1])
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.with_dilation([1, 1])
|
||||
.with_groups(1)
|
||||
.with_bias(true)
|
||||
|
|
|
@ -63,6 +63,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for Conv2dNode<PS> {
|
|||
let stride = self.config.stride.to_tokens();
|
||||
let dilation = self.config.dilation.to_tokens();
|
||||
let groups = self.config.groups.to_tokens();
|
||||
let padding = self.config.padding.to_tokens();
|
||||
let bias = self.config.bias;
|
||||
|
||||
let init_line = match with_record {
|
||||
|
@ -77,6 +78,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for Conv2dNode<PS> {
|
|||
let tokens = quote! {
|
||||
let #name = Conv2dConfig::new(#channels, #kernel_size)
|
||||
.with_stride(#stride)
|
||||
.with_padding(#padding)
|
||||
.with_dilation(#dilation)
|
||||
.with_groups(#groups)
|
||||
.with_bias(#bias)
|
||||
|
@ -117,6 +119,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for Conv2dNode<PS> {
|
|||
}
|
||||
}
|
||||
fn register_imports(&self, imports: &mut BurnImports) {
|
||||
imports.register("burn::nn::PaddingConfig2d");
|
||||
imports.register("burn::nn::conv::Conv2d");
|
||||
imports.register("burn::nn::conv::Conv2dConfig");
|
||||
}
|
||||
|
@ -134,7 +137,9 @@ mod tests {
|
|||
node::{conv2d::Conv2dNode, test::assert_tokens},
|
||||
TensorType,
|
||||
};
|
||||
use burn::{nn::conv::Conv2dConfig, record::FullPrecisionSettings, tensor::Data};
|
||||
use burn::{
|
||||
nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_codegen() {
|
||||
|
@ -146,7 +151,7 @@ mod tests {
|
|||
TensorType::new_float("output", 4),
|
||||
Data::from([2.]).serialize(),
|
||||
None,
|
||||
Conv2dConfig::new([3, 3], [3, 3]),
|
||||
Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
|
||||
));
|
||||
|
||||
let expected = quote! {
|
||||
|
@ -154,6 +159,7 @@ mod tests {
|
|||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
use burn::nn::PaddingConfig2d;
|
||||
use burn::nn::conv::Conv2d;
|
||||
use burn::nn::conv::Conv2dConfig;
|
||||
|
||||
|
@ -166,6 +172,7 @@ mod tests {
|
|||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
let conv2d = Conv2dConfig::new([3, 3], [3, 3])
|
||||
.with_stride([1, 1])
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.with_dilation([1, 1])
|
||||
.with_groups(1)
|
||||
.with_bias(true)
|
||||
|
|
|
@ -77,9 +77,9 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
|
|||
}
|
||||
}
|
||||
fn register_imports(&self, imports: &mut BurnImports) {
|
||||
imports.register("burn::nn::PaddingConfig2d");
|
||||
imports.register("burn::nn::pool::MaxPool2d");
|
||||
imports.register("burn::nn::pool::MaxPool2dConfig");
|
||||
imports.register("burn::nn::pool::MaxPool2dPaddingConfig");
|
||||
}
|
||||
|
||||
fn into_node(self) -> Node<PS> {
|
||||
|
@ -95,10 +95,7 @@ mod tests {
|
|||
node::{max_pool2d::MaxPool2dNode, test::assert_tokens},
|
||||
TensorType,
|
||||
};
|
||||
use burn::{
|
||||
nn::pool::{MaxPool2dConfig, MaxPool2dPaddingConfig},
|
||||
record::FullPrecisionSettings,
|
||||
};
|
||||
use burn::{nn::pool::MaxPool2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings};
|
||||
|
||||
#[test]
|
||||
fn test_codegen() {
|
||||
|
@ -110,7 +107,7 @@ mod tests {
|
|||
TensorType::new_float("output", 4),
|
||||
MaxPool2dConfig::new(1, [3, 3])
|
||||
.with_strides([1, 1])
|
||||
.with_padding(MaxPool2dPaddingConfig::Valid),
|
||||
.with_padding(PaddingConfig2d::Valid),
|
||||
));
|
||||
|
||||
let expected = quote! {
|
||||
|
@ -118,9 +115,9 @@ mod tests {
|
|||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
use burn::nn::PaddingConfig2d;
|
||||
use burn::nn::pool::MaxPool2d;
|
||||
use burn::nn::pool::MaxPool2dConfig;
|
||||
use burn::nn::pool::MaxPool2dPaddingConfig;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model <B: Backend> {
|
||||
|
@ -131,7 +128,7 @@ mod tests {
|
|||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
let max_pool2d = MaxPool2dConfig::new(1, [3, 3])
|
||||
.with_strides([1, 1])
|
||||
.with_padding(MaxPool2dPaddingConfig::Valid)
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.init();
|
||||
|
||||
Self {
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
use burn::nn::{
|
||||
conv::{Conv2dConfig, Conv2dPaddingConfig},
|
||||
pool::{MaxPool2dConfig, MaxPool2dPaddingConfig},
|
||||
BatchNormConfig, LinearConfig,
|
||||
conv::Conv2dConfig, pool::MaxPool2dConfig, BatchNormConfig, LinearConfig, PaddingConfig2d,
|
||||
};
|
||||
|
||||
use super::ir::{ArgType, AttributeValue, Node, StateType};
|
||||
|
@ -56,11 +54,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
|
|||
}
|
||||
}
|
||||
|
||||
let padding = if pads.iter().all(|&x| x == 0) {
|
||||
Conv2dPaddingConfig::Valid
|
||||
} else {
|
||||
todo!("Conv2d: padding({pads:?}) is not fully supported");
|
||||
};
|
||||
let padding = padding_config(&pads);
|
||||
|
||||
if strides.iter().all(|&x| x != 1) {
|
||||
todo!("Conv2d: strides({strides:?}) are not fully supported");
|
||||
|
@ -99,18 +93,7 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
|
|||
}
|
||||
}
|
||||
|
||||
let padding = if pads.iter().all(|&x| x == 0) {
|
||||
MaxPool2dPaddingConfig::Valid
|
||||
} else if (pads[0] == pads[1]) == (pads[2] == pads[3]) {
|
||||
// i.e [2, 2, 2, 2]
|
||||
MaxPool2dPaddingConfig::Explicit(pads[0] as usize, pads[0] as usize)
|
||||
} else if pads[0] == pads[1] && pads[2] == pads[3] {
|
||||
// i.e [2, 2, 3, 3]
|
||||
MaxPool2dPaddingConfig::Explicit(pads[0] as usize, pads[2] as usize)
|
||||
} else {
|
||||
// All other cases, same as input
|
||||
MaxPool2dPaddingConfig::Same
|
||||
};
|
||||
let padding = padding_config(&pads);
|
||||
|
||||
MaxPool2dConfig::new(
|
||||
channels as usize,
|
||||
|
@ -256,3 +239,18 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig {
|
|||
.with_epsilon(epsilon as f64)
|
||||
.with_momentum(momentum as f64)
|
||||
}
|
||||
|
||||
fn padding_config(pads: &[i64]) -> PaddingConfig2d {
|
||||
if pads.iter().all(|&x| x == 0) {
|
||||
PaddingConfig2d::Valid
|
||||
} else if (pads[0] == pads[1]) == (pads[2] == pads[3]) {
|
||||
// i.e [2, 2, 2, 2]
|
||||
PaddingConfig2d::Explicit(pads[0] as usize, pads[0] as usize)
|
||||
} else if pads[0] == pads[1] && pads[2] == pads[3] {
|
||||
// i.e [2, 2, 3, 3]
|
||||
PaddingConfig2d::Explicit(pads[0] as usize, pads[2] as usize)
|
||||
} else {
|
||||
// All other cases, same as input
|
||||
PaddingConfig2d::Same
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ use burn::nn::BatchNorm;
|
|||
use burn::nn::BatchNormConfig;
|
||||
use burn::nn::Linear;
|
||||
use burn::nn::LinearConfig;
|
||||
use burn::nn::PaddingConfig2d;
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
|
@ -23,6 +24,7 @@ impl<B: Backend> Model<B> {
|
|||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
let conv2d1 = Conv2dConfig::new([1, 8], [3, 3])
|
||||
.with_stride([1, 1])
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.with_dilation([1, 1])
|
||||
.with_groups(1)
|
||||
.with_bias(true)
|
||||
|
|
|
@ -24,10 +24,10 @@ pub struct ConvBlockConfig {
|
|||
impl<B: Backend> ConvBlock<B> {
|
||||
pub fn new(config: &ConvBlockConfig) -> Self {
|
||||
let conv = nn::conv::Conv2dConfig::new(config.channels, config.kernel_size)
|
||||
.with_padding(nn::conv::Conv2dPaddingConfig::Same)
|
||||
.with_padding(nn::PaddingConfig2d::Same)
|
||||
.init();
|
||||
let pool = nn::pool::MaxPool2dConfig::new(config.channels[1], config.kernel_size)
|
||||
.with_padding(nn::conv::Conv2dPaddingConfig::Same)
|
||||
.with_padding(nn::PaddingConfig2d::Same)
|
||||
.init();
|
||||
let activation = nn::GELU::new();
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{self, conv::Conv2dPaddingConfig, BatchNorm},
|
||||
nn::{self, BatchNorm, PaddingConfig2d},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
|
@ -76,7 +76,7 @@ pub struct ConvBlock<B: Backend> {
|
|||
impl<B: Backend> ConvBlock<B> {
|
||||
pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self {
|
||||
let conv = nn::conv::Conv2dConfig::new(channels, kernel_size)
|
||||
.with_padding(Conv2dPaddingConfig::Valid)
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.init();
|
||||
let norm = nn::BatchNormConfig::new(channels[1]).init();
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::data::MNISTBatch;
|
|||
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{self, conv::Conv2dPaddingConfig, loss::CrossEntropyLoss, BatchNorm},
|
||||
nn::{self, loss::CrossEntropyLoss, BatchNorm, PaddingConfig2d},
|
||||
tensor::{
|
||||
backend::{ADBackend, Backend},
|
||||
Tensor,
|
||||
|
@ -91,7 +91,7 @@ pub struct ConvBlock<B: Backend> {
|
|||
impl<B: Backend> ConvBlock<B> {
|
||||
pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self {
|
||||
let conv = nn::conv::Conv2dConfig::new(channels, kernel_size)
|
||||
.with_padding(Conv2dPaddingConfig::Valid)
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.init();
|
||||
let norm = nn::BatchNormConfig::new(channels[1]).init();
|
||||
|
||||
|
|
Loading…
Reference in New Issue