!3425 fix avgpoolgrad

Merge pull request !3425 from fangzehua/avgpool
This commit is contained in:
mindspore-ci-bot 2020-07-27 14:15:29 +08:00 committed by Gitee
commit 568da0d510
10 changed files with 189 additions and 30 deletions

View File

@ -46,6 +46,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"reduce_max", "reduce_max_d"},
{"reduce_min", "reduce_min_d"},
{"avg_pool_grad", "avg_pool_grad_d"},
{"avg_pool_grad_vm", "avg_pool_grad_d"},
{"conv2d_backprop_filter", "conv2d_backprop_filter_d"},
{"conv2d_backprop_input", "conv2d_backprop_input_d"},
{"depthwise_conv2d_native", "depthwise_conv2d"},

View File

@ -26,6 +26,7 @@ namespace opt {
ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimCast->name(), {1});
Register(prim::kPrimAvgPoolGrad->name(), {0});
Register(prim::kPrimAvgPoolGradVm->name(), {0});
Register(prim::kPrimConv2DBackpropInput->name(), {2});
Register(prim::kPrimConv2DBackpropFilter->name(), {2});
Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1});

View File

@ -128,6 +128,7 @@ inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp");
inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm");
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");

View File

@ -33,7 +33,6 @@ from .activation import get_activation
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold',
'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag']

View File

@ -14,7 +14,10 @@
# ============================================================================
"""Define the grad rules of neural network related operations."""
import numpy as np
from mindspore.ops import _selected_grad_ops as SG
from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor
from .grad_base import bprop_getters
from .. import functional as F
from .. import operations as P
@ -24,7 +27,6 @@ from ..operations import _inner_ops as inner
from ... import context
@bprop_getters.register(P.BiasAdd)
def get_bprop_bias_add(self):
"""Grad definition for `BiasAdd` operation."""
@ -195,33 +197,133 @@ def get_bprop_max_pool_grad(self):
return bprop
def _windowed_output_size(input_size, ksize, stride, padding):
"""
helper func for AvgPoolGrad
"""
tmp_output = 0
tmp_pad_need = 0
tmp_pad_before = 0
tmp_pad_after = 0
if padding == 'VALID':
tmp_output = (input_size - ksize + stride) // stride
tmp_pad_before = 0
tmp_pad_after = 0
elif padding == 'SAME':
tmp_output = (input_size + stride - 1) // stride
tmp_pad_need = max(0, (tmp_output - 1) * stride + ksize - input_size)
tmp_pad_before = tmp_pad_need // 2
tmp_pad_after = tmp_pad_need - tmp_pad_before
return tmp_output, tmp_pad_before, tmp_pad_after
@constexpr
def _get_mean_matrix(x_shape, ksize, stride, padding, x_dtype):
"""
helper func for AvgPoolGrad.
`assist_input_matrix` is a 2d matrix with input_shape after padding,
the value of element which is padded is 0, else are 1.
For each element of output, it is mapped for slide window: `[h*h_stride : h*h_stride + h_ksize,
w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum of slide window is the
number of input that assosiate with output element.
"""
n_input, c_input, h_input, w_input = x_shape
h_ksize, w_ksize = ksize[2], ksize[3]
h_stride, w_stride = stride[2], stride[3]
n_output = n_input
c_output = c_input
h_output, w_output = 0, 0
pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
h_output, pad_top, pad_bottom = _windowed_output_size(h_input, h_ksize,
h_stride, padding)
w_output, pad_left, pad_right = _windowed_output_size(w_input, w_ksize,
w_stride, padding)
output_size = n_output * c_output * h_output * w_output
output_shape = (n_output, c_output, h_output, w_output)
output = np.array([0.0] * output_size)
output = np.reshape(output, output_shape)
in_shape_after_padding_2d = (h_input + pad_top + pad_bottom, w_input + pad_left + pad_right)
assist_input_matrix = np.ones(in_shape_after_padding_2d).astype(np.float32)
if pad_top > 0:
assist_input_matrix[:pad_top, :] = 0
if pad_bottom > 0:
assist_input_matrix[-pad_bottom:, :] = 0
if pad_left > 0:
assist_input_matrix[:, :pad_left] = 0
if pad_right > 0:
assist_input_matrix[:, -pad_right:] = 0
for h in range(h_output):
for w in range(w_output):
curr_input = assist_input_matrix[h*h_stride : h*h_stride + h_ksize, w*w_stride : w*w_stride + w_ksize]
curr_sum = np.sum(curr_input)
if curr_sum > 0:
output[:, :, h, w] = 1. / curr_sum
return Tensor(output, x_dtype)
@constexpr
def _get_kernel_matrix(kernel_matrix_shape, x_dtype):
kernel_matrix = np.ones(kernel_matrix_shape)
return Tensor(kernel_matrix, x_dtype)
@bprop_getters.register(P.AvgPool)
def get_bprop_avg_pool_grad(self):
"""Grad definition for `AvgPool` operation."""
avgpool_grad = G.AvgPoolGrad(
ksize=self.ksize,
strides=self.strides,
padding=self.padding)
shape_op = P.Shape()
avgpool_grad_gpu = G.AvgPoolGradGpu(
ksize=self.ksize,
strides=self.strides,
padding=self.padding)
def bprop(x, out, dout):
dx = avgpool_grad(shape_op(x), dout)
return (dx,)
def bprop_gpu(x, out, dout):
dx = avgpool_grad_gpu(x, out, dout)
return (dx,)
# the parameter of AvgPoolGrad in GPU and TBE/CPU is not same
if self.target == "GPU":
avgpool_grad_gpu = G.AvgPoolGradGpu(
ksize=self.ksize,
strides=self.strides,
padding=self.padding)
def bprop_gpu(x, out, dout):
dx = avgpool_grad_gpu(x, out, dout)
return (dx,)
bprop_fn = bprop_gpu
elif self.target == "GE":
avgpool_grad_ge = G.AvgPoolGrad(
ksize=self.ksize,
strides=self.strides,
padding=self.padding)
shape_op = P.Shape()
def bprop_ge(x, out, dout):
dx = avgpool_grad_ge(shape_op(x), dout)
return (dx,)
bprop_fn = bprop_ge
else:
bprop_fn = bprop
avgpool_grad_vm = G.AvgPoolGradVm(
ksize=self.ksize,
strides=self.strides,
padding=self.padding)
k_size_nchw = avgpool_grad_vm.ksize
stride_nchw = avgpool_grad_vm.strides
padding = self.padding
def bprop_vm(x, out, dout):
x_shape_nchw = F.shape(x)
x_dtype = F.dtype(x)
kernel_matrix_shape = (1, x_shape_nchw[1],
k_size_nchw[2],
k_size_nchw[3])
mean_matrix = _get_mean_matrix(x_shape_nchw, k_size_nchw, stride_nchw, padding, x_dtype)
kernel_matrix = _get_kernel_matrix(kernel_matrix_shape, x_dtype)
dx = avgpool_grad_vm(x_shape_nchw, dout, mean_matrix, kernel_matrix)
return (dx,)
bprop_fn = bprop_vm
return bprop_fn

View File

@ -196,6 +196,7 @@ from .floor_mod import _floor_mod_tbe
from .scatter_nd_update import _scatter_nd_update_tbe
from .avg_pool import _avg_pool_tbe
from .avg_pool_grad import _avg_pool_grad_tbe
from .avg_pool_grad_vm import _avg_pool_grad_vm_tbe
from .ones_like import _ones_like_tbe
from .batch_to_space import _batch_to_space_tbe
from .space_to_batch import _space_to_batch_tbe

View File

@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AvgPoolGradVm op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
avg_pool_grad_vm_op_info = TBERegOp("AvgPoolGradVm") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("avg_pool_grad_d.so") \
.compute_cost(10) \
.kernel_name("avg_pool_grad_d") \
.partial_flag(True) \
.attr("x_origin", "required", "listInt", "all") \
.attr("ksize", "required", "listInt", "all") \
.attr("strides", "required", "listInt", "all") \
.attr("padding", "required", "str", "all") \
.attr("data_format", "optional", "str", "all") \
.input(0, "input_grad", False, "required", "all") \
.input(1, "mean_matrix", False, "optional", "all") \
.input(2, "kernel_matrix", False, "optional", "all") \
.output(0, "out_grad", True, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_C1HWNCoC0, DataType.F16_5HD) \
.get_op_info()
@op_info_register(avg_pool_grad_vm_op_info)
def _avg_pool_grad_vm_tbe():
"""AvgPoolGradVm TBE register"""
return

View File

@ -23,7 +23,6 @@ from .._utils import get_concat_offset
from ...common import dtype as mstype
from .. import functional as F
class AbsGrad(PrimitiveWithInfer):
"""Computes gradients for abs operation."""
@ -492,7 +491,7 @@ class _PoolGrad(PrimitiveWithInfer):
class AvgPoolGrad(_PoolGrad):
"""Gradients of the avg pool operation."""
"""Gradients of the avg pool operation for ge."""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID"):
@ -508,6 +507,24 @@ class AvgPoolGrad(_PoolGrad):
return out
class AvgPoolGradVm(_PoolGrad):
"""Gradients of the avg pool operation for vm."""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID"):
super(AvgPoolGradVm, self).__init__(ksize, strides, padding)
self.init_prim_io_names(inputs=['x_origin', 'grad', 'mean_matrix', 'kernel_matrix'], outputs=['output'])
def __infer__(self, origin_input, dout, mean_matrix, kernel_matrix):
out = {
'value': None,
'shape': tuple(origin_input['value']),
'dtype': dout['dtype'],
}
return out
class AvgPoolGradGpu(_PoolGrad):
"""Gradients of the avg pool operation for gpu."""

View File

@ -1276,6 +1276,8 @@ class AvgPool(_Pool):
def __init__(self, ksize=1, strides=1, padding="valid"):
if context.get_context("device_target") == "GPU":
self.target = "GPU"
elif context.get_context("enable_ge"):
self.target = "GE"
else:
self.target = "OTHER"
super(AvgPool, self).__init__(ksize, strides, padding)

View File

@ -1316,13 +1316,6 @@ test_case_nn_ops = [
'block': P.AvgPool(ksize=(2, 2), strides=(2, 2), padding="VALID"),
'desc_inputs': [[100, 3, 28, 28]],
'desc_bprop': [[100, 3, 14, 14]]}),
('AvgPoolGrad', {
'block': G.AvgPoolGrad(ksize=(2, 2), strides=(2, 2), padding="VALID"),
'desc_const': [(3, 4, 6, 6)],
'const_first': True,
'desc_inputs': [[3, 4, 6, 6]],
'desc_bprop': [[3, 4, 6, 6]],
'skip': ['backward']}),
('MaxPoolWithArgmax', {
'block': P.MaxPoolWithArgmax(ksize=2, strides=2),
'desc_inputs': [[128, 32, 32, 64]],