!43448 for replicationpad
Merge pull request !43448 from 于振华/replicationpad_develop_1008
This commit is contained in:
commit
246b7c28c6
|
@ -217,6 +217,9 @@ Dropout层
|
|||
mindspore.nn.ConstantPad3d
|
||||
mindspore.nn.ReflectionPad1d
|
||||
mindspore.nn.ReflectionPad2d
|
||||
mindspore.nn.ReplicationPad1d
|
||||
mindspore.nn.ReplicationPad2d
|
||||
mindspore.nn.ReplicationPad3d
|
||||
mindspore.nn.ZeroPad2d
|
||||
|
||||
损失函数
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.nn.ReplicationPad1d
|
||||
=============================
|
||||
|
||||
.. py:class:: mindspore.nn.ReplicationPad1d(padding)
|
||||
|
||||
根据 `padding` 对输入 `x` 的W维度上进行填充。
|
||||
|
||||
参数:
|
||||
- **padding** (union[int, tuple]) - 填充大小,如果输入为int,则对所有边界进行相同大小的填充;如果是tuple,则为(pad_left, pad_right)。
|
||||
|
||||
输入:
|
||||
- **x** (Tensor) - 维度为3D的Tensor。shape为 :math:`(N, C, W_{in})` 。
|
||||
|
||||
输出:
|
||||
Tensor,填充后的Tensor,shape为 :math:`(N, C, W_{out})` 。其中 :math:`W_{out} = W_{in} + pad\_left + pad\_right` 。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `padding` 不是tuple或int。
|
||||
- **TypeError** - `padding` 中存在不是int的元素。
|
||||
- **ValueError** - `padding` 是tuple,且长度不能被2整除。
|
||||
- **ValueError** - `padding` 是tuple,且长度和Tensor的维度不匹配。
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.nn.ReplicationPad2d
|
||||
=============================
|
||||
|
||||
.. py:class:: mindspore.nn.ReplicationPad2d(padding)
|
||||
|
||||
根据 `padding` 对输入 `x` 的HW维度上进行填充。
|
||||
|
||||
参数:
|
||||
- **padding** (union[int, tuple]) - 填充大小,如果输入为int,则对所有边界进行相同大小的填充;如果是tuple,则顺序为 :math:`(pad_{left}, pad_{right}, pad_{up}, pad_{down})`。
|
||||
|
||||
输入:
|
||||
- **x** (Tensor) - 维度为4D的Tensor,shape为 :math:`(N, C, H_{in}, W_{in})` 。
|
||||
|
||||
输出:
|
||||
Tensor,填充后的Tensor,shape为 :math:`(N, C, H_{out}, W_{out})`。其中 :math:`H_{out} = H_{in} + pad_{up} + pad_{down}`, :math:`W_{out} = W_{in} + pad_{left} + pad_{right}` 。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `padding` 不是tuple或int。
|
||||
- **TypeError** - `padding` 中存在不是int的元素。
|
||||
- **ValueError** - `padding` 是tuple,且长度不能被2整除。
|
||||
- **ValueError** - `padding` 是tuple,且长度和Tensor的维度不匹配。
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.nn.ReplicationPad3d
|
||||
=============================
|
||||
|
||||
.. py:class:: mindspore.nn.ReplicationPad3d(padding)
|
||||
|
||||
根据 `padding` 对输入 `x` 的DHW维度上进行填充。
|
||||
|
||||
参数:
|
||||
- **padding** (union[int, tuple]) - 填充大小,如果输入为int,则对所有边界进行相同大小的填充;如果是tuple,则顺序为 :math:`(pad_{left}, pad_{right}, pad_{up}, pad_{down})`。
|
||||
|
||||
输入:
|
||||
- **x** (Tensor) - 维度为5D的Tensor,shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})` 。
|
||||
|
||||
输出:
|
||||
Tensor,填充后的Tensor,shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})`。其中 :math:`D_{out} = D_{in} + pad_{front} + pad_{back}`, :math:`H_{out} = H_{in} + pad_{up} + pad_{down}`, :math:`W_{out} = W_{in} + pad_{left} + pad_{right}`。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `padding` 不是tuple或int。
|
||||
- **TypeError** - `padding` 中存在不是int的元素。
|
||||
- **ValueError** - `padding` 是tuple,且长度不能被2整除。
|
||||
- **ValueError** - `padding` 是tuple,且长度和Tensor的维度不匹配。
|
|
@ -217,6 +217,9 @@ Padding Layer
|
|||
mindspore.nn.ConstantPad3d
|
||||
mindspore.nn.ReflectionPad1d
|
||||
mindspore.nn.ReflectionPad2d
|
||||
mindspore.nn.ReplicationPad1d
|
||||
mindspore.nn.ReplicationPad2d
|
||||
mindspore.nn.ReplicationPad3d
|
||||
mindspore.nn.ZeroPad2d
|
||||
|
||||
Loss Function
|
||||
|
|
|
@ -38,7 +38,7 @@ from mindspore.nn.layer.combined import *
|
|||
from mindspore.nn.layer.timedistributed import *
|
||||
from mindspore.nn.layer.thor_layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
|
||||
from mindspore.nn.layer.padding import ConstantPad1d, ConstantPad2d, ConstantPad3d, ReflectionPad1d, \
|
||||
ReflectionPad2d, ZeroPad2d
|
||||
ReflectionPad2d, ZeroPad2d, ReplicationPad1d, ReplicationPad2d, ReplicationPad3d
|
||||
|
||||
__all__ = []
|
||||
__all__.extend(activation.__all__)
|
||||
|
|
|
@ -17,10 +17,12 @@ from __future__ import absolute_import
|
|||
|
||||
from mindspore.common import Tensor
|
||||
from mindspore import ops
|
||||
from mindspore.ops.operations import nn_ops
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.nn.cell import Cell
|
||||
|
||||
__all__ = ['ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d', 'ZeroPad2d']
|
||||
__all__ = ['ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d', 'ZeroPad2d',
|
||||
'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d']
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -185,7 +187,7 @@ class _ConstantPadNd(Cell):
|
|||
raise ValueError(msg)
|
||||
|
||||
else:
|
||||
msg = "For '{}', the type of parameter 'padding' must be in [int, float], " \
|
||||
msg = "For '{}', the type of parameter 'padding' must be in [int, tuple], " \
|
||||
"but got {}".format(name, type(padding))
|
||||
raise TypeError(msg)
|
||||
|
||||
|
@ -585,3 +587,244 @@ class ZeroPad2d(_ConstantPadNd):
|
|||
def __init__(self, padding):
|
||||
value = 0
|
||||
super(ZeroPad2d, self).__init__(padding, value, name='ZeroPad2d')
|
||||
|
||||
|
||||
class _ReplicationPadNd(Cell):
|
||||
r"""
|
||||
Using a given padding to do replication pad on the given tensor.
|
||||
Work as a parent class, and only accepts tuple as padding input.
|
||||
"""
|
||||
def __init__(self, padding, name="ReplicationPadNd"):
|
||||
super(_ReplicationPadNd, self).__init__()
|
||||
self.name = name
|
||||
if not isinstance(padding, tuple):
|
||||
raise TypeError(f"For '{self.name}' the input 'padding' must be an integer or tuple, "
|
||||
f"but got {type(padding).__name__}")
|
||||
if len(padding) % 2 != 0:
|
||||
raise ValueError(f"For '{self.name}' the length of input 'padding' must be divisible by 2, "
|
||||
f"but got padding of length {len(padding)}. ")
|
||||
if not all(isinstance(i, int) for i in padding):
|
||||
raise TypeError(f"For '{self.name}' every element in 'padding' must be integer, "
|
||||
f"but got {padding}. ")
|
||||
self.padding = padding
|
||||
self.padv3 = nn_ops.PadV3(mode="edge")
|
||||
|
||||
@staticmethod
|
||||
@constexpr
|
||||
def _check_input_dim(shape, cls_name):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@constexpr
|
||||
def _need_expend_dim(x):
|
||||
raise NotImplementedError
|
||||
|
||||
def construct(self, x):
|
||||
self._check_input_dim(x.shape, self.name)
|
||||
need_expend_dims = self._need_expend_dim(x)
|
||||
if need_expend_dims:
|
||||
x = x.expand_dims(0)
|
||||
x = self.padv3(x, self.padding)
|
||||
x = x.squeeze(0)
|
||||
else:
|
||||
x = self.padv3(x, self.padding)
|
||||
return x
|
||||
|
||||
|
||||
class ReplicationPad1d(_ReplicationPadNd):
|
||||
r"""
|
||||
Pad on W dimension of input `x` according to `padding`.
|
||||
|
||||
Args:
|
||||
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
||||
padding in all boundaries. If is tuple, uses :math:`(pad_{left}, pad_{right})` to pad.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - 2D or 3D, shape: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor, after padding. Shape: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`,
|
||||
where :math:`W_{out} = W_{in} + pad_{left} + pad_{right}`
|
||||
|
||||
Raises:
|
||||
TypeError: If 'padding' is neither a tuple nor an int.
|
||||
TypeError: If there is an element in 'padding' that is not int.
|
||||
ValueError: If `padding` is tuple and the length of 'padding' is not divisible by 2.
|
||||
ValueError: If `padding` is tuple and there is a dimension mismatch between the padding and the tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples::
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.nn import ReplicationPad1d
|
||||
>>> x = Tensor(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]).astype(np.float32))
|
||||
>>> pad1d = ReplicationPad1d(2)
|
||||
>>> input = Tensor(np.arange(0, 8).reshape(1, 2, 4), mindspore.float32)
|
||||
>>> input
|
||||
Tensor(shape=[1, 2, 4], dtype=Float32, value=
|
||||
[[[0., 1., 2., 3.],
|
||||
[4., 5., 6., 7.]]])
|
||||
>>> out = pad1d(input)
|
||||
>>> print(out)
|
||||
Tensor(shape=[1, 2, 8], dtype=Float32, value=
|
||||
[[[0., 0., 0., 1., 2., 3., 3., 3.],
|
||||
[4., 4., 4., 5., 6., 7., 7., 7.]]])
|
||||
>>> pad1d = ReplicationPad1d((3, 1))
|
||||
>>> out = pad1d(input)
|
||||
>>> print(out)
|
||||
Tensor(shape=[1, 2, 8], dtype=Float32, value=
|
||||
[[[0., 0., 0., 0., 1., 2., 3., 3.],
|
||||
[4., 4., 4., 4., 5., 6., 7., 7.]]])
|
||||
"""
|
||||
def __init__(self, padding):
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding)
|
||||
super(ReplicationPad1d, self).__init__(padding, name="ReplicationPad1d")
|
||||
|
||||
@staticmethod
|
||||
@constexpr
|
||||
def _check_input_dim(shape, cls_name):
|
||||
dim = len(shape)
|
||||
if dim not in (2, 3):
|
||||
raise ValueError(f"For '{cls_name}', the in_shape must have 2 or 3 dims, but got {dim}.")
|
||||
|
||||
def _need_expend_dim(self, x):
|
||||
input_shape = x.shape
|
||||
return 1 if len(input_shape) == 2 else 0
|
||||
|
||||
|
||||
class ReplicationPad2d(_ReplicationPadNd):
|
||||
r"""
|
||||
Pad on HW dimension of input `x` according to `padding`.
|
||||
|
||||
Args:
|
||||
padding (int, tuple): the size of the padding. If is `int`, uses the same padding in all boundaries.
|
||||
If a 4-`tuple`, uses :math:`(pad_{left}, pad_{right}, pad_{up}, pad_{down})` to pad.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - 3D or 4D, shape: :math:`(C, H_{in}, W_{out})` or :math:`(N, C, H_{out}, W_{out})`.
|
||||
|
||||
Outputs:
|
||||
Tensor, after padding. Shape: :math:`(C, H_{out}, W_{out})` or :math:`(N, C, H_{out}, W_{out})`,
|
||||
where :math:`H_{out} = H_{in} + pad_{up} + pad_{down}`, :math:`W_{out} = W_{in} + pad_{left} + pad_{right}`.
|
||||
|
||||
Raises:
|
||||
TypeError: If 'padding' is neither a tuple nor an int.
|
||||
TypeError: If there is an element in 'padding' that is not int.
|
||||
ValueError: If `padding` is tuple and the length of 'padding' is not divisible by 2.
|
||||
ValueError: If `padding` is tuple and there is a dimension mismatch between the padding and the tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples::
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.nn import ReplicationPad2d
|
||||
>>> pad2d = ReplicationPad2d(2)
|
||||
>>> input = Tensor(np.arange(0, 9).reshape(1, 1, 3, 3), mindspore.float32)
|
||||
>>> input
|
||||
Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
|
||||
[[[[0., 1., 2.],
|
||||
[3., 4., 5.],
|
||||
[6., 7., 8.]]]])
|
||||
>>> out = pad2d(input)
|
||||
>>> print(out)
|
||||
Tensor(shape=[1, 1, 7, 7], dtype=Float32, value=
|
||||
[[[[0., 0., 0., 1., 2., 2., 2.],
|
||||
[0., 0., 0., 1., 2., 2., 2.],
|
||||
[0., 0., 0., 1., 2., 2., 2.],
|
||||
[3., 3., 3., 4., 5., 5., 5.],
|
||||
[6., 6., 6., 7., 8., 8., 8.],
|
||||
[6., 6., 6., 7., 8., 8., 8.],
|
||||
[6., 6., 6., 7., 8., 8., 8.]]]])
|
||||
>>> pad2d = nn.ReplicationPad2d((1, 1, 2, 0))
|
||||
>>> out = m(input)
|
||||
>>> print(out)
|
||||
Tensor(shape=[1, 1, 5, 5], dtype=Float32, value=
|
||||
[[[[0., 0., 1., 2., 2.],
|
||||
[0., 0., 1., 2., 2.],
|
||||
[0., 0., 1., 2., 2.],
|
||||
[3., 3., 4., 5., 5.],
|
||||
[6., 6., 7., 8., 8.]]]])
|
||||
"""
|
||||
|
||||
def __init__(self, padding):
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding, padding)
|
||||
super(ReplicationPad2d, self).__init__(padding, name="ReplicationPad2d")
|
||||
|
||||
@staticmethod
|
||||
@constexpr
|
||||
def _check_input_dim(shape, cls_name):
|
||||
dim = len(shape)
|
||||
if dim not in (3, 4):
|
||||
raise ValueError(f"For '{cls_name}', the in_shape must have 3 or 4 dims, but got {dim}.")
|
||||
|
||||
def _need_expend_dim(self, x):
|
||||
input_shape = x.shape
|
||||
return 1 if len(input_shape) == 3 else 0
|
||||
|
||||
|
||||
class ReplicationPad3d(_ReplicationPadNd):
|
||||
r"""
|
||||
Pad on DHW dimension of input `x` according to `padding`.
|
||||
|
||||
Args:
|
||||
padding (int, tuple): the size of the padding. If is `int`, uses the same padding in all boundaries.
|
||||
If a 6-`tuple`, uses :math:`(pad_{left}, pad_{right}, pad_{up}, pad_{down}, pad_{front}, pad_{back})`.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - 4D or 5D,
|
||||
shape: :math:`(C, D_{in}, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor, after padding, shape: :math:`(C, D_{out}, H_{out}, W_{out})` or
|
||||
:math:`(N, C, D_{out}, H_{out}, W_{out})`, where
|
||||
:math:`D_{out} = D_{in} + pad_{front} + pad_{back}`,
|
||||
:math:`H_{out} = H_{in} + pad_{up} + pad_{down}`,
|
||||
:math:`W_{out} = W_{in} + pad_{left} + pad_{right}`.
|
||||
|
||||
Raises:
|
||||
TypeError: If 'padding' is neither a tuple nor an int.
|
||||
TypeError: If there is an element in 'padding' that is not int.
|
||||
ValueError: If `padding` is tuple and the length of 'padding' is not divisible by 2.
|
||||
ValueError: If `padding` is tuple and there is a dimension mismatch between the padding and the tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples::
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.nn import ReplicationPad3d
|
||||
>>> pad3d = ReplicationPad3d(1)
|
||||
>>> input = Tensor(np.arange(0, 9).reshape(1, 1, 1, 3, 3), mindspore.float32)
|
||||
>>> out = pad3d(input)
|
||||
>>> print(out)
|
||||
Tensor(shape=[1, 1, 7, 7], dtype=Float32, value=
|
||||
[[[[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.],
|
||||
[3., 3., 4., 5., 5.], [6., 6., 7., 8., 8.]],
|
||||
[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.],
|
||||
[3., 3., 4., 5., 5.], [6., 6., 7., 8., 8.]],
|
||||
[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.],
|
||||
[3., 3., 4., 5., 5.], [6., 6., 7., 8., 8.]]]]])
|
||||
"""
|
||||
|
||||
def __init__(self, padding):
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding, padding, padding, padding)
|
||||
super(ReplicationPad3d, self).__init__(padding, name="ReplicationPad3d")
|
||||
|
||||
@staticmethod
|
||||
@constexpr
|
||||
def _check_input_dim(shape, cls_name):
|
||||
dim = len(shape)
|
||||
if dim not in (4, 5):
|
||||
raise ValueError(f"For '{cls_name}', the in_shape must have 4 or 5 dims, but got {dim}.")
|
||||
|
||||
def _need_expend_dim(self, x):
|
||||
input_shape = x.shape
|
||||
return 1 if len(input_shape) == 4 else 0
|
||||
|
|
|
@ -0,0 +1,222 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
|
||||
|
||||
class Net1d(nn.Cell):
|
||||
def __init__(self, padding):
|
||||
super(Net1d, self).__init__()
|
||||
self.pad = nn.ReplicationPad1d(padding)
|
||||
|
||||
def construct(self, x):
|
||||
return self.pad(x)
|
||||
|
||||
|
||||
class Net2d(nn.Cell):
|
||||
def __init__(self, padding):
|
||||
super(Net2d, self).__init__()
|
||||
self.pad = nn.ReplicationPad2d(padding)
|
||||
|
||||
def construct(self, x):
|
||||
return self.pad(x)
|
||||
|
||||
|
||||
class Net3d(nn.Cell):
|
||||
def __init__(self, padding):
|
||||
super(Net3d, self).__init__()
|
||||
self.pad = nn.ReplicationPad3d(padding)
|
||||
|
||||
def construct(self, x):
|
||||
return self.pad(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_replicationpad1d_2d(mode):
|
||||
"""
|
||||
Feature: ReplicationPad1d
|
||||
Description: Infer process of ReplicationPad1d with 2 types of parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode, device_target="GPU")
|
||||
# Test functionality with 2D tensor as input
|
||||
x = Tensor(np.array([[0, 1, 2, 3], [4, 5, 6, 7]]).astype(np.float16))
|
||||
padding = (3, 1)
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[0, 0, 0, 0, 1, 2, 3, 3],
|
||||
[4, 4, 4, 4, 5, 6, 7, 7]]).astype(np.float16))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
expected_output = Tensor(np.array([[0, 0, 0, 1, 2, 3, 3, 3],
|
||||
[4, 4, 4, 5, 6, 7, 7, 7]]).astype(np.float16))
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_replicationpad1d_3d(mode):
|
||||
"""
|
||||
Feature: ReplicationPad1d
|
||||
Description: Infer process of ReplicationPad1d with 2 types of parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode, device_target="GPU")
|
||||
# Test functionality with 3D tensor input
|
||||
x = Tensor(np.array([[[0, 1, 2, 3], [4, 5, 6, 7]]]).astype(np.float32))
|
||||
padding = (3, 1)
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[0, 0, 0, 0, 1, 2, 3, 3],
|
||||
[4, 4, 4, 4, 5, 6, 7, 7]]]).astype(np.float32))
|
||||
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
expected_output = Tensor(np.array([[[0, 0, 0, 1, 2, 3, 3, 3],
|
||||
[4, 4, 4, 5, 6, 7, 7, 7]]]).astype(np.float32))
|
||||
net = Net1d(padding)
|
||||
output = net(x)
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_replicationpad2d_3d(mode):
|
||||
r"""
|
||||
Feature: ReplicationPad2d
|
||||
Description: Infer process of ReplicationPad2d with three type parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode, device_target="GPU")
|
||||
# Test functionality with 3D tensor as input
|
||||
x = Tensor(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]).astype(np.float32))
|
||||
padding = (1, 1, 2, 0)
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[0, 0, 1, 2, 2], [0, 0, 1, 2, 2], [0, 0, 1, 2, 2],
|
||||
[3, 3, 4, 5, 5], [6, 6, 7, 8, 8]]]).astype(np.float32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[0, 0, 0, 1, 2, 2, 2], [0, 0, 0, 1, 2, 2, 2],
|
||||
[0, 0, 0, 1, 2, 2, 2], [3, 3, 3, 4, 5, 5, 5],
|
||||
[6, 6, 6, 7, 8, 8, 8], [6, 6, 6, 7, 8, 8, 8],
|
||||
[6, 6, 6, 7, 8, 8, 8]]]).astype(np.float32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_replicationpad2d_4d(mode):
|
||||
r"""
|
||||
Feature: ReplicationPad2d
|
||||
Description: Infer process of ReplicationPad2d with three type parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode, device_target="GPU")
|
||||
# Test functionality with 4D tensor as input
|
||||
x = Tensor(np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]]).astype(np.int32))
|
||||
padding = (1, 1, 2, 0)
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[0, 0, 1, 2, 2], [0, 0, 1, 2, 2], [0, 0, 1, 2, 2],
|
||||
[3, 3, 4, 5, 5], [6, 6, 7, 8, 8]]]]).astype(np.int32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 2
|
||||
net = Net2d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[0, 0, 0, 1, 2, 2, 2], [0, 0, 0, 1, 2, 2, 2],
|
||||
[0, 0, 0, 1, 2, 2, 2], [3, 3, 3, 4, 5, 5, 5],
|
||||
[6, 6, 6, 7, 8, 8, 8], [6, 6, 6, 7, 8, 8, 8],
|
||||
[6, 6, 6, 7, 8, 8, 8]]]]).astype(np.int32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_replicationpad3d_4d(mode):
|
||||
r"""
|
||||
Feature: ReplicationPad3d
|
||||
Description: Infer process of ReplicationPad3d with three type parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode, device_target="GPU")
|
||||
# Test functionality with 4D tensor as input
|
||||
x = Tensor(np.array([[[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]]]).astype(np.int32))
|
||||
padding = (1, 1, 2, 0, 1, 1)
|
||||
net = Net3d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.],
|
||||
[3., 3., 4., 5., 5.], [6., 6., 7., 8., 8.]],
|
||||
[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.],
|
||||
[3., 3., 4., 5., 5.], [6., 6., 7., 8., 8.]],
|
||||
[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.],
|
||||
[3., 3., 4., 5., 5.], [6., 6., 7., 8., 8.]]]]]).astype(np.int32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 1
|
||||
net = Net3d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [3., 3., 4., 5., 5.],
|
||||
[6., 6., 7., 8., 8.], [6., 6., 7., 8., 8.]],
|
||||
[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [3., 3., 4., 5., 5.],
|
||||
[6., 6., 7., 8., 8.], [6., 6., 7., 8., 8.]],
|
||||
[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [3., 3., 4., 5., 5.],
|
||||
[6., 6., 7., 8., 8.], [6., 6., 7., 8., 8.]]]]]).astype(np.int32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_replicationpad3d_5d(mode):
|
||||
r"""
|
||||
Feature: ReplicationPad3d
|
||||
Description: Infer process of ReplicationPad3d with three type parameters.
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode, device_target="GPU")
|
||||
# Test functionality with 5D tensor as input
|
||||
x = Tensor(np.array([[[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]]]).astype(np.float32))
|
||||
padding = (1, 1, 2, 0, 1, 1)
|
||||
net = Net3d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.],
|
||||
[3., 3., 4., 5., 5.], [6., 6., 7., 8., 8.]],
|
||||
[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.],
|
||||
[3., 3., 4., 5., 5.], [6., 6., 7., 8., 8.]],
|
||||
[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.],
|
||||
[3., 3., 4., 5., 5.], [6., 6., 7., 8., 8.]]]]]).astype(np.float32))
|
||||
assert np.array_equal(output, expected_output)
|
||||
|
||||
padding = 1
|
||||
net = Net3d(padding)
|
||||
output = net(x)
|
||||
expected_output = Tensor(np.array([[[[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [3., 3., 4., 5., 5.],
|
||||
[6., 6., 7., 8., 8.], [6., 6., 7., 8., 8.]],
|
||||
[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [3., 3., 4., 5., 5.],
|
||||
[6., 6., 7., 8., 8.], [6., 6., 7., 8., 8.]],
|
||||
[[0., 0., 1., 2., 2.], [0., 0., 1., 2., 2.], [3., 3., 4., 5., 5.],
|
||||
[6., 6., 7., 8., 8.], [6., 6., 7., 8., 8.]]]]]).astype(np.float32))
|
||||
assert np.array_equal(output, expected_output)
|
Loading…
Reference in New Issue