From 83580c4decaa2392a195c92fee61a2ccddbf8425 Mon Sep 17 00:00:00 2001 From: hanhaocheng Date: Sat, 1 Aug 2020 11:54:00 +0800 Subject: [PATCH] GPU inceptionv3 support in modelzoo --- model_zoo/official/cv/inceptionv3/README.md | 115 ++++++++ model_zoo/official/cv/inceptionv3/eval.py | 53 ++++ model_zoo/official/cv/inceptionv3/export.py | 40 +++ .../scripts/run_distribute_train_for_gpu.sh | 17 ++ .../inceptionv3/scripts/run_eval_for_gpu.sh | 19 ++ .../scripts/run_standalone_train_for_gpu.sh | 19 ++ .../official/cv/inceptionv3/src/config.py | 43 +++ .../official/cv/inceptionv3/src/dataset.py | 69 +++++ .../cv/inceptionv3/src/inception_v3.py | 257 ++++++++++++++++++ model_zoo/official/cv/inceptionv3/src/loss.py | 60 ++++ .../cv/inceptionv3/src/lr_generator.py | 87 ++++++ model_zoo/official/cv/inceptionv3/train.py | 116 ++++++++ 12 files changed, 895 insertions(+) create mode 100644 model_zoo/official/cv/inceptionv3/README.md create mode 100644 model_zoo/official/cv/inceptionv3/eval.py create mode 100644 model_zoo/official/cv/inceptionv3/export.py create mode 100644 model_zoo/official/cv/inceptionv3/scripts/run_distribute_train_for_gpu.sh create mode 100644 model_zoo/official/cv/inceptionv3/scripts/run_eval_for_gpu.sh create mode 100644 model_zoo/official/cv/inceptionv3/scripts/run_standalone_train_for_gpu.sh create mode 100644 model_zoo/official/cv/inceptionv3/src/config.py create mode 100644 model_zoo/official/cv/inceptionv3/src/dataset.py create mode 100644 model_zoo/official/cv/inceptionv3/src/inception_v3.py create mode 100644 model_zoo/official/cv/inceptionv3/src/loss.py create mode 100644 model_zoo/official/cv/inceptionv3/src/lr_generator.py create mode 100644 model_zoo/official/cv/inceptionv3/train.py diff --git a/model_zoo/official/cv/inceptionv3/README.md b/model_zoo/official/cv/inceptionv3/README.md new file mode 100644 index 00000000000..0d84497ac5b --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/README.md @@ -0,0 +1,115 @@ +# Inception-v3 Example + 
+## Description + +This is an example of training Inception-v3 in MindSpore. + +## Requirements + +- Install [MindSpore](http://www.mindspore.cn/install/en). +- Download the dataset. + +## Structure + +```shell +. +└─Inception-v3 + ├─README.md + ├─scripts + ├─run_standalone_train_for_gpu.sh # launch standalone training with gpu platform(1p) + ├─run_distribute_train_for_gpu.sh # launch distributed training with gpu platform(8p) + └─run_eval_for_gpu.sh # launch evaluating with gpu platform + ├─src + ├─config.py # parameter configuration + ├─dataset.py # data preprocessing + ├─inception_v3.py # network definition + ├─loss.py # Customized CrossEntropy loss function + ├─lr_generator.py # learning rate generator + ├─eval.py # eval net + ├─export.py # convert checkpoint + └─train.py # train net + +``` + +## Parameter Configuration + +Parameters for both training and evaluating can be set in config.py + +``` +'random_seed': 1, # fix random seed +'rank': 0, # local rank of distributed +'group_size': 1, # world size of distributed +'work_nums': 8, # number of workers to read the data +'decay_method': 'cosine', # learning rate scheduler mode +"loss_scale": 1, # loss scale +'batch_size': 128, # input batchsize +'epoch_size': 250, # total epoch numbers +'num_classes': 1000, # dataset class numbers +'smooth_factor': 0.1, # label smoothing factor +'aux_factor': 0.2, # loss factor of aux logit +'lr_init': 0.00004, # initial learning rate +'lr_max': 0.4, # max bound of learning rate +'lr_end': 0.000004, # min bound of learning rate +'warmup_epochs': 1, # warmup epoch numbers +'weight_decay': 0.00004, # weight decay +'momentum': 0.9, # momentum +'opt_eps': 1.0, # epsilon +'keep_checkpoint_max': 100, # max numbers to keep checkpoints +'ckpt_path': './checkpoint/', # save checkpoint path +'is_save_on_master': 0 # save checkpoint on rank0, distributed parameters +``` + + + +## Running the example + +### Train + +#### Usage + +``` +# distributed training example(8p) +sh 
run_distribute_train_for_gpu.sh DATA_DIR +# standalone training +sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR +``` + +#### Launch + +```bash +# distributed training example(8p) for GPU +sh scripts/run_distribute_train_for_gpu.sh /dataset/train +# standalone training example for GPU +sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train +``` + +#### Result + +You can find checkpoint file together with result in log. + +### Evaluation + +#### Usage + +``` +# Evaluation +sh run_eval_for_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT +``` + +#### Launch + +```bash +# Evaluation with checkpoint +sh scripts/run_eval_for_gpu.sh 0 /dataset/val ./checkpoint/inceptionv3-rank3-247_1251.ckpt +``` + +> checkpoint can be produced in training process. + +#### Result + +Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log. + +``` +acc=78.75%(TOP1) +acc=94.07%(TOP5) +``` \ No newline at end of file diff --git a/model_zoo/official/cv/inceptionv3/eval.py b/model_zoo/official/cv/inceptionv3/eval.py new file mode 100644 index 00000000000..a2f0ade1d34 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/eval.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
if __name__ == '__main__':
    # Evaluate an Inception-v3 checkpoint on an image-folder dataset and print
    # the loss, top-1 and top-5 accuracy reported by Model.eval.
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inception-v3 (Default: None)')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    args_opt = parser.parse_args()

    # Fail early with a readable message instead of an opaque error from
    # load_checkpoint('') further down.
    if not args_opt.checkpoint:
        parser.error('--checkpoint is required for evaluation')

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
    if args_opt.platform == 'Ascend':
        # DEVICE_ID may be unset; int(None) used to raise TypeError here.
        # Same guard as train.py: only forward it when it is a plain number.
        device_id = os.getenv('DEVICE_ID', 'not_set')
        if device_id.isdigit():
            context.set_context(device_id=int(device_id))

    # is_training=False builds the network without the auxiliary head, so it
    # returns the main logits only.
    net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
    ckpt = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, ckpt)
    net.set_train(False)

    # Single shard: rank 0 of a group of size 1.
    dataset = create_dataset(args_opt.dataset_path, False, 0, 1)

    # Read the smoothing factor from the shared config (currently 0.1) so the
    # reported loss stays consistent with training if the config changes.
    loss = CrossEntropy_Val(smooth_factor=cfg.smooth_factor, num_classes=cfg.num_classes)
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
    metrics = model.eval(dataset)
    print("metric: ", metrics)
if __name__ == '__main__':
    # Load an Inception-v3 checkpoint and export the inference network to the
    # ONNX and GEIR formats.
    parser = argparse.ArgumentParser(description='checkpoint export')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inception-v3 (Default: None)')
    args_opt = parser.parse_args()

    # is_training=False builds the network without the auxiliary head.
    net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
    param_dict = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, param_dict)

    # config_gpu does not define the output file names, so attribute access
    # (cfg.onnx_filename) raised AttributeError. Fall back to defaults while
    # still honouring the keys when a config provides them.
    onnx_filename = cfg.get('onnx_filename', 'inceptionv3.onnx')
    geir_filename = cfg.get('geir_filename', 'inceptionv3.geir')

    # Dummy NCHW input (1 x 3 x 299 x 299) used only to trace the graph.
    input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 299, 299]), ms.float32)
    export(net, input_arr, file_name=onnx_filename, file_format="ONNX")
    export(net, input_arr, file_name=geir_filename, file_format="GEIR")
# Launch 8-process distributed training on GPU in the background.
# Usage: sh run_distribute_train_for_gpu.sh DATA_DIR
# Output (stdout and stderr) goes to train.log in the current directory.
if [ $# -lt 1 ]; then
    echo "Usage: sh run_distribute_train_for_gpu.sh DATA_DIR"
    exit 1
fi

DATA_DIR=$1
# Quote the path so dataset directories containing spaces are passed intact.
mpirun --allow-run-as-root -n 8 python ./train.py --is_distributed --platform 'GPU' --dataset_path "$DATA_DIR" > train.log 2>&1 &
# Launch single-device training on GPU in the background.
# Usage: sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
# Output (stdout and stderr) goes to train.log in the current directory.
if [ $# -lt 2 ]; then
    echo "Usage: sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR"
    exit 1
fi

DEVICE_ID=$1
DATA_DIR=$2
# CUDA_VISIBLE_DEVICES restricts the process to the requested GPU.
# Quote both values so paths containing spaces are passed intact.
CUDA_VISIBLE_DEVICES="$DEVICE_ID" python ./train.py --platform 'GPU' --dataset_path "$DATA_DIR" > train.log 2>&1 &
# Settings shared by train.py, eval.py and export.py for the GPU platform.
config_gpu = edict({
    'random_seed': 1,                     # seed for python, numpy and dataset RNGs
    'rank': 0,                            # rank of this process; overwritten by train.py
    'group_size': 1,                      # number of processes; overwritten by train.py
    'work_nums': 8,                       # parallel workers used by the dataset pipeline
    'decay_method': 'cosine',             # lr_decay_mode passed to get_lr
    "loss_scale": 1,                      # loss scale passed to the optimizer
    'batch_size': 128,                    # samples per batch
    'epoch_size': 250,                    # total number of training epochs
    'num_classes': 1000,                  # number of dataset classes
    'smooth_factor': 0.1,                 # label smoothing factor
    'aux_factor': 0.2,                    # weight of the auxiliary-logits loss
    'lr_init': 0.00004,                   # learning rate at the start of warmup
    'lr_max': 0.4,                        # peak learning rate
    'lr_end': 0.000004,                   # final learning rate
    'warmup_epochs': 1,                   # number of warmup epochs
    'weight_decay': 0.00004,              # weight decay
    'momentum': 0.9,                      # optimizer momentum
    'opt_eps': 1.0,                       # optimizer epsilon
    'keep_checkpoint_max': 100,           # maximum number of checkpoints to keep
    'ckpt_path': './checkpoint/',         # directory checkpoints are saved to
    'is_save_on_master': 0,               # presumably: save only on rank 0 when set -- confirm in train.py
    # Output file names read by export.py; they were missing, which made
    # export.py fail with AttributeError.
    'onnx_filename': 'inceptionv3.onnx',
    'geir_filename': 'inceptionv3.geir'
})
def create_dataset(dataset_path, do_train, rank, group_size, repeat_num=1):
    """
    Build the image-folder pipeline used for training or evaluation.

    Images are always shuffled, transformed to normalized CHW tensors of
    size 299, grouped into batches of ``cfg.batch_size`` (incomplete last
    batch dropped) and finally repeated ``repeat_num`` times.

    Args:
        dataset_path(string): root directory of the image-folder dataset.
        do_train(bool): True selects the random training augmentations,
            False selects decode / resize / center-crop.
        rank (int): shard id of this process; only used when group_size != 1.
        group_size (int): number of shards; 1 disables sharding.
        repeat_num(int): how many times the dataset is repeated. Default: 1.

    Returns:
        dataset
    """
    # Shard arguments are only supplied when more than one process takes part.
    shard_kwargs = {} if group_size == 1 else {'num_shards': group_size, 'shard_id': rank}
    ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True,
                                 **shard_kwargs)

    # Mode-specific front of the image pipeline.
    if do_train:
        image_ops = [
            C.RandomCropDecodeResize(299, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
        ]
    else:
        image_ops = [
            C.Decode(),
            C.Resize(299),
            C.CenterCrop(299)
        ]
    # Shared tail: scale to [0, 1], normalize per channel, HWC -> CHW.
    image_ops.extend([
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        C.HWC2CHW()
    ])
    label_cast = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", operations=image_ops, num_parallel_workers=cfg.work_nums)
    ds = ds.map(input_columns="label", operations=label_cast, num_parallel_workers=cfg.work_nums)
    # Batch first, then repeat.
    return ds.batch(cfg.batch_size, drop_remainder=True).repeat(repeat_num)
class BasicConv2d(nn.Cell):
    """Conv2d (Xavier-uniform weights, with bias) followed by BatchNorm2d and ReLU."""
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, pad_mode='same', padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
                              pad_mode=pad_mode, padding=padding, weight_init=XavierUniform(), has_bias=True)
        self.bn = nn.BatchNorm2d(out_channel, eps=0.001, momentum=0.9997)
        self.relu = nn.ReLU()

    def construct(self, x):
        """Apply conv -> batch norm -> ReLU to x."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Inception_A(nn.Cell):
    """
    Four parallel branches (1x1, 1x1-5x5, 1x1-3x3-3x3, avg-pool-1x1) joined on axis 1.

    The branches emit 64, 64, 96 and pool_features channels respectively.
    """
    def __init__(self, in_channels, pool_features):
        super(Inception_A, self).__init__()
        self.concat = P.Concat(axis=1)
        self.branch0 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch1 = nn.SequentialCell([
            BasicConv2d(in_channels, 48, kernel_size=1),
            BasicConv2d(48, 64, kernel_size=5)
        ])
        self.branch2 = nn.SequentialCell([
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3),
            BasicConv2d(96, 96, kernel_size=3)

        ])
        self.branch_pool = nn.SequentialCell([
            nn.AvgPool2d(kernel_size=3, pad_mode='same'),
            BasicConv2d(in_channels, pool_features, kernel_size=1)
        ])

    def construct(self, x):
        """Run every branch on x and concatenate the results along axis 1."""
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        branch_pool = self.branch_pool(x)
        out = self.concat((x0, x1, x2, branch_pool))
        return out


class Inception_B(nn.Cell):
    """
    Stride-2 reduction block: 3x3 conv, 1x1-3x3-3x3 conv stack and a max pool, joined on axis 1.

    The conv branches emit 384 and 96 channels; the pooling branch has no
    convolution, so it passes the input channels through.
    """
    def __init__(self, in_channels):
        super(Inception_B, self).__init__()
        self.concat = P.Concat(axis=1)
        self.branch0 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2, pad_mode='valid')
        self.branch1 = nn.SequentialCell([
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3),
            BasicConv2d(96, 96, kernel_size=3, stride=2, pad_mode='valid')

        ])
        self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def construct(self, x):
        """Run every branch on x and concatenate the results along axis 1."""
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        branch_pool = self.branch_pool(x)
        out = self.concat((x0, x1, branch_pool))
        return out


class Inception_C(nn.Cell):
    """
    Block built from factorized 1x7 / 7x1 convolutions; four branches joined on axis 1.

    channels_7x7 is the width of the intermediate 7-tap convolutions; each
    of the four branches ends with 192 output channels.
    """
    def __init__(self, in_channels, channels_7x7):
        super(Inception_C, self).__init__()
        self.concat = P.Concat(axis=1)
        self.branch0 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch1 = nn.SequentialCell([
            BasicConv2d(in_channels, channels_7x7, kernel_size=1),
            BasicConv2d(channels_7x7, channels_7x7, kernel_size=(1, 7)),
            BasicConv2d(channels_7x7, 192, kernel_size=(7, 1))
        ])
        self.branch2 = nn.SequentialCell([
            BasicConv2d(in_channels, channels_7x7, kernel_size=1),
            BasicConv2d(channels_7x7, channels_7x7, kernel_size=(7, 1)),
            BasicConv2d(channels_7x7, channels_7x7, kernel_size=(1, 7)),
            BasicConv2d(channels_7x7, channels_7x7, kernel_size=(7, 1)),
            BasicConv2d(channels_7x7, 192, kernel_size=(1, 7))
        ])
        self.branch_pool = nn.SequentialCell([
            nn.AvgPool2d(kernel_size=3, pad_mode='same'),
            BasicConv2d(in_channels, 192, kernel_size=1)
        ])

    def construct(self, x):
        """Run every branch on x and concatenate the results along axis 1."""
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        branch_pool = self.branch_pool(x)
        out = self.concat((x0, x1, x2, branch_pool))
        return out


class Inception_D(nn.Cell):
    """
    Stride-2 reduction block: 1x1-3x3, 1x1-1x7-7x1-3x3 and a max pool, joined on axis 1.

    The conv branches emit 320 and 192 channels; the pooling branch has no
    convolution, so it passes the input channels through.
    """
    def __init__(self, in_channels):
        super(Inception_D, self).__init__()
        self.concat = P.Concat(axis=1)
        self.branch0 = nn.SequentialCell([
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 320, kernel_size=3, stride=2, pad_mode='valid')
        ])
        self.branch1 = nn.SequentialCell([
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7)),  # check
            BasicConv2d(192, 192, kernel_size=(7, 1)),
            BasicConv2d(192, 192, kernel_size=3, stride=2, pad_mode='valid')
        ])
        self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def construct(self, x):
        """Run every branch on x and concatenate the results along axis 1."""
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        branch_pool = self.branch_pool(x)
        out = self.concat((x0, x1, branch_pool))
        return out


class Inception_E(nn.Cell):
    """
    Block whose two middle branches each fan out into a 1x3 and a 3x1 convolution.

    Branch outputs are 320, 384 + 384, 384 + 384 and 192 channels, joined on
    axis 1 (2048 channels in total).
    """
    def __init__(self, in_channels):
        super(Inception_E, self).__init__()
        self.concat = P.Concat(axis=1)
        self.branch0 = BasicConv2d(in_channels, 320, kernel_size=1)
        self.branch1 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch1_a = BasicConv2d(384, 384, kernel_size=(1, 3))
        self.branch1_b = BasicConv2d(384, 384, kernel_size=(3, 1))
        self.branch2 = nn.SequentialCell([
            BasicConv2d(in_channels, 448, kernel_size=1),
            BasicConv2d(448, 384, kernel_size=3)
        ])
        self.branch2_a = BasicConv2d(384, 384, kernel_size=(1, 3))
        self.branch2_b = BasicConv2d(384, 384, kernel_size=(3, 1))
        self.branch_pool = nn.SequentialCell([
            nn.AvgPool2d(kernel_size=3, pad_mode='same'),
            BasicConv2d(in_channels, 192, kernel_size=1)
        ])

    def construct(self, x):
        """Run every branch on x and concatenate the results along axis 1."""
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        # Both the 1x3 and the 3x1 convolution read the same 384-channel tensor.
        x1 = self.concat((self.branch1_a(x1), self.branch1_b(x1)))
        x2 = self.branch2(x)
        x2 = self.concat((self.branch2_a(x2), self.branch2_b(x2)))
        branch_pool = self.branch_pool(x)
        out = self.concat((x0, x1, x2, branch_pool))
        return out


class Logits(nn.Cell):
    """Classifier head: 8x8 average pool, dropout, flatten, Dense(2048, num_classes)."""
    def __init__(self, num_classes=10, dropout_keep_prob=0.8):
        super(Logits, self).__init__()
        self.avg_pool = nn.AvgPool2d(8, pad_mode='valid')
        self.dropout = nn.Dropout(keep_prob=dropout_keep_prob)
        self.flatten = P.Flatten()
        self.fc = nn.Dense(2048, num_classes)

    def construct(self, x):
        """Return the class logits for feature map x."""
        x = self.avg_pool(x)
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


class AuxLogits(nn.Cell):
    """Auxiliary classifier head: average pool, 1x1 conv, 5x5 conv, flatten, dense layer."""
    def __init__(self, in_channels, num_classes=10):
        super(AuxLogits, self).__init__()
        self.avg_pool = nn.AvgPool2d(5, stride=3, pad_mode='valid')
        self.conv2d_0 = nn.Conv2d(in_channels, 128, kernel_size=1)
        self.conv2d_1 = nn.Conv2d(128, 768, kernel_size=5, pad_mode='valid')
        self.flatten = P.Flatten()
        # NOTE(review): the dense layer is sized by in_channels, but its input is the
        # flattened output of conv2d_1 (768 channels). The only caller in this file
        # passes in_channels=768, so the sizes agree there -- confirm before reusing
        # this head with a different in_channels.
        self.fc = nn.Dense(in_channels, num_classes)

    def construct(self, x):
        """Return the auxiliary class logits for feature map x."""
        x = self.avg_pool(x)
        x = self.conv2d_0(x)
        x = self.conv2d_1(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


class InceptionV3(nn.Cell):
    """
    Inception-v3 network.

    Args:
        num_classes (int): size of the logits produced by both heads. Default: 10.
        is_training (bool): when True the auxiliary head is created and
            construct returns (logits, aux_logits); when False only the
            main logits are returned. Default: True.
    """
    def __init__(self, num_classes=10, is_training=True):
        super(InceptionV3, self).__init__()
        self.is_training = is_training
        # Stem: plain convolutions and max pools (3 -> 192 channels).
        self.Conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2, pad_mode='valid')
        self.Conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1, pad_mode='valid')
        self.Conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a = BasicConv2d(80, 192, kernel_size=3, pad_mode='valid')
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        # Inception stages; each in_channels equals the previous block's output width.
        self.Mixed_5b = Inception_A(192, pool_features=32)
        self.Mixed_5c = Inception_A(256, pool_features=64)
        self.Mixed_5d = Inception_A(288, pool_features=64)
        self.Mixed_6a = Inception_B(288)
        self.Mixed_6b = Inception_C(768, channels_7x7=128)
        self.Mixed_6c = Inception_C(768, channels_7x7=160)
        self.Mixed_6d = Inception_C(768, channels_7x7=160)
        self.Mixed_6e = Inception_C(768, channels_7x7=192)
        self.Mixed_7a = Inception_D(768)
        self.Mixed_7b = Inception_E(1280)
        self.Mixed_7c = Inception_E(2048)
        # The auxiliary head only exists in training mode.
        if is_training:
            self.aux_logits = AuxLogits(768, num_classes)
        self.logits = Logits(num_classes, dropout_keep_prob=0.5)

    def construct(self, x):
        """Return (logits, aux_logits) in training mode, otherwise logits only."""
        x = self.Conv2d_1a(x)
        x = self.Conv2d_2a(x)
        x = self.Conv2d_2b(x)
        x = self.maxpool1(x)
        x = self.Conv2d_3b(x)
        x = self.Conv2d_4a(x)
        x = self.maxpool2(x)
        x = self.Mixed_5b(x)
        x = self.Mixed_5c(x)
        x = self.Mixed_5d(x)
        x = self.Mixed_6a(x)
        x = self.Mixed_6b(x)
        x = self.Mixed_6c(x)
        x = self.Mixed_6d(x)
        x = self.Mixed_6e(x)
        # The auxiliary head reads the Mixed_6e output.
        if self.is_training:
            aux_logits = self.aux_logits(x)
        else:
            aux_logits = None
        x = self.Mixed_7a(x)
        x = self.Mixed_7b(x)
        x = self.Mixed_7c(x)
        logits = self.logits(x)
        if self.is_training:
            return logits, aux_logits
        return logits
class CrossEntropy(_Loss):
    """
    Label-smoothed softmax cross entropy over a (logits, aux_logits) pair.

    Args:
        smooth_factor (float): label smoothing factor; the true class gets
            1 - smooth_factor, every other class smooth_factor / (num_classes - 1).
        num_classes (int): number of classes used to spread the smoothing mass.
        factor (float): weight applied to the auxiliary loss term.
    """
    def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
        super(CrossEntropy, self).__init__()
        self.factor = factor
        # Primitives used inside construct.
        self.onehot = P.OneHot()
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)
        # Target values written by OneHot for the true class and for all others.
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)

    def construct(self, logits, label):
        """Return main loss + factor * auxiliary loss, each reduced with a mean over axis 0."""
        main_out, aux_out = logits
        # The one-hot depth is taken from axis 1 of each head's output.
        main_target = self.onehot(label, F.shape(main_out)[1], self.on_value, self.off_value)
        main_loss = self.mean(self.ce(main_out, main_target), 0)
        aux_target = self.onehot(label, F.shape(aux_out)[1], self.on_value, self.off_value)
        aux_loss = self.mean(self.ce(aux_out, aux_target), 0)
        return main_loss + self.factor*aux_loss
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    Build the per-step learning rate schedule.

    Every mode starts with a linear warmup from lr_init towards lr_max over
    ``warmup_epochs * steps_per_epoch`` steps, then decays:

    - 'steps': lr_max scaled by 1, 0.1, 0.01, 0.001 after 0%, 30%, 60%, 80%
      of the total steps.
    - 'steps_decay': lr_max multiplied by 0.94 every two epochs after warmup,
      clamped at 0.
    - 'cosine': half-cosine from lr_max down to lr_end.
    - anything else: linear from lr_max down to lr_end.

    Args:
        lr_init(float): init learning rate
        lr_end(float): end learning rate
        lr_max(float): max learning rate
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): 'steps', 'steps_decay', 'cosine'; any other
            value selects the linear schedule

    Returns:
        np.array, float32 learning rate for each of the
        ``steps_per_epoch * total_epochs`` steps
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs

    def linear_warmup(step):
        # Shared by the 'steps' and the fallback linear schedule.
        return lr_init + (lr_max - lr_init) * step / warmup_steps

    if lr_decay_mode == 'steps':
        boundaries = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]

        def lr_at(step):
            if step < warmup_steps:
                return linear_warmup(step)
            if step < boundaries[0]:
                return lr_max
            if step < boundaries[1]:
                return lr_max * 0.1
            if step < boundaries[2]:
                return lr_max * 0.01
            return lr_max * 0.001
    elif lr_decay_mode == 'steps_decay':
        if warmup_steps != 0:
            warmup_slope = (float(lr_max) - float(lr_init)) / float(warmup_steps)
        else:
            warmup_slope = 0

        def lr_at(step):
            if step < warmup_steps:
                value = float(lr_init) + warmup_slope * float(step)
            else:
                # One decay per two full epochs elapsed since warmup ended.
                decay_count = math.floor((float(step-warmup_steps)/steps_per_epoch) / 2)
                value = float(lr_max)*pow(0.94, decay_count)
            return 0.0 if value < 0.0 else value
    elif lr_decay_mode == 'cosine':
        decay_steps = total_steps - warmup_steps

        def lr_at(step):
            if step < warmup_steps:
                # Unlike the other modes, warmup here is evaluated at step + 1,
                # so the last warmup step already reaches lr_max.
                slope = (float(lr_max) - float(lr_init)) / float(warmup_steps)
                return float(lr_init) + slope * (step + 1)
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (step-warmup_steps) / decay_steps))
            return (lr_max-lr_end)*cosine_decay + lr_end
    else:
        def lr_at(step):
            if step < warmup_steps:
                return linear_warmup(step)
            return lr_max - (lr_max - lr_end) * (step - warmup_steps) / (total_steps - warmup_steps)

    schedule = [lr_at(step) for step in range(total_steps)]
    return np.array(schedule).astype(np.float32)
"""train_imagenet."""
import argparse
import os
import random
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.rmsprop import RMSProp
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import dataset as de

from src.config import config_gpu as cfg
from src.dataset import create_dataset
from src.inception_v3 import InceptionV3
from src.lr_generator import get_lr
from src.loss import CrossEntropy

# Seed every random source the script touches with the configured seed.
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
de.config.set_seed(cfg.random_seed)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification training')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
    parser.add_argument('--is_distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
    # Only pin a device when DEVICE_ID is set to a plain non-negative integer.
    device_id = os.getenv('DEVICE_ID', "not_set")
    if device_id.isdigit():
        context.set_context(device_id=int(device_id))

    # init distributed
    if args_opt.is_distributed:
        if args_opt.platform == "Ascend":
            init()
        else:
            init("nccl")
        cfg.rank = get_rank()
        cfg.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
                                          parameter_broadcast=True, mirror_mean=True)
    else:
        cfg.rank = 0
        cfg.group_size = 1

    # dataloader
    dataset = create_dataset(args_opt.dataset_path, True, cfg.rank, cfg.group_size)
    batches_per_epoch = dataset.get_dataset_size()

    # network
    net = InceptionV3(num_classes=cfg.num_classes)

    # loss
    loss = CrossEntropy(smooth_factor=cfg.smooth_factor, num_classes=cfg.num_classes, factor=cfg.aux_factor)

    # learning rate schedule
    lr = get_lr(lr_init=cfg.lr_init, lr_end=cfg.lr_end, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs,
                total_epochs=cfg.epoch_size, steps_per_epoch=batches_per_epoch, lr_decay_mode=cfg.decay_method)
    lr = Tensor(lr)

    # optimizer
    # Split the parameters in a single pass: names containing 'beta', 'gamma'
    # or 'bias' are excluded from weight decay, everything else is decayed.
    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' in param.name or 'gamma' in param.name or 'bias' in param.name:
            no_decayed_params.append(param)
        else:
            decayed_params.append(param)
    group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]
    optimizer = RMSProp(group_params, lr, decay=0.9, weight_decay=cfg.weight_decay,
                        momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale)
    # NOTE(review): eval_metrics is built but not passed to Model below, which
    # uses metrics={'acc'} instead -- confirm which set is intended.
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}

    if args_opt.resume:
        ckpt = load_checkpoint(args_opt.resume)
        load_param_into_net(net, ckpt)
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={'acc'})

    print("============== Starting Training ==============")
    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    callbacks = [loss_cb, time_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionv3-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck)

    # Logical `and` rather than bitwise `&`: the flag is an int in the config,
    # and `True & flag` is wrong for any truthy flag whose lowest bit is clear.
    master_only = args_opt.is_distributed and cfg.is_save_on_master
    # When only the master saves, every rank other than 0 trains without the
    # checkpoint callback; otherwise each process saves its own checkpoints.
    if not master_only or cfg.rank == 0:
        callbacks.append(ckpoint_cb)
    model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
    print("train success")