mirror of https://github.com/tracel-ai/burn.git
parent
cb616ed72c
commit
445f41bb7b
|
@ -20,7 +20,7 @@ pub struct AvgPool2dConfig {
|
|||
pub padding: PaddingConfig2d,
|
||||
/// If the padding is counted in the denominator when computing the average.
|
||||
#[config(default = "true")]
|
||||
count_include_pad: bool,
|
||||
pub count_include_pad: bool,
|
||||
}
|
||||
|
||||
/// Applies a 2D avg pooling over input tensors.
|
||||
|
|
Binary file not shown.
|
@ -10,14 +10,20 @@ class Model(nn.Module):
|
|||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
# TODO when https://github.com/burn-rs/burn/issues/636 is resolved, test this with a model
|
||||
# that uses `count_include_pad=False` and padding=(2, 1)
|
||||
self.pool2d = nn.AvgPool2d((4, 2), stride=(
|
||||
2, 1), padding=(0, 0), count_include_pad=False)
|
||||
self.pool2d1 = nn.AvgPool2d((4, 2), stride=(
|
||||
2, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pool2d(x)
|
||||
return x
|
||||
self.pool2d2 = nn.AvgPool2d((4, 2), stride=(
|
||||
2, 1), padding=(2, 1), count_include_pad=True)
|
||||
|
||||
self.pool2d3 = nn.AvgPool2d((4, 2), stride=(
|
||||
2, 1), padding=(2, 1), count_include_pad=False)
|
||||
|
||||
def forward(self, x1, x2, x3):
|
||||
y1 = self.pool2d1(x1)
|
||||
y2 = self.pool2d2(x2)
|
||||
y3 = self.pool2d3(x3)
|
||||
return y1, y2, y3
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -33,18 +39,22 @@ def main():
|
|||
device = torch.device("cpu")
|
||||
|
||||
file_name = "avg_pool2d.onnx"
|
||||
test_input = torch.randn(1, 1, 5, 5, device=device)
|
||||
torch.onnx.export(model, test_input, file_name,
|
||||
input1 = torch.randn(1, 1, 5, 5, device=device)
|
||||
torch.onnx.export(model, (input1, input1, input1), file_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(file_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
print("Test input data shape of ones: {}".format(test_input.shape))
|
||||
print("Test input data of ones: {}".format(test_input))
|
||||
output = model.forward(test_input)
|
||||
print("Test output data shape: {}".format(output.shape))
|
||||
print("Test output: {}".format(output))
|
||||
print("Test input data shape: {}".format(input1.shape))
|
||||
print("Test input data: {}".format(input1))
|
||||
output1, output2, output3 = model.forward(input1, input1, input1)
|
||||
print("Test output1 data shape: {}".format(output1.shape))
|
||||
print("Test output2 data shape: {}".format(output2.shape))
|
||||
print("Test output3 data shape: {}".format(output3.shape))
|
||||
print("Test output1: {}".format(output1))
|
||||
print("Test output2: {}".format(output2))
|
||||
print("Test output3: {}".format(output3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -362,10 +362,30 @@ mod tests {
|
|||
[-1.805, -0.476, 0.205, 0.338, 1.353],
|
||||
[0.374, 0.013, 0.774, -0.109, -0.271],
|
||||
]]]);
|
||||
let output = model.forward(input);
|
||||
let expected = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]);
|
||||
let (output1, output2, output3) = model.forward(input.clone(), input.clone(), input);
|
||||
let expected1 = Data::from([[[[0.008, -0.131, -0.208, 0.425]]]]);
|
||||
let expected2 = Data::from([[[
|
||||
[-0.045, 0.202, -0.050, -0.295, 0.162, 0.160],
|
||||
[-0.176, 0.008, -0.131, -0.208, 0.425, 0.319],
|
||||
[-0.084, -0.146, 0.017, 0.170, 0.216, 0.125],
|
||||
]]]);
|
||||
let expected3 = Data::from([[[
|
||||
[-0.182, 0.404, -0.100, -0.590, 0.324, 0.638],
|
||||
[-0.352, 0.008, -0.131, -0.208, 0.425, 0.638],
|
||||
[-0.224, -0.195, 0.023, 0.226, 0.288, 0.335],
|
||||
]]]);
|
||||
|
||||
output.to_data().assert_approx_eq(&expected, 3);
|
||||
let expected_shape1 = Shape::from([1, 1, 1, 4]);
|
||||
let expected_shape2 = Shape::from([1, 1, 3, 6]);
|
||||
let expected_shape3 = Shape::from([1, 1, 3, 6]);
|
||||
|
||||
assert_eq!(output1.shape(), expected_shape1);
|
||||
assert_eq!(output2.shape(), expected_shape2);
|
||||
assert_eq!(output3.shape(), expected_shape3);
|
||||
|
||||
output1.to_data().assert_approx_eq(&expected1, 3);
|
||||
output2.to_data().assert_approx_eq(&expected2, 3);
|
||||
output3.to_data().assert_approx_eq(&expected3, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -51,6 +51,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool2dNode {
|
|||
let kernel_size = self.config.kernel_size.to_tokens();
|
||||
let strides = self.config.strides.to_tokens();
|
||||
let padding = self.config.padding.to_tokens();
|
||||
let count_include_pad = self.config.count_include_pad;
|
||||
|
||||
let init_line = quote! {
|
||||
init();
|
||||
|
@ -60,6 +61,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool2dNode {
|
|||
let #name = AvgPool2dConfig::new(#kernel_size)
|
||||
.with_strides(#strides)
|
||||
.with_padding(#padding)
|
||||
.with_count_include_pad(#count_include_pad)
|
||||
.#init_line
|
||||
};
|
||||
|
||||
|
@ -137,6 +139,7 @@ mod tests {
|
|||
let avg_pool2d = AvgPool2dConfig::new([3, 3])
|
||||
.with_strides([1, 1])
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.with_count_include_pad(true)
|
||||
.init();
|
||||
|
||||
Self {
|
||||
|
|
|
@ -129,6 +129,7 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
|
|||
let mut strides = vec![1, 1];
|
||||
let mut pads = vec![0, 0, 0, 0];
|
||||
let mut count_include_pad: i64 = 0;
|
||||
let mut ceil_mode: i64 = 0;
|
||||
|
||||
for (key, value) in curr.attrs.iter() {
|
||||
match key.as_str() {
|
||||
|
@ -136,19 +137,21 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
|
|||
"strides" => strides = value.clone().into_i64s(),
|
||||
"pads" => pads = value.clone().into_i64s(),
|
||||
"count_include_pad" => count_include_pad = value.clone().into_i64(),
|
||||
"ceil_mode" => ceil_mode = value.clone().into_i64(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let padding = padding_config(&pads);
|
||||
|
||||
if count_include_pad == 1 && padding != PaddingConfig2d::Valid {
|
||||
todo!("AvgPool2d: count_include_pad is not supported. See https://github.com/burn-rs/burn/issues/636");
|
||||
if ceil_mode == 1 {
|
||||
panic!("ceil_mode is not supported");
|
||||
}
|
||||
|
||||
let padding = padding_config(&pads);
|
||||
|
||||
AvgPool2dConfig::new([kernel_shape[0] as usize, kernel_shape[1] as usize])
|
||||
.with_strides([strides[0] as usize, strides[1] as usize])
|
||||
.with_padding(padding)
|
||||
.with_count_include_pad(count_include_pad == 1)
|
||||
}
|
||||
|
||||
/// Create a FlattenConfig from the attributes of the node
|
||||
|
|
Loading…
Reference in New Issue