!18682 Fix adaptiveavgpool2d accuracy error when input H is not equal to W

Merge pull request !18682 from zuochuanyong/fix_AdaptiveAvgPool2D
This commit is contained in:
i-robot 2021-06-22 07:28:05 +00:00 committed by Gitee
commit d90cfd2aaa
1 changed files with 4 additions and 4 deletions

View File

@ -61,8 +61,8 @@ class AdaptiveAvgPool2DKernel : public GpuKernel {
output_height = shape_addr[0];
output_width = shape_addr[0];
} else if (shape_addr.size() == 2) {
output_height = static_cast<uint>(shape_addr[1]);
output_width = static_cast<uint>(shape_addr[0]);
output_height = static_cast<uint>(shape_addr[0]);
output_width = static_cast<uint>(shape_addr[1]);
} else {
MS_LOG(ERROR) << "Input Error.";
return false;
@ -79,8 +79,8 @@ class AdaptiveAvgPool2DKernel : public GpuKernel {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
len = static_cast<uint>(input_shape.size());
input_height = static_cast<uint>(input_shape[len - 1]);
input_width = static_cast<uint>(input_shape[len - 2]);
input_height = static_cast<uint>(input_shape[len - 2]);
input_width = static_cast<uint>(input_shape[len - 1]);
size = static_cast<uint>(len == 3 ? input_shape[0] : input_shape[0] * input_shape[1]);
for (uint i = 0; i < len; i++) {
input_size_ *= input_shape[i];