fix lppool

This commit is contained in:
fengyihang 2022-10-27 12:40:13 +08:00
parent 8d04817a64
commit 98fb8174bf
8 changed files with 24 additions and 8 deletions

View File

@ -32,5 +32,6 @@ mindspore.nn.LPPool1d
- **TypeError** - `x` 不是Tensor。 - **TypeError** - `x` 不是Tensor。
- **TypeError** - `kernel_size``stride` 不是int。 - **TypeError** - `kernel_size``stride` 不是int。
- **TypeError** - `ceil_mode` 不是bool。 - **TypeError** - `ceil_mode` 不是bool。
- **TypeError** - `norm_type` 不是float也不是int。
- **ValueError** - `kernel_size``stride` 小于1。 - **ValueError** - `kernel_size``stride` 小于1。
- **ValueError** - `x` 的shape长度不等于2或3。 - **ValueError** - `x` 的shape长度不等于2或3。

View File

@ -32,6 +32,7 @@ mindspore.nn.LPPool2d
- **TypeError** - `x` 不是Tensor。 - **TypeError** - `x` 不是Tensor。
- **TypeError** - `kernel_size``stride` 不是int也不是tuple。 - **TypeError** - `kernel_size``stride` 不是int也不是tuple。
- **TypeError** - `ceil_mode` 不是bool。 - **TypeError** - `ceil_mode` 不是bool。
- **TypeError** - `norm_type` 不是float也不是int。
- **ValueError** - `kernel_size``stride` 小于1。 - **ValueError** - `kernel_size``stride` 小于1。
- **ValueError** - `kernel_size``stride` 是一个长度不为2的tuple。 - **ValueError** - `kernel_size``stride` 是一个长度不为2的tuple。
- **ValueError** - `x` 的shape长度不等于4。 - **ValueError** - `x` 的shape长度不等于4。

View File

@ -30,5 +30,6 @@ mindspore.ops.lp_pool1d
- **TypeError** - `x` 不是Tensor。 - **TypeError** - `x` 不是Tensor。
- **TypeError** - `kernel_size``stride` 不是int。 - **TypeError** - `kernel_size``stride` 不是int。
- **TypeError** - `ceil_mode` 不是bool。 - **TypeError** - `ceil_mode` 不是bool。
- **TypeError** - `norm_type` 不是float也不是int。
- **ValueError** - `kernel_size``stride` 小于1。 - **ValueError** - `kernel_size``stride` 小于1。
- **ValueError** - `x` 的shape长度不等于2或3。 - **ValueError** - `x` 的shape长度不等于2或3。

View File

@ -30,6 +30,7 @@ mindspore.ops.lp_pool2d
- **TypeError** - `x` 不是Tensor。 - **TypeError** - `x` 不是Tensor。
- **TypeError** - `kernel_size``stride` 不是int也不是tuple。 - **TypeError** - `kernel_size``stride` 不是int也不是tuple。
- **TypeError** - `ceil_mode` 不是bool。 - **TypeError** - `ceil_mode` 不是bool。
- **TypeError** - `norm_type` 不是float也不是int。
- **ValueError** - `kernel_size``stride` 小于1。 - **ValueError** - `kernel_size``stride` 小于1。
- **ValueError** - `kernel_size``stride` 是一个长度不为2的tuple。 - **ValueError** - `kernel_size``stride` 是一个长度不为2的tuple。
- **ValueError** - `x` 的shape长度不等于4。 - **ValueError** - `x` 的shape长度不等于4。

View File

@ -116,6 +116,7 @@ class LPPool1d(Cell):
TypeError: If `x` is not an Tensor. TypeError: If `x` is not an Tensor.
TypeError: If `kernel_size` or `stride` is not an int. TypeError: If `kernel_size` or `stride` is not an int.
TypeError: If `ceil_mode` is not a bool. TypeError: If `ceil_mode` is not a bool.
TypeError: If `norm_type` is neither float nor int.
ValueError: If `kernel_size` or `stride` is less than 1. ValueError: If `kernel_size` or `stride` is less than 1.
ValueError: If length of shape of `x` is not equal to 2 or 3. ValueError: If length of shape of `x` is not equal to 2 or 3.
@ -147,7 +148,7 @@ class LPPool1d(Cell):
self.ceil_mode = ceil_mode self.ceil_mode = ceil_mode
def construct(self, x): def construct(self, x):
return ops.lp_pool1d(x, float(self.norm_type), self.kernel_size, return ops.lp_pool1d(x, self.norm_type, self.kernel_size,
self.stride, self.ceil_mode) self.stride, self.ceil_mode)
@ -189,6 +190,7 @@ class LPPool2d(Cell):
TypeError: If `x` is not an Tensor. TypeError: If `x` is not an Tensor.
TypeError: If `kernel_size` or `stride` is neither int nor tuple. TypeError: If `kernel_size` or `stride` is neither int nor tuple.
TypeError: If `ceil_mode` is not a bool. TypeError: If `ceil_mode` is not a bool.
TypeError: If `norm_type` is neither float nor int.
ValueError: If `kernel_size` or `stride` is less than 1. ValueError: If `kernel_size` or `stride` is less than 1.
ValueError: If `kernel_size` or `stride` is a tuple whose length is not equal to `2`. ValueError: If `kernel_size` or `stride` is a tuple whose length is not equal to `2`.
ValueError: If length of shape of `x` is not equal to 4. ValueError: If length of shape of `x` is not equal to 4.
@ -227,7 +229,7 @@ class LPPool2d(Cell):
self.ceil_mode = ceil_mode self.ceil_mode = ceil_mode
def construct(self, x): def construct(self, x):
return ops.lp_pool2d(x, float(self.norm_type), self.kernel_size, return ops.lp_pool2d(x, self.norm_type, self.kernel_size,
self.stride, self.ceil_mode) self.stride, self.ceil_mode)

View File

@ -3572,6 +3572,7 @@ def lp_pool1d(x, norm_type, kernel_size, stride=None, ceil_mode=False):
TypeError: If `x` is not an Tensor. TypeError: If `x` is not an Tensor.
TypeError: If `kernel_size` or `stride` is not an int. TypeError: If `kernel_size` or `stride` is not an int.
TypeError: If `ceil_mode` is not a bool. TypeError: If `ceil_mode` is not a bool.
TypeError: If `norm_type` is neither float nor int.
ValueError: If `kernel_size` or `stride` is less than 1. ValueError: If `kernel_size` or `stride` is less than 1.
ValueError: If length of shape of `x` is not equal to 2 or 3. ValueError: If length of shape of `x` is not equal to 2 or 3.
@ -3594,6 +3595,10 @@ def lp_pool1d(x, norm_type, kernel_size, stride=None, ceil_mode=False):
[63. 66.]]] [63. 66.]]]
""" """
_shape_check(x.shape, [2, 3], "lp_pool1d") _shape_check(x.shape, [2, 3], "lp_pool1d")
if isinstance(norm_type, (float, int)):
norm_type = float(norm_type)
else:
raise TypeError(f"For lp_pool1d, the type of 'norm_type' must be float or int, but got {type(norm_type)}")
sign = _get_cache_prim(ops.Sign)() sign = _get_cache_prim(ops.Sign)()
squeeze = _get_cache_prim(ops.Squeeze)(0) squeeze = _get_cache_prim(ops.Squeeze)(0)
expand_dims = _get_cache_prim(ops.ExpandDims)() expand_dims = _get_cache_prim(ops.ExpandDims)()
@ -3647,6 +3652,7 @@ def lp_pool2d(x, norm_type, kernel_size, stride=None, ceil_mode=False):
TypeError: If `x` is not an Tensor. TypeError: If `x` is not an Tensor.
TypeError: If `kernel_size` or `stride` is neither int nor tuple. TypeError: If `kernel_size` or `stride` is neither int nor tuple.
TypeError: If `ceil_mode` is not a bool. TypeError: If `ceil_mode` is not a bool.
TypeError: If `norm_type` is neither float nor int.
ValueError: If `kernel_size` or `stride` is less than 1. ValueError: If `kernel_size` or `stride` is less than 1.
ValueError: If `kernel_size` or `stride` is a tuple whose length is not equal to `2`. ValueError: If `kernel_size` or `stride` is a tuple whose length is not equal to `2`.
ValueError: If length of shape of `x` is not equal to 4. ValueError: If length of shape of `x` is not equal to 4.
@ -3677,6 +3683,10 @@ def lp_pool2d(x, norm_type, kernel_size, stride=None, ceil_mode=False):
""" """
_shape_check(x.shape, [4], "lp_pool2d") _shape_check(x.shape, [4], "lp_pool2d")
if isinstance(norm_type, (float, int)):
norm_type = float(norm_type)
else:
raise TypeError(f"For lp_pool2d, the type of 'norm_type' must be float or int, but got {type(norm_type)}")
sign = _get_cache_prim(ops.Sign)() sign = _get_cache_prim(ops.Sign)()
if not isinstance(x, tuple): if not isinstance(x, tuple):
kernel_size = tuple((kernel_size, kernel_size)) kernel_size = tuple((kernel_size, kernel_size))

View File

@ -38,10 +38,10 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_lppool1d_normal(mode): def test_lppool2d_normal(mode):
""" """
Feature: LPPool1d Feature: LPPool2d
Description: Verify the result of LPPool1d Description: Verify the result of LPPool2d
Expectation: success Expectation: success
""" """
ms.set_context(mode=mode) ms.set_context(mode=mode)

View File

@ -35,10 +35,10 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_lppool1d_normal(mode): def test_lppool2d_normal(mode):
""" """
Feature: LPPool1d Feature: LPPool2d
Description: Verify the result of LPPool1d Description: Verify the result of LPPool2d
Expectation: success Expectation: success
""" """
ms.set_context(mode=mode) ms.set_context(mode=mode)