!44399 pixel_shuffle_unshuffle_master
Merge pull request !44399 from yide12/pixel_shuffle_unshuffle_master
This commit is contained in:
commit
aa58bea98c
|
@ -356,6 +356,8 @@ Dynamic LR函数
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.nn.PixelShuffle
|
||||
mindspore.nn.PixelUnshuffle
|
||||
mindspore.nn.ResizeBilinear
|
||||
|
||||
工具
|
||||
|
|
|
@ -484,6 +484,8 @@ Parameter操作函数
|
|||
mindspore.ops.bounding_box_encode
|
||||
mindspore.ops.check_valid
|
||||
mindspore.ops.iou
|
||||
mindspore.ops.pixel_shuffle
|
||||
mindspore.ops.pixel_unshuffle
|
||||
|
||||
光谱函数
|
||||
----------------
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
mindspore.nn.PixelShuffle
|
||||
==========================
|
||||
|
||||
.. py:class:: mindspore.nn.PixelShuffle(upscale_factor)
|
||||
|
||||
PixelShuffle函数。
|
||||
|
||||
在多个输入平面组成的输入上面应用PixelShuffle算法。在平面上应用高效亚像素卷积,步长为 :math:`1/r` 。关于PixelShuffle算法详细介绍,请参考 `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_ 。
|
||||
|
||||
通常情况下,输入shape :math:`(*, C \times r^2, H, W)` ,输出shape :math:`(*, C, H \times r, W \times r)` 。`r` 是缩小因子。 `*` 是大于等于0的维度。
|
||||
|
||||
参数:
|
||||
- **upscale_factor** (int) - 增加空间分辨率的因子,是正整数。
|
||||
|
||||
输入:
|
||||
- **x** (Tensor) - Tensor,shape为 :math:`(*, C \times r^2, H, W)` 。输入Tensor的维度需要大于2,并且倒数第三维length可以被 `upscale_factor` 的平方整除。
|
||||
|
||||
输出:
|
||||
- **output** (Tensor) - Tensor,shape为 :math:`(*, C, H \times r, W \times r)` 。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `upscale_factor` 不是正整数。
|
||||
- **ValueError** - 输入 `x` 倒数第三维度的length不能被 `upscale_factor` 的平方整除。
|
||||
- **TypeError** - 输入 `x` 维度小于3。
|
|
@ -0,0 +1,24 @@
|
|||
mindspore.nn.PixelUnshuffle
|
||||
============================
|
||||
|
||||
.. py:class:: mindspore.nn.PixelUnshuffle(downscale_factor)
|
||||
|
||||
PixelUnshuffle函数。
|
||||
|
||||
在多个输入平面组成的输入上面应用PixelUnshuffle算法。关于PixelUnshuffle算法详细介绍,请参考 `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_ 。
|
||||
|
||||
通常情况下,输入shape :math:`(*, C, H \times r, W \times r)` ,输出shape :math:`(*, C \times r^2, H, W)` 。`r` 是缩小因子。 `*` 是大于等于0的维度。
|
||||
|
||||
参数:
|
||||
- **downscale_factor** (int) - 减小空间分辨率的因子,是正整数。
|
||||
|
||||
输入:
|
||||
- **x** (Tensor) - Tensor,shape为 :math:`(*, C, H \times r, W \times r)` 。输入Tensor的维度需要大于2,并且倒数第一和倒数第二维length可以被 `downscale_factor` 整除。
|
||||
|
||||
输出:
|
||||
- **output** (Tensor) - Tensor,shape为 :math:`(*, C \times r^2, H, W)` 。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `downscale_factor` 不是正整数。
|
||||
- **ValueError** - 输入 `x` 倒数第一和倒数第二维度的length不能被 `downscale_factor` 整除。
|
||||
- **TypeError** - 输入 `x` 维度小于3。
|
|
@ -0,0 +1,22 @@
|
|||
mindspore.ops.pixel_shuffle
|
||||
============================
|
||||
|
||||
.. py:function:: mindspore.ops.pixel_shuffle(x, upscale_factor)
|
||||
|
||||
pixel_shuffle函数。
|
||||
|
||||
在多个输入平面组成的输入上面应用pixel_shuffle算法。在平面上应用高效亚像素卷积,步长为 :math:`1/r` 。关于pixel_shuffle算法详细介绍,请参考 `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_ 。
|
||||
|
||||
通常情况下,`x` shape :math:`(*, C \times r^2, H, W)` ,输出shape :math:`(*, C, H \times r, W \times r)` 。`r` 是缩小因子。 `*` 是大于等于0的维度。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - Tensor,shape为 :math:`(*, C \times r^2, H, W)` 。 `x` 的维度需要大于2,并且倒数第三维length可以被 `upscale_factor` 的平方整除。
|
||||
- **upscale_factor** (int) - 增加空间分辨率的因子,是正整数。。
|
||||
|
||||
返回:
|
||||
- **output** (Tensor) - Tensor,shape为 :math:`(*, C, H \times r, W \times r)` 。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `upscale_factor` 不是正整数。
|
||||
- **ValueError** - `x` 倒数第三维度的length不能被 `upscale_factor` 的平方整除。
|
||||
- **TypeError** - `x` 维度小于3。
|
|
@ -0,0 +1,22 @@
|
|||
mindspore.ops.pixel_unshuffle
|
||||
==============================
|
||||
|
||||
.. py:function:: mindspore.ops.pixel_unshuffle(x, downscale_factor)
|
||||
|
||||
pixel_unshuffle函数。
|
||||
|
||||
在多个输入平面组成的输入上面应用pixel_unshuffle算法。关于pixel_unshuffle算法详细介绍,请参考 `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_ 。
|
||||
|
||||
通常情况下,`x` shape :math:`(*, C, H \times r, W \times r)` ,输出shape :math:`(*, C \times r^2, H, W)` 。`r` 是缩小因子。 `*` 是大于等于0的维度。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - Tensor,shape为 :math:`(*, C, H \times r, W \times r)` 。 `x` 的维度需要大于2,并且倒数第一和倒数第二维length可以被 `downscale_factor` 整除。
|
||||
- **downscale_factor** (int) - 减小空间分辨率的因子,是正整数。
|
||||
|
||||
返回:
|
||||
- **output** (Tensor) - Tensor,shape为 :math:`(*, C \times r^2, H, W)` 。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `downscale_factor` 不是正整数。
|
||||
- **ValueError** - `x` 倒数第一和倒数第二维度的length不能被 `downscale_factor` 整除。
|
||||
- **TypeError** - `x` 维度小于3。
|
|
@ -359,6 +359,8 @@ Image Processing Layer
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.nn.PixelShuffle
|
||||
mindspore.nn.PixelUnshuffle
|
||||
mindspore.nn.ResizeBilinear
|
||||
|
||||
Tools
|
||||
|
|
|
@ -484,6 +484,8 @@ Image Functions
|
|||
mindspore.ops.bounding_box_encode
|
||||
mindspore.ops.check_valid
|
||||
mindspore.ops.iou
|
||||
mindspore.ops.pixel_shuffle
|
||||
mindspore.ops.pixel_unshuffle
|
||||
|
||||
Spectral Functions
|
||||
------------------
|
||||
|
|
|
@ -20,6 +20,7 @@ import numbers
|
|||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.ops as ops
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -31,7 +32,7 @@ from mindspore.nn.layer.pooling import AvgPool2d
|
|||
from mindspore.nn.layer.activation import ReLU
|
||||
from mindspore.nn.cell import Cell
|
||||
|
||||
__all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop']
|
||||
__all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop', 'PixelShuffle', 'PixelUnshuffle']
|
||||
|
||||
|
||||
class ImageGradients(Cell):
|
||||
|
@ -555,3 +556,94 @@ class CentralCrop(Cell):
|
|||
image = self.slice(image, bbox_begin, bbox_size)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class PixelShuffle(Cell):
|
||||
r"""
|
||||
PixelShuffle operatrion.
|
||||
|
||||
Applies a pixelshuffle operation over an input signal composed of several input planes. This is useful for
|
||||
implementiong efficient sub-pixel convolution with a stride of :math:`1/r`. For more details, refer to
|
||||
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
|
||||
<https://arxiv.org/abs/1609.05158> `_ .
|
||||
|
||||
Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
|
||||
:math:`(*, C, H \times r, W \times r)`, where r is an upscale factor and * is zero or more batch dimensions.
|
||||
|
||||
Args:
|
||||
upscale_factor (int): factor to increase spatial resolution by, and is a positive integer.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` . The dimension of `x` is larger than 2, and
|
||||
the length of third to last dimension can be divisible by `upscale_factor` squared.
|
||||
|
||||
Output:
|
||||
- **output** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` .
|
||||
|
||||
Raises:
|
||||
ValueError: If `upscale_factor` is not a positive integer.
|
||||
ValueError: If the length of third to last dimension of `x` is not divisible by `upscale_factor` squared.
|
||||
TypeError: If the dimension of `x` is less than 3.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = np.arange(3 * 2 * 9 * 4 * 4).reshape((3, 2, 9, 4, 4))
|
||||
>>> input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
|
||||
>>> pixel_shuffle = nn.PixelShuffle(3)
|
||||
>>> output = pixel_shuffle(input_x)
|
||||
>>> print(output.shape)
|
||||
(3, 2, 1, 12, 12)
|
||||
"""
|
||||
def __init__(self, upscale_factor):
|
||||
super(PixelShuffle, self).__init__()
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
def construct(self, x):
|
||||
return ops.pixel_shuffle(x, self.upscale_factor)
|
||||
|
||||
|
||||
class PixelUnshuffle(Cell):
|
||||
r"""
|
||||
PixelUnshuffle operatrion.
|
||||
|
||||
Applies a pixelunshuffle operation over an input signal composed of several input planes. For more details, refer to
|
||||
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
|
||||
<https://arxiv.org/abs/1609.05158> `_ .
|
||||
|
||||
Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
|
||||
:math:`(*, C \times r^2, H, W)` , where r is a downscale factor and * is zero or more batch dimensions.
|
||||
|
||||
Args:
|
||||
downscale_factor (int): factor to decrease spatial resolution by, and is a positive integer.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` . The dimension of `x` is larger than
|
||||
2, and the length of second to last dimension or last dimension can be divisible by `downscale_factor` .
|
||||
|
||||
Output:
|
||||
- **output** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` .
|
||||
|
||||
Raises:
|
||||
ValueError: If `downscale_factor` is not a positive integer.
|
||||
ValueError: If the length of second to last dimension or last dimension is not divisible by `downscale_factor` .
|
||||
TypeError: If the dimension of `x` is less than 3.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> pixel_unshuffle = nn.PixelUnshuffle(3)
|
||||
>>> input_x = np.arange(12 * 12).reshape((1, 1, 12, 12))
|
||||
>>> input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
|
||||
>>> output = pixel_unshuffle(input_x)
|
||||
>>> print(output.shape)
|
||||
>>> (1, 9, 4, 4)
|
||||
"""
|
||||
def __init__(self, downscale_factor):
|
||||
super(PixelUnshuffle, self).__init__()
|
||||
self.downscale_factor = downscale_factor
|
||||
|
||||
def construct(self, x):
|
||||
return ops.pixel_unshuffle(x, self.downscale_factor)
|
||||
|
|
|
@ -310,6 +310,8 @@ from .nn_func import (
|
|||
dropout3d,
|
||||
deformable_conv2d,
|
||||
fast_gelu,
|
||||
pixel_shuffle,
|
||||
pixel_unshuffle,
|
||||
hardshrink,
|
||||
soft_shrink,
|
||||
intopk,
|
||||
|
|
|
@ -3624,6 +3624,127 @@ def conv3d(inputs, weight, pad_mode="valid", padding=0, stride=1, dilation=1, gr
|
|||
return output
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_positive_int(arg_value, arg_name=None, prim_name=None):
|
||||
validator.check_positive_int(arg_value, arg_name=arg_name, prim_name=prim_name)
|
||||
|
||||
|
||||
def pixel_shuffle(x, upscale_factor):
|
||||
r"""
|
||||
pixel_shuffle operatrion.
|
||||
|
||||
Applies a pixel_shuffle operation over an input signal composed of several input planes. This is useful for
|
||||
implementiong efficient sub-pixel convolution with a stride of :math:`1/r`. For more details, refer to
|
||||
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
|
||||
<https://arxiv.org/abs/1609.05158> `_ .
|
||||
|
||||
Typically, the `x` is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
|
||||
:math:`(*, C, H \times r, W \times r)`, where `r` is an upscale factor and `*` is zero or more batch dimensions.
|
||||
|
||||
Args:
|
||||
x (Tensor): Tensor of shape :math:`(*, C \times r^2, H, W)` . The dimension of `x` is larger than 2, and the
|
||||
length of third to last dimension can be divisible by `upscale_factor` squared.
|
||||
upscale_factor (int): factor to increase spatial resolution by, and is a positive integer.
|
||||
|
||||
Returns:
|
||||
- **output** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` .
|
||||
|
||||
Raises:
|
||||
ValueError: If `upscale_factor` is not a positive integer.
|
||||
ValueError: If the length of third to last dimension is not divisible by `upscale_factor` squared.
|
||||
TypeError: If the dimension of `x` is less than 3.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = np.arange(3 * 2 * 9 * 4 * 4).reshape((3, 2, 9, 4, 4))
|
||||
>>> input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
|
||||
>>> output = ops.pixel_shuffle(input_x, 3)
|
||||
>>> print(output.shape)
|
||||
(3, 2, 1, 12, 12)
|
||||
"""
|
||||
_check_positive_int(upscale_factor, "upscale_factor")
|
||||
idx = x.shape
|
||||
length = len(idx)
|
||||
if length < 3:
|
||||
raise TypeError(f"For pixel_shuffle, the dimension of `x` should be larger than 2, but got {length}.")
|
||||
pre = idx[:-3]
|
||||
c, h, w = idx[-3:]
|
||||
if c % upscale_factor ** 2 != 0:
|
||||
raise ValueError("For 'pixel_shuffle', the length of third to last dimension is not divisible"
|
||||
"by `upscale_factor` squared.")
|
||||
c = c // upscale_factor ** 2
|
||||
input_perm = (pre + (c, upscale_factor, upscale_factor, h, w))
|
||||
reshape = ops.Reshape()
|
||||
x = reshape(x, input_perm)
|
||||
input_perm = [i for i in range(length - 2)]
|
||||
input_perm = input_perm + [length, length - 2, length + 1, length - 1]
|
||||
input_perm = tuple(input_perm)
|
||||
transpose = ops.Transpose()
|
||||
x = transpose(x, input_perm)
|
||||
x = reshape(x, (pre + (c, upscale_factor * h, upscale_factor * w)))
|
||||
return x
|
||||
|
||||
|
||||
def pixel_unshuffle(x, downscale_factor):
|
||||
r"""
|
||||
pixel_unshuffle operatrion.
|
||||
|
||||
Applies a pixel_unshuffle operation over an input signal composed of several input planes. For more details, refer
|
||||
to `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
|
||||
<https://arxiv.org/abs/1609.05158>`_ .
|
||||
|
||||
Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
|
||||
:math:`(*, C \times r^2, H, W)` , where `r` is a downscale factor and `*` is zero or more batch dimensions.
|
||||
|
||||
Args:
|
||||
x (Tensor): Tensor of shape :math:`(*, C, H \times r, W \times r)` . The dimension of `x` is larger than 2,
|
||||
and the length of second to last dimension or last dimension can be divisible by `downscale_factor` .
|
||||
downscale_factor (int): factor to decrease spatial resolution by, and is a positive integer.
|
||||
|
||||
Returns:
|
||||
- **output** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` .
|
||||
|
||||
Raises:
|
||||
ValueError: If `downscale_factor` is not a positive integer.
|
||||
ValueError: If the length of second to last dimension or last dimension is not divisible by `downscale_factor` .
|
||||
TypeError: If the dimension of `x` is less than 3.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = np.arange(12 * 12).reshape((1, 1, 12, 12))
|
||||
>>> input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
|
||||
>>> output = ops.pixel_unshuffle(input_x, 3)
|
||||
>>> print(output.shape)
|
||||
>>> (1, 9, 4, 4)
|
||||
"""
|
||||
_check_positive_int(downscale_factor, "downscale_factor")
|
||||
idx = x.shape
|
||||
length = len(idx)
|
||||
if length < 3:
|
||||
raise TypeError(f"For pixel_unshuffle, the dimension of `x` should be larger than 2, but got {length}.")
|
||||
pre = idx[:-3]
|
||||
c, h, w = idx[-3:]
|
||||
if h % downscale_factor != 0 or w % downscale_factor != 0:
|
||||
raise ValueError("For 'pixel_unshuffle', the length of second to last 2 dimension should be divisible "
|
||||
"by downscale_factor.")
|
||||
h = h // downscale_factor
|
||||
w = w // downscale_factor
|
||||
input_perm = (pre + (c, h, downscale_factor, w, downscale_factor))
|
||||
reshape = ops.Reshape()
|
||||
x = reshape(x, input_perm)
|
||||
input_perm = [i for i in range(length - 2)]
|
||||
input_perm = input_perm + [length - 1, length + 1, length - 2, length]
|
||||
input_perm = tuple(input_perm)
|
||||
transpose = ops.Transpose()
|
||||
x = transpose(x, input_perm)
|
||||
x = reshape(x, (pre + (c * downscale_factor * downscale_factor, h, w)))
|
||||
return x
|
||||
|
||||
|
||||
def glu(x, axis=-1):
|
||||
r"""
|
||||
Computes GLU (Gated Linear Unit activation function) of input tensors .
|
||||
|
@ -4086,6 +4207,8 @@ __all__ = [
|
|||
'dropout2d',
|
||||
'dropout3d',
|
||||
'fast_gelu',
|
||||
'pixel_shuffle',
|
||||
'pixel_unshuffle',
|
||||
'hardshrink',
|
||||
'soft_shrink',
|
||||
'intopk',
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
|
||||
|
||||
class PixelShuffleNet(nn.Cell):
|
||||
"""PixelShuffleNet"""
|
||||
|
||||
def __init__(self):
|
||||
super(PixelShuffleNet, self).__init__()
|
||||
self.pixel = nn.PixelShuffle(2)
|
||||
|
||||
def construct(self, x):
|
||||
output = self.pixel(x)
|
||||
return output
|
||||
|
||||
|
||||
class PixelUnShuffleNet(nn.Cell):
|
||||
"""PixelUnShuffleNet"""
|
||||
|
||||
def __init__(self):
|
||||
super(PixelUnShuffleNet, self).__init__()
|
||||
self.pixel = nn.PixelUnshuffle(2)
|
||||
|
||||
def construct(self, x):
|
||||
output = self.pixel(x)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_compile_max(mode):
|
||||
"""
|
||||
Feature: Test PixelShuffleAndUnShuffle
|
||||
Description: Test the functionality of PixelShuffleAndUnShuffle
|
||||
Expectation: Success
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
input_x = np.arange(4 * 2 * 2).reshape((4, 2, 2))
|
||||
input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
|
||||
shufflenet = PixelShuffleNet()
|
||||
unshufflenet = PixelUnShuffleNet()
|
||||
output1 = shufflenet(input_x)
|
||||
expect_output1 = np.array([[[0, 4, 1, 5],
|
||||
[8, 12, 9, 13],
|
||||
[2, 6, 3, 7],
|
||||
[10, 14, 11, 15]]])
|
||||
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||
output2 = unshufflenet(output1)
|
||||
assert np.allclose(input_x.asnumpy(), output2.asnumpy())
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
import mindspore.context as context
|
||||
|
||||
|
||||
class PixelShuffleNet(nn.Cell):
|
||||
"""PixelShuffleNet"""
|
||||
|
||||
def construct(self, x):
|
||||
output = ops.pixel_shuffle(x, 2)
|
||||
return output
|
||||
|
||||
|
||||
class PixelUnShuffleNet(nn.Cell):
|
||||
"""PixelUnShuffleNet"""
|
||||
|
||||
def construct(self, x):
|
||||
output = ops.pixel_unshuffle(x, 2)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_compile_max(mode):
|
||||
"""
|
||||
Feature: Test PixelShuffleAndUnShuffle
|
||||
Description: Test the functionality of PixelShuffleAndUnShuffle
|
||||
Expectation: Success
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
input_x = np.arange(4 * 2 * 2).reshape((4, 2, 2))
|
||||
input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
|
||||
shufflenet = PixelShuffleNet()
|
||||
unshufflenet = PixelUnShuffleNet()
|
||||
output1 = shufflenet(input_x)
|
||||
expect_output1 = np.array([[[0, 4, 1, 5],
|
||||
[8, 12, 9, 13],
|
||||
[2, 6, 3, 7],
|
||||
[10, 14, 11, 15]]])
|
||||
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||
output2 = unshufflenet(output1)
|
||||
assert np.allclose(input_x.asnumpy(), output2.asnumpy())
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
test image api
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
|
||||
|
||||
class PixelShuffleAndUnShuffleNet(nn.Cell):
|
||||
"""PixelShuffleAndUnShuffleNet"""
|
||||
|
||||
def __init__(self):
|
||||
super(PixelShuffleAndUnShuffleNet, self).__init__()
|
||||
self.pixelshuffle = nn.PixelShuffle(3)
|
||||
self.pixelunshuffle = nn.PixelUnshuffle(3)
|
||||
|
||||
def construct(self, x):
|
||||
output_shuffle = self.pixelshuffle(x)
|
||||
output_unshuffle = self.pixelunshuffle(output_shuffle)
|
||||
return output_unshuffle
|
||||
|
||||
|
||||
def test_compile_pixel_shuffle_unshuffle():
|
||||
"""
|
||||
Feature: Test PixelShuffleAndUnShuffle
|
||||
Description: Test the functionality of PixelShuffleAndUnShuffle
|
||||
Expectation: Success
|
||||
"""
|
||||
net = PixelShuffleAndUnShuffleNet()
|
||||
input_x = np.arange(3 * 2 * 9 * 4 * 4).reshape((3, 2, 9, 4, 4))
|
||||
input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
|
||||
_cell_graph_executor.compile(net, input_x)
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
test image api
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
|
||||
|
||||
class PixelShuffleAndUnShuffleNet(nn.Cell):
|
||||
"""PixelShuffleAndUnShuffleNet"""
|
||||
|
||||
def construct(self, x):
|
||||
scaler = 3
|
||||
output_shuffle = ops.pixel_shuffle(x, scaler)
|
||||
output_unshuffle = ops.pixel_unshuffle(output_shuffle, scaler)
|
||||
return output_unshuffle
|
||||
|
||||
|
||||
def test_compile_pixel_shuffle_unshuffle():
|
||||
"""
|
||||
Feature: Test ops PixelShuffleAndUnShuffle
|
||||
Description: Test the functionality of PixelShuffleAndUnShuffle
|
||||
Expectation: Success
|
||||
"""
|
||||
net = PixelShuffleAndUnShuffleNet()
|
||||
input_x = np.arange(3 * 2 * 9 * 4 * 4).reshape((3, 2, 9, 4, 4))
|
||||
input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
|
||||
_cell_graph_executor.compile(net, input_x)
|
Loading…
Reference in New Issue