!37250 Fix the mismatch between input shape and format for pooling grad
Merge pull request !37250 from 范吉斌/fix_maxpool
This commit is contained in:
commit
5944f29e4f
|
@ -175,15 +175,15 @@ bool PoolingGradGpuKernelMod::InitShape(const std::vector<KernelTensorPtr> &inpu
|
|||
int *strideAout, int nbDims) {
|
||||
ShapeVector dout_shape, input_mask, output_shape, input_shape;
|
||||
if (kernel_name_ == kAvgPool3DGrad) {
|
||||
dout_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
dout_shape = inputs.at(kIndex0)->GetDeviceShapeAdaptively();
|
||||
output_shape = outputs.at(kIndex0)->GetDeviceShapeAdaptively();
|
||||
input_mask = dout_shape;
|
||||
input_shape = output_shape;
|
||||
} else {
|
||||
input_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
input_mask = inputs.at(kIndex1)->GetShapeVector();
|
||||
dout_shape = inputs.at(kIndex2)->GetShapeVector();
|
||||
output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
input_shape = inputs.at(kIndex0)->GetDeviceShapeAdaptively();
|
||||
input_mask = inputs.at(kIndex1)->GetDeviceShapeAdaptively();
|
||||
dout_shape = inputs.at(kIndex2)->GetDeviceShapeAdaptively();
|
||||
output_shape = outputs.at(kIndex0)->GetDeviceShapeAdaptively();
|
||||
}
|
||||
is_null_input_ =
|
||||
CHECK_SHAPE_NULL(input_shape, kernel_name_, "input") || CHECK_SHAPE_NULL(input_mask, kernel_name_, "mask") ||
|
||||
|
|
Loading…
Reference in New Issue