fix AdaptiveAvgPool2D when output_size contains 0

This commit is contained in:
zuochuanyong 2021-06-29 19:29:54 +08:00
parent 98412c8a32
commit 5b2b0c8dad
1 changed files with 16 additions and 7 deletions

View File

@ -139,7 +139,7 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer):
Args:
output_size (Union[int, tuple]): The target output size is H x W.
ouput_size can be a tuple, or a single H for H x H, and H x W can be int or None
ouput_size can be a tuple, or a single H for H x H, and H and W can be int or None
which means the output size is the same as the input.
Inputs:
@ -147,13 +147,21 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer):
with float16, float32, float64 data type.
Outputs:
Tensor, with the same type and same dimensions as the input_x.
Tensor, with the same type as the `input_x`.
Shape of the output is `input_x_shape[:len(input_x_shape) - len(out_shape)] + out_shape`.
If output_size contains None:
`out_shape = input_x_shape[-2] + output_size[1]`: If `output_size` is `(None, w)`
`out_shape = output_size[1] + input_x_shape[-1]`: If `output_size` is `(h, None)`
`out_shape = input_x_shape[-2:]: If output_size` is `(None, None)`
If `output_size` dees not contain `None`:
`out_shape = (h, h)`: If `output_size` is `h`
`out_shape = (h, w)`: If `output_size` is `(h, w)`
Raises:
ValueError: if `output_size` is not a tuple and if `output_size` length is not 2.
ValueError: If `output_size` is a tuple and if `output_size` length is not 2.
TypeError: If `input_x` is not a tensor.
TypeError: If dtype of `input_x` is not float16, float32, float64.
ValueError: If `input_x` dimension is less than or more than output_size dimension.
ValueError: If `input_x` dimension is less than or equal to output_size dimension.
Supported Platforms:
``GPU``
@ -175,17 +183,18 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer):
"""Initialize AdaptiveAvgPool2D."""
validator.check_value_type("output_size", output_size, [int, tuple], self.name)
if isinstance(output_size, tuple):
validator.check_int(len(output_size), 2, Rel.EQ, 'output_size', self.name)
validator.check_int(len(output_size), 2, Rel.EQ, 'length of output_size', self.name)
self.output_size = (output_size, output_size) if isinstance(self.output_size, int) else output_size
def infer_shape(self, x_shape):
if len(x_shape) <= len(self.output_size):
raise ValueError("{} dimension should be larger than {} dimension".format(x_shape, self.output_size))
raise ValueError("input_x {} dimension should be larger than output_size {} "
"dimension".format(x_shape, self.output_size))
validator.check_int(len(x_shape), 5, Rel.LT, 'input_x_dimensions', self.name)
for input_x_dimension in x_shape:
validator.check_int(input_x_dimension, 0, Rel.GT, 'input_x dimension', self.name)
zipped = zip(self.output_size, x_shape[-len(self.output_size):])
out_size = [i if i else j for i, j in zipped]
out_size = [i if i is not None else j for i, j in zipped]
for item in out_size:
validator.check_value_type("item of output_size", item, [int], self.name)
self.add_prim_attr('output_size', out_size)