From fc310bc5aef039214cf8f800128b8243781b9971 Mon Sep 17 00:00:00 2001 From: panfengfeng Date: Tue, 8 Sep 2020 15:38:42 +0800 Subject: [PATCH] fix network bug --- model_zoo/official/cv/shufflenetv2/blocks.py | 83 ------------------- model_zoo/official/cv/shufflenetv2/eval.py | 8 +- .../scripts/run_distribute_train_for_gpu.sh | 17 ++-- .../scripts/run_standalone_train_for_gpu.sh | 16 +++- .../cv/shufflenetv2/src/CrossEntropySmooth.py | 38 +++++++++ .../official/cv/shufflenetv2/src/loss.py | 60 -------------- .../{network.py => src/shufflenetv2.py} | 69 ++++++++++++++- model_zoo/official/cv/shufflenetv2/train.py | 7 +- model_zoo/official/recommend/deepfm/train.py | 10 ++- 9 files changed, 144 insertions(+), 164 deletions(-) delete mode 100644 model_zoo/official/cv/shufflenetv2/blocks.py create mode 100644 model_zoo/official/cv/shufflenetv2/src/CrossEntropySmooth.py delete mode 100644 model_zoo/official/cv/shufflenetv2/src/loss.py rename model_zoo/official/cv/shufflenetv2/{network.py => src/shufflenetv2.py} (63%) diff --git a/model_zoo/official/cv/shufflenetv2/blocks.py b/model_zoo/official/cv/shufflenetv2/blocks.py deleted file mode 100644 index 253abce0fc2..00000000000 --- a/model_zoo/official/cv/shufflenetv2/blocks.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import mindspore.nn as nn -import mindspore.ops.operations as P - - -class ShuffleV2Block(nn.Cell): - def __init__(self, inp, oup, mid_channels, *, ksize, stride): - super(ShuffleV2Block, self).__init__() - self.stride = stride - ##assert stride in [1, 2] - - self.mid_channels = mid_channels - self.ksize = ksize - pad = ksize // 2 - self.pad = pad - self.inp = inp - - outputs = oup - inp - - branch_main = [ - # pw - nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1, - pad_mode='pad', padding=0, has_bias=False), - nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), - nn.ReLU(), - # dw - nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride, - pad_mode='pad', padding=pad, group=mid_channels, has_bias=False), - nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), - # pw-linear - nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1, - pad_mode='pad', padding=0, has_bias=False), - nn.BatchNorm2d(num_features=outputs, momentum=0.9), - nn.ReLU(), - ] - self.branch_main = nn.SequentialCell(branch_main) - - if stride == 2: - branch_proj = [ - # dw - nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride, - pad_mode='pad', padding=pad, group=inp, has_bias=False), - nn.BatchNorm2d(num_features=inp, momentum=0.9), - # pw-linear - nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1, - pad_mode='pad', padding=0, has_bias=False), - nn.BatchNorm2d(num_features=inp, momentum=0.9), - nn.ReLU(), - ] - self.branch_proj = nn.SequentialCell(branch_proj) - else: - self.branch_proj = None - - def construct(self, old_x): - if self.stride == 1: - x_proj, x = self.channel_shuffle(old_x) - return P.Concat(1)((x_proj, self.branch_main(x))) - if self.stride == 2: - x_proj = old_x - x = old_x - return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x))) - return None - - def channel_shuffle(self, x): - batchsize, num_channels, height, width = P.Shape()(x) - ##assert (num_channels % 4 == 0) - x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,)) - x = P.Transpose()(x, (1, 0, 2,)) - x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,)) - return x[0], x[1] diff --git a/model_zoo/official/cv/shufflenetv2/eval.py b/model_zoo/official/cv/shufflenetv2/eval.py index fdbddcf3760..2dcfdc52c93 100644 --- a/model_zoo/official/cv/shufflenetv2/eval.py +++ b/model_zoo/official/cv/shufflenetv2/eval.py @@ -23,8 +23,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.config import config_gpu as cfg from src.dataset import create_dataset -from network import ShuffleNetV2 - +from src.shufflenetv2 import ShuffleNetV2 +from src.CrossEntropySmooth import CrossEntropySmooth if __name__ == '__main__': parser = argparse.ArgumentParser(description='image classification evaluation') @@ -43,8 +43,8 @@ if __name__ == '__main__': load_param_into_net(net, ckpt) net.set_train(False) dataset = create_dataset(args_opt.dataset_path, False, 0, 1) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False, - smooth_factor=0.1, num_classes=cfg.num_classes) + loss = CrossEntropySmooth(sparse=True, reduction='mean', + smooth_factor=0.1, num_classes=cfg.num_classes) eval_metrics = {'Loss': nn.Loss(), 'Top1-Acc': nn.Top1CategoricalAccuracy(), 'Top5-Acc': nn.Top5CategoricalAccuracy()} diff --git a/model_zoo/official/cv/shufflenetv2/scripts/run_distribute_train_for_gpu.sh b/model_zoo/official/cv/shufflenetv2/scripts/run_distribute_train_for_gpu.sh index c3bfedeaf8c..ec03ea7bfe2 100644 --- a/model_zoo/official/cv/shufflenetv2/scripts/run_distribute_train_for_gpu.sh +++ b/model_zoo/official/cv/shufflenetv2/scripts/run_distribute_train_for_gpu.sh @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -if [ $# -lt 3 ] +if [ $# != 3 ] && [ $# != 4 ] then - echo "Usage: \ - sh run_distribute_train_for_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] \ + echo "Usage: + sh run_distribute_train_for_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) " exit 1 fi @@ -48,10 +48,15 @@ cd ../train || exit export CUDA_VISIBLE_DEVICES="$2" -if [ $1 -gt 1 ] +if [ $# == 3 ] then mpirun -n $1 --allow-run-as-root \ python ${BASEPATH}/../train.py --platform='GPU' --is_distributed=True --dataset_path=$3 > train.log 2>&1 & -else - python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$3 > train.log 2>&1 & fi + +if [ $# == 4 ] +then + mpirun -n $1 --allow-run-as-root \ + python ${BASEPATH}/../train.py --platform='GPU' --is_distributed=True --dataset_path=$3 --resume=$4 > train.log 2>&1 & +fi + diff --git a/model_zoo/official/cv/shufflenetv2/scripts/run_standalone_train_for_gpu.sh b/model_zoo/official/cv/shufflenetv2/scripts/run_standalone_train_for_gpu.sh index 02da407d1ca..8d83803eff2 100644 --- a/model_zoo/official/cv/shufflenetv2/scripts/run_standalone_train_for_gpu.sh +++ b/model_zoo/official/cv/shufflenetv2/scripts/run_standalone_train_for_gpu.sh @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -if [ $# -lt 1 ] +if [ $# != 1 ] && [ $# != 2 ] then - echo "Usage: \ - sh run_standalone_train_for_gpu.sh [DATASET_PATH] \ + echo "Usage: + sh run_standalone_train_for_gpu.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) " exit 1 fi @@ -37,4 +37,12 @@ fi mkdir ../train cd ../train || exit -python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 > train.log 2>&1 & +if [ $# == 1 ] +then + python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 > train.log 2>&1 & +fi + +if [ $# == 2 ] +then + python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 --resume=$2 > train.log 2>&1 & +fi diff --git a/model_zoo/official/cv/shufflenetv2/src/CrossEntropySmooth.py b/model_zoo/official/cv/shufflenetv2/src/CrossEntropySmooth.py new file mode 100644 index 00000000000..bf38c6e77b0 --- /dev/null +++ b/model_zoo/official/cv/shufflenetv2/src/CrossEntropySmooth.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""define loss function for network""" +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import functional as F +from mindspore.ops import operations as P + + +class CrossEntropySmooth(_Loss): + """CrossEntropy""" + def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000): + super(CrossEntropySmooth, self).__init__() + self.onehot = P.OneHot() + self.sparse = sparse + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction) + + def construct(self, logit, label): + if self.sparse: + label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) + loss = self.ce(logit, label) + return loss diff --git a/model_zoo/official/cv/shufflenetv2/src/loss.py b/model_zoo/official/cv/shufflenetv2/src/loss.py deleted file mode 100644 index 01757501e7e..00000000000 --- a/model_zoo/official/cv/shufflenetv2/src/loss.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""define loss function for network.""" -from mindspore.common import dtype as mstype -from mindspore.nn.loss.loss import _Loss -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore import Tensor -import mindspore.nn as nn - - -class CrossEntropy(_Loss): - """the redefined loss function with SoftmaxCrossEntropyWithLogits""" - def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4): - super(CrossEntropy, self).__init__() - self.factor = factor - self.onehot = P.OneHot() - self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) - self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) - self.ce = nn.SoftmaxCrossEntropyWithLogits() - self.mean = P.ReduceMean(False) - - def construct(self, logits, label): - logit, aux = logits - one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) - loss_logit = self.ce(logit, one_hot_label) - loss_logit = self.mean(loss_logit, 0) - one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value) - loss_aux = self.ce(aux, one_hot_label_aux) - loss_aux = self.mean(loss_aux, 0) - return loss_logit + self.factor*loss_aux - - -class CrossEntropy_Val(_Loss): - """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process""" - def __init__(self, smooth_factor=0, num_classes=1000): - super(CrossEntropy_Val, self).__init__() - self.onehot = P.OneHot() - self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) - self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) - self.ce = nn.SoftmaxCrossEntropyWithLogits() - self.mean = P.ReduceMean(False) - - def construct(self, logits, label): - one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value) - loss_logit = self.ce(logits, one_hot_label) - loss_logit = self.mean(loss_logit, 0) - return loss_logit diff --git a/model_zoo/official/cv/shufflenetv2/network.py b/model_zoo/official/cv/shufflenetv2/src/shufflenetv2.py similarity index 63% rename from model_zoo/official/cv/shufflenetv2/network.py rename to model_zoo/official/cv/shufflenetv2/src/shufflenetv2.py index f2d2105cfbf..1245d94798a 100644 --- a/model_zoo/official/cv/shufflenetv2/network.py +++ b/model_zoo/official/cv/shufflenetv2/src/shufflenetv2.py @@ -14,13 +14,78 @@ # ============================================================================ import numpy as np -from blocks import ShuffleV2Block - from mindspore import Tensor import mindspore.nn as nn import mindspore.ops.operations as P +class ShuffleV2Block(nn.Cell): + def __init__(self, inp, oup, mid_channels, *, ksize, stride): + super(ShuffleV2Block, self).__init__() + self.stride = stride + ##assert stride in [1, 2] + + self.mid_channels = mid_channels + self.ksize = ksize + pad = ksize // 2 + self.pad = pad + self.inp = inp + + outputs = oup - inp + + branch_main = [ + # pw + nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1, + pad_mode='pad', padding=0, has_bias=False), + nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), + nn.ReLU(), + # dw + nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride, + pad_mode='pad', padding=pad, group=mid_channels, has_bias=False), + nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), + # pw-linear + nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1, + pad_mode='pad', padding=0, has_bias=False), + nn.BatchNorm2d(num_features=outputs, momentum=0.9), + nn.ReLU(), + ] + self.branch_main = nn.SequentialCell(branch_main) + + if stride == 2: + branch_proj = [ + # dw + nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride, + pad_mode='pad', padding=pad, group=inp, has_bias=False), + nn.BatchNorm2d(num_features=inp, momentum=0.9), + # pw-linear + nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1, + pad_mode='pad', padding=0, has_bias=False), + nn.BatchNorm2d(num_features=inp, momentum=0.9), + nn.ReLU(), + ] + self.branch_proj = nn.SequentialCell(branch_proj) + else: + self.branch_proj = None + + def construct(self, old_x): + if self.stride == 1: + x_proj, x = self.channel_shuffle(old_x) + return P.Concat(1)((x_proj, self.branch_main(x))) + if self.stride == 2: + x_proj = old_x + x = old_x + return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x))) + return None + + def channel_shuffle(self, x): + batchsize, num_channels, height, width = P.Shape()(x) + ##assert (num_channels % 4 == 0) + x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,)) + x = P.Transpose()(x, (1, 0, 2,)) + x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,)) + return x[0], x[1] + + class ShuffleNetV2(nn.Cell): def __init__(self, input_size=224, n_class=1000, model_size='1.0x'): super(ShuffleNetV2, self).__init__() diff --git a/model_zoo/official/cv/shufflenetv2/train.py b/model_zoo/official/cv/shufflenetv2/train.py index 066b225d9f2..8c7f5b215c7 100644 --- a/model_zoo/official/cv/shufflenetv2/train.py +++ b/model_zoo/official/cv/shufflenetv2/train.py @@ -17,7 +17,6 @@ import argparse import ast import os -from network import ShuffleNetV2 import mindspore.nn as nn from mindspore import context @@ -30,9 +29,11 @@ from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import set_seed +from src.shufflenetv2 import ShuffleNetV2 from src.config import config_gpu as cfg from src.dataset import create_dataset from src.lr_generator import get_lr_basic +from src.CrossEntropySmooth import CrossEntropySmooth set_seed(cfg.random_seed) @@ -73,8 +74,8 @@ if __name__ == '__main__': net = ShuffleNetV2(n_class=cfg.num_classes, model_size=args_opt.model_size) # loss - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False, - smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) + loss = CrossEntropySmooth(sparse=True, reduction="mean", + smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) # learning rate schedule lr = get_lr_basic(lr_init=cfg.lr_init, total_epochs=cfg.epoch_size, diff --git a/model_zoo/official/recommend/deepfm/train.py b/model_zoo/official/recommend/deepfm/train.py index f3299a42d69..37887acaa2a 100644 --- a/model_zoo/official/recommend/deepfm/train.py +++ b/model_zoo/official/recommend/deepfm/train.py @@ -71,8 +71,14 @@ if __name__ == '__main__': print("Unsupported device_target ", args_opt.device_target) exit() else: - device_id = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) + if args_opt.device_target == "Ascend": + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) + elif args_opt.device_target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) + else: + print("Unsupported device_target ", args_opt.device_target) + exit() rank_size = None rank_id = None