forked from mindspore-Ecosystem/mindspore
!10575 fix pooling max grad
From: @wuxuejian Reviewed-by: @liangchenghui,@oacjiewen Signed-off-by: @liangchenghui
This commit is contained in:
commit
45d6dc716f
|
@ -117,15 +117,12 @@ bool MaxPoolingGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
|
|||
size_t src_wh = src_shape_[2] * src_shape_[3];
|
||||
size_t dst_wh = dst_shape_[2] * dst_shape_[3];
|
||||
for (size_t n = 0; n < src_shape_[0]; ++n) {
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t c = start; c < end; ++c) {
|
||||
ChannelPoolingGrad(input, diff, output);
|
||||
input = input + src_wh;
|
||||
output = output + src_wh;
|
||||
diff = diff + dst_wh;
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, src_shape_[1]);
|
||||
for (size_t c = 0; c < src_shape_[1]; ++c) {
|
||||
ChannelPoolingGrad(input, diff, output);
|
||||
input = input + src_wh;
|
||||
output = output + src_wh;
|
||||
diff = diff + dst_wh;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue