From 09b2dcb3fbf2a7ca8cda5deaa364b15e2ac5d305 Mon Sep 17 00:00:00 2001 From: wandongdong Date: Fri, 24 Apr 2020 21:15:13 +0800 Subject: [PATCH] add mobilenetv2 --- example/mobilenetv2_imagenet2012/README.md | 101 +++++++ example/mobilenetv2_imagenet2012/config.py | 35 +++ example/mobilenetv2_imagenet2012/dataset.py | 84 ++++++ example/mobilenetv2_imagenet2012/eval.py | 56 ++++ example/mobilenetv2_imagenet2012/launch.py | 150 +++++++++ .../mobilenetv2_imagenet2012/lr_generator.py | 54 ++++ example/mobilenetv2_imagenet2012/run_infer.sh | 33 ++ example/mobilenetv2_imagenet2012/run_train.sh | 33 ++ example/mobilenetv2_imagenet2012/train.py | 149 +++++++++ mindspore/model_zoo/mobilenet.py | 284 ++++++++++++++++++ 10 files changed, 979 insertions(+) create mode 100644 example/mobilenetv2_imagenet2012/README.md create mode 100644 example/mobilenetv2_imagenet2012/config.py create mode 100644 example/mobilenetv2_imagenet2012/dataset.py create mode 100644 example/mobilenetv2_imagenet2012/eval.py create mode 100644 example/mobilenetv2_imagenet2012/launch.py create mode 100644 example/mobilenetv2_imagenet2012/lr_generator.py create mode 100644 example/mobilenetv2_imagenet2012/run_infer.sh create mode 100644 example/mobilenetv2_imagenet2012/run_train.sh create mode 100644 example/mobilenetv2_imagenet2012/train.py create mode 100644 mindspore/model_zoo/mobilenet.py diff --git a/example/mobilenetv2_imagenet2012/README.md b/example/mobilenetv2_imagenet2012/README.md new file mode 100644 index 00000000000..bb5288908d5 --- /dev/null +++ b/example/mobilenetv2_imagenet2012/README.md @@ -0,0 +1,101 @@ +# MobileNetV2 Example + +## Description + +This is an example of training MobileNetV2 with ImageNet2012 dataset in MindSpore. + +## Requirements + +* Install [MindSpore](https://www.mindspore.cn/install/en). + +* Download the dataset [ImageNet2012](http://www.image-net.org/). + +> Unzip the ImageNet2012 dataset to any path you want and the folder structure should be as follows: +> ``` +> . +> ├── train # train dataset +> └── val # infer dataset +> ``` + +## Example structure + +``` shell +. +├── config.py # parameter configuration +├── dataset.py # data preprocessing +├── eval.py # infer script +├── launch.py # launcher for distributed training +├── lr_generator.py # generate learning rate for each step +├── run_infer.sh # launch infering +├── run_train.sh # launch training +└── train.py # train script +``` + +## Parameter configuration + +Parameters for both training and inference can be set in 'config.py'. + +``` +"num_classes": 1000, # dataset class num +"image_height": 224, # image height +"image_width": 224, # image width +"batch_size": 256, # training or infering batch size +"epoch_size": 200, # total training epochs, including warmup_epochs +"warmup_epochs": 4, # warmup epochs +"lr": 0.4, # base learning rate +"momentum": 0.9, # momentum +"weight_decay": 4e-5, # weight decay +"loss_scale": 1024, # loss scale +"save_checkpoint": True, # whether save checkpoint +"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints +"keep_checkpoint_max": 200, # only keep the last keep_checkpoint_max checkpoint +"save_checkpoint_path": "./checkpoint" # path to save checkpoint +``` + +## Running the example + +### Train + +#### Usage +Usage: sh run_train.sh [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] + +#### Launch + +``` +# training example +sh run_train.sh 8 192.168.0.1 0,1,2,3,4,5,6,7 ~/imagenet +``` + +#### Result + +Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings. + +``` +epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] +epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 +epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] +epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 +``` + +### Infer + +#### Usage + +Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH] + +#### Launch + +``` +# infer example +sh run_infer.sh ~/imagenet ~/train/mobilenet-200_625.ckpt +``` + +> checkpoint can be produced in training process. + +#### Result + +Inference result will be stored in the example path, you can find result like the followings in `val.log`. + +``` +result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt +``` diff --git a/example/mobilenetv2_imagenet2012/config.py b/example/mobilenetv2_imagenet2012/config.py new file mode 100644 index 00000000000..32df4eabc93 --- /dev/null +++ b/example/mobilenetv2_imagenet2012/config.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in train.py and eval.py +""" +from easydict import EasyDict as ed + +config = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 256, + "epoch_size": 200, + "warmup_epochs": 4, + "lr": 0.4, + "momentum": 0.9, + "weight_decay": 4e-5, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 200, + "save_checkpoint_path": "./checkpoint", +}) diff --git a/example/mobilenetv2_imagenet2012/dataset.py b/example/mobilenetv2_imagenet2012/dataset.py new file mode 100644 index 00000000000..9df34d51dcd --- /dev/null +++ b/example/mobilenetv2_imagenet2012/dataset.py @@ -0,0 +1,84 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +create train or eval dataset. +""" +import os +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.c_transforms as C +import mindspore.dataset.transforms.c_transforms as C2 +from config import config + + +def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): + """ + create a train or eval dataset + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + + Returns: + dataset + """ + rank_size = int(os.getenv("RANK_SIZE")) + rank_id = int(os.getenv("RANK_ID")) + + if rank_size == 1: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=16, shuffle=True) + else: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=16, shuffle=True, + num_shards=rank_size, shard_id=rank_id) + + resize_height = config.image_height + resize_width = config.image_width + rescale = 1.0 / 255.0 + shift = 0.0 + buffer_size = 1000 + + # define map operations + decode_op = C.Decode() + resize_crop_op = C.RandomResizedCrop(resize_height, scale=(0.2, 1.0)) + horizontal_flip_op = C.RandomHorizontalFlip() + + resize_op = C.Resize((256, 256)) + center_crop = C.CenterCrop(resize_width) + rescale_op = C.Rescale(rescale, shift) + normalize_op = C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + change_swap_op = C.HWC2CHW() + + if do_train: + trans = [decode_op, resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op] + else: + trans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, change_swap_op] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(input_columns="image", operations=trans) + ds = ds.map(input_columns="label", operations=type_cast_op) + + # apply shuffle operations + ds = ds.shuffle(buffer_size=buffer_size) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds diff --git a/example/mobilenetv2_imagenet2012/eval.py b/example/mobilenetv2_imagenet2012/eval.py new file mode 100644 index 00000000000..6c51fc042b4 --- /dev/null +++ b/example/mobilenetv2_imagenet2012/eval.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +eval. +""" +import os +import argparse +from dataset import create_dataset +from config import config +from mindspore import context +from mindspore.model_zoo.mobilenet import mobilenet_v2 +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +args_opt = parser.parse_args() + +device_id = int(os.getenv('DEVICE_ID')) + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) +context.set_context(enable_task_sink=True) +context.set_context(enable_loop_sink=True) +context.set_context(enable_mem_reuse=True) + +if __name__ == '__main__': + context.set_context(enable_hccl=False) + + loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + net = mobilenet_v2() + + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) + step_size = dataset.get_dataset_size() + + if args_opt.checkpoint_path: + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(net, param_dict) + net.set_train(False) + + model = Model(net, loss_fn=loss, metrics={'acc'}) + res = model.eval(dataset) + print("result:", res, "ckpt=", args_opt.checkpoint_path) diff --git a/example/mobilenetv2_imagenet2012/launch.py b/example/mobilenetv2_imagenet2012/launch.py new file mode 100644 index 00000000000..5a8977c64b2 --- /dev/null +++ b/example/mobilenetv2_imagenet2012/launch.py @@ -0,0 +1,150 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""launch train script""" +import os +import sys +import subprocess +import json +from argparse import ArgumentParser + + +def parse_args(): + """ + parse args . + + Args: + + Returns: + args. + + Examples: + >>> parse_args() + """ + parser = ArgumentParser(description="mindspore distributed training launch " + "helper utilty that will spawn up " + "multiple distributed processes") + parser.add_argument("--nproc_per_node", type=int, default=1, + help="The number of processes to launch on each node, " + "for D training, this is recommended to be set " + "to the number of D in your system so that " + "each process can be bound to a single D.") + parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", + help="will use the visible devices sequentially") + parser.add_argument("--server_id", type=str, default="", + help="server ip") + parser.add_argument("--training_script", type=str, + help="The full path to the single D training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script") + # rest from the training program + args, unknown = parser.parse_known_args() + args.training_script_args = unknown + return args + + +def main(): + print("start", __file__) + args = parse_args() + print(args) + visible_devices = args.visible_devices.split(',') + assert os.path.isfile(args.training_script) + assert len(visible_devices) >= args.nproc_per_node + print('visible_devices:{}'.format(visible_devices)) + if not args.server_id: + print('pleaser input server ip!!!') + exit(0) + print('server_id:{}'.format(args.server_id)) + + # construct hccn_table + hccn_configs = open('/etc/hccn.conf', 'r').readlines() + device_ips = {} + for hccn_item in hccn_configs: + hccn_item = hccn_item.strip() + if hccn_item.startswith('address_'): + device_id, device_ip = hccn_item.split('=') + device_id = device_id.split('_')[1] + device_ips[device_id] = device_ip + print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) + hccn_table = {} + hccn_table['board_id'] = '0x0000' + hccn_table['chip_info'] = '910' + hccn_table['deploy_mode'] = 'lab' + hccn_table['group_count'] = '1' + hccn_table['group_list'] = [] + instance_list = [] + usable_dev = '' + for instance_id in range(args.nproc_per_node): + instance = {} + instance['devices'] = [] + device_id = visible_devices[instance_id] + device_ip = device_ips[device_id] + usable_dev += str(device_id) + instance['devices'].append({ + 'device_id': device_id, + 'device_ip': device_ip, + }) + instance['rank_id'] = str(instance_id) + instance['server_id'] = args.server_id + instance_list.append(instance) + hccn_table['group_list'].append({ + 'device_num': str(args.nproc_per_node), + 'server_num': '1', + 'group_name': '', + 'instance_count': str(args.nproc_per_node), + 'instance_list': instance_list, + }) + hccn_table['para_plane_nic_location'] = 'device' + hccn_table['para_plane_nic_name'] = [] + for instance_id in range(args.nproc_per_node): + eth_id = visible_devices[instance_id] + hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) + hccn_table['para_plane_nic_num'] = str(args.nproc_per_node) + hccn_table['status'] = 'completed' + + # save hccn_table to file + table_path = os.getcwd() + if not os.path.exists(table_path): + os.mkdir(table_path) + table_fn = os.path.join(table_path, + 'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id)) + with open(table_fn, 'w') as table_fp: + json.dump(hccn_table, table_fp, indent=4) + sys.stdout.flush() + + # spawn the processes + current_env = os.environ.copy() + current_env["RANK_SIZE"] = str(args.nproc_per_node) + if args.nproc_per_node > 1: + current_env["MINDSPORE_HCCL_CONFIG_PATH"] = table_fn + processes = [] + cmds = [] + for rank_id in range(0, args.nproc_per_node): + current_env["RANK_ID"] = str(rank_id) + current_env["DEVICE_ID"] = visible_devices[rank_id] + cmd = [sys.executable, "-u"] + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + process = subprocess.Popen(cmd, env=current_env) + processes.append(process) + cmds.append(cmd) + for process, cmd in zip(processes, cmds): + process.wait() + if process.returncode != 0: + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) + + +if __name__ == "__main__": + main() diff --git a/example/mobilenetv2_imagenet2012/lr_generator.py b/example/mobilenetv2_imagenet2012/lr_generator.py new file mode 100644 index 00000000000..68bbfe31584 --- /dev/null +++ b/example/mobilenetv2_imagenet2012/lr_generator.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""learning rate generator""" +import math +import numpy as np + + +def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch): + """ + generate learning rate array + + Args: + global_step(int): total steps of the training + lr_init(float): init learning rate + lr_end(float): end learning rate + lr_max(float): max learning rate + warmup_epochs(int): number of warmup epochs + total_epochs(int): total epoch of training + steps_per_epoch(int): steps of one epoch + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + lr = lr_end + \ + (lr_max - lr_end) * \ + (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2. + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + + current_step = global_step + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[current_step:] + + return learning_rate diff --git a/example/mobilenetv2_imagenet2012/run_infer.sh b/example/mobilenetv2_imagenet2012/run_infer.sh new file mode 100644 index 00000000000..dc1e4d0b5d8 --- /dev/null +++ b/example/mobilenetv2_imagenet2012/run_infer.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +if [ $# != 2 ] +then + echo "Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +if [ ! -d $1 ] +then + echo "error: DATASET_PATH=$1 is not a directory" +exit 1 +fi + +if [ ! -f $2 ] +then + echo "error: CHECKPOINT_PATH=$2 is not a file" +exit 1 +fi + +BASEPATH=$(cd "`dirname $0`" || exit; pwd) +export PYTHONPATH=${BASEPATH}:$PYTHONPATH +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval +cd ./eval || exit +python ${BASEPATH}/eval.py \ + --checkpoint_path=$2 \ + --dataset_path=$1 &> infer.log & # dataset val folder path diff --git a/example/mobilenetv2_imagenet2012/run_train.sh b/example/mobilenetv2_imagenet2012/run_train.sh new file mode 100644 index 00000000000..3f92b4f1725 --- /dev/null +++ b/example/mobilenetv2_imagenet2012/run_train.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +if [ $# != 4 ] +then + echo "Usage: sh run_train.sh [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]" +exit 1 +fi + +if [ $1 -lt 1 ] && [ $1 -gt 8 ] +then + echo "error: DEVICE_NUM=$1 is not in (1-8)" +exit 1 +fi + +if [ ! -d $4 ] +then + echo "error: DATASET_PATH=$4 is not a directory" +exit 1 +fi + +BASEPATH=$(cd "`dirname $0`" || exit; pwd) +export PYTHONPATH=${BASEPATH}:$PYTHONPATH +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cd ./train || exit +python ${BASEPATH}/launch.py \ + --nproc_per_node=$1 \ + --visible_devices=$3 \ + --server_id=$2 \ + --training_script=${BASEPATH}/train.py \ + --dataset_path=$4 &> train.log & # dataset train folder diff --git a/example/mobilenetv2_imagenet2012/train.py b/example/mobilenetv2_imagenet2012/train.py new file mode 100644 index 00000000000..584e89fe431 --- /dev/null +++ b/example/mobilenetv2_imagenet2012/train.py @@ -0,0 +1,149 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train_imagenet.""" +import os +import time +import argparse +import random +import numpy as np +from dataset import create_dataset +from lr_generator import get_lr +from config import config +from mindspore import context +from mindspore import Tensor +from mindspore.model_zoo.mobilenet import mobilenet_v2 +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.nn.optim.momentum import Momentum +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits + +from mindspore.train.model import Model, ParallelMode + +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback +from mindspore.train.loss_scale_manager import FixedLossScaleManager +import mindspore.dataset.engine as de +from mindspore.communication.management import init + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +args_opt = parser.parse_args() + +device_id = int(os.getenv('DEVICE_ID')) +rank_id = int(os.getenv('RANK_ID')) +rank_size = int(os.getenv('RANK_SIZE')) +run_distribute = rank_size > 1 + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) +context.set_context(enable_task_sink=True) +context.set_context(enable_loop_sink=True) +context.set_context(enable_mem_reuse=True) + + +class Monitor(Callback): + """ + Monitor loss and time. + + Args: + lr_init (numpy array): train lr + + Returns: + None. + + Examples: + >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) + """ + + def __init__(self, lr_init=None): + super(Monitor, self).__init__() + self.lr_init = lr_init + self.lr_init_len = len(lr_init) + + def epoch_begin(self, run_context): + self.losses = [] + self.epoch_time = time.time() + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / cb_params.batch_num + print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds, + per_step_mseconds, + np.mean(self.losses) + ), flush=True) + + def step_begin(self, run_context): + self.step_time = time.time() + + def step_end(self, run_context): + cb_params = run_context.original_args() + step_mseconds = (time.time() - self.step_time) * 1000 + step_loss = cb_params.net_outputs + + if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): + step_loss = step_loss[0] + if isinstance(step_loss, Tensor): + step_loss = np.mean(step_loss.asnumpy()) + + self.losses.append(step_loss) + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + + print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format( + cb_params.cur_epoch_num - 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, + np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]), flush=True) + + +if __name__ == '__main__': + if run_distribute: + context.set_context(enable_hccl=True) + context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL, + parameter_broadcast=True, mirror_mean=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([140]) + init() + else: + context.set_context(enable_hccl=False) + + epoch_size = config.epoch_size + net = mobilenet_v2(num_classes=config.num_classes) + loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + + print("train args: ", args_opt, "\ncfg: ", config, + "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) + + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, + repeat_num=epoch_size, batch_size=config.batch_size) + step_size = dataset.get_dataset_size() + + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=config.lr, + warmup_epochs=config.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size)) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, + config.weight_decay, config.loss_scale) + + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, amp_level='O0', + keep_batchnorm_fp32=False) + + cb = None + if rank_id == 0: + cb = [Monitor(lr_init=lr.asnumpy())] + if config.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="mobilenet", directory=config.save_checkpoint_path, config=config_ck) + cb += [ckpt_cb] + model.train(epoch_size, dataset, callbacks=cb) diff --git a/mindspore/model_zoo/mobilenet.py b/mindspore/model_zoo/mobilenet.py new file mode 100644 index 00000000000..1d4f1b10b58 --- /dev/null +++ b/mindspore/model_zoo/mobilenet.py @@ -0,0 +1,284 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MobileNetV2 model define""" +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops.operations import TensorAdd +from mindspore import Parameter, Tensor +from mindspore.common.initializer import initializer + +__all__ = ['MobileNetV2', 'mobilenet_v2'] + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class GlobalAvgPooling(nn.Cell): + """ + Global avg pooling definition. + + Args: + + Returns: + Tensor, output tensor. + + Examples: + >>> GlobalAvgPooling() + """ + def __init__(self): + super(GlobalAvgPooling, self).__init__() + self.mean = P.ReduceMean(keep_dims=False) + + def construct(self, x): + x = self.mean(x, (2, 3)) + return x + + +class DepthwiseConv(nn.Cell): + """ + Depthwise Convolution warpper definition. + + Args: + in_planes (int): Input channel. + kernel_size (int): Input kernel size. + stride (int): Stride size. + pad_mode (str): pad mode in (pad, same, valid) + channel_multiplier (int): Output channel multiplier + has_bias (bool): has bias or not + + Returns: + Tensor, output tensor. + + Examples: + >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) + """ + def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): + super(DepthwiseConv, self).__init__() + self.has_bias = has_bias + self.in_channels = in_planes + self.channel_multiplier = channel_multiplier + self.out_channels = in_planes * channel_multiplier + self.kernel_size = (kernel_size, kernel_size) + self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size, + stride=stride, pad_mode=pad_mode, pad=pad) + self.bias_add = P.BiasAdd() + weight_shape = [channel_multiplier, in_planes, *self.kernel_size] + self.weight = Parameter(initializer('ones', weight_shape), name='weight') + + if has_bias: + bias_shape = [channel_multiplier * in_planes] + self.bias = Parameter(initializer('zeros', bias_shape), name='bias') + else: + self.bias = None + + def construct(self, x): + output = self.depthwise_conv(x, self.weight) + if self.has_bias: + output = self.bias_add(output, self.bias) + return output + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = (kernel_size - 1) // 2 + if groups == 1: + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', + padding=padding) + else: + conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) + layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] + self.features = nn.SequentialCell(layers) + + def construct(self, x): + output = self.features(x) + return output + + +class InvertedResidual(nn.Cell): + """ + Mobilenetv2 residual block definition. + + Args: + inp (int): Input channel. + oup (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + expand_ratio (int): expand ration of input channel + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, 1, 1) + """ + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.SequentialCell(layers) + self.add = TensorAdd() + self.cast = P.Cast() + + def construct(self, x): + identity = x + x = self.conv(x) + if self.use_res_connect: + return self.add(identity, x) + return x + + +class MobileNetV2(nn.Cell): + """ + MobileNetV2 architecture. + + Args: + class_num (Cell): number of classes. + width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. + has_dropout (bool): Is dropout used. Default is false + inverted_residual_setting (list): Inverted residual settings. Default is None + round_nearest (list): Channel round to . Default is 8 + Returns: + Tensor, output tensor. + + Examples: + >>> MobileNetV2(num_classes=1000) + """ + def __init__(self, num_classes=1000, width_mult=1., + has_dropout=False, inverted_residual_setting=None, round_nearest=8): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + # setting of inverted residual blocks + self.cfgs = inverted_residual_setting + if inverted_residual_setting is None: + self.cfgs = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(3, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in self.cfgs: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + # building last several layers + features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1)) + # make it nn.CellList + self.features = nn.SequentialCell(features) + # mobilenet head + head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else + [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)]) + self.head = nn.SequentialCell(head) + + self._initialize_weights() + + def construct(self, x): + x = self.features(x) + x = self.head(x) + return x + + def _initialize_weights(self): + """ + Initialize weights. + + Args: + + Returns: + None. + + Examples: + >>> _initialize_weights() + """ + for _, m in self.cells_and_names(): + if isinstance(m, (nn.Conv2d, DepthwiseConv)): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n), + m.weight.data.shape()).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape(), dtype="float32"))) + m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape(), dtype="float32"))) + elif isinstance(m, nn.Dense): + m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape()).astype("float32"))) + if m.bias is not None: + m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape(), dtype="float32"))) + + +def mobilenet_v2(**kwargs): + """ + Constructs a MobileNet V2 model + """ + return MobileNetV2(**kwargs)