forked from mindspore-Ecosystem/mindspore

!904 fix avgpool and add dimension check for groupnorm

Merge pull request !904 from JichenZhao/groupn

commit 108ef72aaf
@@ -185,6 +185,10 @@ def _channel_check(channel, num_channel):
         raise ValueError("the input channel is not equal with num_channel")
 
 @constexpr
+def _shape_check(in_shape):
+    if len(in_shape) != 4:
+        raise ValueError("The input must has 4 dims")
+
+@constexpr
 def _shape_infer(x_shape, num_feature):
     """global batch normalization shape and axes infer"""
     if len(x_shape) == 4:
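For context, `_shape_check` is a `@constexpr` guard, so the dimension check fires at graph-compile time rather than on device. A minimal usage sketch of what it now enforces (the shapes and parameter values below are illustrative, not taken from this diff):

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# GroupNorm expects NCHW (4-dim) input after this change.
net = nn.GroupNorm(num_groups=2, num_channels=4)
ok = Tensor(np.ones([1, 4, 8, 8]), mindspore.float32)  # 4 dims: passes _shape_check
out = net(ok)
bad = Tensor(np.ones([1, 4, 8]), mindspore.float32)    # 3 dims: rejected
# net(bad)  # would raise ValueError("The input must has 4 dims")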
@@ -539,7 +543,8 @@ class GroupNorm(Cell):
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.sqrt = P.Sqrt()
 
-    def construct(self, x):
+    def _cal_output(self, x):
+        """calculate groupnorm output"""
         batch, channel, height, width = self.shape(x)
         _channel_check(channel, self.num_channels)
         x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups))
@@ -551,6 +556,11 @@ class GroupNorm(Cell):
         output = x * self.gamma + self.beta
         return output
 
+    def construct(self, x):
+        _shape_check(self.shape(x))
+        output = self._cal_output(x)
+        return output
+
     def extend_repr(self):
         """Display instance object as string."""
         s = 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
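The refactor leaves `construct` as a thin wrapper: validate the input shape first, then delegate the arithmetic to `_cal_output`. For readers unfamiliar with the computation, here is a NumPy sketch of the equivalent group-norm math; the `eps` value, function name, and parameter shapes are assumptions for illustration, not taken from this diff:

import numpy as np

def group_norm_ref(x, num_groups, gamma, beta, eps=1e-5):
    # x: (N, C, H, W); gamma/beta assumed per-channel with shape (C, 1, 1).
    # Each group of C // num_groups channels is normalized by its own
    # mean/variance, mirroring the reshape to
    # (batch, num_groups, channel*height*width/num_groups) in the hunk above.
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, -1)
    mean = g.mean(axis=2, keepdims=True)
    var = g.var(axis=2, keepdims=True)
    g = (g - mean) / np.sqrt(var + eps)
    return g.reshape(n, c, h, w) * gamma + beta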
@@ -16,6 +16,7 @@
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore._checkparam import Validator as validator
+from mindspore.ops.primitive import constexpr
 from ... import context
 from ..cell import Cell
 from ..._checkparam import Rel
@@ -52,7 +53,10 @@ class _PoolNd(Cell):
 
     def extend_repr(self):
         return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
 
+@constexpr
+def _shape_check(in_shape):
+    if len(in_shape) != 3:
+        raise ValueError("The input must has 3 dim")
 
 class MaxPool2d(_PoolNd):
     r"""
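This mirrors the GroupNorm change: AvgPool1d input is now validated as 3-dim :math:`(N, C, L)` at compile time. A small sketch of the new behavior (shapes are illustrative):

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

pool = nn.AvgPool1d(kernel_size=2, stride=1)
good = Tensor(np.ones([1, 3, 6]), mindspore.float32)    # (N, C, L): accepted
out = pool(good)
bad = Tensor(np.ones([1, 3, 6, 6]), mindspore.float32)  # 4 dims: rejected
# pool(bad)  # would raise ValueError("The input must has 3 dim")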
@@ -218,13 +222,13 @@ class AvgPool1d(_PoolNd):
 
     Applies a 1D average pooling over an input Tensor which can be regarded as a composition of 1D input planes.
 
-    Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool1d outputs
-    regional average in the :math:`(W_{in})`-dimension. Given kernel size
-    :math:`ks = w_{ker}` and stride :math:`s = s_0`, the operation is as follows.
+    Typically the input is of shape :math:`(N_{in}, C_{in}, L_{in})`, AvgPool1d outputs
+    regional average in the :math:`(L_{in})`-dimension. Given kernel size
+    :math:`ks = l_{ker}` and stride :math:`s = s_0`, the operation is as follows.
 
     .. math::
-        \text{output}(N_i, C_j, h_k, w) = \frac{1}{w_{ker}} \sum_{n=0}^{w_{ker}-1}
-        \text{input}(N_i, C_j, h_k, s_0 \times w + n)
+        \text{output}(N_i, C_j, l) = \frac{1}{l_{ker}} \sum_{n=0}^{l_{ker}-1}
+        \text{input}(N_i, C_j, s_0 \times l + n)
 
     Note:
         pad_mode for training only supports "same" and "valid".
@@ -246,17 +250,17 @@ class AvgPool1d(_PoolNd):
 
     Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, L_{in})`.
 
     Outputs:
-        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+        Tensor of shape :math:`(N, C_{out}, L_{out})`.
 
     Examples:
-        >>> pool = nn.AvgPool1d(kernel_size=3, strides=1)
-        >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32)
+        >>> pool = nn.AvgPool1d(kernel_size=6, strides=1)
+        >>> x = Tensor(np.random.randint(0, 10, [1, 3, 6]), mindspore.float32)
         >>> output = pool(x)
         >>> output.shape()
-        (1, 2, 4, 2)
+        (1, 3, 1)
     """
 
     def __init__(self,
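Sanity check on the corrected example, assuming "valid"-style windows: :math:`L_{out} = \lfloor (L_{in} - l_{ker}) / s_0 \rfloor + 1`. With the new docstring values (:math:`L_{in}=6`, :math:`l_{ker}=6`, :math:`s_0=1`) this gives :math:`(6-6)/1+1 = 1`, matching the `(1, 3, 1)` output shape; the old 4-dim example was inconsistent with a 1D pool.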
@@ -277,14 +281,17 @@ class AvgPool1d(_PoolNd):
         self.shape = F.shape
         self.reduce_mean = P.ReduceMean(keep_dims=True)
         self.slice = P.Slice()
+        self.expand = P.ExpandDims()
 
     def construct(self, x):
-        batch, channel, high, width = self.shape(x)
+        _shape_check(self.shape(x))
+        batch, channel, width = self.shape(x)
         if width == self.kernel_size[1]:
-            x = self.reduce_mean(x, 3)
+            x = self.reduce_mean(x, 2)
         elif width - self.kernel_size[1] < self.stride[1]:
-            x = self.slice(x, (0, 0, 0, 0), (batch, channel, high, self.kernel_size[1]))
-            x = self.reduce_mean(x, 3)
+            x = self.slice(x, (0, 0, 0), (batch, channel, self.kernel_size[1]))
+            x = self.reduce_mean(x, 2)
         else:
+            x = self.expand(x, 2)
             x = self.avg_pool(x)
         return x
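To make the three branches concrete: the underlying `P.AvgPool` primitive works on 4-dim input, which is why the `else` branch inserts a height axis with `ExpandDims` before calling `self.avg_pool`. A NumPy sketch of the same logic, assuming "valid" padding (function name and structure are illustrative):

import numpy as np

def avg_pool1d_ref(x, kernel_size, stride):
    # x: (N, C, L), mirroring the branches in construct() above.
    n, c, l = x.shape
    if l == kernel_size:                 # exactly one window: plain mean
        return x.mean(axis=2, keepdims=True)
    if l - kernel_size < stride:         # only one window fits: slice, then mean
        return x[:, :, :kernel_size].mean(axis=2, keepdims=True)
    out_l = (l - kernel_size) // stride + 1   # general case: slide the window
    return np.stack([x[:, :, i * stride:i * stride + kernel_size].mean(axis=2)
                     for i in range(out_l)], axis=2)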
@@ -69,6 +69,6 @@ class Avg1dNet(nn.Cell):
         return self.avg1d(x)
 
 def test_avg1d():
-    net = Avg1dNet(3, 1)
-    input = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32))
+    net = Avg1dNet(6, 1)
+    input = Tensor(np.random.randint(0, 255, [1, 3, 6]).astype(np.float32))
     _executor.compile(net, input)