forked from mindspore-Ecosystem/mindspore
added reflectionpad1d and 2d
This commit is contained in:
parent
4cbfd148eb
commit
c0b9d74987
|
@ -0,0 +1,39 @@
|
|||
mindspore.nn.ReflectionPad1d
|
||||
============================
|
||||
|
||||
.. py:class:: mindspore.nn.ReflectionPad1d(paddings)
|
||||
|
||||
根据 `paddings` 对输入 `x` 进行填充。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **paddings** (tuple/int) - 填充大小,如果输入为int, 则对所有边界进行相同大小的padding,如果是tuple,则为(pad_left, pad_right)。
|
||||
|
||||
.. code-block::
|
||||
|
||||
# 假设参数和输入如下:
|
||||
paddings = (3, 1)。
|
||||
x = [[[0, 1, 2, 3], [4, 5, 6, 7]]].
|
||||
# `x` 的第一个维度为1, 第二个维度为2, 第三个维度为4。
|
||||
# 输出的第一个维度不变。
|
||||
# 输出的第二个维度不变。
|
||||
# 输出的第三个维度为W_out = W_in + pad_left + pad_right = 4 + 3 + 1 = 8。
|
||||
# 所以最终的输出shape为(1, 2, 8)。
|
||||
|
||||
**输入:**
|
||||
|
||||
- **x** (Tensor) - 输入Tensor, shape为:math:`(C, W_in)` 或:math:`(N, C, W_in)`。
|
||||
|
||||
**输出:**
|
||||
|
||||
Tensor,填充后的Tensor, shape为:math:`(C, W_out)`或:math:`(N, C, W_out)`。其中:math:`W_out = W_in + pad_left + pad_right`
|
||||
|
||||
- 对 `x` 使用对称轴进行对称复制的方式进行填充(复制时不包括对称轴)。例如 `x` 为[[[0, 1, 2, 3], [4, 5, 6, 7]]], `paddings` 为(3, 1),则输出为[[[3, 2, 1, 0, 1, 2, 3, 2], [7, 6, 5, 4, 5, 6, 7, 6]]]。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `padding` 不是tuple或integer。
|
||||
- **ValueError** - `padding` 中存在不是integer的元素
|
||||
- **ValueError** - `padding` 是tuple,且长度不能被2整除。
|
||||
- **ValueError** - `padding` 是tuple,且存在负数。
|
||||
- **ValueError** - `padding` 是tuple,且和tensor的维度不匹配。
|
|
@ -0,0 +1,40 @@
|
|||
mindspore.nn.ReflectionPad2d
|
||||
============================
|
||||
|
||||
.. py:class:: mindspore.nn.ReflectionPad2d(paddings)
|
||||
|
||||
根据 `paddings` 对输入 `x` 进行填充。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **paddings** (tuple/int) - 填充大小,如果输入为integer, 则对所有边界进行相同大小的padding,如果是tuple,则顺序为(pad_left, pad_right, pad_up, pad_down)。
|
||||
|
||||
.. code-block::
|
||||
|
||||
# 假设参数和输入如下:
|
||||
paddings = (1, 1, 2, 0).
|
||||
x = [[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]].
|
||||
# `x` 的第一个维度为1, 第二个维度为1, 第三个维度为3,第四个维度为3。
|
||||
# 输出的第一个维度不变。
|
||||
# 输出的第二个维度不变。
|
||||
# 输出的第三个维度为H_out = H_in + pad_up + pad_down = 3 + 1 + 1 = 5。
|
||||
# 输出的第四个维度为W_out = W_in + pad_left + pad_right = 3 + 2 + 0 = 5。
|
||||
# 所以最终的输出shape为(1, 1, 5, 5)
|
||||
|
||||
**输入:**
|
||||
|
||||
- **x** (Tensor) - 输入Tensor, shape为:math:`(C, H_in, W_in)`或:math:`(N, C, H_in, W_in)`。
|
||||
|
||||
**输出:**
|
||||
|
||||
Tensor,填充后的Tensor, shape为:math:`(C, H_out, W_out)`或:math:`(N, C, H_out, W_out)`。其中:math:`H_out = H_in + pad_up + pad_down`,:math:`W_out = W_in + pad_left + pad_right, H_out = H_in`
|
||||
|
||||
- 对 `x` 使用对称轴进行对称复制的方式进行填充(复制时不包括对称轴)。例如 `x` 为[[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]], `paddings` 为(1, 1, 2, 0),则输出为[[[[7., 6., 7., 8., 7.], [4., 3., 4., 5., 4.], [1., 0., 1., 2., 1.], [4., 3., 4., 5., 4.], [7., 6., 7., 8., 7.]]]]。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `padding` 不是tuple或integer。
|
||||
- **ValueError** - `padding` 中存在不是integer的元素
|
||||
- **ValueError** - `padding` 是tuple,且长度不能被2整除。
|
||||
- **ValueError** - `padding` 是tuple,且存在负数。
|
||||
- **ValueError** - `padding` 是tuple,且和tensor的维度不匹配。
|
|
@ -36,6 +36,7 @@ from .combined import *
|
|||
from .timedistributed import *
|
||||
from .thor_layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
|
||||
from .padding import ConstantPad1d, ConstantPad2d, ConstantPad3d, ZeroPad2d
|
||||
from .reflectionpad import ReflectionPad1d, ReflectionPad2d
|
||||
|
||||
__all__ = []
|
||||
__all__.extend(activation.__all__)
|
||||
|
@ -55,3 +56,4 @@ __all__.extend(combined.__all__)
|
|||
__all__.extend(timedistributed.__all__)
|
||||
__all__.extend(thor_layer.__all__)
|
||||
__all__.extend(padding.__all__)
|
||||
__all__.extend(reflectionpad.__all__)
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ReflectionPad"""
|
||||
from mindspore.common import Tensor
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from ..cell import Cell
|
||||
|
||||
__all__ = ['ReflectionPad1d', 'ReflectionPad2d']
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_padding_dimension(dimension, padding):
|
||||
r"""
|
||||
Validate the input padding and add place holders if needed.
|
||||
Note: the input 'padding' in this function is already converted to list of lists to match MirrorPad
|
||||
"""
|
||||
if dimension < len(padding):
|
||||
raise ValueError(f"For padding with length {len(padding) * 2}, the dimension of the tensor should be at least "
|
||||
f"{len(padding)}, but got {dimension}")
|
||||
# add place holders
|
||||
if dimension > len(padding):
|
||||
padding = [(0, 0) for _ in range(dimension - len(padding))] + [x for x in padding]
|
||||
return padding
|
||||
|
||||
|
||||
def _swap_to_ms_padding_order(padding):
|
||||
r"""
|
||||
Check whether the input padding is a tuple or a int converted to a tuple.
|
||||
Check if the length of padding is divisible by 2.
|
||||
Convert the input padding to the format that MirrorPad would understand.
|
||||
"""
|
||||
number_of_paddings = len(padding) // 2
|
||||
new_padding = [[0, 0]] * number_of_paddings
|
||||
for i in range(number_of_paddings):
|
||||
new_padding[i] = [padding[2 * i], padding[2 * i + 1]]
|
||||
# reverse the padding list to match the order of paddings for MirrorPad
|
||||
new_padding.reverse()
|
||||
return new_padding
|
||||
|
||||
|
||||
class _ReflectionPadNd(Cell):
|
||||
r"""
|
||||
Using a given padding to do reflection pad on the given tensor.
|
||||
Work as a parent class, and only accepts tuple as padding input.
|
||||
"""
|
||||
def __init__(self, padding, name="ReflectionPadNd"):
|
||||
super(_ReflectionPadNd, self).__init__()
|
||||
self.name = name
|
||||
# check if padding and its elements are valid
|
||||
if not isinstance(padding, tuple):
|
||||
raise TypeError(f"For '{self.name}' the input 'padding' must be an integer or tuple, "
|
||||
f"but got {type(padding).__name__}")
|
||||
if len(padding) % 2 != 0:
|
||||
raise ValueError(f"For '{self.name}' the length of input 'padding' must be divisible by 2, "
|
||||
f"but got padding of length {len(padding)}. ")
|
||||
if not all(isinstance(i, int) for i in padding):
|
||||
raise TypeError(f"For '{self.name}' every element in 'padding' must be integer, "
|
||||
f"but got {padding}. ")
|
||||
if not all(i >= 0 for i in padding):
|
||||
raise ValueError(f"For '{self.name}' every element in 'padding' must be >= 0. "
|
||||
f"but got {padding}. ")
|
||||
self.padding = _swap_to_ms_padding_order(padding)
|
||||
|
||||
def construct(self, x):
|
||||
input_shape = x.shape
|
||||
padding = _check_padding_dimension(len(input_shape), self.padding)
|
||||
x = ops.MirrorPad(mode='REFLECT')(x, Tensor(padding))
|
||||
return x
|
||||
|
||||
|
||||
class ReflectionPad1d(_ReflectionPadNd):
|
||||
r"""
|
||||
Using a given padding to do reflection pad on the last dimension of the given tensor.
|
||||
|
||||
Args:
|
||||
padding (union[int, tuple]): The padding size to pad the last dimension of input tensor.
|
||||
If padding is an integer: all directions will be padded with the same size.
|
||||
If padding is a tuple: uses (pad_left, pad_right, pad_up, pad_down) to pad.
|
||||
|
||||
Inputs:
|
||||
Tensor, 2D or 3D
|
||||
|
||||
Outputs:
|
||||
Tensor, after padding.
|
||||
Suppose the tensor has dimension (N, C, W_in), the padded dimension will become (N, C, W_out),
|
||||
where W_out = W_in + pad_left + pad_right
|
||||
|
||||
Raises:
|
||||
TypeError: If 'padding' is not a tuple or int.
|
||||
TypeError: If there is an element in 'padding' that is not int64.
|
||||
ValueError: If the length of 'padding' is not divisible by 2.
|
||||
ValueError: If there is an element in 'padding' that is negative.
|
||||
ValueError: If the there is a dimension mismatch between the padding and the tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, padding):
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding)
|
||||
super(ReflectionPad1d, self).__init__(padding, 'ReflectionPad1d')
|
||||
|
||||
|
||||
class ReflectionPad2d(_ReflectionPadNd):
|
||||
r"""
|
||||
Using a given padding to do reflection pad on the last dimension of the given tensor.
|
||||
|
||||
Args:
|
||||
padding (union[int, tuple]): The padding size to pad the last dimension of input tensor.
|
||||
If padding is an integer: all directions will be padded with the same size.
|
||||
If padding is a tuple: uses (pad_left, pad_right, pad_up, pad_down) to pad.
|
||||
|
||||
Inputs:
|
||||
Tensor, 3D or 4D
|
||||
|
||||
Output:
|
||||
Tensor, after padding.
|
||||
Suppose the tensor has dimension (N, C, H_in, W_in), the padded dimension will become (N, C, H_out, W_out),
|
||||
where W_out = W_in + pad_left + pad_right, H_out = H_in + pad_up + pad_down
|
||||
|
||||
Raises:
|
||||
TypeError: If 'padding' is not a tuple or int.
|
||||
TypeError: If there is an element in 'padding' that is not int64.
|
||||
ValueError: If the length of 'padding' is not divisible by 2.
|
||||
ValueError: If there is an element in 'padding' that is negative.
|
||||
ValueError: If the there is a dimension mismatch between the padding and the tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, padding):
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding, padding)
|
||||
super(ReflectionPad2d, self).__init__(padding, 'ReflectionPad2d')
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net1d(nn.Cell):
|
||||
def __init__(self, padding):
|
||||
super(Net1d, self).__init__()
|
||||
self.pad = nn.ReflectionPad1d(padding)
|
||||
|
||||
def construct(self, x):
|
||||
return self.pad(x)
|
||||
|
||||
|
||||
class Net2d(nn.Cell):
|
||||
def __init__(self, padding):
|
||||
super(Net2d, self).__init__()
|
||||
self.pad = nn.ReflectionPad2d(padding)
|
||||
|
||||
def construct(self, x):
|
||||
return self.pad(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend
|
||||
@pytest.mark.env_onecard
|
||||
def test_reflection_pad_1d():
|
||||
"""
|
||||
Feature: ReflectionPad1d
|
||||
Description: Infer process of ReflectionPad1d with 2 types of parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
# Test functionality with 3D tensor input
|
||||
x = Tensor(np.array([[[0, 1, 2, 3], [4, 5, 6, 7]]]).astype(np.float32))
|
||||
padding = (3, 1)
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[3, 2, 1, 0, 1, 2, 3, 2],
|
||||
[7, 6, 5, 4, 5, 6, 7, 6]]]).astype(np.float32))
|
||||
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
expected_output = Tensor(np.array([[[2, 1, 0, 1, 2, 3, 2, 1],
|
||||
[6, 5, 4, 5, 6, 7, 6, 5]]]).astype(np.float32))
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
# Test functionality with 2D tensor as input
|
||||
x = Tensor(np.array([[0, 1, 2, 3], [4, 5, 6, 7]]).astype(np.float16))
|
||||
padding = (3, 1)
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[3, 2, 1, 0, 1, 2, 3, 2],
|
||||
[7, 6, 5, 4, 5, 6, 7, 6]]).astype(np.float16))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
expected_output = Tensor(np.array([[2, 1, 0, 1, 2, 3, 2, 1],
|
||||
[6, 5, 4, 5, 6, 7, 6, 5]]).astype(np.float16))
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend
|
||||
@pytest.mark.env_onecard
|
||||
def test_reflection_pad_2d():
|
||||
r"""
|
||||
Feature: ReflectionPad2d
|
||||
Description: Infer process of ReflectionPad2d with three type parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
# Test functionality with 4D tensor as input
|
||||
x = Tensor(np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]]).astype(np.int32))
|
||||
padding = (1, 1, 2, 0)
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1],
|
||||
[4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]]).astype(np.int32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3],
|
||||
[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0]]]]).astype(np.int32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
# Test functionality with 3D tensor as input
|
||||
x = Tensor(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]).astype(np.float32))
|
||||
padding = (1, 1, 2, 0)
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1],
|
||||
[4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]).astype(np.float32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3],
|
||||
[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0]]]).astype(np.float32))
|
||||
assert np.array_equal(output, expected_output)
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import pytest
|
||||
|
||||
|
||||
class Net1d(nn.Cell):
|
||||
def __init__(self, padding):
|
||||
super(Net1d, self).__init__()
|
||||
self.pad = nn.ReflectionPad1d(padding)
|
||||
|
||||
def construct(self, x):
|
||||
return self.pad(x)
|
||||
|
||||
|
||||
class Net2d(nn.Cell):
|
||||
def __init__(self, padding):
|
||||
super(Net2d, self).__init__()
|
||||
self.pad = nn.ReflectionPad2d(padding)
|
||||
|
||||
def construct(self, x):
|
||||
return self.pad(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_reflection_pad_1d():
|
||||
"""
|
||||
Feature: ReflectionPad1d
|
||||
Description: Infer process of ReflectionPad1d with 2 types of parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
# Test functionality with 3D tensor input
|
||||
x = Tensor(np.array([[[0, 1, 2, 3], [4, 5, 6, 7]]]).astype(np.float32))
|
||||
padding = (3, 1)
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[3, 2, 1, 0, 1, 2, 3, 2],
|
||||
[7, 6, 5, 4, 5, 6, 7, 6]]]).astype(np.float32))
|
||||
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
expected_output = Tensor(np.array([[[2, 1, 0, 1, 2, 3, 2, 1],
|
||||
[6, 5, 4, 5, 6, 7, 6, 5]]]).astype(np.float32))
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
# Test functionality with 2D tensor as input
|
||||
x = Tensor(np.array([[0, 1, 2, 3], [4, 5, 6, 7]]).astype(np.float16))
|
||||
padding = (3, 1)
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[3, 2, 1, 0, 1, 2, 3, 2],
|
||||
[7, 6, 5, 4, 5, 6, 7, 6]]).astype(np.float16))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
expected_output = Tensor(np.array([[2, 1, 0, 1, 2, 3, 2, 1],
|
||||
[6, 5, 4, 5, 6, 7, 6, 5]]).astype(np.float16))
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_reflection_pad_2d():
|
||||
r"""
|
||||
Feature: ReflectionPad2d
|
||||
Description: Infer process of ReflectionPad2d with three type parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
# Test functionality with 4D tensor as input
|
||||
x = Tensor(np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]]).astype(np.int32))
|
||||
padding = (1, 1, 2, 0)
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1],
|
||||
[4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]]).astype(np.int32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3],
|
||||
[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0]]]]).astype(np.int32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
# Test functionality with 3D tensor as input
|
||||
x = Tensor(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]).astype(np.float32))
|
||||
padding = (1, 1, 2, 0)
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1],
|
||||
[4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]).astype(np.float32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3],
|
||||
[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0]]]).astype(np.float32))
|
||||
assert np.array_equal(output, expected_output)
|
|
@ -0,0 +1,111 @@
|
|||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import pytest
|
||||
|
||||
|
||||
class Net1d(nn.Cell):
|
||||
def __init__(self, padding):
|
||||
super(Net1d, self).__init__()
|
||||
self.pad = nn.ReflectionPad1d(padding)
|
||||
|
||||
def construct(self, x):
|
||||
return self.pad(x)
|
||||
|
||||
|
||||
class Net2d(nn.Cell):
|
||||
def __init__(self, padding):
|
||||
super(Net2d, self).__init__()
|
||||
self.pad = nn.ReflectionPad2d(padding)
|
||||
|
||||
def construct(self, x):
|
||||
return self.pad(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_reflection_pad_1d():
|
||||
"""
|
||||
Feature: ReflectionPad1d
|
||||
Description: Infer process of ReflectionPad1d with 2 types of parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
# Test functionality with 3D tensor input
|
||||
x = Tensor(np.array([[[0, 1, 2, 3], [4, 5, 6, 7]]]).astype(np.float32))
|
||||
padding = (3, 1)
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[3, 2, 1, 0, 1, 2, 3, 2],
|
||||
[7, 6, 5, 4, 5, 6, 7, 6]]]).astype(np.float32))
|
||||
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
expected_output = Tensor(np.array([[[2, 1, 0, 1, 2, 3, 2, 1],
|
||||
[6, 5, 4, 5, 6, 7, 6, 5]]]).astype(np.float32))
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
# Test functionality with 2D tensor as input
|
||||
x = Tensor(np.array([[0, 1, 2, 3], [4, 5, 6, 7]]).astype(np.float16))
|
||||
padding = (3, 1)
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[3, 2, 1, 0, 1, 2, 3, 2],
|
||||
[7, 6, 5, 4, 5, 6, 7, 6]]).astype(np.float16))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
expected_output = Tensor(np.array([[2, 1, 0, 1, 2, 3, 2, 1],
|
||||
[6, 5, 4, 5, 6, 7, 6, 5]]).astype(np.float16))
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_reflection_pad_2d():
|
||||
r"""
|
||||
Feature: ReflectionPad2d
|
||||
Description: Infer process of ReflectionPad2d with three type parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
# Test functionality with 4D tensor as input
|
||||
x = Tensor(np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]]).astype(np.int32))
|
||||
padding = (1, 1, 2, 0)
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1],
|
||||
[4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]]).astype(np.int32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3],
|
||||
[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0]]]]).astype(np.int32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
# Test functionality with 3D tensor as input
|
||||
x = Tensor(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]).astype(np.float32))
|
||||
padding = (1, 1, 2, 0)
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1],
|
||||
[4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]).astype(np.float32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3],
|
||||
[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0]]]).astype(np.float32))
|
||||
assert np.array_equal(output, expected_output)
|
|
@ -0,0 +1,158 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.nn import ReflectionPad1d
|
||||
from mindspore.nn import ReflectionPad2d
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
|
||||
def test_reflection_pad_1d():
|
||||
"""
|
||||
Feature: ReflectionPad1d
|
||||
Description: Infer process of ReflectionPad1d with 2 types of parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
# Test functionality with 3D tensor input
|
||||
x = Tensor(np.array([[[0, 1, 2, 3], [4, 5, 6, 7]]]).astype(np.float32))
|
||||
padding = (3, 1)
|
||||
net = ReflectionPad1d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[3, 2, 1, 0, 1, 2, 3, 2],
|
||||
[7, 6, 5, 4, 5, 6, 7, 6]]]).astype(np.float32))
|
||||
|
||||
print(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
expected_output = Tensor(np.array([[[2, 1, 0, 1, 2, 3, 2, 1],
|
||||
[6, 5, 4, 5, 6, 7, 6, 5]]]).astype(np.float32))
|
||||
net = ReflectionPad1d(padding)
|
||||
output = net(x)
|
||||
print(output, expected_output)
|
||||
|
||||
# Test functionality with 2D tensor as input
|
||||
x = Tensor(np.array([[0, 1, 2, 3], [4, 5, 6, 7]]).astype(np.float32))
|
||||
padding = (3, 1)
|
||||
net = ReflectionPad1d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[3, 2, 1, 0, 1, 2, 3, 2],
|
||||
[7, 6, 5, 4, 5, 6, 7, 6]]).astype(np.float32))
|
||||
print(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
expected_output = Tensor(np.array([[2, 1, 0, 1, 2, 3, 2, 1],
|
||||
[6, 5, 4, 5, 6, 7, 6, 5]]).astype(np.float32))
|
||||
net = ReflectionPad1d(padding)
|
||||
output = net(x)
|
||||
print(output, expected_output)
|
||||
|
||||
|
||||
def test_reflection_pad_2d():
|
||||
r"""
|
||||
Feature: ReflectionPad2d
|
||||
Description: Infer process of ReflectionPad2d with three type parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
|
||||
# Test functionality with 4D tensor as input
|
||||
x = Tensor(np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]]).astype(np.float32))
|
||||
padding = (1, 1, 2, 0)
|
||||
net = ReflectionPad2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1],
|
||||
[4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]]).astype(np.float32))
|
||||
print(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
output = ReflectionPad2d(padding)(x)
|
||||
expected_output = Tensor(np.array([[[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3],
|
||||
[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0]]]]).astype(np.float32))
|
||||
print(output, expected_output)
|
||||
|
||||
# Test functionality with 3D tensor as input
|
||||
x = Tensor(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]).astype(np.float32))
|
||||
padding = (1, 1, 2, 0)
|
||||
net = ReflectionPad2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1],
|
||||
[4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]).astype(np.float32))
|
||||
print(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
output = ReflectionPad2d(padding)(x)
|
||||
|
||||
expected_output = Tensor(np.array([[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3],
|
||||
[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3],
|
||||
[2, 1, 0, 1, 2, 1, 0]]]).astype(np.float32))
|
||||
print(output, expected_output)
|
||||
|
||||
|
||||
|
||||
def test_invalid_padding_reflection_pad_1d():
|
||||
"""
|
||||
Feature: ReflectionPad1d
|
||||
Description: test 5 cases of invalid input.
|
||||
Expectation: success
|
||||
"""
|
||||
# case 1: padding is not int or tuple
|
||||
padding = '-1'
|
||||
with pytest.raises(TypeError):
|
||||
ReflectionPad1d(padding)
|
||||
|
||||
# case 2: padding length is not divisible by 2
|
||||
padding = (1, 2, 2)
|
||||
with pytest.raises(ValueError):
|
||||
ReflectionPad1d(padding)
|
||||
|
||||
# case 3: padding element is not int
|
||||
padding = ('2', 2)
|
||||
with pytest.raises(TypeError):
|
||||
ReflectionPad1d(padding)
|
||||
|
||||
# case 4: negative padding
|
||||
padding = (-1, 2)
|
||||
with pytest.raises(ValueError):
|
||||
ReflectionPad1d(padding)
|
||||
|
||||
# case 5: padding dimension does not match tensor dimension
|
||||
padding = (1, 1, 1, 1, 1, 1, 1, 1)
|
||||
x = Tensor([[1, 2, 3], [1, 2, 3]])
|
||||
with pytest.raises(ValueError):
|
||||
ReflectionPad1d(padding)(x)
|
||||
|
||||
|
||||
def test_invalid_padding_reflection_pad_2d():
|
||||
"""
|
||||
Feature: ReflectionPad2d
|
||||
Description: test 5 cases of invalid input.
|
||||
Expectation: success
|
||||
"""
|
||||
# case 1: padding is not int or tuple
|
||||
padding = '-1'
|
||||
with pytest.raises(TypeError):
|
||||
ReflectionPad2d(padding)
|
||||
|
||||
# case 2: padding length is not divisible by 2
|
||||
padding = (1, 2, 2)
|
||||
with pytest.raises(ValueError):
|
||||
ReflectionPad2d(padding)
|
||||
|
||||
# case 3: padding element is not int
|
||||
padding = ('2', 2)
|
||||
with pytest.raises(TypeError):
|
||||
ReflectionPad2d(padding)
|
||||
|
||||
# case 4: negative padding
|
||||
padding = (-1, 2)
|
||||
with pytest.raises(ValueError):
|
||||
ReflectionPad2d(padding)
|
||||
|
||||
# case 5: padding dimension does not match tensor dimension
|
||||
padding = (1, 1, 1, 1, 1, 1, 1, 1)
|
||||
x = Tensor([[1, 2, 3], [1, 2, 3]])
|
||||
with pytest.raises(ValueError):
|
||||
ReflectionPad2d(padding)(x)
|
Loading…
Reference in New Issue