fix AdaptiveAvgPool2D when output_size contains 0
This commit is contained in:
parent
98412c8a32
commit
5b2b0c8dad
|
@ -139,7 +139,7 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
output_size (Union[int, tuple]): The target output size is H x W.
|
||||
ouput_size can be a tuple, or a single H for H x H, and H x W can be int or None
|
||||
ouput_size can be a tuple, or a single H for H x H, and H and W can be int or None
|
||||
which means the output size is the same as the input.
|
||||
|
||||
Inputs:
|
||||
|
@ -147,13 +147,21 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer):
|
|||
with float16, float32, float64 data type.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and same dimensions as the input_x.
|
||||
Tensor, with the same type as the `input_x`.
|
||||
Shape of the output is `input_x_shape[:len(input_x_shape) - len(out_shape)] + out_shape`.
|
||||
If output_size contains None:
|
||||
`out_shape = input_x_shape[-2] + output_size[1]`: If `output_size` is `(None, w)`
|
||||
`out_shape = output_size[1] + input_x_shape[-1]`: If `output_size` is `(h, None)`
|
||||
`out_shape = input_x_shape[-2:]: If output_size` is `(None, None)`
|
||||
If `output_size` dees not contain `None`:
|
||||
`out_shape = (h, h)`: If `output_size` is `h`
|
||||
`out_shape = (h, w)`: If `output_size` is `(h, w)`
|
||||
|
||||
Raises:
|
||||
ValueError: if `output_size` is not a tuple and if `output_size` length is not 2.
|
||||
ValueError: If `output_size` is a tuple and if `output_size` length is not 2.
|
||||
TypeError: If `input_x` is not a tensor.
|
||||
TypeError: If dtype of `input_x` is not float16, float32, float64.
|
||||
ValueError: If `input_x` dimension is less than or more than output_size dimension.
|
||||
ValueError: If `input_x` dimension is less than or equal to output_size dimension.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
@ -175,17 +183,18 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer):
|
|||
"""Initialize AdaptiveAvgPool2D."""
|
||||
validator.check_value_type("output_size", output_size, [int, tuple], self.name)
|
||||
if isinstance(output_size, tuple):
|
||||
validator.check_int(len(output_size), 2, Rel.EQ, 'output_size', self.name)
|
||||
validator.check_int(len(output_size), 2, Rel.EQ, 'length of output_size', self.name)
|
||||
self.output_size = (output_size, output_size) if isinstance(self.output_size, int) else output_size
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
if len(x_shape) <= len(self.output_size):
|
||||
raise ValueError("{} dimension should be larger than {} dimension".format(x_shape, self.output_size))
|
||||
raise ValueError("input_x {} dimension should be larger than output_size {} "
|
||||
"dimension".format(x_shape, self.output_size))
|
||||
validator.check_int(len(x_shape), 5, Rel.LT, 'input_x_dimensions', self.name)
|
||||
for input_x_dimension in x_shape:
|
||||
validator.check_int(input_x_dimension, 0, Rel.GT, 'input_x dimension', self.name)
|
||||
zipped = zip(self.output_size, x_shape[-len(self.output_size):])
|
||||
out_size = [i if i else j for i, j in zipped]
|
||||
out_size = [i if i is not None else j for i, j in zipped]
|
||||
for item in out_size:
|
||||
validator.check_value_type("item of output_size", item, [int], self.name)
|
||||
self.add_prim_attr('output_size', out_size)
|
||||
|
|
Loading…
Reference in New Issue