!44702 Add nn.ChannelShuffle
Merge pull request !44702 from pkuliuliu/master
This commit is contained in:
commit
c1c0dd86f8
|
@ -368,4 +368,5 @@ Dynamic LR函数
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.nn.ChannelShuffle
|
||||
mindspore.nn.Flatten
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.nn.ChannelShuffle
|
||||
============================
|
||||
|
||||
.. py:class:: mindspore.nn.ChannelShuffle()
|
||||
|
||||
将shape的为 :math`(*, C, H, W)` 的Tensor的通道划分成 :math`g` 组,并将其以 :math`(*, C \frac g, g, H, W)` 的shape重新排列, 同时保持Tensor原有的shape。
|
||||
|
||||
参数:
|
||||
- **groups** (int) - 划分通道的组数。取值范围是 :math`(0, \inf)` 。在上述公式中表示为 :math`g` 。
|
||||
|
||||
输入:
|
||||
- **x** (Tensor) - Tensor的shape :math:`(*, C_{in}, H_{in}, W_{in})` 。
|
||||
|
||||
输出:
|
||||
Tensor,数据类型和shape与 `x` 相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `groups` 非整数。
|
||||
- **ValueError** - `groups` 小于1。
|
||||
- **ValueError** - `x` 的维度小于3。
|
||||
- **ValueError** - Tensor的通道数不能被 `groups` 整除。
|
|
@ -371,4 +371,5 @@ Tools
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.nn.ChannelShuffle
|
||||
mindspore.nn.Flatten
|
||||
|
|
|
@ -36,6 +36,7 @@ from mindspore.nn.layer.quant import *
|
|||
from mindspore.nn.layer.math import *
|
||||
from mindspore.nn.layer.combined import *
|
||||
from mindspore.nn.layer.timedistributed import *
|
||||
from mindspore.nn.layer.channel_shuffle import ChannelShuffle
|
||||
from mindspore.nn.layer.thor_layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
|
||||
from mindspore.nn.layer.padding import ConstantPad1d, ConstantPad2d, ConstantPad3d, ReflectionPad1d, \
|
||||
ReflectionPad2d, ZeroPad2d, ReplicationPad1d, ReplicationPad2d, ReplicationPad3d
|
||||
|
@ -58,3 +59,4 @@ __all__.extend(combined.__all__)
|
|||
__all__.extend(timedistributed.__all__)
|
||||
__all__.extend(thor_layer.__all__)
|
||||
__all__.extend(padding.__all__)
|
||||
__all__.extend(channel_shuffle.__all__)
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""channel shuffle"""
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.cell import Cell
|
||||
|
||||
__all__ = ['ChannelShuffle']
|
||||
|
||||
|
||||
class ChannelShuffle(Cell):
|
||||
r"""
|
||||
Divide the channels in a tensor of shape :math:`(*, C , H, W)`
|
||||
into g groups and rearrange them as :math:`(*, C \frac g, g, H, W)`,
|
||||
while keeping the original tensor shape.
|
||||
|
||||
Args:
|
||||
groups (int): Number of groups to divide channels in. Refer to :math`g`.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(*, C_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If groups is not an int.
|
||||
ValueError: If `groups` is less than 1.
|
||||
ValueError: If dims of `x` is less than 3.
|
||||
ValueError: If number of channels can not be divisible by groups.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> channel_shuffle = nn.ChannelShuffle(2)
|
||||
>>> x = Tensor(np.arange(16).astype(np.int32).reshape(1, 4, 2, 2))
|
||||
>>> print(x)
|
||||
[[[[0 1],
|
||||
[2 3]],
|
||||
[[4 5],
|
||||
[6 7]],
|
||||
[[8 9],
|
||||
[10 11]],
|
||||
[[12 13],
|
||||
[14 15]],
|
||||
]]
|
||||
>>> output = channel_shuffle(x)
|
||||
>>> print(output)
|
||||
[[[[0 1],
|
||||
[2 3]],
|
||||
[[8 9],
|
||||
[10 11]],
|
||||
[[4 5],
|
||||
[6 7]],
|
||||
[[12 13],
|
||||
[14 15]],
|
||||
]]
|
||||
"""
|
||||
def __init__(self, groups):
|
||||
"""Initialize ChannelShuffle."""
|
||||
super(ChannelShuffle, self).__init__()
|
||||
if not isinstance(groups, int):
|
||||
raise TypeError("For ChannelShuffle, the param `groups` must be int, but got {}.".format(type(groups)))
|
||||
if groups < 1:
|
||||
raise ValueError(f"For ChannelShuffle, the param `groups` must be larger than 0, but got {groups}.")
|
||||
|
||||
self.groups = groups
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
@staticmethod
|
||||
@constexpr
|
||||
def _check_input_dim(shape, channels, groups, cls_name):
|
||||
dim = len(shape)
|
||||
if dim < 3:
|
||||
raise ValueError(f"For {cls_name}, the in_shape must have more than 2 dims, but got {dim}.")
|
||||
|
||||
if channels % groups != 0:
|
||||
raise ValueError(f"For {cls_name}, number of channels must be divisible by groups, "
|
||||
f"but got {channels} channels and {groups} groups.")
|
||||
|
||||
def construct(self, x):
|
||||
x_shape = self.shape(x)
|
||||
n, c = x_shape[0], x_shape[1]
|
||||
self._check_input_dim(x_shape, c, self.groups, self.cls_name)
|
||||
out = self.reshape(x, (n, self.groups, c // self.groups, -1))
|
||||
out = self.transpose(out, (0, 2, 1, 3))
|
||||
return self.reshape(out, x_shape)
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, groups):
|
||||
super(Net, self).__init__()
|
||||
self.channel_shuffle = nn.ChannelShuffle(groups)
|
||||
|
||||
def construct(self, x):
|
||||
return self.channel_shuffle(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_channel_shuffle_normal(mode):
|
||||
"""
|
||||
Feature: ChannelShuffle
|
||||
Description: Verify the result of ChannelShuffle
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net(2)
|
||||
x = ms.Tensor(np.arange(16).reshape((1, 4, 2, 2)), dtype=ms.int32)
|
||||
out = net(x)
|
||||
expect_out = np.array([[[[0, 1], [2, 3]], [[8, 9], [10, 11]],
|
||||
[[4, 5], [6, 7]], [[12, 13], [14, 15]]]]).astype(np.int32)
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
test channel_shuffle api
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
|
||||
|
||||
class ChannelShuffleNet(nn.Cell):
|
||||
"""ChannelShuffle"""
|
||||
def __init__(self, groups):
|
||||
super(ChannelShuffleNet, self).__init__()
|
||||
self.channel_shuffle = nn.ChannelShuffle(groups)
|
||||
|
||||
def construct(self, x):
|
||||
return self.channel_shuffle(x)
|
||||
|
||||
|
||||
def test_compile_channel_shuffle():
|
||||
"""
|
||||
Feature: Test ChannelShuffleNet
|
||||
Description: Test the functionality of ChannelShuffle
|
||||
Expectation: Success
|
||||
"""
|
||||
net = ChannelShuffleNet(2)
|
||||
x = ms.Tensor(np.arange(16).astype(np.int32).reshape(1, 4, 2, 2))
|
||||
_cell_graph_executor.compile(net, x)
|
Loading…
Reference in New Issue