!18682 Fix adaptiveavgpool2d accuracy error when input H is not equal to W
Merge pull request !18682 from zuochuanyong/fix_AdaptiveAvgPool2D
This commit is contained in:
commit
d90cfd2aaa
|
@ -61,8 +61,8 @@ class AdaptiveAvgPool2DKernel : public GpuKernel {
|
|||
output_height = shape_addr[0];
|
||||
output_width = shape_addr[0];
|
||||
} else if (shape_addr.size() == 2) {
|
||||
output_height = static_cast<uint>(shape_addr[1]);
|
||||
output_width = static_cast<uint>(shape_addr[0]);
|
||||
output_height = static_cast<uint>(shape_addr[0]);
|
||||
output_width = static_cast<uint>(shape_addr[1]);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Input Error.";
|
||||
return false;
|
||||
|
@ -79,8 +79,8 @@ class AdaptiveAvgPool2DKernel : public GpuKernel {
|
|||
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
len = static_cast<uint>(input_shape.size());
|
||||
input_height = static_cast<uint>(input_shape[len - 1]);
|
||||
input_width = static_cast<uint>(input_shape[len - 2]);
|
||||
input_height = static_cast<uint>(input_shape[len - 2]);
|
||||
input_width = static_cast<uint>(input_shape[len - 1]);
|
||||
size = static_cast<uint>(len == 3 ? input_shape[0] : input_shape[0] * input_shape[1]);
|
||||
for (uint i = 0; i < len; i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
|
|
Loading…
Reference in New Issue