From 039c75af8ec4afb45990ee8e470e73ad80003e29 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Thu, 30 Apr 2020 06:45:41 -0400
Subject: [PATCH] fix avgpool and add check

---
 mindspore/nn/layer/normalization.py | 12 +++++++++-
 mindspore/nn/layer/pooling.py       | 39 ++++++++++++++++++-----------
 tests/ut/python/nn/test_pooling.py  |  4 ++--
 3 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 7a102b0bbe9..90a7ad788fc 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -182,6 +182,10 @@ def _channel_check(channel, num_channel):
         raise ValueError("the input channel is not equal with num_channel")
 
 @constexpr
+def _shape_check(in_shape):
+    if len(in_shape) != 4:
+        raise ValueError("The input must have 4 dims")
+@constexpr
 def _shape_infer(x_shape, num_feature):
     """global batch normalization shape and axes infer"""
     if len(x_shape) == 4:
@@ -536,7 +540,8 @@ class GroupNorm(Cell):
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.sqrt = P.Sqrt()
 
-    def construct(self, x):
+    def _cal_output(self, x):
+        """Calculate the GroupNorm output."""
         batch, channel, height, width = self.shape(x)
         _channel_check(channel, self.num_channels)
         x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups))
@@ -548,6 +553,11 @@ class GroupNorm(Cell):
         output = x * self.gamma + self.beta
         return output
 
+    def construct(self, x):
+        _shape_check(self.shape(x))
+        output = self._cal_output(x)
+        return output
+
     def extend_repr(self):
         """Display instance object as string."""
         s = 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py
index 0569a8ada6e..b0c0816ca45 100644
--- a/mindspore/nn/layer/pooling.py
+++ b/mindspore/nn/layer/pooling.py
@@ -16,6 +16,7 @@
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore._checkparam import Validator as validator
+from mindspore.ops.primitive import constexpr
 from ... import context
 from ..cell import Cell
 from ..._checkparam import Rel
@@ -52,7 +53,10 @@ class _PoolNd(Cell):
     def extend_repr(self):
         return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
 
-
+@constexpr
+def _shape_check(in_shape):
+    if len(in_shape) != 3:
+        raise ValueError("The input must have 3 dims")
 
 class MaxPool2d(_PoolNd):
     r"""
@@ -218,13 +222,13 @@ class AvgPool1d(_PoolNd):
     Applies a 1D average pooling over an input Tensor which can be regarded as
     a composition of 1D input planes.
 
-    Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool1d outputs
-    regional average in the :math:`(W_{in})`-dimension. Given kernel size
-    :math:`ks = w_{ker}` and stride :math:`s = s_0`, the operation is as follows.
+    Typically, the input is of shape :math:`(N_{in}, C_{in}, L_{in})`, and AvgPool1d outputs
+    regional averages over the :math:`L_{in}` dimension. Given kernel size
+    :math:`ks = l_{ker}` and stride :math:`s = s_0`, the operation is as follows.
 
     .. math::
-        \text{output}(N_i, C_j, h_k, w) = \frac{1}{w_{ker}} \sum_{n=0}^{w_{ker}-1}
-        \text{input}(N_i, C_j, h_k, s_0 \times w + n)
+        \text{output}(N_i, C_j, l) = \frac{1}{l_{ker}} \sum_{n=0}^{l_{ker}-1}
+        \text{input}(N_i, C_j, s_0 \times l + n)
 
     Note:
         pad_mode for training only supports "same" and "valid".
@@ -246,17 +250,17 @@ class AvgPool1d(_PoolNd):
 
     Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, L_{in})`.
 
     Outputs:
-        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+        Tensor of shape :math:`(N, C_{out}, L_{out})`.
 
     Examples:
-        >>> pool = nn.AvgPool1d(kernel_size=3, strides=1)
-        >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32)
+        >>> pool = nn.AvgPool1d(kernel_size=6, stride=1)
+        >>> x = Tensor(np.random.randint(0, 10, [1, 3, 6]), mindspore.float32)
         >>> output = pool(x)
         >>> output.shape()
-        (1, 2, 4, 2)
+        (1, 3, 1)
     """
 
     def __init__(self,
@@ -277,14 +281,19 @@ class AvgPool1d(_PoolNd):
         self.shape = F.shape
         self.reduce_mean = P.ReduceMean(keep_dims=True)
         self.slice = P.Slice()
+        self.expand = P.ExpandDims()
+        self.squeeze = P.Squeeze(2)
 
     def construct(self, x):
-        batch, channel, high, width = self.shape(x)
+        _shape_check(self.shape(x))
+        batch, channel, width = self.shape(x)
         if width == self.kernel_size[1]:
-            x = self.reduce_mean(x, 3)
+            x = self.reduce_mean(x, 2)
         elif width - self.kernel_size[1] < self.stride[1]:
-            x = self.slice(x, (0, 0, 0, 0), (batch, channel, high, self.kernel_size[1]))
-            x = self.reduce_mean(x, 3)
+            x = self.slice(x, (0, 0, 0), (batch, channel, self.kernel_size[1]))
+            x = self.reduce_mean(x, 2)
         else:
+            x = self.expand(x, 2)
             x = self.avg_pool(x)
+            x = self.squeeze(x)
         return x
diff --git a/tests/ut/python/nn/test_pooling.py b/tests/ut/python/nn/test_pooling.py
index 428e050ea2a..863ffc555ff 100644
--- a/tests/ut/python/nn/test_pooling.py
+++ b/tests/ut/python/nn/test_pooling.py
@@ -69,6 +69,6 @@ class Avg1dNet(nn.Cell):
         return self.avg1d(x)
 
 def test_avg1d():
-    net = Avg1dNet(3, 1)
-    input = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32))
+    net = Avg1dNet(6, 1)
+    input = Tensor(np.random.randint(0, 255, [1, 3, 6]).astype(np.float32))
     _executor.compile(net, input)
\ No newline at end of file
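
Note on the new AvgPool1d.construct: the sketch below is a minimal NumPy mock-up
(not MindSpore code) of the three branches the patched method takes, assuming the
default "valid" pad_mode; the helper name avg_pool_1d and the test shapes are
illustrative assumptions, not part of this patch.

import numpy as np

def avg_pool_1d(x, kernel, stride):
    # x: (N, C, W) -- the 3-dim layout the new _shape_check enforces.
    batch, channel, width = x.shape
    if width == kernel:
        # One full-width window: a plain mean over W; keepdims mirrors
        # ReduceMean(keep_dims=True).
        return x.mean(axis=2, keepdims=True)
    if width - kernel < stride:
        # Only the first window fits: slice it off, then average it.
        return x[:, :, :kernel].mean(axis=2, keepdims=True)
    # General case: emulate the 4D P.AvgPool by inserting a height axis of
    # size 1 (ExpandDims(x, 2)), pooling along W under "valid" padding, then
    # dropping that axis again (Squeeze(2)) so every branch returns 3 dims.
    out_w = (width - kernel) // stride + 1
    x4 = x[:, :, None, :]                                    # (N, C, 1, W)
    pooled = np.stack([x4[..., i * stride:i * stride + kernel].mean(axis=-1)
                       for i in range(out_w)], axis=-1)      # (N, C, 1, out_w)
    return pooled.squeeze(axis=2)                            # (N, C, out_w)

x = np.random.rand(1, 3, 6).astype(np.float32)
assert avg_pool_1d(x, kernel=6, stride=1).shape == (1, 3, 1)   # the docstring example
assert avg_pool_1d(x, kernel=2, stride=2).shape == (1, 3, 3)   # general branch

The two ReduceMean branches cover windows that can only produce a single output
column; handling them directly keeps those edge shapes away from the 4D AvgPool
primitive, which the general branch reaches through the expand/pool/squeeze
round trip.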