Fix checks_channels_div_groups condition and ONNX conv import with groups (#2051)

* Fix checks_channels_div_groups condition

* Fix conv channels config w/ groups
Guillaume Lagrange 2024-07-22 13:53:48 -04:00 committed by GitHub
parent 0bbc1ed30f
commit 4c7353230e
3 changed files with 37 additions and 29 deletions
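
The first change tightens the guard in `checks_channels_div_groups`. With the old `&&`, the function only panicked when *both* channel counts failed to divide evenly by `groups`, so an invalid config such as `Conv2dConfig::new([1, 4], [1, 1]).with_groups(4)` (channels_in = 1 is not divisible by 4, channels_out = 4 is) slipped through even though the error message says both channels must be divisible. Switching to `||` rejects either invalid count; the new `channels_with_groups_is_invalid` test below covers exactly that case. A minimal standalone sketch of the corrected check (free-standing and illustrative only, not the module's actual layout):

    fn checks_channels_div_groups(channels_in: usize, channels_out: usize, groups: usize) {
        let channels_in_div_by_group = channels_in % groups == 0;
        let channels_out_div_by_group = channels_out % groups == 0;

        // Panic if *either* count is not divisible by `groups`; the old `&&`
        // only caught configs where both counts were invalid.
        if !channels_in_div_by_group || !channels_out_div_by_group {
            panic!(
                "Both channels must be divisible by the number of groups. Got \
                 channels_in={channels_in}, channels_out={channels_out}, groups={groups}"
            );
        }
    }

    fn main() {
        checks_channels_div_groups(4, 4, 4); // ok: both counts divisible by 4
        checks_channels_div_groups(1, 4, 4); // panics: channels_in is not divisible by 4
    }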


@@ -2,7 +2,7 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize
     let channels_in_div_by_group = channels_in % groups == 0;
     let channels_out_div_by_group = channels_out % groups == 0;
-    if !channels_in_div_by_group && !channels_out_div_by_group {
+    if !channels_in_div_by_group || !channels_out_div_by_group {
         panic!(
             "Both channels must be divisible by the number of groups. Got \
              channels_in={channels_in}, channels_out={channels_out}, groups={groups}"


@@ -220,6 +220,14 @@ mod tests {
         assert_eq!(config.initializer, init);
     }
 
+    #[test]
+    #[should_panic = "Both channels must be divisible by the number of groups."]
+    fn channels_with_groups_is_invalid() {
+        let device = Default::default();
+        let config = Conv2dConfig::new([1, 4], [1, 1]).with_groups(4);
+        let _ = config.init::<TestBackend>(&device);
+    }
+
     #[test]
     fn display() {
         let config = Conv2dConfig::new([5, 1], [5, 5]);
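
The remaining changes fix the ONNX import side. An ONNX `Conv` weight tensor is laid out as [out_channels, in_channels / group, kernel dims...], so reading the input channel count straight from `shape[1]` under-reports it whenever `group > 1`. The config builders now parse the `group` attribute first (as `usize` rather than `i64`) and move the shape-based channel computation below the attribute loop, recovering the logical count as `shape[1] * group`; the conv-transpose configs receive the same `shape[1] * group` adjustment. A small sketch of the arithmetic, using a hypothetical helper name:

    /// Recover (channels_in, channels_out) of an ONNX Conv from its weight shape.
    /// `weight_shape` is [out_channels, in_channels / group, k0, k1, ...].
    fn conv_channels(weight_shape: &[usize], group: usize) -> (usize, usize) {
        let channels_out = weight_shape[0];
        let channels_in = weight_shape[1] * group; // undo the per-group split
        (channels_in, channels_out)
    }

    fn main() {
        // Depthwise-style 2D conv: 4 input channels, 4 output channels, group = 4.
        // ONNX stores its weight as [4, 1, 3, 3]; shape[1] alone would yield 1.
        assert_eq!(conv_channels(&[4, 1, 3, 3], 4), (4, 4));
    }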


@@ -16,7 +16,7 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
     let mut strides = vec![1];
     let mut pads = vec![0, 0];
     let mut dilations = vec![1];
-    let mut group: i64 = 1;
+    let mut group: usize = 1;
 
     // extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
     let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty {
@@ -28,28 +28,28 @@ pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
     // check if the bias is present
     let bias = curr.inputs.len() == 3;
 
-    // the channels are inverted in the weight tensor
-    let shape = weight.shape.clone().unwrap();
-    let channels_in = shape[1];
-    let channels_out = shape[0];
-
     for (key, value) in curr.attrs.iter() {
         match key.as_str() {
             "kernel_shape" => kernel_shape = value.clone().into_i64s(),
             "strides" => strides = value.clone().into_i64s(),
             "pads" => pads = value.clone().into_i64s(),
             "dilations" => dilations = value.clone().into_i64s(),
-            "group" => group = value.clone().into_i64(),
+            "group" => group = value.clone().into_i64() as usize,
             _ => {}
         }
     }
 
+    // the channels are inverted in the weight tensor
+    let shape = weight.shape.clone().unwrap();
+    let channels_in = shape[1] * group;
+    let channels_out = shape[0];
+
     let padding = padding_config_1d(&pads);
 
     Conv1dConfig::new(channels_in, channels_out, kernel_shape[0] as usize)
         .with_stride(strides[0] as usize)
         .with_dilation(dilations[0] as usize)
-        .with_groups(group as usize)
+        .with_groups(group)
         .with_bias(bias)
         .with_padding(padding)
 }
@@ -60,7 +60,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
     let mut strides = vec![1, 1];
     let mut pads = vec![0, 0, 0, 0];
     let mut dilations = vec![1, 1];
-    let mut group: i64 = 1;
+    let mut group: usize = 1;
 
     // extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
     let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty {
@@ -71,21 +71,21 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
     // check if the bias is present
     let bias = curr.inputs.len() == 3;
 
-    // the channels are inverted in the weight tensor
-    let shape = weight.shape.clone().unwrap();
-    let channels: [usize; 2] = [shape[1], shape[0]];
-
     for (key, value) in curr.attrs.iter() {
         match key.as_str() {
             "kernel_shape" => kernel_shape = value.clone().into_i64s(),
             "strides" => strides = value.clone().into_i64s(),
             "pads" => pads = value.clone().into_i64s(),
             "dilations" => dilations = value.clone().into_i64s(),
-            "group" => group = value.clone().into_i64(),
+            "group" => group = value.clone().into_i64() as usize,
             _ => {}
         }
     }
 
+    // the channels are inverted in the weight tensor
+    let shape = weight.shape.clone().unwrap();
+    let channels: [usize; 2] = [shape[1] * group, shape[0]];
+
     let padding = padding_config_2d(&pads);
 
     Conv2dConfig::new(
@@ -94,7 +94,7 @@ pub fn conv2d_config(curr: &Node) -> Conv2dConfig {
     )
     .with_stride([strides[0] as usize, strides[1] as usize])
     .with_dilation([dilations[0] as usize, dilations[1] as usize])
-    .with_groups(group as usize)
+    .with_groups(group)
     .with_bias(bias)
     .with_padding(padding)
 }
@@ -105,7 +105,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig {
     let mut strides = vec![1, 1, 1];
     let mut pads = vec![0, 0, 0, 0, 0, 0];
     let mut dilations = vec![1, 1, 1];
-    let mut group: i64 = 1;
+    let mut group: usize = 1;
 
     // extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
     let weight = if let ArgType::Tensor(ref weight) = curr.inputs[1].ty {
@@ -116,21 +116,21 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig {
     // check if the bias is present
     let bias = curr.inputs.len() == 3;
 
-    // the channels are inverted in the weight tensor
-    let shape = weight.shape.clone().unwrap();
-    let channels: [usize; 2] = [shape[1], shape[0]];
-
     for (key, value) in curr.attrs.iter() {
         match key.as_str() {
             "kernel_shape" => kernel_shape = value.clone().into_i64s(),
             "strides" => strides = value.clone().into_i64s(),
             "pads" => pads = value.clone().into_i64s(),
             "dilations" => dilations = value.clone().into_i64s(),
-            "group" => group = value.clone().into_i64(),
+            "group" => group = value.clone().into_i64() as usize,
            _ => {}
         }
     }
 
+    // the channels are inverted in the weight tensor
+    let shape = weight.shape.clone().unwrap();
+    let channels: [usize; 2] = [shape[1] * group, shape[0]];
+
     let padding = padding_config_3d(&pads);
 
     Conv3dConfig::new(
@@ -151,7 +151,7 @@ pub fn conv3d_config(curr: &Node) -> Conv3dConfig {
         dilations[1] as usize,
         dilations[2] as usize,
     ])
-    .with_groups(group as usize)
+    .with_groups(group)
     .with_bias(bias)
     .with_padding(padding)
 }
@@ -228,7 +228,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
     let group = attrs
         .remove("group")
         .map(AttributeValue::into_i64)
-        .unwrap_or(1);
+        .unwrap_or(1) as usize;
 
     // Trick with remove + empty check is simplest way to not forget some attribute for runtime:
     if !attrs.is_empty() {
@@ -247,7 +247,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
     // the channels are inverted in the weight tensor
     let shape = weight.shape.clone().unwrap();
-    let channels: [usize; 2] = [shape[1], shape[0]];
+    let channels: [usize; 2] = [shape[1] * group, shape[0]];
 
     ConvTranspose2dConfig::new(
         channels,
@@ -256,7 +256,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
     .with_stride([stride[0] as usize, stride[1] as usize])
     .with_padding([pads[0] as usize, pads[1] as usize])
     .with_dilation([dilations[0] as usize, dilations[1] as usize])
-    .with_groups(group as usize)
+    .with_groups(group)
     .with_bias(bias)
 }
 
 pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
@@ -280,7 +280,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
     let group = attrs
         .remove("group")
         .map(AttributeValue::into_i64)
-        .unwrap_or(1);
+        .unwrap_or(1) as usize;
 
     // Trick with remove + empty check is simplest way to not forget some attribute for runtime:
     if !attrs.is_empty() {
@@ -299,7 +299,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
     // the channels are inverted in the weight tensor
     let shape = weight.shape.clone().unwrap();
-    let channels: [usize; 2] = [shape[1], shape[0]];
+    let channels: [usize; 2] = [shape[1] * group, shape[0]];
 
     ConvTranspose3dConfig::new(
         channels,
@@ -316,7 +316,7 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
         dilations[1] as usize,
         dilations[2] as usize,
     ])
-    .with_groups(group as usize)
+    .with_groups(group)
     .with_bias(bias)
 }