rewrite pad ops

fengyihang 2022-11-08 17:03:45 +08:00
parent 54c66072f3
commit ce2a9a5708
5 changed files with 268 additions and 218 deletions

View File

@@ -1,38 +1,29 @@
mindspore.ops.pad
=================
.. py:function:: mindspore.ops.pad(input_x, paddings)
.. py:function:: mindspore.ops.pad(input_x, padding, mode='constant', value=None)
Pads the input according to the parameter `paddings`.

The formula to calculate the shape of the output Tensor is as follows:

.. math::
\begin{aligned}
&\text{ input_x_shape} = (N_{1},N_{2},...,N_{n}) \\
&\begin{aligned}
\text{output_shape = }(&N_{1}+paddings[0,0]+paddings[0,1], \\
& N_{2}+paddings[1,0]+paddings[1,1], \\
&... , \\
& N_{n}+paddings[n-1,0]+paddings[n-1,1])
\end{aligned}
\end{aligned}

.. note::
    Negative values in `paddings` are supported only when `input_x` does not have a dynamic shape.

Args:
    - **input_x** (Tensor) - The input Tensor.
    - **paddings** (tuple) - The padding size, whose shape is (N, 2); N is the rank of the input, and the elements of `paddings` are all int. For the `D`-th dimension of `x`, paddings[D, 0] indicates how much to extend (if the value is greater than 0) or to clip (if the value is less than 0) in front of the input Tensor in the `D`-th dimension, and paddings[D, 1] indicates how much to extend (if the value is greater than 0) or to clip (if the value is less than 0) behind the input Tensor in the `D`-th dimension.
    - **input_x** (Tensor) - The input Tensor, with shape :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
    - **padding** (Union[tuple[int], list[int], Tensor]) - The positions to pad.
      :math:`\left\lfloor\frac{\text{len(padding)}}{2}\right\rfloor` dimensions of `input_x` will be padded.
      Example: to pad only the last dimension of the input tensor, `padding` has the form :math:`(\text{padding\_left}, \text{padding\_right})`;
      Example: to pad the last two dimensions of the input tensor, `padding` has the form :math:`(\text{padding\_left}, \text{padding\_right}, \text{padding\_top}, \text{padding\_bottom})`;
      Example: to pad the last three dimensions, `padding` has the form :math:`(\text{padding\_left}, \text{padding\_right}, \text{padding\_top}, \text{padding\_bottom}, \text{padding\_front}, \text{padding\_back})`;
      and so on.
    - **mode** (str, optional) - The padding mode, one of "constant", "reflect" or "replicate". Default: "constant".
      For "constant" mode, refer to :class:`mindspore.nn.ConstantPad1d` as an example to understand this padding pattern and extend it to n dimensions.
      For "reflect" mode, refer to :class:`mindspore.nn.ReflectionPad1d` as an example and extend the pattern to n dimensions.
      For "replicate" mode, refer to :class:`mindspore.nn.ReplicationPad1d` as an example and extend the pattern to n dimensions.
    - **value** (Union[int, float, None], optional) - Effective only in "constant" mode; the fill value for "constant" padding. If value is None, 0 is used as the default fill value.

Returns:
    The Tensor after padding.

Raises:
    - **TypeError** - `paddings` is not a tuple.
    - **TypeError** - `padding` is not a tuple or list whose elements are all int.
    - **TypeError** - `input_x` is not a Tensor.
    - **ValueError** - The shape of `paddings` is not :math:`(N, 2)`.
    - **ValueError** - The size of `paddings` is not equal to 2 * len(input_x).
    - **ValueError** - The calculated output shape contains a zero or negative dimension.
    - **ValueError** - `paddings` contains a negative value while `input_x` has a dynamic shape.
    - **ValueError** - The size of `padding` is not equal to 2 * len(input_x).
    - **ValueError** - `mode` is not "constant" and `value` is not None.
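
A minimal usage sketch of the new signature (illustrative only; assumes a MindSpore build in which `ops.pad` dispatches to the PadV3 backend described in this commit):

    import numpy as np
    import mindspore as ms
    import mindspore.ops as ops

    # Pad only the last dimension: 1 element on the left, 2 on the right.
    x = ms.Tensor(np.ones((1, 2, 3)), ms.float32)
    y = ops.pad(x, (1, 2), mode="constant", value=0.5)
    print(y.shape)  # (1, 2, 6)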

View File

@@ -23,7 +23,6 @@ from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import nn_ops as NN_OPS
from mindspore.ops.operations import image_ops as IMG
from mindspore.ops._utils import is_shape_unknown
import mindspore.common.dtype as mstype
from mindspore.ops.function.math_func import logsumexp
from mindspore.common.tensor import Tensor
@@ -34,6 +33,7 @@ from mindspore._checkparam import Validator as validator
from mindspore.ops.composite.multitype_ops._constexpr_utils import raise_value_error
from mindspore.ops.operations.nn_ops import MaxUnpool2D, MaxUnpool3D
from mindspore.ops.operations.nn_ops import FractionalMaxPoolWithFixedKsize, FractionalMaxPool3DWithFixedKsize
from mindspore.ops.operations.nn_ops import PadV3
slice_ = P.Slice()
fast_gelu_ = P.FastGeLU()
@@ -41,6 +41,7 @@ softsign_ = P.Softsign()
hardswish_ = P.HSwish()
mish_ = NN_OPS.Mish()
selu_ = NN_OPS.SeLU()
scalar_to_tensor_ = P.ScalarToTensor()
sigmoid_ = NN_OPS.Sigmoid()
signed_type = [mstype.int8, mstype.byte, mstype.int16, mstype.short, mstype.int32, mstype.intc, mstype.int64,
mstype.intp, mstype.float16, mstype.half, mstype.float32, mstype.single, mstype.float64,
@@ -2418,123 +2419,112 @@ def pdist(x, p=2.0):
@constexpr
def _check_pad_inputs(x_shape, paddings):
def _check_pad_inputs(padding):
    """check the input of pad"""
    for _, pd in enumerate(paddings):
        if not isinstance(pd, (list, tuple)) or len(pd) != 2 or not isinstance(pd[0], int) or \
                not isinstance(pd[1], int):
            raise TypeError(f"For 'pad', each element in 'paddings' must be a list or tuple of 2 int, but got {pd}.")
    x_shape_unknown = is_shape_unknown(x_shape)
    if not x_shape_unknown and len(x_shape) != len(paddings):
        raise ValueError(f"For 'pad', the size of paddings must be 2 * {len(x_shape)}, but got {2 * len(paddings)}")
    pad_all_non_negative = True
    for _, pd in enumerate(paddings):
        if pd[0] < 0 or pd[1] < 0:
            pad_all_non_negative = False
    if x_shape_unknown and not pad_all_non_negative:
        # in this case, we can not infer the slice size
        raise ValueError(f"For 'pad', if 'input_x' is dynamic shape, 'paddings' must be non-negative value, but got "
                         f"{paddings}")
    if not isinstance(padding, (tuple, list)):
        raise TypeError(f"For 'pad', the type of 'padding' must be a tuple of int, a list of int or a Tensor,"
                        f" but got {type(padding)}.")
    if len(padding) % 2 != 0:
        raise ValueError(f"For 'pad', the size of padding must be divisible by 2, but got {len(padding)}")
    for pd in padding:
        if not isinstance(pd, int):
            raise TypeError(f"For 'pad', the padding value must be a tuple of int or a list of int, but got {padding}")
def pad(input_x, paddings):
def pad(input_x, padding, mode='constant', value=None):
r"""
Pads the input tensor according to the paddings.
The formula to calculate the shape of the output tensor is as follows,
.. math::
\begin{aligned}
&\text{ input_x_shape} = (N_{1},N_{2},...,N_{n}) \\
&\begin{aligned}
\text{output_shape = }(&N_{1}+paddings[0,0]+paddings[0,1], \\
& N_{2}+paddings[1,0]+paddings[1,1], \\
&... , \\
& N_{n}+paddings[n-1,0]+paddings[n-1,1])
\end{aligned}
\end{aligned}
Note:
Negative `paddings` value is only supported when `input_x` is not dynamic shape.
    Args:
        input_x (Tensor): Tensor of shape :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
        paddings (tuple): The shape of parameter `paddings` is (N, 2). N is the rank of input data. All elements of
            paddings are int type. For the input in the `D` th dimension, paddings[D, 0] indicates how many sizes to
            be extended (if this value > 0) or clipped (if this value < 0) ahead of the input tensor in the `D` th
            dimension, and paddings[D, 1] indicates how many sizes to be extended (if this value > 0) or
            clipped (if this value < 0) behind the input tensor in the `D` th dimension.
        padding (Union[tuple[int], list[int], Tensor]): The positions to pad.
            :math:`\left\lfloor\frac{\text{len(padding)}}{2}\right\rfloor` dimensions
            of `input_x` will be padded.
            Example: to pad only the last dimension of the input tensor, then
            :attr:`padding` has the form
            :math:`(\text{padding\_left}, \text{padding\_right})`;
            Example: to pad the last 2 dimensions of the input tensor, then use
            :math:`(\text{padding\_left}, \text{padding\_right},`
            :math:`\text{padding\_top}, \text{padding\_bottom})`;
            Example: to pad the last 3 dimensions, use
            :math:`(\text{padding\_left}, \text{padding\_right},`
            :math:`\text{padding\_top}, \text{padding\_bottom},`
            :math:`\text{padding\_front}, \text{padding\_back})`, and so on.
        mode (str, optional): Pad filling mode, "constant", "reflect" or "replicate". Default: "constant".
            For "constant" mode, please refer to :class:`mindspore.nn.ConstantPad1d` as an example to understand
            this filling pattern and extend the padding pattern to n dimensions.
            For "reflect" mode, please refer to :class:`mindspore.nn.ReflectionPad1d` as an example
            and extend the padding pattern to n dimensions.
            For "replicate" mode, please refer to :class:`mindspore.nn.ReplicationPad1d` as an example
            and extend the padding pattern to n dimensions.
        value (Union[int, float, None], optional): Valid only in "constant" mode; the fill value for "constant"
            padding. If the value is None, the default value 0 is used.
    Returns:
        Tensor, the tensor after padding.

    Raises:
        TypeError: If `paddings` is not a tuple.
        TypeError: If `padding` is not a tuple or list whose elements are all int.
        TypeError: If `input_x` is not a Tensor.
        ValueError: If the shape of `paddings` is not :math:`(N, 2)`.
        ValueError: If paddings.size is not equal to 2 * len(input_x).
        ValueError: If the calculated output shape contains a zero or negative dimension.
        ValueError: If `paddings` contains a negative value and `input_x` has a dynamic shape.
        ValueError: If padding.size is not equal to 2 * len(input_x).
        ValueError: If mode is not "constant" and value is not None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> paddings = ((1, 2), (2, 1))
        >>> output = ops.pad(input_x, paddings)
        >>> print(output)
        [[ 0.   0.   0.   0.   0.   0. ]
         [ 0.   0.  -0.1  0.3  3.6  0. ]
         [ 0.   0.   0.4  0.5 -3.2  0. ]
         [ 0.   0.   0.   0.   0.   0. ]
         [ 0.   0.   0.   0.   0.   0. ]]
        >>> import mindspore as ms
        >>> import mindspore.ops as ops
        >>> import numpy as np
        >>> x = ms.Tensor(np.arange(1 * 2 * 2 * 2).reshape((1, 2, 2, 2)), dtype=ms.float64)
        >>> output = ops.pad(x, [1, 0, 0, 1], mode='constant', value=6.0)
        >>> print(output)
        [[[[6. 0. 1.]
           [6. 2. 3.]
           [6. 6. 6.]]
          [[6. 4. 5.]
           [6. 6. 7.]
           [6. 6. 6.]]]]
        >>> output1 = ops.pad(x, (1, 0, 0, 1), mode='reflect')
        >>> print(output1)
        [[[[1. 0. 1.]
           [3. 2. 3.]
           [1. 0. 1.]]
          [[5. 4. 5.]
           [7. 6. 7.]
           [5. 4. 5.]]]]
        >>> output2 = ops.pad(x, (1, 1, 2, 1), mode='replicate')
        >>> print(output2)
        [[[[0. 0. 1. 1.]
           [0. 0. 1. 1.]
           [0. 0. 1. 1.]
           [2. 2. 3. 3.]
           [2. 2. 3. 3.]]
          [[4. 4. 5. 5.]
           [4. 4. 5. 5.]
           [4. 4. 5. 5.]
           [6. 6. 7. 7.]
           [6. 6. 7. 7.]]]]
    """
    if not isinstance(input_x, Tensor):
        raise TypeError(f"For 'pad', the type of 'input_x' must be Tensor, but got {type(input_x)}.")
    if not isinstance(paddings, tuple):
        raise TypeError(f"For 'pad', the type of 'paddings' must be tuple, but got {type(paddings)}.")
    x_shape = input_x.shape
    _check_pad_inputs(x_shape, paddings)
    x_shape_unknown = is_shape_unknown(x_shape)
    # input_x is dynamic shape
    if x_shape_unknown:
        _pad = _get_cache_prim(P.Pad)(paddings)
        return _pad(input_x)
    # input_x is static shape
    pad_all_non_negative = True
    pad_all_non_positive = True
    slice_begin = []
    slice_size = []
    non_negative_padding = []
    for i, pd in enumerate(paddings):
        sz = x_shape[i] + pd[0]
        if sz <= 0:
            raise ValueError(f"For 'pad', input_x_shape[{i}] + paddings[{i}, 0] is {sz}, which is <= 0 and causes "
                             f"the output shape invalid.")
        sz = sz + pd[1]
        if sz <= 0:
            raise ValueError(f"For 'pad', input_x_shape[{i}] + paddings[{i}, 0] + paddings[{i}, 1] is {sz}, which is "
                             f"<= 0 and causes the output shape invalid.")
        slice_size.append(sz)
        if pd[0] < 0:
            slice_begin.append(abs(pd[0]))
        else:
            slice_begin.append(0)
        if pd[0] < 0 or pd[1] < 0:
            pad_all_non_negative = False
        if pd[0] > 0 or pd[1] > 0:
            pad_all_non_positive = False
        non_negative_padding.append((max(0, pd[0]), max(0, pd[1])))
    if pad_all_non_negative:
        _pad = _get_cache_prim(P.Pad)(paddings)
        return _pad(input_x)
    if pad_all_non_positive:
        return slice_(input_x, slice_begin, slice_size)
    _pad = _get_cache_prim(P.Pad)(tuple(non_negative_padding))
    out = _pad(input_x)
    return slice_(out, slice_begin, slice_size)
    if not isinstance(padding, Tensor):
        _check_pad_inputs(padding)
    if mode == "constant":
        value = 0 if value is None else value
        if isinstance(value, (float, int)):
            value = scalar_to_tensor_(value, input_x.dtype)
    else:
        if value is not None:
            raise ValueError(f"For 'pad', padding mode '{mode}' does not take a value, but got value {value}.")
        if mode == "replicate":
            mode = "edge"
    out = PadV3(mode=mode, paddings_contiguous=True)(input_x, padding, value)
    return out
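
To make the flat `padding` layout concrete, here is a NumPy-only cross-check of the constant-mode docstring example; `flat_padding_to_pad_width` is a hypothetical helper written for illustration, not part of MindSpore:

    import numpy as np

    def flat_padding_to_pad_width(padding, ndim):
        """Convert ops.pad's flat (left, right, top, bottom, ...) layout, which
        lists dimensions from last to first, into np.pad's ((before, after), ...)
        layout, which lists dimensions from first to last."""
        n_padded = len(padding) // 2
        pad_width = [(0, 0)] * (ndim - n_padded)      # leading dims are not padded
        for d in range(n_padded - 1, -1, -1):         # walk the pairs in reverse
            pad_width.append((padding[2 * d], padding[2 * d + 1]))
        return pad_width

    x = np.arange(8, dtype=np.float64).reshape(1, 2, 2, 2)
    pad_width = flat_padding_to_pad_width((1, 0, 0, 1), x.ndim)
    out = np.pad(x, pad_width, mode="constant", constant_values=6.0)
    print(out)  # reproduces the constant-mode output printed in the Examples above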
def relu(x):

View File

@@ -22,86 +22,6 @@ from mindspore import Tensor
from mindspore.ops.functional import vmap
class FuncNet(nn.Cell):
    def __init__(self, paddings):
        super(FuncNet, self).__init__()
        self.paddings = paddings

    def construct(self, x):
        return ops.pad(x, self.paddings)


class GradNet(nn.Cell):
    def __init__(self, network):
        super(GradNet, self).__init__()
        self.network = network
        self.grad = ops.GradOperation()

    def construct(self, x):
        return self.grad(self.network)(x)


def run_case(x, paddings, expect):
    net = FuncNet(paddings)
    out_ms = net(Tensor(x))
    assert np.allclose(expect, out_ms.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pad_function_cpu():
    """
    Feature: test ops.Pad functional interface.
    Description: paddings has negative values.
    Expectation: the result match with numpy result.
    """
    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
    # case1: padding values are non-negative
    paddings1 = ((0, 1), (1, 0))
    expect1 = np.array([[0, 1, 2, 3], [0, 4, 5, 6], [0, 7, 8, 9], [0, 0, 0, 0]], dtype=np.float32)
    # case2: padding values are non-positive
    paddings2 = ((-1, 0), (-1, -1))
    expect2 = np.array([[5], [8]], dtype=np.float32)
    # case3: padding with positive and negative values
    paddings3 = ((-1, 1), (1, -1))
    expect3 = np.array([[0, 4, 5], [0, 7, 8], [0, 0, 0]], dtype=np.float32)
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    run_case(x, paddings1, expect1)
    run_case(x, paddings2, expect2)
    run_case(x, paddings3, expect3)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    run_case(x, paddings1, expect1)
    run_case(x, paddings2, expect2)
    run_case(x, paddings3, expect3)
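
The mixed-sign case above (case3) is equivalent to cropping on the negative entries first and then padding on the positive ones; a NumPy sketch of that semantics, under the stated paddings:

    import numpy as np

    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
    # paddings3 = ((-1, 1), (1, -1)): crop 1 row at the top, pad 1 row at the
    # bottom, pad 1 column on the left, crop 1 column on the right.
    cropped = x[1:, :-1]                              # apply the negative entries
    out = np.pad(cropped, ((0, 1), (1, 0)), mode="constant")
    # out == [[0., 4., 5.], [0., 7., 8.], [0., 0., 0.]], matching expect3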
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pad_function_grad_cpu():
    """
    Feature: test ops.Pad functional interface backward.
    Description: paddings has negative values.
    Expectation: the result match with numpy result.
    """
    paddings = ((1, -1), (1, -1))
    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
    expect = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32)
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    net = GradNet(FuncNet(paddings))
    out_ms = net(Tensor(x))
    assert np.allclose(expect, out_ms.asnumpy())
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    net = GradNet(FuncNet(paddings))
    out_ms = net(Tensor(x))
    assert np.allclose(expect, out_ms.asnumpy())


def vmap_case():
    class Net(nn.Cell):
        def __init__(self, paddings):

View File

@@ -35,13 +35,14 @@ class PadNet(nn.Cell):
def run_case():
    paddings = ((1, 0), (0, 2))
    paddings_ms = (0, 2, 1, 0)
    shape = (4, 4)
    shape_dyn = (None, 4)
    sz = reduce(lambda a, b: a * b, shape)
    x = np.arange(sz).reshape(shape).astype(np.float32)
    expect = np.pad(x, paddings, mode="constant", constant_values=0)
    x_dyn = Tensor(shape=shape_dyn, dtype=mindspore.float32)
    net = PadNet(paddings)
    net = PadNet(paddings_ms)
    net.set_inputs(x_dyn)
    output = net(Tensor(x))
    assert (output.asnumpy() == expect).all()
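
The `paddings` → `paddings_ms` change above is the heart of this rewrite: the old `P.Pad` attribute lists dimensions first-to-last as (before, after) pairs, while the new flat `padding` argument lists them last-to-first. A hedged sketch of the conversion (`nested_to_flat` is a hypothetical helper named here for illustration):

    def nested_to_flat(paddings):
        """Turn ((before, after), ...) per-dimension pairs (first dim first)
        into the flat last-dim-first tuple that ops.pad expects."""
        flat = []
        for before, after in reversed(paddings):
            flat.extend((before, after))
        return tuple(flat)

    assert nested_to_flat(((1, 0), (0, 2))) == (0, 2, 1, 0)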
@@ -75,19 +76,3 @@ def test_pad_dyn_gpu():
    run_case()
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    run_case()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pad_dyn_ascend():
    """
    Feature: test Pad dynamic shape on Ascend.
    Description: inputs is dynamic shape.
    Expectation: the result match with expect
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    run_case()
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    run_case()

View File

@@ -0,0 +1,164 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
    def __init__(self, mode):
        super(Net, self).__init__()
        self.mode = mode

    def construct(self, x, padding, value=None):
        output = ops.pad(x, padding, self.mode, value)
        return output

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('pad_mode', ["constant", "reflect", "replicate"])
@pytest.mark.parametrize('padding', [[1, 2, 2, 1], (1, 2, 2, 1), ms.Tensor([1, 2, 2, 1])])
def test_pad_normal(mode, pad_mode, padding):
    """
    Feature: pad
    Description: Verify the result of pad
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net(pad_mode)
    x = ms.Tensor(np.arange(1 * 2 * 3 * 4).reshape((1, 2, 3, 4)), dtype=ms.float64)
    if pad_mode == "constant":
        output = net(x, padding, 6)
        expect_output = np.array([[[[6., 6., 6., 6., 6., 6., 6.],
[6., 6., 6., 6., 6., 6., 6.],
[6., 0., 1., 2., 3., 6., 6.],
[6., 4., 5., 6., 7., 6., 6.],
[6., 8., 9., 10., 11., 6., 6.],
[6., 6., 6., 6., 6., 6., 6.]],
[[6., 6., 6., 6., 6., 6., 6.],
[6., 6., 6., 6., 6., 6., 6.],
[6., 12., 13., 14., 15., 6., 6.],
[6., 16., 17., 18., 19., 6., 6.],
[6., 20., 21., 22., 23., 6., 6.],
[6., 6., 6., 6., 6., 6., 6.]]]])
    elif pad_mode == "reflect":
        output = net(x, padding)
        expect_output = np.array([[[[9., 8., 9., 10., 11., 10., 9.],
[5., 4., 5., 6., 7., 6., 5.],
[1., 0., 1., 2., 3., 2., 1.],
[5., 4., 5., 6., 7., 6., 5.],
[9., 8., 9., 10., 11., 10., 9.],
[5., 4., 5., 6., 7., 6., 5.]],
[[21., 20., 21., 22., 23., 22., 21.],
[17., 16., 17., 18., 19., 18., 17.],
[13., 12., 13., 14., 15., 14., 13.],
[17., 16., 17., 18., 19., 18., 17.],
[21., 20., 21., 22., 23., 22., 21.],
[17., 16., 17., 18., 19., 18., 17.]]]])
    else:
        output = net(x, padding)
        expect_output = np.array([[[[0., 0., 1., 2., 3., 3., 3.],
[0., 0., 1., 2., 3., 3., 3.],
[0., 0., 1., 2., 3., 3., 3.],
[4., 4., 5., 6., 7., 7., 7.],
[8., 8., 9., 10., 11., 11., 11.],
[8., 8., 9., 10., 11., 11., 11.]],
[[12., 12., 13., 14., 15., 15., 15.],
[12., 12., 13., 14., 15., 15., 15.],
[12., 12., 13., 14., 15., 15., 15.],
[16., 16., 17., 18., 19., 19., 19.],
[20., 20., 21., 22., 23., 23., 23.],
[20., 20., 21., 22., 23., 23., 23.]]]])
    assert np.allclose(output.asnumpy(), expect_output)

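
The three expected arrays above can be reproduced with NumPy, which supports the same modes ("replicate" corresponds to NumPy's "edge"); a cross-check sketch:

    import numpy as np

    x = np.arange(24, dtype=np.float64).reshape(1, 2, 3, 4)
    # padding (1, 2, 2, 1): last dim gets (1, 2), second-to-last gets (2, 1)
    pad_width = [(0, 0), (0, 0), (2, 1), (1, 2)]
    np_constant = np.pad(x, pad_width, mode="constant", constant_values=6)
    np_reflect = np.pad(x, pad_width, mode="reflect")
    np_replicate = np.pad(x, pad_width, mode="edge")
    # each of these matches the corresponding expect_output above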
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('pad_mode', ["constant", "reflect", "replicate"])
@pytest.mark.parametrize('padding', [[-1, 2, 2, 1]])
def test_pad_negative(mode, pad_mode, padding):
    """
    Feature: pad
    Description: Verify the result of pad when padding is negative
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net(pad_mode)
    x = ms.Tensor(np.arange(1 * 2 * 3 * 4).reshape((1, 2, 3, 4)), dtype=ms.float64)
    if pad_mode == "constant":
        output = net(x, padding, 6)
        expect_output = np.array([[[[6., 6., 6., 6., 6.],
[6., 6., 6., 6., 6.],
[1., 2., 3., 6., 6.],
[5., 6., 7., 6., 6.],
[9., 10., 11., 6., 6.],
[6., 6., 6., 6., 6.]],
[[6., 6., 6., 6., 6.],
[6., 6., 6., 6., 6.],
[13., 14., 15., 6., 6.],
[17., 18., 19., 6., 6.],
[21., 22., 23., 6., 6.],
[6., 6., 6., 6., 6.]]]])
    elif pad_mode == "reflect":
        output = net(x, padding)
        expect_output = np.array([[[[9., 10., 11., 10., 9.],
[5., 6., 7., 6., 5.],
[1., 2., 3., 2., 1.],
[5., 6., 7., 6., 5.],
[9., 10., 11., 10., 9.],
[5., 6., 7., 6., 5.]],
[[21., 22., 23., 22., 21.],
[17., 18., 19., 18., 17.],
[13., 14., 15., 14., 13.],
[17., 18., 19., 18., 17.],
[21., 22., 23., 22., 21.],
[17., 18., 19., 18., 17.]]]])
    else:
        output = net(x, padding)
        expect_output = np.array([[[[1., 2., 3., 3., 3.],
[1., 2., 3., 3., 3.],
[1., 2., 3., 3., 3.],
[5., 6., 7., 7., 7.],
[9., 10., 11., 11., 11.],
[9., 10., 11., 11., 11.]],
[[13., 14., 15., 15., 15.],
[13., 14., 15., 15., 15.],
[13., 14., 15., 15., 15.],
[17., 18., 19., 19., 19.],
[21., 22., 23., 23., 23.],
[21., 22., 23., 23., 23.]]]])
    assert np.allclose(output.asnumpy(), expect_output)
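
For the negative constant-mode case, the -1 entry crops one column on the left before the remaining positive entries pad; a NumPy sketch under that reading:

    import numpy as np

    x = np.arange(24, dtype=np.float64).reshape(1, 2, 3, 4)
    cropped = x[..., 1:]                              # the -1 crops the left column
    out = np.pad(cropped, [(0, 0), (0, 0), (2, 1), (0, 2)], constant_values=6)
    # out matches the constant-mode expect_output above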