forked from mindspore-Ecosystem/mindspore
!541 add average pooling 1D
Merge pull request !541 from JichenZhao/avgpooling
This commit is contained in:
commit
67057d1309
|
@ -24,7 +24,7 @@ from .conv import Conv2d, Conv2dTranspose
|
|||
from .lstm import LSTM
|
||||
from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold
|
||||
from .embedding import Embedding
|
||||
from .pooling import AvgPool2d, MaxPool2d
|
||||
from .pooling import AvgPool2d, MaxPool2d, AvgPool1d
|
||||
from .image import ImageGradients, SSIM, PSNR
|
||||
|
||||
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
|
||||
|
@ -35,6 +35,6 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
|
|||
'LSTM',
|
||||
'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot',
|
||||
'Embedding',
|
||||
'AvgPool2d', 'MaxPool2d', 'Pad', 'Unfold',
|
||||
'AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'Pad', 'Unfold',
|
||||
'ImageGradients', 'SSIM', 'PSNR',
|
||||
]
|
||||
|
|
|
@ -14,9 +14,12 @@
|
|||
# ============================================================================
|
||||
"""pooling"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from ... import context
|
||||
from ..cell import Cell
|
||||
from ..._checkparam import Rel
|
||||
from ..._checkparam import ParamValidator
|
||||
|
||||
|
||||
class _PoolNd(Cell):
|
||||
|
@ -208,3 +211,81 @@ class AvgPool2d(_PoolNd):
|
|||
|
||||
def construct(self, x):
|
||||
return self.avg_pool(x)
|
||||
|
||||
|
||||
class AvgPool1d(_PoolNd):
|
||||
r"""
|
||||
Average pooling for temporal data.
|
||||
|
||||
Applies a 1D average pooling over an input Tensor which can be regarded as a composition of 1D input planes.
|
||||
|
||||
Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool1d outputs
|
||||
regional average in the :math:`(W_{in})`-dimension. Given kernel size
|
||||
:math:`ks = w_{ker}` and stride :math:`s = s_0`, the operation is as follows.
|
||||
|
||||
.. math::
|
||||
\text{output}(N_i, C_j, h_k, w) = \frac{1}{w_{ker}} \sum_{n=0}^{w_{ker}-1}
|
||||
\text{input}(N_i, C_j, h_k, s_0 \times w + n)
|
||||
|
||||
Note:
|
||||
pad_mode for training only supports "same" and "valid".
|
||||
|
||||
Args:
|
||||
kernel_size (int): The size of kernel window used to take the average value, Default: 1.
|
||||
stride (int): The distance of kernel moving, an int number that represents
|
||||
the width of movement is strides, Default: 1.
|
||||
pad_mode (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
|
||||
Default: "valid".
|
||||
|
||||
- same: Adopts the way of completion. Output height and width will be the same as
|
||||
the input. Total number of padding will be calculated for horizontal and vertical
|
||||
direction and evenly distributed to top and bottom, left and right if possible.
|
||||
Otherwise, the last extra padding will be done from the bottom and the right side.
|
||||
|
||||
- valid: Adopts the way of discarding. The possibly largest height and width of output
|
||||
will be return without padding. Extra pixels will be discarded.
|
||||
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
|
||||
Examples:
|
||||
>>> pool = nn.AvgPool1d(kernel_size=3, strides=1)
|
||||
>>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32)
|
||||
>>> output = pool(x)
|
||||
>>> output.shape()
|
||||
(1, 2, 4, 2)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
pad_mode="valid"):
|
||||
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
|
||||
ParamValidator.check_type('kernel_size', kernel_size, [int,])
|
||||
ParamValidator.check_type('stride', stride, [int,])
|
||||
self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
|
||||
ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
|
||||
ParamValidator.check_integer("stride", stride, 1, Rel.GE)
|
||||
self.kernel_size = (1, kernel_size)
|
||||
self.stride = (1, stride)
|
||||
self.avg_pool = P.AvgPool(ksize=self.kernel_size,
|
||||
strides=self.stride,
|
||||
padding=self.pad_mode)
|
||||
self.shape = F.shape
|
||||
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
||||
self.slice = P.Slice()
|
||||
|
||||
def construct(self, x):
|
||||
batch, channel, high, width = self.shape(x)
|
||||
if width == self.kernel_size[1]:
|
||||
x = self.reduce_mean(x, 3)
|
||||
elif width - self.kernel_size[1] < self.stride[1]:
|
||||
x = self.slice(x, (0, 0, 0, 0), (batch, channel, high, self.kernel_size[1]))
|
||||
x = self.reduce_mean(x, 3)
|
||||
else:
|
||||
x = self.avg_pool(x)
|
||||
return x
|
||||
|
|
|
@ -56,3 +56,19 @@ def test_compile_max():
|
|||
net = MaxNet(3, stride=1, padding=0)
|
||||
x = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32))
|
||||
_executor.compile(net, x)
|
||||
|
||||
|
||||
class Avg1dNet(nn.Cell):
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
stride=None):
|
||||
super(Avg1dNet, self).__init__()
|
||||
self.avg1d = nn.AvgPool1d(kernel_size, stride)
|
||||
|
||||
def construct(self, x):
|
||||
return self.avg1d(x)
|
||||
|
||||
def test_avg1d():
|
||||
net = Avg1dNet(3, 1)
|
||||
input = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32))
|
||||
_executor.compile(net, input)
|
Loading…
Reference in New Issue