add nn api of MaxUnpool1d, 2d, 3d

This commit is contained in:
LV 2022-10-12 20:15:45 +08:00 committed by ZhidanLiu
parent c5bafc2bce
commit 60b2bce240
10 changed files with 687 additions and 1 deletions

View File

@ -199,6 +199,9 @@ Dropout层
mindspore.nn.MaxPool1d mindspore.nn.MaxPool1d
mindspore.nn.MaxPool2d mindspore.nn.MaxPool2d
mindspore.nn.MaxPool3d mindspore.nn.MaxPool3d
mindspore.nn.MaxUnpool1d
mindspore.nn.MaxUnpool2d
mindspore.nn.MaxUnpool3d
填充层 填充层
-------------- --------------

View File

@ -0,0 +1,43 @@
mindspore.nn.MaxUnpool1d
========================
.. py:class:: mindspore.nn.MaxUnpool1d(kernel_size, stride=0, padding=0, output_size=())
`Maxpool1d` 的部分逆过程。 `Maxpool1d` 不是完全可逆的,因为非最大值丢失。
`MaxUnpool1d``MaxPool1d` 的输出为输入,包括最大值的索引。在计算 `maxpool1d` 部分逆的过程中,非最大值设置为零。
支持的输入数据格式为 :math:`(N, C, H_{in})`:math:`(C, H_{in})` ,输出数据的个格式为 :math:`(N, C, H_{out})`
:math:`(C, H_{out})` ,计算公式如下:
.. math::
\begin{array}{ll} \\
H_{out} = (H{in} - 1) \times stride[0] - 2 \times padding[0] + kernel_size[0] \\
\end{array}
参数:
- **kernel_size** (Union[int, tuple[int]]) - 池化核尺寸大小。
- **stride** (Union[int, tuple[int]]) - 池化操作的移动步长,若取值为 '0' 或者 '(0)' `stride` 值与 `kernel_size`
相同。默认值None。
- **padding** (str) - 填充值。默认值0。
- **output_size** (tuple[int]) - 输出shape可选参数。默认值()。
如果output_size为()那么输出shape根据 `kernel_size``stride``padding` 计算得出。
如果output_size不为(),那么 `output_size` 必须满足格式 :math:`(N, C, H)`:math:`(C, H)` ,取值范围需满足:
:math:`[(N, C, H_{out} - stride[0]), (N, C, H_{out} + stride[0])]`
输入:
- **x** (Tensor) - 待求逆的Tensor。shape为 :math:`(N, C, H_{in})`:math:`(C, H_{in})`
- **indices** (Tensor) - 最大值的索引。shape必须与输入`x`相同。取值范围需满足 :math:`[0, H_{in} - 1]`
数据类型必须是int32或int64。
输出:
shape为 :math:`(N, C, H_{out})`:math:`(C, H_{out})` 的Tensor数据类型与输入 `x` 相同。
异常:
- **TypeError** - `x``indices` 的数据类型不支持。
- **TypeError** - `kernel_size` `stride``padding` 既不是整数也不是tuple。
- **ValueError** - `stride``kernel_size` 的值不是非负的。
- **ValueError** - `x``indices` 的shape不一致。
- **ValueError** - `padding` 中的值有负数。
- **ValueError** - `output_size` 的长度不为0、2或3。
- **ValueError** - `output_size` 的取值与根据 `kernel_size, stride, padding` 计算得到的结果差距太大。

View File

@ -0,0 +1,46 @@
mindspore.nn.MaxUnpool2d
========================
.. py:class:: mindspore.nn.MaxUnpool2d(kernel_size, stride=0, padding=0, output_size=())
`Maxpool2d` 的部分逆过程。 `Maxpool2d` 不是完全可逆的,因为非最大值丢失。
`MaxUnpool2d``MaxPool2d` 的输出为输入,包括最大值的索引。在计算 `maxpool2d` 部分逆的过程中,非最大值设置为零。
支持的输入数据格式为 :math:`(N, C, H_{in}, W_{in})`:math:`(C, H_{in}, W_{in})`
输出数据的个格式为 :math:`(N, C, H_{out}, W_{out})`:math:`(C, H_{out}, W_{out})` ,计算公式如下:
.. math::
\begin{array}{ll} \\
H_{out} = (H{in} - 1) \times stride[0] - 2 \times padding[0] + kernel_size[0] \\
W_{out} = (W{in} - 1) \times stride[1] - 2 \times padding[1] + kernel_size[1] \\
\end{array}
参数:
- **kernel_size** (Union[int, tuple[int]]) - 池化核尺寸大小。int类型表示池化核的长宽相同。
tuple类型中的两个值分别代表池化核的长和宽。
- **stride** (Union[int, tuple[int]]) - 池化操作的移动步长int类型表示长宽方向的移动步长相同。
tuple中的两个值分别代表长宽方向移动的步长。若取值为 '0' 或者 '(0, 0)'`stride` 值与 `kernel_size` 相同。
默认值None。
- **padding** (str) - 填充值。默认值0。若为int类型则长宽方向的填充大小相同均为 `padding`
若为tuple类型则tuple中的两个值分别代表长宽方向填充的大小。
- **output_size** (tuple[int]) - 输出shape可选参数。默认值()。
如果output_size为()那么输出shape根据 `kernel_size``stride``padding` 计算得出。
如果output_size不为(),那么 `output_size` 必须满足格式 :math:`(N, C, H, W)`:math:`(C, H, W)` ,取值范围需满足:
:math:`[(N, C, H_{out} - stride[0], W_{out} - stride[1]), (N, C, H_{out} + stride[0], W_{out} + stride[1])]`
输入:
- **x** (Tensor) - 待求逆的Tensor。shape为 :math:`(N, C, H_{in}, W_{in})`:math:`(C, H_{in}, W_{in})`
- **indices** (Tensor) - 最大值的索引。shape必须与输入 `x` 相同。取值范围需满足 :math:`[0, H_{in} \times W_{in} - 1]`
数据类型必须是int32或int64。
输出:
shape为 :math:`(N, C, H_{out}, W_{out})`:math:`(C, H_{out}, W_{out})` 的Tensor数据类型与输入 `x` 相同。
异常:
- **TypeError** - `x``indices` 的数据类型不支持。
- **TypeError** - `kernel_size` `stride``padding` 既不是整数也不是tuple。
- **TypeError** - `kernel_size` `stride``padding` 为tuple时长度不等于2。
- **ValueError** - `stride` `kernel_size``padding` 的值不是非负的。
- **ValueError** - `x``indices` 的shape不一致。
- **ValueError** - `padding` 中的值有负数。
- **ValueError** - `output_size` 的长度不为0、3或4。
- **ValueError** - `output_size` 的取值与根据 `kernel_size, stride, padding` 计算得到的结果差距太大。

View File

@ -0,0 +1,50 @@
mindspore.nn.MaxUnpool3d
========================
.. py:class:: mindspore.nn.MaxUnpool3d(kernel_size, stride=0, padding=0, output_size=())
`Maxpool3d` 的部分逆过程。 `Maxpool3d` 不是完全可逆的,因为非最大值丢失。
`MaxUnpool3d``MaxPool3d` 的输出为输入,包括最大值的索引。在计算 `maxpool3d` 部分逆的过程中,非最大值设置为零。
支持的输入数据格式为 :math:`(N, C, D_{in}, H_{in}, W_{in})`:math:`(C, D_{in}, H_{in}, W_{in})`
输出数据的个格式为 :math:`(N, C, D_{out}, H_{out}, W_{out})`:math:`(C, D_{out}, H_{out}, W_{out})` ,计算公式如下:
.. math::
\begin{array}{ll} \\
D_{out} = (D{in} - 1) \times stride[0] - 2 \times padding[0] + kernel_size[0] \\
H_{out} = (H{in} - 1) \times stride[1] - 2 \times padding[1] + kernel_size[1] \\
W_{out} = (W{in} - 1) \times stride[2] - 2 \times padding[2] + kernel_size[2] \\
\end{array}
参数:
- **kernel_size** (Union[int, tuple[int]]) - 池化核尺寸大小。int类型表示池化核的深度、长和宽相同。
tuple类型中的三个值分别代表池化核的深度、长和宽。
- **stride** (Union[int, tuple[int]]) - 池化操作的移动步长int类型表示深度、长和宽方向的移动步长相同。
tuple中的三个值分别代表深度、长和宽方向移动的步长。若取值为 '0' 或者 '(0, 0, 0)' `stride` 值与 `kernel_size` 相同。
默认值None。
- **padding** (str) - 填充值。默认值0。若为int类型则深度、长和宽方向的填充大小相同均为 `padding`
若为tuple类型则tuple中的三个值分别代表深度、长和宽方向填充的大小。
- **output_size** (tuple[int]) - 输出shape可选参数。默认值()。
如果output_size为()那么输出shape根据 `kernel_size``stride``padding` 计算得出。
如果output_size不为(),那么 `output_size` 必须满足格式 :math:`(N, C, D, H, W)`:math:`(C, D, H, W)`
取值范围需满足:
:math:`[(N, C, D_{out} - stride[0], H_{out} - stride[1], W_{out} - stride[2]), (N, C, D_{out} + stride[0], H_{out} + stride[1], W_{out} + stride[2])]`
输入:
- **x** (Tensor) - 待求逆的Tensor。shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})`
:math:`(C, D_{in}, H_{in}, W_{in})`
- **indices** (Tensor) - 最大值的索引。shape必须与输入 `x` 相同。取值范围需满足
:math:`[0, D_{in} \times H_{in} \times W_{in} - 1]` 。数据类型必须是int32或int64。
输出:
shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})`:math:`(C, D_{out}, H_{out}, W_{out})` 的Tensor
数据类型与输入 `x` 相同。
异常:
- **TypeError** - `x``indices` 的数据类型不支持。
- **TypeError** - `kernel_size` `stride``padding` 既不是整数也不是tuple。
- **TypeError** - `kernel_size` `stride``padding` 为tuple时长度不等于3。
- **ValueError** - `stride` `kernel_size``padding` 的值不是非负的。
- **ValueError** - `x``indices` 的shape不一致。
- **ValueError** - `padding` 中的值有负数。
- **ValueError** - `output_size` 的长度不为0、4或5。
- **ValueError** - `output_size` 的取值与根据 `kernel_size, stride, padding` 计算得到的结果差距太大。

View File

@ -199,6 +199,9 @@ Pooling Layer
mindspore.nn.MaxPool1d mindspore.nn.MaxPool1d
mindspore.nn.MaxPool2d mindspore.nn.MaxPool2d
mindspore.nn.MaxPool3d mindspore.nn.MaxPool3d
mindspore.nn.MaxUnpool1d
mindspore.nn.MaxUnpool2d
mindspore.nn.MaxUnpool3d
Padding Layer Padding Layer
------------- -------------

View File

@ -26,11 +26,12 @@ from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool3D, AdaptiveAvgPool3D from mindspore.ops.operations.nn_ops import AdaptiveMaxPool3D, AdaptiveAvgPool3D
from mindspore.ops.operations.nn_ops import FractionalMaxPoolWithFixedKsize, FractionalMaxPool3DWithFixedKsize from mindspore.ops.operations.nn_ops import FractionalMaxPoolWithFixedKsize, FractionalMaxPool3DWithFixedKsize
from mindspore.ops.operations.nn_ops import MaxPool3DWithArgmax from mindspore.ops.operations.nn_ops import MaxPool3DWithArgmax
from mindspore.ops.operations.nn_ops import MaxUnpool2D, MaxUnpool3D
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
__all__ = ['AvgPool3d', 'MaxPool3d', 'AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'MaxPool1d', 'FractionalMaxPool2d', __all__ = ['AvgPool3d', 'MaxPool3d', 'AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'MaxPool1d', 'FractionalMaxPool2d',
'FractionalMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'FractionalMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d'] 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d']
class _PoolNd(Cell): class _PoolNd(Cell):
@ -1327,3 +1328,320 @@ class FractionalMaxPool3d(Cell):
if self.return_indices: if self.return_indices:
return output return output
return output[0] return output[0]
class MaxUnpool1d(Cell):
r"""
Computes a partial inverse of MaxPool1d.
MaxPool1d is not fully invertible, since the non-maximal values are lost.
MaxUnpool1d takes in as input the output of MaxPool1d including the indices of the maximal values
and computes a partial inverse in which all non-maximal values are set to zero. Typically the input
is of shape :math:`(N, C, H_{in})` or :math:`(C, H_{in})`, and the output is of shape :math:`(N, C, H_{out}`
or :math:`(C, H_{out}`. The operation is as follows.
.. math::
\begin{array}{ll} \\
H_{out} = (H{in} - 1) \times stride[0] - 2 \times padding[0] + kernel_size[0] \\
\end{array}
Args:
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value.
stride (Union[int, tuple[int]]): The distance of kernel moving,
If stride is 0 or (0), then stride equal to kernel_size. Default: None.
padding (Union[int, tuple[int]]): The pad value to be filled. Default: 0.
output_size (tuple[int]) : The target output size is an optional input. Default: ().
If output_size == (), then the shape of output computed by kernel_size, stride and padding.
If output_size != (), then output_size must be :math:`(N, C, H)` or
:math:`(C, H)` and output_size must belong to
:math:`[(N, C, H_{out} - stride[0]), (N, C, H_{out} + stride[0])]`.
Inputs:
- **x** (Tensor) - The input Tensor to invert.
Tensor of shape :math:`(N, C, H_{in})` or :math:`(C, H_{in})`.
- **indices** (Tensor) - Max values' index represented by the indices.
Tensor of shape must be same with input 'x'.
Values of indices must belong to :math:`[0, H_{in} - 1]`.
Data type must be in int32 or int64.
Outputs:
Tensor, with shape :math:`(N, C, H_{out})` or :math:`(C, H_{out})`,
with the same data type with `x`.
Raises:
TypeError: If data type of `x` or `indices` is not supported.
TypeError: If `kernel_size`, `stride` or `padding` is neither int nor tuple.
ValueError: If numbers in `stride` (also support 0 and (0)) or `kernel_size` is not positive.
ValueError: If the shape of `x` and `indices` are not equal.
ValueError: If numbers in `padding` is negative.
ValueError: If `output_size` whose length is neither 0, 2 or 3.
ValueError: If `output_size` is not close to output size
computed by attr `kernel_size, stride, padding`.
Supported Platforms:
``CPU`` ``GPU``
Examples:
>>> x = Tensor(np.array([[2, 4, 6, 8]]).astype(np.float32))
>>> indices = Tensor(np.array([[1, 3, 5, 7]]).astype(np.int64))
>>> maxunpool1d = nn.MaxUnpool1d(kernel_size =2, stride=2, padding=0)
>>> output = maxunpool1d(x, indices)
>>> print(output.asnumpy())
[[0, 2, 0, 4, 0, 6, 0, 8]]
"""
def __init__(self, kernel_size, stride=None, padding=0, output_size=()):
"""Initialize MaxUnpool1d."""
super(MaxUnpool1d, self).__init__()
if len(output_size) == 2:
output_size = (1,) + output_size
if not stride:
stride = 0
self.max_unpool2d = MaxUnpool2D(ksize=(kernel_size, 1), strides=(stride, 1), pads=(padding, 0),
output_shape=output_size, data_format="NCHW")
self.shape = P.Shape()
@staticmethod
@constexpr
def _check_input_dim(x_shape, indices_shape, cls_name):
x_dim = len(x_shape)
if x_shape != indices_shape:
raise ValueError(f"For '{cls_name}', the x shape and indices shape must be equal, but got input "
f"shape {x_shape} and indices shape {indices_shape}.")
if x_dim not in (2, 3):
raise ValueError(f"For '{cls_name}', the x shape must have 2 or 3 dims, but got {x_dim}.")
return x_dim
def construct(self, x, indices):
x_shape = self.shape(x)
indices_shape = self.shape(indices)
x_dim = self._check_input_dim(x_shape, indices_shape, self.cls_name)
if x_dim == 2:
x = x.expand_dims(axis=0)
indices = indices.expand_dims(axis=0)
x = x.expand_dims(axis=3)
indices = indices.expand_dims(axis=3)
out = self.max_unpool2d(x, indices)
out = out.squeeze(-1)
out = out.squeeze(0)
else:
x = x.expand_dims(axis=3)
indices = indices.expand_dims(axis=3)
out = self.max_unpool2d(x, indices)
out = out.squeeze(-1)
return out
class MaxUnpool2d(Cell):
r"""
Computes a partial inverse of Maxpool2d.
MaxPool2d is not fully invertible, since the non-maximal values are lost.
MaxUnpool2d takes in as input the output of Maxpool2d including the indices of the maximal values
and computes a partial inverse in which all non-maximal values are set to zero. Typically the input
is of shape :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`, and the output is of
shape :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`. The operation is as follows.
.. math::
\begin{array}{ll} \\
H_{out} = (H{in} - 1) \times stride[0] - 2 \times padding[0] + kernel_size[0] \\
W_{out} = (W{in} - 1) \times stride[1] - 2 \times padding[1] + kernel_size[1] \\
\end{array}
Args:
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
an int number that represents height and width of the kernel, or a tuple
of two int numbers that represent height and width respectively.
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
the height and width of movement are both stride, or a tuple of two int numbers that
represent height and width of movement respectively.
If stride is 0 or (0, 0), then stride equal to kernel_size. Default: None.
padding (Union[int, tuple[int]]): The pad value to be filled. Default: 0. If `padding` is an integer,
the paddings of height and width are the same, equal to padding. If `padding` is a tuple of two
integers, the padding of height and width equal to padding[0] and padding[1] correspondingly.
output_size (tuple[int]) : The target output size is an optional parameter. Default: ().
If output_size == (), then the shape of output computed by kernel_size, stride and padding.
If output_size != (), then output_size must be :math:`(N, C, H, W)` and output_size must belong to
:math:`[(N, C, H_{out} - stride[0], W_{out} - stride[1]),
(N, C, H_{out} + stride[0], W_{out} + stride[1])]`.
Inputs:
- **x** (Tensor) - The input Tensor to invert.
Tensor of shape :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
- **indices** (Tensor) - Max values' index represented by the indices.
Tensor of shape must be same with input 'x'.
Values of indices must belong to :math:`[0, H_{in} \times W_{in} - 1]`.
Data type must be in int32 or int64.
Outputs:
Tensor, with shape :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`,
with the same data type with `x`.
Raises:
TypeError: If data type of `x` or `indices` is not supported.
TypeError: If `kernel_size`, `stride` or `padding` is neither int nor tuple.
ValueError: If numbers in `stride` (also support 0 and (0, 0)) or `kernel_size` is not positive.
ValueError: If the shape of `x` and `indices` are not equal.
ValueError: If numbers in `padding` is negative.
ValueError: If `kernel_size`, `stride` or `padding` is a tuple whose length is not equal to 2.
ValueError: If `output_size` whose length is neither 0, 3 or 4.
ValueError: If `output_size` is not close to output size
computed by attr `kernel_size, stride, padding`.
Supported Platforms:
``CPU`` ``GPU``
Examples:
>>> x = Tensor(np.array([[[[0, 1], [8, 9]]]]).astype(np.float32))
>>> indices = Tensor(np.array([[[[0, 1], [2, 3]]]]).astype(np.int64))
>>> maxunpool2d = nn.MaxUnpool2d(kernel_size=1, stride=1, padding=0)
>>> output = maxunpool2d(x, indices)
>>> print(output.asnumpy())
[[[[0. 1.]
[8. 9.]]]]
"""
def __init__(self, kernel_size, stride=None, padding=0, output_size=()):
"""Initialize MaxUnpool2d."""
super(MaxUnpool2d, self).__init__()
if len(output_size) == 3:
output_size = (1,) + output_size
if not stride:
stride = 0
self.max_unpool2d = MaxUnpool2D(ksize=kernel_size, strides=stride, pads=padding, output_shape=output_size,
data_format="NCHW")
self.shape = P.Shape()
@staticmethod
@constexpr
def _check_input_dim(x_shape, indices_shape, cls_name):
x_dim = len(x_shape)
if x_shape != indices_shape:
raise ValueError(f"For '{cls_name}', the x shape and indices shape must be equal, but got input "
f"shape {x_shape} and indices shape {indices_shape}.")
if x_dim not in (3, 4):
raise ValueError(f"For '{cls_name}', the x shape must have 3 or 4 dims, but got {x_dim}.")
return x_dim
def construct(self, x, indices):
x_shape = self.shape(x)
indices_shape = self.shape(indices)
x_dim = self._check_input_dim(x_shape, indices_shape, self.cls_name)
if x_dim == 3:
x = x.expand_dims(axis=0)
indices = indices.expand_dims(axis=0)
out = self.max_unpool2d(x, indices)
out = out.squeeze(0)
else:
out = self.max_unpool2d(x, indices)
return out
class MaxUnpool3d(Cell):
r"""
Computes a partial inverse of MaxPool3d.
MaxPool3d is not fully invertible, since the non-maximal values are lost.
MaxUnpool3d takes in as input the output of MaxPool3d including the indices of the maximal
values and computes a partial inverse in which all non-maximal values are set to zero.
Typically the input is of shape :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`,
and the output is of shape :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`.
The operation is as follows.
.. math::
\begin{array}{ll} \\
D_{out} = (D{in} - 1) \times stride[0] - 2 \times padding[0] + kernel_size[0] \\
H_{out} = (H{in} - 1) \times stride[1] - 2 \times padding[1] + kernel_size[1] \\
W_{out} = (W{in} - 1) \times stride[2] - 2 \times padding[2] + kernel_size[2] \\
\end{array}
Args:
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
an int number that represents depth, height and width of the kernel, or a tuple
of three int numbers that represent depth, height and width respectively.
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
the depth, height and width of movement are both stride, or a tuple of three int numbers that
represent depth, height and width of movement respectively.
If stride is 0 or (0, 0, 0), then stride equal to kernel_size. Default: None.
padding (Union[int, tuple[int]]): The pad value to be filled. Default: 0. If `padding` is an integer,
the paddings of depth, height and width are the same, equal to padding. If `padding` is a tuple of three
integers, the padding of depth, height and width equal to padding[0], padding[1] and padding[2]
correspondingly.
output_size (tuple[int]) : The target output size is an optional input. Default: ().
If output_size == (), then the shape of output computed by kernel_size, stride and padding.
If output_size != (), then output_size must be :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` and
output_size must belong to
:math:`[(N, C, D_{out} - stride[0], H_{out} - stride[1], W_{out} - stride[2]),
(N, C, D_{out} + stride[0], H_{out} + stride[1], W_{out} + stride[2])]`.
Inputs:
- **x** (Tensor) - The input Tensor to invert.
Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
- **indices** (Tensor) - Max values' index represented by the indices.
Tensor of shape must be same with input 'x'.
Values of indices must belong to :math:`[0, D_{in} \times H_{in} \times W_{in} - 1]`.
Data type must be in int32 or int64.
Outputs:
Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
with the same data type with `x`.
Raises:
TypeError: If data type of `x` or `indices` is not supported.
TypeError: If `kernel_size`, `stride` or `padding` is neither int nor tuple.
ValueError: If numbers in `stride` (also support 0 and (0, 0, 0)) or `kernel_size` is not positive.
ValueError: If numbers in `padding` is negative.
ValueError: If `kernel_size`, `stride` or `padding` is a tuple whose length is not equal to 3.
ValueError: If `output_size` whose length is neither 0, 4 or 5.
ValueError: If `output_size` is not close to output size
computed by attr `kernel_size, stride, padding`.
Supported Platforms:
``CPU`` ``GPU``
Examples:
>>> x = Tensor(np.array([[[[[0, 1], [8, 9]]]]]).astype(np.float32))
>>> indices= Tensor(np.array([[[[[0, 1], [2, 3]]]]]).astype(np.int64))
>>> maxunpool3d = nn.MaxUnpool3d(kernel_size=1, stride=1, padding=0)
>>> output = maxunpool3d(x, indices)
>>> print(output.asnumpy())
[[[[[0. 1.]
[8. 9.]]]]]
"""
def __init__(self, kernel_size, stride=None, padding=0, output_size=()):
super(MaxUnpool3d, self).__init__()
if len(output_size) == 4:
output_size = (1,) + output_size
if not stride:
stride = 0
self.max_unpool3d = MaxUnpool3D(ksize=kernel_size, strides=stride, pads=padding, output_shape=output_size,
data_format="NCDHW")
self.shape = P.Shape()
@staticmethod
@constexpr
def _check_input_dim(x_shape, indices_shape, cls_name):
x_dim = len(x_shape)
if x_shape != indices_shape:
raise ValueError(f"For '{cls_name}', the x shape and indices shape must be equal, but got input "
f"shape {x_shape} and indices shape {indices_shape}.")
if x_dim not in (4, 5):
raise ValueError(f"For '{cls_name}', the x shape must have 4 or 5 dims, but got {x_dim}.")
return x_dim
def construct(self, x, indices):
x_shape = self.shape(x)
indices_shape = self.shape(indices)
x_dim = self._check_input_dim(x_shape, indices_shape, self.cls_name)
if x_dim == 4:
x = x.expand_dims(axis=0)
indices = indices.expand_dims(axis=0)
out = self.max_unpool3d(x, indices)
out = out.squeeze(0)
else:
out = self.max_unpool3d(x, indices)
return out

View File

@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
class Net(nn.Cell):
def __init__(self, kernel_size, stride=0, padding=0, output_size=()):
super(Net, self).__init__()
self.max_unpool1d = nn.MaxUnpool1d(kernel_size, stride, padding, output_size)
def construct(self, x, indices):
return self.max_unpool1d(x, indices)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_max_unpool1d_normal(mode):
"""
Feature: max_unpool1d
Description: Verify the result of MaxUnpool1d
Expectation: success
"""
context.set_context(mode=mode)
x = Tensor(np.array([[2, 4, 6, 8]]).astype(np.float32))
incices = Tensor(np.array([[1, 3, 5, 7]]).astype(np.int64))
net = Net(kernel_size=2, stride=2, padding=0)
output = net(x, incices).asnumpy()
expect = np.array([[0, 2, 0, 4, 0, 6, 0, 8]]).astype(np.float32)
assert np.allclose(output, expect, rtol=0.0001)

View File

@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
class Net(nn.Cell):
def __init__(self, kernel_size, stride=0, padding=0, output_size=()):
super(Net, self).__init__()
self.max_unpool2d = nn.MaxUnpool2d(kernel_size, stride, padding, output_size)
def construct(self, x, indices):
return self.max_unpool2d(x, indices)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_max_unpool2d_normal(mode):
"""
Feature: max_unpool2d
Description: Verify the result of MaxUnpool2d
Expectation: success
"""
context.set_context(mode=mode)
x = Tensor(np.array([[[6., 8.],
[14., 16.]]]).astype(np.float32))
incices = Tensor(np.array([[[5, 7], [13, 15]]]).astype(np.int64))
net = Net(kernel_size=2, stride=2, padding=0)
output = net(x, incices).asnumpy()
expected_output = np.array([[[0., 0., 0., 0.],
[0, 6., 0., 8.],
[0., 0., 0., 0.],
[0., 14., 0., 16.]]]).astype(np.float32)
assert np.allclose(output, expected_output, rtol=0.0001)

View File

@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
class Net(nn.Cell):
def __init__(self, kernel_size, stride=0, padding=0, output_size=()):
super(Net, self).__init__()
self.max_unpool3d = nn.MaxUnpool3d(kernel_size, stride, padding, output_size)
def construct(self, x, indices):
return self.max_unpool3d(x, indices)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_max_unpool3d_normal(mode):
"""
Feature: max_unpool3d
Description: Verify the result of MaxUnpool3d
Expectation: success
"""
context.set_context(mode=mode)
x = Tensor(np.array([[[[[7.]]]], [[[[15.]]]]]), mindspore.float32)
incices = Tensor(np.array([[[[[7]]]], [[[[7]]]]]), mindspore.int64)
net = Net(kernel_size=2, stride=1, padding=0)
output = net(x, incices).asnumpy()
expect = np.array([[[[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 7.]]]],
[[[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 15.]]]]]).astype(np.float32)
assert np.allclose(output, expect, rtol=0.0001)

View File

@ -120,3 +120,66 @@ def test_adaptive_max_pool_1d():
net = AdaptiveMaxPool1dNet(2) net = AdaptiveMaxPool1dNet(2)
input_ = Tensor(np.random.randint(0, 255, [1, 3, 6]).astype(np.float32)) input_ = Tensor(np.random.randint(0, 255, [1, 3, 6]).astype(np.float32))
_cell_graph_executor.compile(net, input_) _cell_graph_executor.compile(net, input_)
class MaxUnpool2dNet(nn.Cell):
def __init__(self, kernel_size, stride=0, padding=0, output_size=()):
super(MaxUnpool2dNet, self).__init__()
self.max_unpool2d = nn.MaxUnpool2d(kernel_size, stride, padding, output_size)
def construct(self, x, indices):
return self.max_unpool2d(x, indices)
class MaxUnpool1dNet(nn.Cell):
def __init__(self, kernel_size, stride=0, padding=0, output_size=()):
super(MaxUnpool1dNet, self).__init__()
self.max_unpool1d = nn.MaxUnpool1d(kernel_size, stride, padding, output_size)
def construct(self, x, indices):
return self.max_unpool1d(x, indices)
class MaxUnpool3dNet(nn.Cell):
def __init__(self, kernel_size, stride=0, padding=0, output_size=()):
super(MaxUnpool3dNet, self).__init__()
self.max_unpool3d = nn.MaxUnpool3d(kernel_size, stride, padding, output_size)
def construct(self, x, indices):
return self.max_unpool3d(x, indices)
def test_max_unpool2d_normal():
"""
Feature: max_unpool2d
Description: Verify the result of MaxUnpool2d
Expectation: success
"""
x = Tensor(np.array([[[6., 8.], [14., 16.]]]).astype(np.float32))
incices = Tensor(np.array([[[5, 7], [13, 15]]]).astype(np.int64))
net = MaxUnpool2dNet(kernel_size=2, stride=2, padding=0)
_cell_graph_executor.compile(net, x, incices)
def test_max_unpool1d_normal():
"""
Feature: max_unpool1d
Description: Verify the result of MaxUnpool1d
Expectation: success
"""
x = Tensor(np.array([[2, 4, 6, 8]]).astype(np.float32))
incices = Tensor(np.array([[1, 3, 5, 7]]).astype(np.int64))
net = MaxUnpool1dNet(kernel_size=2, stride=2, padding=0)
_cell_graph_executor.compile(net, x, incices)
def test_max_unpool3d_normal():
"""
Feature: max_unpool3d
Description: Verify the result of MaxUnpool3d
Expectation: success
"""
x = Tensor(np.array([[[[[7.]]]], [[[[15.]]]]]).astype(np.float32))
incices = Tensor(np.array([[[[[7]]]], [[[[7]]]]]).astype(np.int64))
net = MaxUnpool3dNet(kernel_size=2, stride=1, padding=0)
_cell_graph_executor.compile(net, x, incices)