!37250 Fix the mismatch between input shape and format for pooling grad

Merge pull request !37250 from 范吉斌/fix_maxpool
This commit is contained in:
i-robot 2022-07-05 08:23:41 +00:00 committed by Gitee
commit 5944f29e4f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 6 additions and 6 deletions

View File

@ -175,15 +175,15 @@ bool PoolingGradGpuKernelMod::InitShape(const std::vector<KernelTensorPtr> &inpu
int *strideAout, int nbDims) {
ShapeVector dout_shape, input_mask, output_shape, input_shape;
if (kernel_name_ == kAvgPool3DGrad) {
dout_shape = inputs.at(kIndex0)->GetShapeVector();
output_shape = outputs.at(kIndex0)->GetShapeVector();
dout_shape = inputs.at(kIndex0)->GetDeviceShapeAdaptively();
output_shape = outputs.at(kIndex0)->GetDeviceShapeAdaptively();
input_mask = dout_shape;
input_shape = output_shape;
} else {
input_shape = inputs.at(kIndex0)->GetShapeVector();
input_mask = inputs.at(kIndex1)->GetShapeVector();
dout_shape = inputs.at(kIndex2)->GetShapeVector();
output_shape = outputs.at(kIndex0)->GetShapeVector();
input_shape = inputs.at(kIndex0)->GetDeviceShapeAdaptively();
input_mask = inputs.at(kIndex1)->GetDeviceShapeAdaptively();
dout_shape = inputs.at(kIndex2)->GetDeviceShapeAdaptively();
output_shape = outputs.at(kIndex0)->GetDeviceShapeAdaptively();
}
is_null_input_ =
CHECK_SHAPE_NULL(input_shape, kernel_name_, "input") || CHECK_SHAPE_NULL(input_mask, kernel_name_, "mask") ||