forked from mindspore-Ecosystem/mindspore
fix lppool
This commit is contained in:
parent
8d04817a64
commit
98fb8174bf
|
@ -32,5 +32,6 @@ mindspore.nn.LPPool1d
|
|||
- **TypeError** - `x` 不是Tensor。
|
||||
- **TypeError** - `kernel_size` 或 `stride` 不是int。
|
||||
- **TypeError** - `ceil_mode` 不是bool。
|
||||
- **TypeError** - `norm_type` 不是float也不是int。
|
||||
- **ValueError** - `kernel_size` 或 `stride` 小于1。
|
||||
- **ValueError** - `x` 的shape长度不等于2或3。
|
|
@ -32,6 +32,7 @@ mindspore.nn.LPPool2d
|
|||
- **TypeError** - `x` 不是Tensor。
|
||||
- **TypeError** - `kernel_size` 或 `stride` 不是int也不是tuple。
|
||||
- **TypeError** - `ceil_mode` 不是bool。
|
||||
- **TypeError** - `norm_type` 不是float也不是int。
|
||||
- **ValueError** - `kernel_size` 或 `stride` 小于1。
|
||||
- **ValueError** - `kernel_size` 或 `stride` 是一个长度不为2的tuple。
|
||||
- **ValueError** - `x` 的shape长度不等于4。
|
|
@ -30,5 +30,6 @@ mindspore.ops.lp_pool1d
|
|||
- **TypeError** - `x` 不是Tensor。
|
||||
- **TypeError** - `kernel_size` 或 `stride` 不是int。
|
||||
- **TypeError** - `ceil_mode` 不是bool。
|
||||
- **TypeError** - `norm_type` 不是float也不是int。
|
||||
- **ValueError** - `kernel_size` 或 `stride` 小于1。
|
||||
- **ValueError** - `x` 的shape长度不等于2或3。
|
|
@ -30,6 +30,7 @@ mindspore.ops.lp_pool2d
|
|||
- **TypeError** - `x` 不是Tensor。
|
||||
- **TypeError** - `kernel_size` 或 `stride` 不是int也不是tuple。
|
||||
- **TypeError** - `ceil_mode` 不是bool。
|
||||
- **TypeError** - `norm_type` 不是float也不是int。
|
||||
- **ValueError** - `kernel_size` 或 `stride` 小于1。
|
||||
- **ValueError** - `kernel_size` 或 `stride` 是一个长度不为2的tuple。
|
||||
- **ValueError** - `x` 的shape长度不等于4。
|
|
@ -116,6 +116,7 @@ class LPPool1d(Cell):
|
|||
TypeError: If `x` is not an Tensor.
|
||||
TypeError: If `kernel_size` or `stride` is not an int.
|
||||
TypeError: If `ceil_mode` is not a bool.
|
||||
TypeError: If `norm_type` is neither float nor int.
|
||||
ValueError: If `kernel_size` or `stride` is less than 1.
|
||||
ValueError: If length of shape of `x` is not equal to 2 or 3.
|
||||
|
||||
|
@ -147,7 +148,7 @@ class LPPool1d(Cell):
|
|||
self.ceil_mode = ceil_mode
|
||||
|
||||
def construct(self, x):
|
||||
return ops.lp_pool1d(x, float(self.norm_type), self.kernel_size,
|
||||
return ops.lp_pool1d(x, self.norm_type, self.kernel_size,
|
||||
self.stride, self.ceil_mode)
|
||||
|
||||
|
||||
|
@ -189,6 +190,7 @@ class LPPool2d(Cell):
|
|||
TypeError: If `x` is not an Tensor.
|
||||
TypeError: If `kernel_size` or `stride` is neither int nor tuple.
|
||||
TypeError: If `ceil_mode` is not a bool.
|
||||
TypeError: If `norm_type` is neither float nor int.
|
||||
ValueError: If `kernel_size` or `stride` is less than 1.
|
||||
ValueError: If `kernel_size` or `stride` is a tuple whose length is not equal to `2`.
|
||||
ValueError: If length of shape of `x` is not equal to 4.
|
||||
|
@ -227,7 +229,7 @@ class LPPool2d(Cell):
|
|||
self.ceil_mode = ceil_mode
|
||||
|
||||
def construct(self, x):
|
||||
return ops.lp_pool2d(x, float(self.norm_type), self.kernel_size,
|
||||
return ops.lp_pool2d(x, self.norm_type, self.kernel_size,
|
||||
self.stride, self.ceil_mode)
|
||||
|
||||
|
||||
|
|
|
@ -3572,6 +3572,7 @@ def lp_pool1d(x, norm_type, kernel_size, stride=None, ceil_mode=False):
|
|||
TypeError: If `x` is not an Tensor.
|
||||
TypeError: If `kernel_size` or `stride` is not an int.
|
||||
TypeError: If `ceil_mode` is not a bool.
|
||||
TypeError: If `norm_type` is neither float nor int.
|
||||
ValueError: If `kernel_size` or `stride` is less than 1.
|
||||
ValueError: If length of shape of `x` is not equal to 2 or 3.
|
||||
|
||||
|
@ -3594,6 +3595,10 @@ def lp_pool1d(x, norm_type, kernel_size, stride=None, ceil_mode=False):
|
|||
[63. 66.]]]
|
||||
"""
|
||||
_shape_check(x.shape, [2, 3], "lp_pool1d")
|
||||
if isinstance(norm_type, (float, int)):
|
||||
norm_type = float(norm_type)
|
||||
else:
|
||||
raise TypeError(f"For lp_pool1d, the type of 'norm_type' must be float or int, but got {type(norm_type)}")
|
||||
sign = _get_cache_prim(ops.Sign)()
|
||||
squeeze = _get_cache_prim(ops.Squeeze)(0)
|
||||
expand_dims = _get_cache_prim(ops.ExpandDims)()
|
||||
|
@ -3647,6 +3652,7 @@ def lp_pool2d(x, norm_type, kernel_size, stride=None, ceil_mode=False):
|
|||
TypeError: If `x` is not an Tensor.
|
||||
TypeError: If `kernel_size` or `stride` is neither int nor tuple.
|
||||
TypeError: If `ceil_mode` is not a bool.
|
||||
TypeError: If `norm_type` is neither float nor int.
|
||||
ValueError: If `kernel_size` or `stride` is less than 1.
|
||||
ValueError: If `kernel_size` or `stride` is a tuple whose length is not equal to `2`.
|
||||
ValueError: If length of shape of `x` is not equal to 4.
|
||||
|
@ -3677,6 +3683,10 @@ def lp_pool2d(x, norm_type, kernel_size, stride=None, ceil_mode=False):
|
|||
|
||||
"""
|
||||
_shape_check(x.shape, [4], "lp_pool2d")
|
||||
if isinstance(norm_type, (float, int)):
|
||||
norm_type = float(norm_type)
|
||||
else:
|
||||
raise TypeError(f"For lp_pool2d, the type of 'norm_type' must be float or int, but got {type(norm_type)}")
|
||||
sign = _get_cache_prim(ops.Sign)()
|
||||
if not isinstance(x, tuple):
|
||||
kernel_size = tuple((kernel_size, kernel_size))
|
||||
|
|
|
@ -38,10 +38,10 @@ class Net(nn.Cell):
|
|||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_lppool1d_normal(mode):
|
||||
def test_lppool2d_normal(mode):
|
||||
"""
|
||||
Feature: LPPool1d
|
||||
Description: Verify the result of LPPool1d
|
||||
Feature: LPPool2d
|
||||
Description: Verify the result of LPPool2d
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
|
|
|
@ -35,10 +35,10 @@ class Net(nn.Cell):
|
|||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_lppool1d_normal(mode):
|
||||
def test_lppool2d_normal(mode):
|
||||
"""
|
||||
Feature: LPPool1d
|
||||
Description: Verify the result of LPPool1d
|
||||
Feature: LPPool2d
|
||||
Description: Verify the result of LPPool2d
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
|
|
Loading…
Reference in New Issue