From 45dbc8bf047c0d3ae021d5aebecc50e8a8433f8f Mon Sep 17 00:00:00 2001 From: panbingao Date: Fri, 10 Jul 2020 09:11:38 +0800 Subject: [PATCH] Move model_zoo.resnet.py --- cmake/package.cmake | 1 - mindspore/model_zoo/__init__.py | 0 .../networks/models/resnet50/src}/resnet.py | 0 .../models/resnet50/test_resnet50_imagenet.py | 2 +- .../gtest_input/optimizer/ad/ad_test.py | 2 +- .../pipeline/parse/parser_integrate.py | 2 +- tests/ut/python/model/resnet.py | 282 ++++++++++++++++++ tests/ut/python/train/test_amp.py | 2 +- 8 files changed, 286 insertions(+), 5 deletions(-) delete mode 100644 mindspore/model_zoo/__init__.py rename {mindspore/model_zoo => tests/st/networks/models/resnet50/src}/resnet.py (100%) mode change 100755 => 100644 create mode 100644 tests/ut/python/model/resnet.py diff --git a/cmake/package.cmake b/cmake/package.cmake index 42821cf41dd..2034b550406 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -210,7 +210,6 @@ install( ${CMAKE_SOURCE_DIR}/mindspore/parallel ${CMAKE_SOURCE_DIR}/mindspore/mindrecord ${CMAKE_SOURCE_DIR}/mindspore/train - ${CMAKE_SOURCE_DIR}/mindspore/model_zoo ${CMAKE_SOURCE_DIR}/mindspore/common ${CMAKE_SOURCE_DIR}/mindspore/ops ${CMAKE_SOURCE_DIR}/mindspore/communication diff --git a/mindspore/model_zoo/__init__.py b/mindspore/model_zoo/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/mindspore/model_zoo/resnet.py b/tests/st/networks/models/resnet50/src/resnet.py old mode 100755 new mode 100644 similarity index 100% rename from mindspore/model_zoo/resnet.py rename to tests/st/networks/models/resnet50/src/resnet.py diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index c88af6bcf7e..e721b62c589 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -27,10 +27,10 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train.model import Model, ParallelMode from mindspore.train.callback import Callback from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.model_zoo.resnet import resnet50 import mindspore.nn as nn import mindspore.dataset as ds +from tests.st.networks.models.resnet50.src.resnet import resnet50 from tests.st.networks.models.resnet50.src.dataset import create_dataset from tests.st.networks.models.resnet50.src.lr_generator import get_learning_rate from tests.st.networks.models.resnet50.src.config import config diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/ad/ad_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/ad/ad_test.py index e38c61f16e8..bcfa077ea5e 100644 --- a/tests/ut/cpp/python_input/gtest_input/optimizer/ad/ad_test.py +++ b/tests/ut/cpp/python_input/gtest_input/optimizer/ad/ad_test.py @@ -17,8 +17,8 @@ import numpy as np import mindspore as ms from mindspore.common.tensor import Tensor -from mindspore.model_zoo.resnet import resnet50 from mindspore.ops import Primitive +from tests.ut.python.model.resnet import resnet50 scala_add = Primitive('scalar_add') diff --git a/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parser_integrate.py b/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parser_integrate.py index fa5b1b90558..28bded64016 100644 --- a/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parser_integrate.py +++ b/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parser_integrate.py @@ -22,9 +22,9 @@ from mindspore.common import dtype from mindspore.common.api import ms_function, _executor from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore.model_zoo.resnet import resnet50 from mindspore.ops import functional as F from mindspore.train.model import Model +from tests.ut.python.model.resnet import resnet50 def test_high_order_function(a): diff --git a/tests/ut/python/model/resnet.py b/tests/ut/python/model/resnet.py new file mode 100644 index 00000000000..001e1db0cf3 --- /dev/null +++ b/tests/ut/python/model/resnet.py @@ -0,0 +1,282 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ResNet.""" +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor + + +def _weight_variable(shape, factor=0.01): + init_value = np.random.randn(*shape).astype(np.float32) * factor + return Tensor(init_value) + + +def _conv3x3(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 3, 3) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _conv1x1(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 1, 1) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _conv7x7(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 7, 7) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _bn_last(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _fc(in_channel, out_channel): + weight_shape = (out_channel, in_channel) + weight = _weight_variable(weight_shape) + return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0) + + +class ResidualBlock(nn.Cell): + """ + ResNet V1 residual block definition. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, stride=2) + """ + expansion = 4 + + def __init__(self, + in_channel, + out_channel, + stride=1): + super(ResidualBlock, self).__init__() + + channel = out_channel // self.expansion + self.conv1 = _conv1x1(in_channel, channel, stride=1) + self.bn1 = _bn(channel) + + self.conv2 = _conv3x3(channel, channel, stride=stride) + self.bn2 = _bn(channel) + + self.conv3 = _conv1x1(channel, out_channel, stride=1) + self.bn3 = _bn_last(out_channel) + + self.relu = nn.ReLU() + + self.down_sample = False + + if stride != 1 or in_channel != out_channel: + self.down_sample = True + self.down_sample_layer = None + + if self.down_sample: + self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), + _bn(out_channel)]) + self.add = P.TensorAdd() + + def construct(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.down_sample: + identity = self.down_sample_layer(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class ResNet(nn.Cell): + """ + ResNet architecture. + + Args: + block (Cell): Block for network. + layer_nums (list): Numbers of block in different layers. + in_channels (list): Input channel in each layer. + out_channels (list): Output channel in each layer. + strides (list): Stride size in each layer. + num_classes (int): The number of classes that the training images are belonging to. + Returns: + Tensor, output tensor. + + Examples: + >>> ResNet(ResidualBlock, + >>> [3, 4, 6, 3], + >>> [64, 256, 512, 1024], + >>> [256, 512, 1024, 2048], + >>> [1, 2, 2, 2], + >>> 10) + """ + + def __init__(self, + block, + layer_nums, + in_channels, + out_channels, + strides, + num_classes): + super(ResNet, self).__init__() + + if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: + raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") + + self.conv1 = _conv7x7(3, 64, stride=2) + self.bn1 = _bn(64) + self.relu = P.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + + self.layer1 = self._make_layer(block, + layer_nums[0], + in_channel=in_channels[0], + out_channel=out_channels[0], + stride=strides[0]) + self.layer2 = self._make_layer(block, + layer_nums[1], + in_channel=in_channels[1], + out_channel=out_channels[1], + stride=strides[1]) + self.layer3 = self._make_layer(block, + layer_nums[2], + in_channel=in_channels[2], + out_channel=out_channels[2], + stride=strides[2]) + self.layer4 = self._make_layer(block, + layer_nums[3], + in_channel=in_channels[3], + out_channel=out_channels[3], + stride=strides[3]) + + self.mean = P.ReduceMean(keep_dims=True) + self.flatten = nn.Flatten() + self.end_point = _fc(out_channels[3], num_classes) + + def _make_layer(self, block, layer_num, in_channel, out_channel, stride): + """ + Make stage network of ResNet. + + Args: + block (Cell): Resnet block. + layer_num (int): Layer number. + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. + + Returns: + SequentialCell, the output layer. + + Examples: + >>> _make_layer(ResidualBlock, 3, 128, 256, 2) + """ + layers = [] + + resnet_block = block(in_channel, out_channel, stride=stride) + layers.append(resnet_block) + + for _ in range(1, layer_num): + resnet_block = block(out_channel, out_channel, stride=1) + layers.append(resnet_block) + + return nn.SequentialCell(layers) + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + c1 = self.maxpool(x) + + c2 = self.layer1(c1) + c3 = self.layer2(c2) + c4 = self.layer3(c3) + c5 = self.layer4(c4) + + out = self.mean(c5, (2, 3)) + out = self.flatten(out) + out = self.end_point(out) + + return out + + +def resnet50(class_num=10): + """ + Get ResNet50 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet50 neural network. + + Examples: + >>> net = resnet50(10) + """ + return ResNet(ResidualBlock, + [3, 4, 6, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num) + +def resnet101(class_num=1001): + """ + Get ResNet101 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet101 neural network. + + Examples: + >>> net = resnet101(1001) + """ + return ResNet(ResidualBlock, + [3, 4, 23, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num) diff --git a/tests/ut/python/train/test_amp.py b/tests/ut/python/train/test_amp.py index c7befb6c2be..6bb4ec54642 100644 --- a/tests/ut/python/train/test_amp.py +++ b/tests/ut/python/train/test_amp.py @@ -22,10 +22,10 @@ from mindspore import amp from mindspore import nn from mindspore.train import Model, ParallelMode from mindspore.common import dtype as mstype -from mindspore.model_zoo.resnet import resnet50 from ....dataset_mock import MindData from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.communication.management import init +from tests.ut.python.model.resnet import resnet50 def setup_module(module): _ = module