mirror of https://github.com/tracel-ai/burn.git
Fix checks_channels_div_groups condition and ONNX conv import with groups (#2051)
* Fix checks_channels_div_groups condition * Fix conv channels config w/ groups
This commit is contained in:
parent
0bbc1ed30f
commit
4c7353230e
|
@ -2,7 +2,7 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize
|
|||
let channels_in_div_by_group = channels_in % groups == 0;
|
||||
let channels_out_div_by_group = channels_out % groups == 0;
|
||||
|
||||
if !channels_in_div_by_group && !channels_out_div_by_group {
|
||||
if !channels_in_div_by_group || !channels_out_div_by_group {
|
||||
panic!(
|
||||
"Both channels must be divisible by the number of groups. Got \
|
||||
channels_in={channels_in}, channels_out={channels_out}, groups={groups}"
|
||||
|
|
|
@ -220,6 +220,14 @@ mod tests {
|
|||
assert_eq!(config.initializer, init);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "Both channels must be divisible by the number of groups."]
|
||||
fn channels_with_groups_is_invalid() {
|
||||
let device = Default::default();
|
||||
let config = Conv2dConfig::new([1, 4], [1, 1]).with_groups(4);
|
||||
let _ = config.init::<TestBackend>(&device);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = Conv2dConfig::new([5, 1], [5, 5]);
|
||||
|
|
|
@ -16,7 +16,7 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
|
|||
let mut strides = vec![1];
|
||||
let mut pads = vec![0, 0];
|
||||
let mut dilations = vec![1];
|
||||
let mut group: i64 = 1;
|
||||
let mut group: usize = 1;
|
||||
|
||||
// extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
|
||||
let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty {
|
||||
|
@ -28,28 +28,28 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
|
|||
// check if the bias is present
|
||||
let bias = curr.inputs.len() == 3;
|
||||
|
||||
// the channels are inverted in the weight tensor
|
||||
let shape = weight.shape.clone().unwrap();
|
||||
let channels_in = shape[1];
|
||||
let channels_out = shape[0];
|
||||
|
||||
for (key, value) in curr.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"kernel_shape" => kernel_shape = value.clone().into_i64s(),
|
||||
"strides" => strides = value.clone().into_i64s(),
|
||||
"pads" => pads = value.clone().into_i64s(),
|
||||
"dilations" => dilations = value.clone().into_i64s(),
|
||||
"group" => group = value.clone().into_i64(),
|
||||
"group" => group = value.clone().into_i64() as usize,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// the channels are inverted in the weight tensor
|
||||
let shape = weight.shape.clone().unwrap();
|
||||
let channels_in = shape[1] * group;
|
||||
let channels_out = shape[0];
|
||||
|
||||
let padding = padding_config_1d(&pads);
|
||||
|
||||
Conv1dConfig::new(channels_in, channels_out, kernel_shape[0] as usize)
|
||||
.with_stride(strides[0] as usize)
|
||||
.with_dilation(dilations[0] as usize)
|
||||
.with_groups(group as usize)
|
||||
.with_groups(group)
|
||||
.with_bias(bias)
|
||||
.with_padding(padding)
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
|
|||
let mut strides = vec![1, 1];
|
||||
let mut pads = vec![0, 0, 0, 0];
|
||||
let mut dilations = vec![1, 1];
|
||||
let mut group: i64 = 1;
|
||||
let mut group: usize = 1;
|
||||
|
||||
// extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
|
||||
let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty {
|
||||
|
@ -71,21 +71,21 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
|
|||
// check if the bias is present
|
||||
let bias = curr.inputs.len() == 3;
|
||||
|
||||
// the channels are inverted in the weight tensor
|
||||
let shape = weight.shape.clone().unwrap();
|
||||
let channels: [usize; 2] = [shape[1], shape[0]];
|
||||
|
||||
for (key, value) in curr.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"kernel_shape" => kernel_shape = value.clone().into_i64s(),
|
||||
"strides" => strides = value.clone().into_i64s(),
|
||||
"pads" => pads = value.clone().into_i64s(),
|
||||
"dilations" => dilations = value.clone().into_i64s(),
|
||||
"group" => group = value.clone().into_i64(),
|
||||
"group" => group = value.clone().into_i64() as usize,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// the channels are inverted in the weight tensor
|
||||
let shape = weight.shape.clone().unwrap();
|
||||
let channels: [usize; 2] = [shape[1] * group, shape[0]];
|
||||
|
||||
let padding = padding_config_2d(&pads);
|
||||
|
||||
Conv2dConfig::new(
|
||||
|
@ -94,7 +94,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
|
|||
)
|
||||
.with_stride([strides[0] as usize, strides[1] as usize])
|
||||
.with_dilation([dilations[0] as usize, dilations[1] as usize])
|
||||
.with_groups(group as usize)
|
||||
.with_groups(group)
|
||||
.with_bias(bias)
|
||||
.with_padding(padding)
|
||||
}
|
||||
|
@ -105,7 +105,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig {
|
|||
let mut strides = vec![1, 1, 1];
|
||||
let mut pads = vec![0, 0, 0, 0, 0, 0];
|
||||
let mut dilations = vec![1, 1, 1];
|
||||
let mut group: i64 = 1;
|
||||
let mut group: usize = 1;
|
||||
|
||||
// extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
|
||||
let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty {
|
||||
|
@ -116,21 +116,21 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig {
|
|||
// check if the bias is present
|
||||
let bias = curr.inputs.len() == 3;
|
||||
|
||||
// the channels are inverted in the weight tensor
|
||||
let shape = weight.shape.clone().unwrap();
|
||||
let channels: [usize; 2] = [shape[1], shape[0]];
|
||||
|
||||
for (key, value) in curr.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"kernel_shape" => kernel_shape = value.clone().into_i64s(),
|
||||
"strides" => strides = value.clone().into_i64s(),
|
||||
"pads" => pads = value.clone().into_i64s(),
|
||||
"dilations" => dilations = value.clone().into_i64s(),
|
||||
"group" => group = value.clone().into_i64(),
|
||||
"group" => group = value.clone().into_i64() as usize,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// the channels are inverted in the weight tensor
|
||||
let shape = weight.shape.clone().unwrap();
|
||||
let channels: [usize; 2] = [shape[1] * group, shape[0]];
|
||||
|
||||
let padding = padding_config_3d(&pads);
|
||||
|
||||
Conv3dConfig::new(
|
||||
|
@ -151,7 +151,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig {
|
|||
dilations[1] as usize,
|
||||
dilations[2] as usize,
|
||||
])
|
||||
.with_groups(group as usize)
|
||||
.with_groups(group)
|
||||
.with_bias(bias)
|
||||
.with_padding(padding)
|
||||
}
|
||||
|
@ -228,7 +228,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
|
|||
let group = attrs
|
||||
.remove("group")
|
||||
.map(AttributeValue::into_i64)
|
||||
.unwrap_or(1);
|
||||
.unwrap_or(1) as usize;
|
||||
|
||||
// Trick with remove + empty check is simplest way to not forget some attribute for runtime:
|
||||
if !attrs.is_empty() {
|
||||
|
@ -247,7 +247,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
|
|||
|
||||
// the channels are inverted in the weight tensor
|
||||
let shape = weight.shape.clone().unwrap();
|
||||
let channels: [usize; 2] = [shape[1], shape[0]];
|
||||
let channels: [usize; 2] = [shape[1] * group, shape[0]];
|
||||
|
||||
ConvTranspose2dConfig::new(
|
||||
channels,
|
||||
|
@ -256,7 +256,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
|
|||
.with_stride([stride[0] as usize, stride[1] as usize])
|
||||
.with_padding([pads[0] as usize, pads[1] as usize])
|
||||
.with_dilation([dilations[0] as usize, dilations[1] as usize])
|
||||
.with_groups(group as usize)
|
||||
.with_groups(group)
|
||||
.with_bias(bias)
|
||||
}
|
||||
pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
|
||||
|
@ -280,7 +280,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
|
|||
let group = attrs
|
||||
.remove("group")
|
||||
.map(AttributeValue::into_i64)
|
||||
.unwrap_or(1);
|
||||
.unwrap_or(1) as usize;
|
||||
|
||||
// Trick with remove + empty check is simplest way to not forget some attribute for runtime:
|
||||
if !attrs.is_empty() {
|
||||
|
@ -299,7 +299,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
|
|||
|
||||
// the channels are inverted in the weight tensor
|
||||
let shape = weight.shape.clone().unwrap();
|
||||
let channels: [usize; 2] = [shape[1], shape[0]];
|
||||
let channels: [usize; 2] = [shape[1] * group, shape[0]];
|
||||
|
||||
ConvTranspose3dConfig::new(
|
||||
channels,
|
||||
|
@ -316,7 +316,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
|
|||
dilations[1] as usize,
|
||||
dilations[2] as usize,
|
||||
])
|
||||
.with_groups(group as usize)
|
||||
.with_groups(group)
|
||||
.with_bias(bias)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue