add mobilenetv2

This commit is contained in:
wandongdong 2020-04-24 21:15:13 +08:00
parent e8172b3e8f
commit 09b2dcb3fb
10 changed files with 979 additions and 0 deletions

View File

@ -0,0 +1,101 @@
# MobileNetV2 Example
## Description
This is an example of training MobileNetV2 with ImageNet2012 dataset in MindSpore.
## Requirements
* Install [MindSpore](https://www.mindspore.cn/install/en).
* Download the dataset [ImageNet2012](http://www.image-net.org/).
> Unzip the ImageNet2012 dataset to any path you want and the folder structure should be as follows:
> ```
> .
> ├── train # train dataset
> └── val # infer dataset
> ```
## Example structure
``` shell
.
├── config.py # parameter configuration
├── dataset.py # data preprocessing
├── eval.py # infer script
├── launch.py # launcher for distributed training
├── lr_generator.py # generate learning rate for each step
├── run_infer.sh # launch infering
├── run_train.sh # launch training
└── train.py # train script
```
## Parameter configuration
Parameters for both training and inference can be set in 'config.py'.
```
"num_classes": 1000, # dataset class num
"image_height": 224, # image height
"image_width": 224, # image width
"batch_size": 256, # training or infering batch size
"epoch_size": 200, # total training epochs, including warmup_epochs
"warmup_epochs": 4, # warmup epochs
"lr": 0.4, # base learning rate
"momentum": 0.9, # momentum
"weight_decay": 4e-5, # weight decay
"loss_scale": 1024, # loss scale
"save_checkpoint": True, # whether save checkpoint
"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints
"keep_checkpoint_max": 200, # only keep the last keep_checkpoint_max checkpoint
"save_checkpoint_path": "./checkpoint" # path to save checkpoint
```
## Running the example
### Train
#### Usage
Usage: sh run_train.sh [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
#### Launch
```
# training example
sh run_train.sh 8 192.168.0.1 0,1,2,3,4,5,6,7 ~/imagenet
```
#### Result
Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings.
```
epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100]
epoch time: 140522.500, per step time: 224.836, avg loss: 5.258
epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200]
epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
```
### Infer
#### Usage
Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]
#### Launch
```
# infer example
sh run_infer.sh ~/imagenet ~/train/mobilenet-200_625.ckpt
```
> checkpoint can be produced in training process.
#### Result
Inference result will be stored in the example path, you can find result like the followings in `val.log`.
```
result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt
```

View File

@ -0,0 +1,35 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config = ed({
"num_classes": 1000,
"image_height": 224,
"image_width": 224,
"batch_size": 256,
"epoch_size": 200,
"warmup_epochs": 4,
"lr": 0.4,
"momentum": 0.9,
"weight_decay": 4e-5,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
})

View File

@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from config import config
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
Returns:
dataset
"""
rank_size = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
if rank_size == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=16, shuffle=True)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=16, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
resize_height = config.image_height
resize_width = config.image_width
rescale = 1.0 / 255.0
shift = 0.0
buffer_size = 1000
# define map operations
decode_op = C.Decode()
resize_crop_op = C.RandomResizedCrop(resize_height, scale=(0.2, 1.0))
horizontal_flip_op = C.RandomHorizontalFlip()
resize_op = C.Resize((256, 256))
center_crop = C.CenterCrop(resize_width)
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
change_swap_op = C.HWC2CHW()
if do_train:
trans = [decode_op, resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op]
else:
trans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, change_swap_op]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", operations=trans)
ds = ds.map(input_columns="label", operations=type_cast_op)
# apply shuffle operations
ds = ds.shuffle(buffer_size=buffer_size)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds

View File

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval.
"""
import os
import argparse
from dataset import create_dataset
from config import config
from mindspore import context
from mindspore.model_zoo.mobilenet import mobilenet_v2
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__':
context.set_context(enable_hccl=False)
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
net = mobilenet_v2()
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
model = Model(net, loss_fn=loss, metrics={'acc'})
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

View File

@ -0,0 +1,150 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""launch train script"""
import os
import sys
import subprocess
import json
from argparse import ArgumentParser
def parse_args():
"""
parse args .
Args:
Returns:
args.
Examples:
>>> parse_args()
"""
parser = ArgumentParser(description="mindspore distributed training launch "
"helper utilty that will spawn up "
"multiple distributed processes")
parser.add_argument("--nproc_per_node", type=int, default=1,
help="The number of processes to launch on each node, "
"for D training, this is recommended to be set "
"to the number of D in your system so that "
"each process can be bound to a single D.")
parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7",
help="will use the visible devices sequentially")
parser.add_argument("--server_id", type=str, default="",
help="server ip")
parser.add_argument("--training_script", type=str,
help="The full path to the single D training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script")
# rest from the training program
args, unknown = parser.parse_known_args()
args.training_script_args = unknown
return args
def main():
print("start", __file__)
args = parse_args()
print(args)
visible_devices = args.visible_devices.split(',')
assert os.path.isfile(args.training_script)
assert len(visible_devices) >= args.nproc_per_node
print('visible_devices:{}'.format(visible_devices))
if not args.server_id:
print('pleaser input server ip!!!')
exit(0)
print('server_id:{}'.format(args.server_id))
# construct hccn_table
hccn_configs = open('/etc/hccn.conf', 'r').readlines()
device_ips = {}
for hccn_item in hccn_configs:
hccn_item = hccn_item.strip()
if hccn_item.startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
hccn_table = {}
hccn_table['board_id'] = '0x0000'
hccn_table['chip_info'] = '910'
hccn_table['deploy_mode'] = 'lab'
hccn_table['group_count'] = '1'
hccn_table['group_list'] = []
instance_list = []
usable_dev = ''
for instance_id in range(args.nproc_per_node):
instance = {}
instance['devices'] = []
device_id = visible_devices[instance_id]
device_ip = device_ips[device_id]
usable_dev += str(device_id)
instance['devices'].append({
'device_id': device_id,
'device_ip': device_ip,
})
instance['rank_id'] = str(instance_id)
instance['server_id'] = args.server_id
instance_list.append(instance)
hccn_table['group_list'].append({
'device_num': str(args.nproc_per_node),
'server_num': '1',
'group_name': '',
'instance_count': str(args.nproc_per_node),
'instance_list': instance_list,
})
hccn_table['para_plane_nic_location'] = 'device'
hccn_table['para_plane_nic_name'] = []
for instance_id in range(args.nproc_per_node):
eth_id = visible_devices[instance_id]
hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
hccn_table['para_plane_nic_num'] = str(args.nproc_per_node)
hccn_table['status'] = 'completed'
# save hccn_table to file
table_path = os.getcwd()
if not os.path.exists(table_path):
os.mkdir(table_path)
table_fn = os.path.join(table_path,
'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id))
with open(table_fn, 'w') as table_fp:
json.dump(hccn_table, table_fp, indent=4)
sys.stdout.flush()
# spawn the processes
current_env = os.environ.copy()
current_env["RANK_SIZE"] = str(args.nproc_per_node)
if args.nproc_per_node > 1:
current_env["MINDSPORE_HCCL_CONFIG_PATH"] = table_fn
processes = []
cmds = []
for rank_id in range(0, args.nproc_per_node):
current_env["RANK_ID"] = str(rank_id)
current_env["DEVICE_ID"] = visible_devices[rank_id]
cmd = [sys.executable, "-u"]
cmd.append(args.training_script)
cmd.extend(args.training_script_args)
process = subprocess.Popen(cmd, env=current_env)
processes.append(process)
cmds.append(cmd)
for process, cmd in zip(processes, cmds):
process.wait()
if process.returncode != 0:
raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
generate learning rate array
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_end + \
(lr_max - lr_end) * \
(1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate

View File

@ -0,0 +1,33 @@
#!/usr/bin/env bash
if [ $# != 2 ]
then
echo "Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
if [ ! -d $1 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
exit 1
fi
if [ ! -f $2 ]
then
echo "error: CHECKPOINT_PATH=$2 is not a file"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cd ./eval || exit
python ${BASEPATH}/eval.py \
--checkpoint_path=$2 \
--dataset_path=$1 &> infer.log & # dataset val folder path

View File

@ -0,0 +1,33 @@
#!/usr/bin/env bash
if [ $# != 4 ]
then
echo "Usage: sh run_train.sh [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]"
exit 1
fi
if [ $1 -lt 1 ] && [ $1 -gt 8 ]
then
echo "error: DEVICE_NUM=$1 is not in (1-8)"
exit 1
fi
if [ ! -d $4 ]
then
echo "error: DATASET_PATH=$4 is not a directory"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cd ./train || exit
python ${BASEPATH}/launch.py \
--nproc_per_node=$1 \
--visible_devices=$3 \
--server_id=$2 \
--training_script=${BASEPATH}/train.py \
--dataset_path=$4 &> train.log & # dataset train folder

View File

@ -0,0 +1,149 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import os
import time
import argparse
import random
import numpy as np
from dataset import create_dataset
from lr_generator import get_lr
from config import config
from mindspore import context
from mindspore import Tensor
from mindspore.model_zoo.mobilenet import mobilenet_v2
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
import mindspore.dataset.engine as de
from mindspore.communication.management import init
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
rank_id = int(os.getenv('RANK_ID'))
rank_size = int(os.getenv('RANK_SIZE'))
run_distribute = rank_size > 1
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
class Monitor(Callback):
"""
Monitor loss and time.
Args:
lr_init (numpy array): train lr
Returns:
None.
Examples:
>>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
"""
def __init__(self, lr_init=None):
super(Monitor, self).__init__()
self.lr_init = lr_init
self.lr_init_len = len(lr_init)
def epoch_begin(self, run_context):
self.losses = []
self.epoch_time = time.time()
def epoch_end(self, run_context):
cb_params = run_context.original_args()
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / cb_params.batch_num
print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
per_step_mseconds,
np.mean(self.losses)
), flush=True)
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
cb_params = run_context.original_args()
step_mseconds = (time.time() - self.step_time) * 1000
step_loss = cb_params.net_outputs
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())
self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
cb_params.cur_epoch_num - 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]), flush=True)
if __name__ == '__main__':
if run_distribute:
context.set_context(enable_hccl=True)
context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True, mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()
else:
context.set_context(enable_hccl=False)
epoch_size = config.epoch_size
net = mobilenet_v2(num_classes=config.num_classes)
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
print("train args: ", args_opt, "\ncfg: ", config,
"\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
repeat_num=epoch_size, batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=config.lr,
warmup_epochs=config.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size))
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
config.weight_decay, config.loss_scale)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, amp_level='O0',
keep_batchnorm_fp32=False)
cb = None
if rank_id == 0:
cb = [Monitor(lr_init=lr.asnumpy())]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="mobilenet", directory=config.save_checkpoint_path, config=config_ck)
cb += [ckpt_cb]
model.train(epoch_size, dataset, callbacks=cb)

View File

@ -0,0 +1,284 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV2 model define"""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops.operations import TensorAdd
from mindspore import Parameter, Tensor
from mindspore.common.initializer import initializer
__all__ = ['MobileNetV2', 'mobilenet_v2']
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class GlobalAvgPooling(nn.Cell):
"""
Global avg pooling definition.
Args:
Returns:
Tensor, output tensor.
Examples:
>>> GlobalAvgPooling()
"""
def __init__(self):
super(GlobalAvgPooling, self).__init__()
self.mean = P.ReduceMean(keep_dims=False)
def construct(self, x):
x = self.mean(x, (2, 3))
return x
class DepthwiseConv(nn.Cell):
"""
Depthwise Convolution warpper definition.
Args:
in_planes (int): Input channel.
kernel_size (int): Input kernel size.
stride (int): Stride size.
pad_mode (str): pad mode in (pad, same, valid)
channel_multiplier (int): Output channel multiplier
has_bias (bool): has bias or not
Returns:
Tensor, output tensor.
Examples:
>>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
"""
def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
super(DepthwiseConv, self).__init__()
self.has_bias = has_bias
self.in_channels = in_planes
self.channel_multiplier = channel_multiplier
self.out_channels = in_planes * channel_multiplier
self.kernel_size = (kernel_size, kernel_size)
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size,
stride=stride, pad_mode=pad_mode, pad=pad)
self.bias_add = P.BiasAdd()
weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
self.weight = Parameter(initializer('ones', weight_shape), name='weight')
if has_bias:
bias_shape = [channel_multiplier * in_planes]
self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
else:
self.bias = None
def construct(self, x):
output = self.depthwise_conv(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
return output
class ConvBNReLU(nn.Cell):
"""
Convolution/Depthwise fused with Batchnorm and ReLU block definition.
Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size for the first convolutional layer. Default: 1.
groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
"""
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
if groups == 1:
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
padding=padding)
else:
conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
self.features = nn.SequentialCell(layers)
def construct(self, x):
output = self.features(x)
return output
class InvertedResidual(nn.Cell):
"""
Mobilenetv2 residual block definition.
Args:
inp (int): Input channel.
oup (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
expand_ratio (int): expand ration of input channel
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, 1, 1)
"""
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.SequentialCell(layers)
self.add = TensorAdd()
self.cast = P.Cast()
def construct(self, x):
identity = x
x = self.conv(x)
if self.use_res_connect:
return self.add(identity, x)
return x
class MobileNetV2(nn.Cell):
"""
MobileNetV2 architecture.
Args:
class_num (Cell): number of classes.
width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1.
has_dropout (bool): Is dropout used. Default is false
inverted_residual_setting (list): Inverted residual settings. Default is None
round_nearest (list): Channel round to . Default is 8
Returns:
Tensor, output tensor.
Examples:
>>> MobileNetV2(num_classes=1000)
"""
def __init__(self, num_classes=1000, width_mult=1.,
has_dropout=False, inverted_residual_setting=None, round_nearest=8):
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
# setting of inverted residual blocks
self.cfgs = inverted_residual_setting
if inverted_residual_setting is None:
self.cfgs = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in self.cfgs:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
# make it nn.CellList
self.features = nn.SequentialCell(features)
# mobilenet head
head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else
[GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
self.head = nn.SequentialCell(head)
self._initialize_weights()
def construct(self, x):
x = self.features(x)
x = self.head(x)
return x
def _initialize_weights(self):
"""
Initialize weights.
Args:
Returns:
None.
Examples:
>>> _initialize_weights()
"""
for _, m in self.cells_and_names():
if isinstance(m, (nn.Conv2d, DepthwiseConv)):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
m.weight.data.shape()).astype("float32")))
if m.bias is not None:
m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_parameter_data(Tensor(np.ones(m.gamma.data.shape(), dtype="float32")))
m.beta.set_parameter_data(Tensor(np.zeros(m.beta.data.shape(), dtype="float32")))
elif isinstance(m, nn.Dense):
m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape()).astype("float32")))
if m.bias is not None:
m.bias.set_parameter_data(Tensor(np.zeros(m.bias.data.shape(), dtype="float32")))
def mobilenet_v2(**kwargs):
"""
Constructs a MobileNet V2 model
"""
return MobileNetV2(**kwargs)