fractionalmaxpool_ops

This commit is contained in:
yide12 2022-11-07 16:30:09 +08:00 committed by yide12
parent 11b1963e11
commit 851697791d
12 changed files with 504 additions and 147 deletions

View File

@ -32,6 +32,8 @@ mindspore.ops.function
mindspore.ops.dropout2d
mindspore.ops.dropout3d
mindspore.ops.flatten
mindspore.ops.fractional_max_pool2d
mindspore.ops.fractional_max_pool3d
mindspore.ops.interpolate
mindspore.ops.lp_pool1d
mindspore.ops.lp_pool2d

View File

@ -0,0 +1,35 @@
mindspore.ops.fractional_max_pool2d
===================================
.. py:function:: mindspore.ops.fractional_max_pool2d(input_x, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)
对输入的多维数据进行二维的分数最大池化运算。
对多个输入平面组成的输入上应用2D分数最大池化。在 :math:`(kH_{in}, kW_{in})` 区域上应用最大池化操作由输出shape决定随机步长。对于任何输入shape指定输出shape为 :math:`(H, W)` 。输出特征的数量等于输入平面的数量。
在一个输入Tensor上应用2D fractional max pooling可被视为组成一个2D平面。
分数最大池化的详细描述在 `Fractional Max-Pooling <https://arxiv.org/pdf/1412.6071>`_
参数:
- **input_x** (Tensor) - shape为 :math:`(N, C, H_{in}, W_{in})` 的Tensor。支持的数据类型float16、float32、float64、int32和int64。
- **kernel_size** (Union[int, tuple[int]]) - 指定池化核尺寸大小如果为int则代表池化核的高和宽。如果为tuple其值必须包含两个正整数值分别表示池化核的高和宽。取值必须为正整数。
- **output_size** (Union[int, tuple[int]],可选) - 目标输出shape。如果是整数则表示输出目标的高和宽。如果是tuple其值必须包含两个整数值分别表示目标输出的高和宽。默认值None。
- **output_ratio** (Union[float, tuple[float]],可选) - 目标输出shape与输入shape的比率。通过输入shape和 `output_ratio` 确定输出shape。支持数据类型float16、float32、double数值介于0到1之间。默认值None。
- **return_indices** (bool可选) - 如果为 `True` 返回分数最大池化的最大值的的索引值。默认值False。
- **_random_samples** (Tensor可选) - 3D张量分数最大池化的随机步长。支持的数据类型float16、float32、double。数值介于0到1之间。shape为 :math:`(N, C, 2)` 的Tensor。默认值None。
返回:
- **y** (Tensor) - 数据类型和输入相同shape是 :math:`(N, C, H, W)`
- **argmax** (Tensor) - 输出的索引是一个张量。shape和输出 `y` 一致数据类型是int64。仅当 `return_indices` 为True时输出最大池化的索引值。
异常:
- **TypeError** - `input_x` 不是float16、float32、float64、int32或int64。
- **TypeError** - `_random_samples` 不是float16、float32或float64。
- **ValueError** - `kernel_size` 不是整数并且不是长度为2的元组。
- **ValueError** - `output_shape` 不是整数并且不是长度为2的元组。
- **ValueError** - `kernel_size` `output_shape` 与-1的和大于 `input_x` 的对应维度的量。
- **ValueError** - `_random_samples` 维度不是3。
- **ValueError** - `output_size``output_ratio` 同时为 `None`
- **ValueError** - `input_x` 和 `_random_samples` 的第一维度大小不相等。
- **ValueError** - `input_x` 和 `_random_samples` 第二维度大小不相等。
- **ValueError** - `_random_samples` 第三维度大小不是2。

View File

@ -0,0 +1,38 @@
mindspore.ops.fractional_max_pool3d
===================================
.. py:function:: mindspore.ops.fractional_max_pool3d(input_x, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)
对输入的多维数据进行三维的分数最大池化运算。
对多个输入平面组成的输入上应用3D分数最大池化。在 :math:`(kD_{in}, kH_{in}, kW_{in})` 区域上应用最大池化操作由输出shape决定随机步长。输出特征的数量等于输入平面的数量。
分数最大池化的详细描述在 `Fractional MaxPooling by Ben Graham <https://arxiv.org/abs/1412.6071>`_
输入输出的数据格式可以是”NCDHW“。其中N是批次大小C是通道数D是特征深度H是特征高度W是特征宽度。
参数:
- **input_x** (Tensor) - 4维或5维的张量支持的数据类型float16、float32、double、int32、int64。支持shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})`
- **kernel_size** (Union[int, tuple[int]]) - 指定池化核尺寸大小如果为int则代表池化核的深度高和宽。如果为tuple其值必须包含三个正整数值分别表示池化核的深度高和宽。取值必须为正整数。
- **output_size** (Union[int, tuple[int]],可选) - 目标输出大小。如果是整数则表示输出目标的深、高和宽。如果是tuple其值必须包含三个整数值分别表示目标输出的深、高和宽。默认值None。
- **output_ratio** (Union[float, tuple[float]],可选) - 目标输出shape与输入shape的比率。通过输入shape和 `output_ratio` 确定输出shape。支持数据类型float16、float32、double数值介于0到1之间。默认值None。
- **return_indices** (bool可选) - 如果为 `True` 返回分数最大池化的最大值的的索引值。默认值False。
- **_random_samples** (Tensor可选) - 随机步长。支持的数据类型float16、float32、double。shape为 :math:`(N, C, 3)` 的Tensor。数值介于0到1之间。默认值None。
返回:
- **y** (Tensor) - 3D分数最大池化的输出是一个张量。数据类型和输入相同shape是 :math:`(N, C, D, H, W)`
- **argmax** (Tensor) - 仅当 `return_indices` 为True时输出最大池化的索引值。shape和输出 `y` 一致。
异常:
- **TypeError** - `input_x` 不是4维或5维张量。
- **TypeError** - `random_samples` 不是3维张量。
- **TypeError** - `x` 数据类型不是float16、float32、double、int32、int64。
- **TypeError** - `random_samples` 数据类型不是float16、float32、double。
- **TypeError** - `argmax` 数据类型不是int32、int64。
- **ValueError** - `output_shape` 不是长度为3的元组。
- **ValueError** - `kernal_size` 不是长度为3的元组。
- **ValueError** - `output_shape``kernel_size` 不是正数。
- **ValueError** - `output_size``output_ratio` 同时为 `None`
- **ValueError** - `input_x``random_samples` 的第一维度大小不相等。
- **ValueError** - `input_x``random_samples` 第二维度大小不相等。
- **ValueError** - `random_samples` 第三维度大小不是3。

View File

@ -32,6 +32,8 @@ Neural Network
mindspore.ops.dropout2d
mindspore.ops.dropout3d
mindspore.ops.flatten
mindspore.ops.fractional_max_pool2d
mindspore.ops.fractional_max_pool3d
mindspore.ops.interpolate
mindspore.ops.lp_pool1d
mindspore.ops.lp_pool2d

View File

@ -25,7 +25,6 @@ import mindspore.context as context
from mindspore.common import dtype as mstype
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool3D, AdaptiveAvgPool3D
from mindspore.ops.operations.nn_ops import FractionalMaxPoolWithFixedKsize, FractionalMaxPool3DWithFixedKsize
from mindspore.ops.operations.nn_ops import MaxPool3DWithArgmax
from mindspore.nn.cell import Cell
@ -1328,59 +1327,35 @@ class FractionalMaxPool2d(Cell):
>>> net = nn.FractionalMaxPool2d(kernel_size=2, output_size=(2, 2), _random_samples=_random_samples,
... return_indices=True)
>>> y, argmax = net(input_x)
>>> print(y)
Tensor(shape=[1, 1, 2, 2], dtype=Float32, value=
[[[[9.54500020e-001, 8.76399994e-001],
[9.67299998e-001, 9.85199988e-001]]]])
>>> print(argmax)
Tensor(shape=[1, 1, 2, 2], dtype=Int64, value=
[[[[ 1, 9],
[16, 24]]]])
>>> y
[[[[0.9545 0.8764]
[0.9673 0.9852]]]]
>>> argmax
[[[[ 1 9]
[16 24]]]]
>>> net = nn.FractionalMaxPool2d(kernel_size=2, output_ratio=(0.5, 0.5), _random_samples=_random_samples,
... return_indices=True)
>>> y, argmax = net(input_x)
>>> print(y)
Tensor(shape=[1, 1, 2, 2], dtype=Float32, value=
[[[[9.54500020e-001, 8.76399994e-001],
[9.67299998e-001, 9.85199988e-001]]]])
[[[[0.9545 0.8764]
[0.9673 0.9852]]]]
>>> print(argmax)
Tensor(shape=[1, 1, 2, 2], dtype=Int64, value=
[[[[ 1, 9],
[16, 24]]]])
[[[[ 1 9]
[16 24]]]]
"""
def __init__(self, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None):
"""Initialize FractionalMaxPool2d."""
super(FractionalMaxPool2d, self).__init__()
self.kernel_size = kernel_size
self.output_size = output_size
self.output_ratio = output_ratio
self.return_indices = return_indices
self.output_ratio = None
if _random_samples is None:
_random_samples = Tensor(np.array([[[0, 0]]]), mstype.float32)
self.random_samples = _random_samples
if output_ratio is not None:
if isinstance(output_ratio, float):
output_ratio = (output_ratio, output_ratio)
validator.check_float_range(output_ratio[0], 0.0, 1.0, Rel.INC_RIGHT)
validator.check_float_range(output_ratio[1], 0.0, 1.0, Rel.INC_RIGHT)
self.kernel_size = kernel_size
self.output_ratio = output_ratio
elif output_size is not None:
self.fractional_max_pool2d = FractionalMaxPoolWithFixedKsize(kernel_size, output_size)
else:
raise ValueError("'output_size' and 'output_ratio' can not be None at the same time.")
self._random_samples = _random_samples
def construct(self, x):
if self.output_ratio is not None:
output_size = (int(x.shape[-2] * self.output_ratio[0]), int(x.shape[-1] * self.output_ratio[1]))
fractional_max_pool2d = FractionalMaxPoolWithFixedKsize(self.kernel_size, output_size)
output = fractional_max_pool2d(x, self.random_samples)
if self.return_indices:
return output
return output[0]
output = self.fractional_max_pool2d(x, self.random_samples)
if self.return_indices:
return output
return output[0]
return ops.fractional_max_pool2d(x, self.kernel_size, self.output_size, self.output_ratio, self.return_indices,
self._random_samples)
class FractionalMaxPool3d(Cell):
@ -1458,56 +1433,30 @@ class FractionalMaxPool3d(Cell):
... _random_samples=_random_samples, return_indices=True)
>>> output, argmax = net(x)
>>> print(output)
Tensor(shape=[1, 1, 1, 1, 3], dtype=Float32, value=
[[[[[1.30000000e+001, 1.40000000e+001, 1.60000000e+001]]]]])
[[[[[13. 14. 16.]]]]]
>>> print(argmax)
Tensor(shape=[1, 1, 1, 1, 3], dtype=Int64, value=
[[[[[12, 13, 15]]]]])
[[[[[12 13 15]]]]]
>>> net = nn.FractionalMaxPool3d(kernel_size=(1.0, 1.0, 1.0), output_ratio=(0.5, 0.5, 0.5),
... _random_samples=_random_samples, return_indices=True)
>>> output, argmax = net(x)
>>> print(output)
Tensor(shape=[1, 1, 1, 1, 2], dtype=Float32, value=
[[[[[1.30000000e+001, 1.60000000e+001]]]]])
[[[[[13. 16.]]]]]
>>> print(argmax)
Tensor(shape=[1, 1, 1, 1, 2], dtype=Int64, value=
[[[[[12, 15]]]]])
[[[[[12 15]]]]]
"""
def __init__(self, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None):
"""Initialize FractionalMaxPool3d."""
super(FractionalMaxPool3d, self).__init__()
self.kernel_size = kernel_size
self.output_size = output_size
self.output_ratio = output_ratio
self.return_indices = return_indices
self.output_ratio = None
if _random_samples is None:
_random_samples = Tensor(np.array([0, 0, 0]).reshape([1, 1, 3]), mstype.float32)
self.random_samples = _random_samples
if output_ratio is not None:
if isinstance(output_ratio, float):
output_ratio = (output_ratio, output_ratio, output_ratio)
validator.check_float_range(output_ratio[0], 0.0, 1.0, Rel.INC_RIGHT)
validator.check_float_range(output_ratio[1], 0.0, 1.0, Rel.INC_RIGHT)
validator.check_float_range(output_ratio[2], 0.0, 1.0, Rel.INC_RIGHT)
self.kernel_size = kernel_size
self.output_ratio = output_ratio
elif output_size is not None:
self.fractional_max_pool3d = FractionalMaxPool3DWithFixedKsize(kernel_size, output_size)
else:
raise ValueError("'output_size' and 'output_ratio' can not be None at the same time.")
self._random_samples = _random_samples
def construct(self, x):
if self.output_ratio:
output_size = (int(x.shape[-3] * self.output_ratio[0]), int(x.shape[-2] * self.output_ratio[1]),
int(x.shape[-1] * self.output_ratio[2]))
fractional_max_pool3d = FractionalMaxPool3DWithFixedKsize(self.kernel_size, output_size)
output = fractional_max_pool3d(x, self.random_samples)
if self.return_indices:
return output
return output[0]
output = self.fractional_max_pool3d(x, self.random_samples)
if self.return_indices:
return output
return output[0]
return ops.fractional_max_pool3d(x, self.kernel_size, self.output_size, self.output_ratio, self.return_indices,
self._random_samples)
class MaxUnpool1d(Cell):

View File

@ -335,6 +335,8 @@ from .nn_func import (
flip,
fliplr,
flipud,
fractional_max_pool2d,
fractional_max_pool3d,
pixel_shuffle,
pixel_unshuffle,
hardshrink,

View File

@ -33,6 +33,7 @@ from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.composite.multitype_ops._constexpr_utils import raise_value_error
from mindspore.ops.operations.nn_ops import MaxUnpool2D, MaxUnpool3D
from mindspore.ops.operations.nn_ops import FractionalMaxPoolWithFixedKsize, FractionalMaxPool3DWithFixedKsize
slice_ = P.Slice()
fast_gelu_ = P.FastGeLU()
@ -1328,6 +1329,217 @@ def fast_gelu(x):
return fast_gelu_(x)
@constexpr
def _check_float_range_inc_right(arg_value, lower_limit, upper_limit, arg_name=None, prim_name=None):
"""
Method for checking whether input value is in float range inc right.
"""
return validator.check_float_range(arg_value, lower_limit, upper_limit, Rel.INC_RIGHT, arg_name, prim_name)
def fractional_max_pool2d(input_x, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None):
r"""
2D fractional max pooling operation for temporal data.
Applies a 2D fractional max pooling to an input signal composed of multiple input planes.
The max-pooling operation is applied in kH × kW regions by a stochastic step size determined by
the target output size. For any input size, the size of the specified output is H x W. The number
of output features is equal to the number of input planes.
Fractional MaxPooling is described in the paper `Fractional Max-Pooling <https://arxiv.org/pdf/1412.6071>`_.
Args:
input_x (Tensor): Tensor of shape :math:`(N, C, H_{in}, W_{in})`,
with float16, float32, float64, int32, int64 data type.
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
is an int number that represents height and width of the kernel, or a tuple
of two int numbers that represent height and width respectively.
The value must be a positive integer.
output_size (Union[int, tuple[int]], optional): The Shape of the target `output_size`,
is an int number that represents height and width, or a tuple
of two int numbers that represent height and width respectively.
The value must be a positive integer.
Default: None.
output_ratio (Union[float, tuple[float]], optional): The ratio of target output shape to input shape.
Specifying the size of the output tensor by using a ratio of the input size.
Data type : float16, float32, double, and value is between (0, 1).
Default: None.
return_indices (bool, optional): If `return_indices` is True, the indices of max value would be output.
Default: False.
_random_samples (Tensor, optional): The random step of FractionalMaxPool2d, which is a 3D tensor.
Tensor of data type : float16, float32, double, and value is between (0, 1).
Supported shape :math:`(N, C, 2)`.
Default: None.
Returns:
- **y** (Tensor) - Has the same type as the `input_x`.
Has the shape :math:`(N, C, H, W)`.
- **argmax** (Tensor) - The indices along with the outputs, which is a Tensor, with the same shape as the
`y` and int64 data type. It will output only when `return_indices` is True.
Raises:
TypeError: If data type of `input_x` is not one of the following: float16, float32, float64, int32, int64.
TypeError: If data type of `_random_samples` is not one of the following: float16, float32, float64.
ValueError: If `kernel_size` is not a number and `kernel_size` is not a tuple of length 2.
ValueError: If `output_size` is not a number and `output_size` is not a tuple of length 2.
ValueError: If the sum of `kernel_size` , `output_size` and -1 is larger than the corresponding
dimension of `input_x`.
ValueError: If the dimension of `_random_samples` is not 3.
ValueError: if `output_size` and `output_ratio` are None at the same time.
ValueError: If the first dimension size of `input_x` and `_random_samples` is not equal.
ValueError: If the second dimension size of `input_x` and `_random_samples` is not equal.
ValueError: If the third dimension size of `_random_samples` is not 2.
Supported Platforms:
``CPU``
Examples:
>>> input_x = Tensor(np.array([0.3220, 0.9545, 0.7879, 0.0975, 0.3698,
... 0.5135, 0.5740, 0.3435, 0.1895, 0.8764,
... 0.9581, 0.4760, 0.9014, 0.8522, 0.3664,
... 0.4980, 0.9673, 0.9879, 0.6988, 0.9022,
... 0.9304, 0.1558, 0.0153, 0.1559, 0.9852]).reshape([1, 1, 5, 5]), mstype.float32)
>>> _random_samples = Tensor(np.array([[[0.8, 0.8]]]), mstype.float32)
>>> y, argmax = ops.fractional_max_pool2d(input_x, kernel_size=2, output_size=(2, 2),
... _random_samples=_random_samples, return_indices=True)
>>> print(y)
[[[[0.9545 0.8764]
[0.9673 0.9852]]]]
>>> print(argmax)
[[[[ 1 9]
[16 24]]]]
>>> y, argmax = ops.fractional_max_pool2d(input_x, kernel_size=2, output_ratio=(0.5, 0.5),
... _random_samples=_random_samples, return_indices=True)
>>> print(y)
[[[[0.9545 0.8764]
[0.9673 0.9852]]]]
>>> print(argmax)
[[[[ 1 9]
[16 24]]]]
"""
if output_ratio is not None and output_size is not None or output_ratio is None and output_size is None:
raise ValueError(f"For fractional_max_pool2d, 'output_size' and 'output_ratio' can not be specified or None"
f"at the same time, but got {output_ratio} and {output_size} .")
if len(input_x.shape) == 3:
input_x.expend_dims(axis=0)
if _random_samples is None:
_random_samples = Tensor([[[0, 0]]], mstype.float32)
if output_ratio is not None:
if isinstance(output_ratio, float):
output_ratio = (output_ratio, output_ratio)
_check_float_range_inc_right(output_ratio[0], 0.0, 1.0)
_check_float_range_inc_right(output_ratio[1], 0.0, 1.0)
output_size = (int(input_x.shape[-2] * output_ratio[0]), int(input_x.shape[-1] * output_ratio[1]))
fractional_max_pool = FractionalMaxPoolWithFixedKsize(kernel_size, output_size)
output = fractional_max_pool(input_x, _random_samples)
if return_indices:
return output
return output[0]
def fractional_max_pool3d(input_x, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None):
r"""
3D fractional max pooling operation for temporal data.
This operator applies a 3D fractional max pooling over an input signal composed of several input planes.
The max-pooling operation is applied in kD x kH x kW regions by a stochastic step size determined
by the target output size.The number of output features is equal to the number of input planes.
Refer to the paper `Fractional MaxPooling by Ben Graham <https://arxiv.org/abs/1412.6071>`_ for more details.
The input and output data format can be "NCDHW". N is the batch size, C is the number of channels,
D the feature depth, H is the feature height, and W is the feature width.
Args:
input_x (Tensor): The input of FractionalMaxPool3d, which is a 4D or 5D tensor.
Tensor of data type : float16, float32, double, int32, int64.
Supported shape :math:`(N, C, D_{in}, H_{in}, W_{in})` .
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
is an int number that represents depth, height and width of the kernel, or a tuple
of three int numbers that represent depth, height and width respectively.
The value must be a positive integer.
output_size (Union[int, tuple[int]], optional): The Shape of the target `output_size`,
is an int number that represents depth, height and width, or a tuple
of three int numbers that represent depth, height and width respectively.
The value must be a positive integer.
Default: None.
output_ratio (Union[float, tuple[float]], optional): The ratio of target output shape to input shape.
Specifying the size of the output tensor by using a ratio of the input size.
Data type : float16, float32, double, and value is between (0, 1).
Default: None.
return_indices (bool, optional): If `return_indices` is True, the indices of max value would be output.
Default: False.
_random_samples (Tensor, optional): The random step of FractionalMaxPool3d, which is a 3D tensor.
Tensor of data type : float16, float32, double, and value is between (0, 1).
Supported shape :math:`(N, C, 3)`
Returns:
- **y** (Tensor) - A tensor, the output of FractionalMaxPool3d.
Has the same data type with `imput_x`.
Tensor of shape :math:`(N, C, D, H, W)` .
- **argmax** (Tensor) - The indices along with the outputs, which is a Tensor, with the same shape as the
`y` and int32 data type. It will output only when `return_indices` is True.
Raises:
TypeError: If `input_x` is not a 4D or 5D tensor.
TypeError: If `_random_samples` is not a 3D tensor.
TypeError: If data type of `imput_x` is not float16, float32, double, int32, int64.
TypeError: If dtype of `_random_samples` is not float16, float32, double.
TypeError: If dtype of `argmax` is not int32, int64.
ValueError: If `output_size` is a tuple and if `output_size` length is not 3.
ValueError: If `kernel_size` is a tuple and if `kernel_size` length is not 3.
ValueError: If numbers in `output_size` or `kernel_size` is not positive.
ValueError: if `output_size` and `output_ratio` are None at the same time.
ValueError: If the first dimension size of `input_x` and `_random_samples` is not equal.
ValueError: If the second dimension size of `input_x` and `_random_samples` is not equal.
ValueError: If the third dimension size of `_random_samples` is not 3.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
... .reshape([1, 1, 2, 2, 4]), mstype.float32)
>>> _random_samples = Tensor(np.array([0.7, 0.7, 0.7]).reshape([1, 1, 3]), mstype.float32)
>>> output, argmax = ops.fractional_max_pool3d(x, kernel_size=(1.0, 1.0, 1.0), output_size=(1, 1, 3),
... _random_samples=_random_samples, return_indices=True)
>>> print(output)
[[[[[13. 14. 16.]]]]]
>>> print(argmax)
[[[[[12 13 15]]]]]
>>> output, argmax = ops.fractional_max_pool3d(x, kernel_size=(1.0, 1.0, 1.0), output_ratio=(0.5, 0.5, 0.5),
... _random_samples=_random_samples, return_indices=True)
>>> print(output)
[[[[[13. 16.]]]]]
>>> print(argmax)
[[[[[12 15]]]]]
"""
if output_ratio is not None and output_size is not None or output_ratio is None and output_size is None:
raise ValueError(f"For fractional_max_pool2d, 'output_size' and 'output_ratio' can not be specified or None"
f"at the same time, but got {output_ratio} and {output_size} .")
if len(input_x.shape) == 4:
input_x.expend_dims(axis=0)
if _random_samples is None:
_random_samples = Tensor([[[0, 0, 0]]], mstype.float32)
if output_ratio is not None:
if isinstance(output_ratio, float):
output_ratio = (output_ratio, output_ratio, output_ratio)
_check_float_range_inc_right(output_ratio[0], 0.0, 1.0)
_check_float_range_inc_right(output_ratio[1], 0.0, 1.0)
_check_float_range_inc_right(output_ratio[2], 0.0, 1.0)
output_size = (int(input_x.shape[-3] * output_ratio[0]), int(input_x.shape[-2] * output_ratio[1]),
int(input_x.shape[-1] * output_ratio[2]))
fractional_max_pool = FractionalMaxPool3DWithFixedKsize(kernel_size, output_size)
output = fractional_max_pool(input_x, _random_samples)
if return_indices:
return output
return output[0]
def kl_div(logits, labels, reduction='mean'):
r"""
Computes the Kullback-Leibler divergence between the logits and the labels.
@ -4810,6 +5022,8 @@ __all__ = [
'dropout2d',
'dropout3d',
'fast_gelu',
'fractional_max_pool2d',
'fractional_max_pool3d',
'pixel_shuffle',
'pixel_unshuffle',
'hardshrink',

View File

@ -10318,7 +10318,7 @@ class FractionalMaxPoolWithFixedKsize(Primitive):
ValueError: If the third dimension size of `random_samples` is not 2.
Supported Platforms:
``Ascend`` ``CPU``
``CPU``
Examples:
>>> # the ksize is an int number and the output_shape is a tuple.

View File

@ -26,7 +26,7 @@ class FractionalMaxPool2dNet(nn.Cell):
def __init__(self):
super(FractionalMaxPool2dNet, self).__init__()
_random_samples = Tensor(np.array([[[0.8, 0.8]]]), mstype.float32)
_random_samples = Tensor(np.array([[[0.0, 0.0]]]), mstype.float32)
self.pool1 = nn.FractionalMaxPool2d(kernel_size=2, output_size=(2, 2), _random_samples=_random_samples,
return_indices=True)
self.pool2 = nn.FractionalMaxPool2d(kernel_size=2, output_ratio=(0.5, 0.5), _random_samples=_random_samples,
@ -77,7 +77,7 @@ class FractionalMaxPool3dNet(nn.Cell):
def __init__(self):
super(FractionalMaxPool3dNet, self).__init__()
_random_samples = Tensor(np.array([0.7, 0.7, 0.7]).reshape([1, 1, 3]), mstype.float32)
_random_samples = Tensor(np.array([0.0, 0.0, 0.0]).reshape([1, 1, 3]), mstype.float32)
self.pool1 = nn.FractionalMaxPool3d(kernel_size=(1.0, 1.0, 1.0), _random_samples=_random_samples,
output_size=(1, 1, 2), return_indices=True)
self.pool2 = nn.FractionalMaxPool3d(kernel_size=(1.0, 1.0, 1.0), output_ratio=(0.5, 0.5, 0.5),
@ -92,6 +92,7 @@ class FractionalMaxPool3dNet(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fractional_maxpool3d_normal(mode):

View File

@ -1,67 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore as ms
class FractionalMaxPool3dNet(nn.Cell):
"""FractionalMaxPool3d"""
def __init__(self):
super(FractionalMaxPool3dNet, self).__init__()
_random_samples = Tensor(np.array([0.7, 0.7, 0.7]).reshape([1, 1, 3]), mstype.float32)
self.pool1 = nn.FractionalMaxPool3d(kernel_size=(1.0, 1.0, 1.0), _random_samples=_random_samples,
output_size=(1, 1, 2), return_indices=True)
self.pool2 = nn.FractionalMaxPool3d(kernel_size=(1.0, 1.0, 1.0), output_ratio=(0.5, 0.5, 0.5),
_random_samples=_random_samples, return_indices=True)
def construct(self, x):
output1 = self.pool1(x)
output2 = self.pool2(x)
return output1, output2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fractional_maxpool3d_normal(mode):
"""
Feature: Test FractioanlMaxPool3d
Description: Test the functionality of FractionalMaxPool3d
Expectation: Success
"""
ms.set_context(mode=mode)
input_x = Tensor(np.random.rand(16).reshape([1, 1, 2, 2, 4]), mstype.float32)
net = FractionalMaxPool3dNet()
output1, output2 = net(input_x)
assert output1[0].shape == output1[1].shape == (1, 1, 1, 1, 2)
assert output2[0].shape == output2[1].shape == (1, 1, 1, 1, 2)
input_x = Tensor([[[[[5.76273143e-001, 7.97047436e-001, 5.05385816e-001, 7.98332036e-001],
[5.79880655e-001, 9.75979388e-001, 3.17571498e-002, 8.08261558e-002]],
[[3.82758647e-001, 7.09801614e-001, 4.39641386e-001, 5.71077049e-001],
[9.16305065e-001, 3.71438652e-001, 6.52868748e-001, 6.91260636e-001]]]]], mstype.float32)
output1, output2 = net(input_x)
expect_output_y = np.array([[[[[9.16305065e-001, 6.91260636e-001]]]]])
expect_output_argmax = np.array([[[[[12, 15]]]]])
assert np.allclose(output1[0].asnumpy(), expect_output_y)
assert np.allclose(output1[1].asnumpy(), expect_output_argmax)
assert np.allclose(output2[0].asnumpy(), expect_output_y)
assert np.allclose(output2[1].asnumpy(), expect_output_argmax)

View File

@ -0,0 +1,106 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import ops
import mindspore.common.dtype as mstype
import mindspore as ms
class FractionalMaxPool2dNet(nn.Cell):
"""FractionalMaxPool2d ops"""
def construct(self, x):
output1 = ops.fractional_max_pool2d(x, kernel_size=2, output_size=(2, 2), return_indices=True)
output2 = ops.fractional_max_pool2d(x, kernel_size=2, output_ratio=(0.5, 0.5), return_indices=True)
return output1, output2
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fractional_maxpool2d_normal(mode):
"""
Feature: FractionalMaxPool2d
Description: Verify the result of FractionalMaxPool2d
Expectation: success
"""
ms.set_context(mode=mode)
net = FractionalMaxPool2dNet()
input_x = Tensor(np.random.rand(25).reshape([1, 1, 5, 5]), mstype.float32)
output1, output2 = net(input_x)
assert output1[0].shape == output1[1].shape == (1, 1, 2, 2)
assert output2[0].shape == output2[1].shape == (1, 1, 2, 2)
input_x = Tensor([[[[5.58954370e-001, 6.63938331e-001, 6.21228504e-001, 2.42979444e-001, 3.76893662e-001],
[1.81983045e-003, 3.52343421e-001, 4.62048613e-001, 1.10343760e-001, 1.39571702e-001],
[4.99799584e-001, 4.64907907e-001, 6.20357162e-001, 3.59420753e-001, 1.26215309e-001],
[7.71829579e-002, 4.58553624e-001, 3.58015698e-001, 3.53923170e-001, 1.75972716e-001],
[5.65106732e-001, 6.46603699e-001, 6.05013040e-001, 3.82114821e-001, 4.62306777e-003]]]],
mstype.float32)
output1, output2 = net(input_x)
expect_output_y = np.array([[[[6.63938344e-001, 3.76893669e-001],
[6.46603703e-001, 3.82114828e-001]]]])
expect_output_argmax = np.array([[[[1, 4],
[21, 23]]]])
assert np.allclose(output1[0].asnumpy(), expect_output_y)
assert np.allclose(output1[1].asnumpy(), expect_output_argmax)
assert np.allclose(output2[0].asnumpy(), expect_output_y)
assert np.allclose(output2[1].asnumpy(), expect_output_argmax)
class FractionalMaxPool3dNet(nn.Cell):
"""FractionalMaxPool3d ops"""
def construct(self, x):
output1 = ops.fractional_max_pool3d(x, kernel_size=(1.0, 1.0, 1.0), output_size=(1, 1, 2), return_indices=True)
output2 = ops.fractional_max_pool3d(x, kernel_size=(1.0, 1.0, 1.0), output_ratio=(0.5, 0.5, 0.5),
return_indices=True)
return output1, output2
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fractional_maxpool3d_normal(mode):
"""
Feature: Test FractioanlMaxPool3d
Description: Test the functionality of FractionalMaxPool3d
Expectation: Success
"""
ms.set_context(mode=mode)
input_x = Tensor(np.random.rand(16).reshape([1, 1, 2, 2, 4]), mstype.float32)
net = FractionalMaxPool3dNet()
output1, output2 = net(input_x)
assert output1[0].shape == output1[1].shape == (1, 1, 1, 1, 2)
assert output2[0].shape == output2[1].shape == (1, 1, 1, 1, 2)
input_x = Tensor([[[[[5.76273143e-001, 7.97047436e-001, 5.05385816e-001, 7.98332036e-001],
[5.79880655e-001, 9.75979388e-001, 3.17571498e-002, 8.08261558e-002]],
[[3.82758647e-001, 7.09801614e-001, 4.39641386e-001, 5.71077049e-001],
[9.16305065e-001, 3.71438652e-001, 6.52868748e-001, 6.91260636e-001]]]]], mstype.float32)
output1, output2 = net(input_x)
expect_output_y = np.array([[[[[9.16305065e-001, 6.91260636e-001]]]]])
expect_output_argmax = np.array([[[[[12, 15]]]]])
assert np.allclose(output1[0].asnumpy(), expect_output_y)
assert np.allclose(output1[1].asnumpy(), expect_output_argmax)
assert np.allclose(output2[0].asnumpy(), expect_output_y)
assert np.allclose(output2[1].asnumpy(), expect_output_argmax)

View File

@ -0,0 +1,75 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test fractional maxpooling ops
"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import ops
from mindspore.common.api import _cell_graph_executor
import mindspore.common.dtype as mstype
class FractionalMaxPool2dNet(nn.Cell):
"""fractional_max_pool2d"""
def construct(self, x, _random_samples):
output1 = ops.fractional_max_pool2d(x, kernel_size=2, output_size=(2, 2), _random_samples=_random_samples,
return_indices=True)
output2 = ops.fractional_max_pool2d(x, kernel_size=2, output_ratio=(0.5, 0.5), _random_samples=_random_samples,
return_indices=True)
return output1, output2
def test_compile_fractional_maxpool2d():
"""
Feature: Test fractional_max_pool2d
Description: Test the functionality of fractional_max_pool2d
Expectation: Success
"""
input_x = Tensor(np.array([0.3220, 0.9545, 0.7879, 0.0975, 0.3698,
0.5135, 0.5740, 0.3435, 0.1895, 0.8764,
0.9581, 0.4760, 0.9014, 0.8522, 0.3664,
0.4980, 0.9673, 0.9879, 0.6988, 0.9022,
0.9304, 0.1558, 0.0153, 0.1559, 0.9852]).reshape([1, 1, 5, 5]), mstype.float32)
_random_samples = Tensor(np.array([[[0.0, 0.0]]]), mstype.float32)
net = FractionalMaxPool2dNet()
_cell_graph_executor.compile(net, input_x, _random_samples)
class FractionalMaxPool3dNet(nn.Cell):
"""fractional_max_pool3d"""
def construct(self, x, _random_samples):
output1 = ops.fractional_max_pool3d(x, kernel_size=(1.0, 1.0, 1.0), output_size=(1, 1, 2),
_random_samples=_random_samples, return_indices=True)
output2 = ops.fractional_max_pool3d(x, kernel_size=(1.0, 1.0, 1.0), output_ratio=(0.5, 0.5, 0.5),
_random_samples=_random_samples, return_indices=True)
return output1, output2
def test_compile_fractional_maxpool3d():
"""
Feature: Test fractional_max_pool3d
Description: Test the functionality of fractional_max_pool3d
Expectation: Success
"""
input_x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
.reshape([1, 1, 2, 2, 4]), mstype.float32)
_random_samples = Tensor(np.array([0.0, 0.0, 0.0]).reshape([1, 1, 3]), mstype.float32)
net = FractionalMaxPool3dNet()
_cell_graph_executor.compile(net, input_x, _random_samples)