!2729 Add resnext50 network

Merge pull request !2729 from z00378171/r0.5
This commit is contained in:
mindspore-ci-bot 2020-06-30 11:15:52 +08:00 committed by Gitee
commit 3c324e1031
23 changed files with 2038 additions and 0 deletions

View File

@ -0,0 +1,128 @@
# ResNext50 Example
## Description
This is an example of training ResNext50 with ImageNet dataset in Mindspore.
## Requirements
- Install [Mindspore](http://www.mindspore.cn/install/en).
- Downlaod the dataset ImageNet2012.
## Structure
```shell
.
└─resnext50
├─README.md
├─scripts
├─run_standalone_train.sh # launch standalone training(1p)
├─run_distribute_train.sh # launch distributed training(8p)
└─run_eval.sh # launch evaluating
├─src
├─backbone
├─_init_.py # initalize
├─resnet.py # resnext50 backbone
├─utils
├─_init_.py # initalize
├─cunstom_op.py # network operation
├─logging.py # print log
├─optimizers_init_.py # get parameters
├─sampler.py # distributed sampler
├─var_init_.py # calculate gain value
├─_init_.py # initalize
├─config.py # parameter configuration
├─crossentropy.py # CrossEntropy loss function
├─dataset.py # data preprocessing
├─head.py # commom head
├─image_classification.py # get resnet
├─linear_warmup.py # linear warmup learning rate
├─warmup_cosine_annealing.py # learning rate each step
├─warmup_step_lr.py # warmup step learning rate
├─eval.py # eval net
└─train.py # train net
```
## Parameter Configuration
Parameters for both training and evaluating can be set in config.py
```
"image_height": '224,224' # image size
"num_classes": 1000, # dataset class number
"per_batch_size": 128, # batch size of input tensor
"lr": 0.05, # base learning rate
"lr_scheduler": 'cosine_annealing', # learning rate mode
"lr_epochs": '30,60,90,120', # epoch of lr changing
"lr_gamma": 0.1, # decrease lr by a factor of exponential lr_scheduler
"eta_min": 0, # eta_min in cosine_annealing scheduler
"T_max": 150, # T-max in cosine_annealing scheduler
"max_epoch": 150, # max epoch num to train the model
"backbone": 'resnext50', # backbone metwork
"warmup_epochs" : 1, # warmup epoch
"weight_decay": 0.0001, # weight decay
"momentum": 0.9, # momentum
"is_dynamic_loss_scale": 0, # dynamic loss scale
"loss_scale": 1024, # loss scale
"label_smooth": 1, # label_smooth
"label_smooth_factor": 0.1, # label_smooth_factor
"ckpt_interval": 2000, # ckpt_interval
"ckpt_path": 'outputs/', # checkpoint save location
"is_save_on_master": 1,
"rank": 0, # local rank of distributed
"group_size": 1 # world size of distributed
```
## Running the example
### Train
#### Usage
```
# distribute training example(8p)
sh run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH DATA_PATH
# standalone training
sh run_standalone_train.sh DEVICE_ID DATA_PATH
```
#### Launch
```bash
# distributed training example(8p)
sh scripts/run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH /ImageNet/train
# standalone training example
sh scripts/run_standalone_train.sh 0 /ImageNet_Original/train
```
#### Result
You can find checkpoint file together with result in log.
### Evaluation
#### Usage
```
# Evaluation
sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH
```
#### Launch
```bash
# Evaluation with checkpoint
sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt
```
> checkpoint can be produced in training process.
#### Result
Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log.
```
acc=78,16%(TOP1)
acc=93.88%(TOP5)
```

243
model_zoo/resnext50/eval.py Normal file
View File

@ -0,0 +1,243 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval"""
import os
import time
import argparse
import datetime
import glob
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.communication.management import init, get_rank, get_group_size, release
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from src.utils.logging import get_logger
from src.image_classification import get_network
from src.dataset import classification_dataset
from src.config import config
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target="Ascend", save_graphs=False, device_id=devid)
class ParameterReduce(nn.Cell):
"""ParameterReduce"""
def __init__(self):
super(ParameterReduce, self).__init__()
self.cast = P.Cast()
self.reduce = P.AllReduce()
def construct(self, x):
one = self.cast(F.scalar_to_array(1.0), mstype.float32)
out = x * one
ret = self.reduce(out)
return ret
def parse_args(cloud_args=None):
"""parse_args"""
parser = argparse.ArgumentParser('mindspore classification test')
# dataset related
parser.add_argument('--data_dir', type=str, default='/opt/npu/datasets/classification/val', help='eval data dir')
parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu')
# network related
parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt')
parser.add_argument('--pretrained', default='', type=str, help='fully path of pretrained model to load. '
'If it is a direction, it will test all ckpt')
# logging related
parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
args.image_size = config.image_size
args.num_classes = config.num_classes
args.backbone = config.backbone
args.rank = config.rank
args.group_size = config.group_size
args.image_size = list(map(int, args.image_size.split(',')))
return args
def get_top5_acc(top5_arg, gt_class):
sub_count = 0
for top5, gt in zip(top5_arg, gt_class):
if gt in top5:
sub_count += 1
return sub_count
def merge_args(args, cloud_args):
"""merge_args"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args
def test(cloud_args=None):
"""test"""
args = parse_args(cloud_args)
# init distributed
if args.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
args.outputs_dir = os.path.join(args.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
if os.path.isdir(args.pretrained):
models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
print(models)
if args.graph_ckpt:
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0])
else:
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1])
args.models = sorted(models, key=f)
else:
args.models = [args.pretrained,]
for model in args.models:
de_dataset = classification_dataset(args.data_dir, image_size=args.image_size,
per_batch_size=args.per_batch_size,
max_epoch=1, rank=args.rank, group_size=args.group_size,
mode='eval')
eval_dataloader = de_dataset.create_tuple_iterator()
network = get_network(args.backbone, args.num_classes)
if network is None:
raise NotImplementedError('not implement {}'.format(args.backbone))
param_dict = load_checkpoint(model)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[key[8:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load model {} success'.format(model))
# must add
network.add_flags_recursive(fp16=True)
img_tot = 0
top1_correct = 0
top5_correct = 0
network.set_train(False)
t_end = time.time()
it = 0
for data, gt_classes in eval_dataloader:
output = network(Tensor(data, mstype.float32))
output = output.asnumpy()
top1_output = np.argmax(output, (-1))
top5_output = np.argsort(output)[:, -5:]
t1_correct = np.equal(top1_output, gt_classes).sum()
top1_correct += t1_correct
top5_correct += get_top5_acc(top5_output, gt_classes)
img_tot += args.per_batch_size
if args.rank == 0 and it == 0:
t_end = time.time()
it = 1
if args.rank == 0:
time_used = time.time() - t_end
fps = (img_tot - args.per_batch_size) * args.group_size / time_used
args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps))
results = [[top1_correct], [top5_correct], [img_tot]]
args.logger.info('before results={}'.format(results))
if args.is_distributed:
model_md5 = model.replace('/', '')
tmp_dir = '/cache'
if not os.path.exists(tmp_dir):
os.mkdir(tmp_dir)
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(args.rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(args.rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(args.rank, model_md5)
np.save(top1_correct_npy, top1_correct)
np.save(top5_correct_npy, top5_correct)
np.save(img_tot_npy, img_tot)
while True:
rank_ok = True
for other_rank in range(args.group_size):
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) or \
not os.path.exists(img_tot_npy):
rank_ok = False
if rank_ok:
break
top1_correct_all = 0
top5_correct_all = 0
img_tot_all = 0
for other_rank in range(args.group_size):
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
top1_correct_all += np.load(top1_correct_npy)
top5_correct_all += np.load(top5_correct_npy)
img_tot_all += np.load(img_tot_npy)
results = [[top1_correct_all], [top5_correct_all], [img_tot_all]]
results = np.array(results)
else:
results = np.array(results)
args.logger.info('after results={}'.format(results))
top1_correct = results[0, 0]
top5_correct = results[1, 0]
img_tot = results[2, 0]
acc1 = 100.0 * top1_correct / img_tot
acc5 = 100.0 * top5_correct / img_tot
args.logger.info('after allreduce eval: top1_correct={}, tot={},'
'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1))
args.logger.info('after allreduce eval: top5_correct={}, tot={},'
'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5))
if args.is_distributed:
release()
if __name__ == "__main__":
test()

View File

@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$2
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
PATH_CHECKPOINT=""
if [ $# == 3 ]
then
PATH_CHECKPOINT=$3
fi
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cd ./LOG$i || exit
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
taskset -c $cmdopt python ../train.py \
--is_distribute=1 \
--device_id=$DEVICE_ID \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log.txt 2>&1 &
cd ../
done

View File

@ -0,0 +1,24 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
python eval.py \
--device_id=$DEVICE_ID \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log.txt 2>&1 &

View File

@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=""
if [ $# == 3 ]
then
PATH_CHECKPOINT=$3
fi
python train.py \
--is_distribute=0 \
--device_id=$DEVICE_ID \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log.txt 2>&1 &

View File

View File

@ -0,0 +1,16 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""resnet"""
from .resnet import *

View File

@ -0,0 +1,273 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ResNet based ResNext
"""
import mindspore.nn as nn
from mindspore.ops.operations import TensorAdd, Split, Concat
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
from src.utils.cunstom_op import SEBlock, GroupConv
__all__ = ['ResNet', 'resnext50']
def weight_variable(shape, factor=0.1):
return TruncatedNormal(0.02)
def conv7x7(in_channels, out_channels, stride=1, padding=3, has_bias=False, groups=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad", group=groups)
def conv3x3(in_channels, out_channels, stride=1, padding=1, has_bias=False, groups=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad", group=groups)
def conv1x1(in_channels, out_channels, stride=1, padding=0, has_bias=False, groups=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad", group=groups)
class _DownSample(nn.Cell):
"""
Downsample for ResNext-ResNet.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
stride (int): Stride size for the 1*1 convolutional layer.
Returns:
Tensor, output tensor.
Examples:
>>>DownSample(32, 64, 2)
"""
def __init__(self, in_channels, out_channels, stride):
super(_DownSample, self).__init__()
self.conv = conv1x1(in_channels, out_channels, stride=stride, padding=0)
self.bn = nn.BatchNorm2d(out_channels)
def construct(self, x):
out = self.conv(x)
out = self.bn(out)
return out
class BasicBlock(nn.Cell):
"""
ResNet basic block definition.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>>BasicBlock(32, 256, stride=2)
"""
expansion = 1
def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False, **kwargs):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = P.ReLU()
self.conv2 = conv3x3(out_channels, out_channels, stride=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.use_se = use_se
if self.use_se:
self.se = SEBlock(out_channels)
self.down_sample_flag = False
if down_sample is not None:
self.down_sample = down_sample
self.down_sample_flag = True
self.add = TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.use_se:
out = self.se(out)
if self.down_sample_flag:
identity = self.down_sample(x)
out = self.add(out, identity)
out = self.relu(out)
return out
class Bottleneck(nn.Cell):
"""
ResNet Bottleneck block definition.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
stride (int): Stride size for the initial convolutional layer. Default: 1.
Returns:
Tensor, the ResNet unit's output.
Examples:
>>>Bottleneck(3, 256, stride=2)
"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1, down_sample=None,
base_width=64, groups=1, use_se=False, **kwargs):
super(Bottleneck, self).__init__()
width = int(out_channels * (base_width / 64.0)) * groups
self.groups = groups
self.conv1 = conv1x1(in_channels, width, stride=1)
self.bn1 = nn.BatchNorm2d(width)
self.relu = P.ReLU()
self.conv3x3s = nn.CellList()
self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups)
self.op_split = Split(axis=1, output_num=self.groups)
self.op_concat = Concat(axis=1)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = conv1x1(width, out_channels * self.expansion, stride=1)
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
self.use_se = use_se
if self.use_se:
self.se = SEBlock(out_channels * self.expansion)
self.down_sample_flag = False
if down_sample is not None:
self.down_sample = down_sample
self.down_sample_flag = True
self.cast = P.Cast()
self.add = TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.use_se:
out = self.se(out)
if self.down_sample_flag:
identity = self.down_sample(x)
out = self.add(out, identity)
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
ResNet architecture.
Args:
block (cell): Block for network.
layers (list): Numbers of block in different layers.
width_per_group (int): Width of every group.
groups (int): Groups number.
Returns:
Tuple, output tensor tuple.
Examples:
>>>ResNet()
"""
def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False):
super(ResNet, self).__init__()
self.in_channels = 64
self.groups = groups
self.base_width = width_per_group
self.conv = conv7x7(3, self.in_channels, stride=2, padding=3)
self.bn = nn.BatchNorm2d(self.in_channels)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se)
self.out_channels = 512 * block.expansion
self.cast = P.Cast()
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False):
"""_make_layer"""
down_sample = None
if stride != 1 or self.in_channels != out_channels * block.expansion:
down_sample = _DownSample(self.in_channels,
out_channels * block.expansion,
stride=stride)
layers = []
layers.append(block(self.in_channels,
out_channels,
stride=stride,
down_sample=down_sample,
base_width=self.base_width,
groups=self.groups,
use_se=use_se))
self.in_channels = out_channels * block.expansion
for _ in range(1, blocks_num):
layers.append(block(self.in_channels, out_channels,
base_width=self.base_width, groups=self.groups, use_se=use_se))
return nn.SequentialCell(layers)
def get_out_channels(self):
return self.out_channels
def resnext50():
return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32)

View File

@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config"""
from easydict import EasyDict as ed
config = ed({
"image_size": '224,224',
"num_classes": 1000,
"lr": 0.4,
"lr_scheduler": 'cosine_annealing',
"lr_epochs": '30,60,90,120',
"lr_gamma": 0.1,
"eta_min": 0,
"T_max": 150,
"max_epoch": 150,
"backbone": 'resnext50',
"warmup_epochs": 1,
"weight_decay": 0.0001,
"momentum": 0.9,
"is_dynamic_loss_scale": 0,
"loss_scale": 1024,
"label_smooth": 1,
"label_smooth_factor": 0.1,
"ckpt_interval": 1250,
"ckpt_path": 'outputs/',
"is_save_on_master": 1,
"rank": 0,
"group_size": 1
})

View File

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
define loss function for network.
"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
"""
the redefined loss function with SoftmaxCrossEntropyWithLogits.
"""
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes -1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss

View File

@ -0,0 +1,155 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset processing.
"""
import os
from mindspore.common import dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as V_C
from PIL import Image, ImageFile
from src.utils.sampler import DistributedSampler
ImageFile.LOAD_TRUNCATED_IMAGES = True
class TxtDataset():
"""
create txt dataset.
Args:
Returns:
de_dataset.
"""
def __init__(self, root, txt_name):
super(TxtDataset, self).__init__()
self.imgs = []
self.labels = []
fin = open(txt_name, "r")
for line in fin:
img_name, label = line.strip().split(' ')
self.imgs.append(os.path.join(root, img_name))
self.labels.append(int(label))
fin.close()
def __getitem__(self, index):
img = Image.open(self.imgs[index]).convert('RGB')
return img, self.labels[index]
def __len__(self):
return len(self.imgs)
def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank, group_size,
mode='train',
input_mode='folder',
root='',
num_parallel_workers=None,
shuffle=None,
sampler=None,
class_indexing=None,
drop_remainder=True,
transform=None,
target_transform=None):
"""
A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt".
If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images
are written into a textfile.
Args:
data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"".
Or path of the textfile that contains every image's path of the dataset.
image_size (str): Size of the input images.
per_batch_size (int): the batch size of evey step during training.
max_epoch (int): the number of epochs.
rank (int): The shard ID within num_shards (default=None).
group_size (int): Number of shards that the dataset should be divided
into (default=None).
mode (str): "train" or others. Default: " train".
input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder".
root (str): the images path for "input_mode="txt"". Default: " ".
num_parallel_workers (int): Number of workers to read the data. Default: None.
shuffle (bool): Whether or not to perform shuffle on the dataset
(default=None, performs shuffle).
sampler (Sampler): Object used to choose samples from the dataset. Default: None.
class_indexing (dict): A str-to-int mapping from folder name to index
(default=None, the folder names will be sorted
alphabetically and each class will be given a
unique index starting from 0).
Examples:
>>> from mindvision.common.datasets.classification import classification_dataset
>>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
>>> dataset_dir = "/path/to/imagefolder_directory"
>>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244],
>>> per_batch_size=64, max_epoch=100,
>>> rank=0, group_size=4)
>>> # Path of the textfile that contains every image's path of the dataset.
>>> dataset_dir = "/path/to/dataset/images/train.txt"
>>> images_dir = "/path/to/dataset/images"
>>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244],
>>> per_batch_size=64, max_epoch=100,
>>> rank=0, group_size=4,
>>> input_mode="txt", root=images_dir)
"""
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
if transform is None:
if mode == 'train':
transform_img = [
V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
V_C.RandomHorizontalFlip(prob=0.5),
V_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
V_C.Normalize(mean=mean, std=std),
V_C.HWC2CHW()
]
else:
transform_img = [
V_C.Decode(),
V_C.Resize((256, 256)),
V_C.CenterCrop(image_size),
V_C.Normalize(mean=mean, std=std),
V_C.HWC2CHW()
]
else:
transform_img = transform
if target_transform is None:
transform_label = [C.TypeCast(mstype.int32)]
else:
transform_label = target_transform
if input_mode == 'folder':
de_dataset = de.ImageFolderDatasetV2(data_dir, num_parallel_workers=num_parallel_workers,
shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
num_shards=group_size, shard_id=rank)
else:
dataset = TxtDataset(root, data_dir)
sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
de_dataset.set_dataset_size(len(sampler))
de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
columns_to_project = ["image", "label"]
de_dataset = de_dataset.project(columns=columns_to_project)
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
de_dataset = de_dataset.repeat(max_epoch)
return de_dataset

View File

@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
common architecture.
"""
import mindspore.nn as nn
from src.utils.cunstom_op import GlobalAvgPooling
__all__ = ['CommonHead']
class CommonHead(nn.Cell):
"""
commom architecture definition.
Args:
num_classes (int): Number of classes.
out_channels (int): Output channels.
Returns:
Tensor, output tensor.
"""
def __init__(self, num_classes, out_channels):
super(CommonHead, self).__init__()
self.avgpool = GlobalAvgPooling()
self.fc = nn.Dense(out_channels, num_classes, has_bias=True).add_flags_recursive(fp16=True)
def construct(self, x):
x = self.avgpool(x)
x = self.fc(x)
return x

View File

@ -0,0 +1,85 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Image classifiation.
"""
import math
import mindspore.nn as nn
from mindspore.common import initializer as init
import src.backbone as backbones
import src.head as heads
from src.utils.var_init import default_recurisive_init, KaimingNormal
class ImageClassificationNetwork(nn.Cell):
"""
architecture of image classification network.
Args:
Returns:
Tensor, output tensor.
"""
def __init__(self, backbone, head):
super(ImageClassificationNetwork, self).__init__()
self.backbone = backbone
self.head = head
def construct(self, x):
x = self.backbone(x)
x = self.head(x)
return x
class Resnet(ImageClassificationNetwork):
"""
Resnet architecture.
Args:
backbone_name (string): backbone.
num_classes (int): number of classes.
Returns:
Resnet.
"""
def __init__(self, backbone_name, num_classes):
self.backbone_name = backbone_name
backbone = backbones.__dict__[self.backbone_name]()
out_channels = backbone.get_out_channels()
head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels)
super(Resnet, self).__init__(backbone, head)
default_recurisive_init(self)
for cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = init.initializer(
KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
elif isinstance(cell, nn.BatchNorm2d):
cell.gamma.default_input = init.initializer('ones', cell.gamma.default_input.shape).to_tensor()
cell.beta.default_input = init.initializer('zeros', cell.beta.default_input.shape).to_tensor()
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
for cell in self.cells_and_names():
if isinstance(cell, backbones.resnet.Bottleneck):
cell.bn3.gamma.default_input = init.initializer('zeros', cell.bn3.gamma.default_input.shape).to_tensor()
elif isinstance(cell, backbones.resnet.BasicBlock):
cell.bn2.gamma.default_input = init.initializer('zeros', cell.bn2.gamma.default_input.shape).to_tensor()
def get_network(backbone_name, num_classes):
if backbone_name in ['resnext50']:
return Resnet(backbone_name, num_classes)
return None

View File

@ -0,0 +1,21 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
linear warm up learning rate.
"""
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr

View File

@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network operations
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
class GlobalAvgPooling(nn.Cell):
"""
global average pooling feature map.
Args:
mean (tuple): means for each channel.
"""
def __init__(self):
super(GlobalAvgPooling, self).__init__()
self.mean = P.ReduceMean(True)
self.shape = P.Shape()
self.reshape = P.Reshape()
def construct(self, x):
x = self.mean(x, (2, 3))
b, c, _, _ = self.shape(x)
x = self.reshape(x, (b, c))
return x
class SEBlock(nn.Cell):
"""
squeeze and excitation block.
Args:
channel (int): number of feature maps.
reduction (int): weight.
"""
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = GlobalAvgPooling()
self.fc1 = nn.Dense(channel, channel // reduction)
self.relu = P.ReLU()
self.fc2 = nn.Dense(channel // reduction, channel)
self.sigmoid = P.Sigmoid()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.sum = P.Sum()
self.cast = P.Cast()
def construct(self, x):
b, c = self.shape(x)
y = self.avg_pool(x)
y = self.reshape(y, (b, c))
y = self.fc1(y)
y = self.relu(y)
y = self.fc2(y)
y = self.sigmoid(y)
y = self.reshape(y, (b, c, 1, 1))
return x * y
class GroupConv(nn.Cell):
"""
group convolution operation.
Args:
in_channels (int): Input channels of feature map.
out_channels (int): Output channels of feature map.
kernel_size (int): Size of convolution kernel.
stride (int): Stride size for the group convolution layer.
Returns:
tensor, output tensor.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
super(GroupConv, self).__init__()
assert in_channels % groups == 0 and out_channels % groups == 0
self.groups = groups
self.convs = nn.CellList()
self.op_split = P.Split(axis=1, output_num=self.groups)
self.op_concat = P.Concat(axis=1)
self.cast = P.Cast()
for _ in range(groups):
self.convs.append(nn.Conv2d(in_channels//groups, out_channels//groups,
kernel_size=kernel_size, stride=stride, has_bias=has_bias,
padding=pad, pad_mode=pad_mode, group=1))
def construct(self, x):
features = self.op_split(x)
outputs = ()
for i in range(self.groups):
outputs = outputs + (self.convs[i](self.cast(features[i], mstype.float32)),)
out = self.op_concat(outputs)
return out

View File

@ -0,0 +1,82 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get logger.
"""
import logging
import os
import sys
from datetime import datetime
class LOGGER(logging.Logger):
"""
set up logging file.
Args:
logger_name (string): logger name.
log_dir (string): path of logger.
Returns:
string, logger path
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""set up log file"""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
logger = LOGGER("mindversion", rank)
logger.setup_logging_file(path, rank)
return logger

View File

@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
optimizer parameters.
"""
def get_param_groups(network):
"""get param groups"""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
# print('no decay:{}'.format(parameter_name))
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
# print('no decay:{}'.format(parameter_name))
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
# print('no decay:{}'.format(parameter_name))
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]

View File

@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
choose samples from the dataset
"""
import math
import numpy as np
class DistributedSampler():
"""
sampling the dataset.
Args:
Returns:
num_samples, number of samples.
"""
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
self.dataset = dataset
self.rank = rank
self.group_size = group_size
self.dataset_length = len(self.dataset)
self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
self.total_size = self.num_samples * self.group_size
self.shuffle = shuffle
self.seed = seed
def __iter__(self):
if self.shuffle:
self.seed = (self.seed + 1) & 0xffffffff
np.random.seed(self.seed)
indices = np.random.permutation(self.dataset_length).tolist()
else:
indices = list(range(len(self.dataset_length)))
indices += indices[:(self.total_size - len(indices))]
indices = indices[self.rank::self.group_size]
return iter(indices)
def __len__(self):
return self.num_samples

View File

@ -0,0 +1,213 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Initialize.
"""
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import initializer as init
def _calculate_gain(nonlinearity, param=None):
r"""
Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function
param: optional parameter for the non-linear function
Examples:
>>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of `num` to `arr`."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_in_and_out(arr):
"""
Calculate n_in and n_out.
Args:
arr (Array): Input array.
Returns:
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
"""
dim = len(arr.shape)
if dim < 2:
raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
n_in = arr.shape[1]
n_out = arr.shape[0]
if dim > 2:
counter = reduce(lambda x, y: x * y, arr.shape[2:])
n_in *= counter
n_out *= counter
return n_in, n_out
def _select_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_in_and_out(array)
return fan_in if mode == 'fan_in' else fan_out
class KaimingInit(init.Initializer):
r"""
Base Class. Initialize the array with He kaiming algorithm.
Args:
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function, recommended to use only with
``'relu'`` or ``'leaky_relu'`` (default).
"""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingInit, self).__init__()
self.mode = mode
self.gain = _calculate_gain(nonlinearity, a)
def _initialize(self, arr):
pass
class KaimingUniform(KaimingInit):
r"""
Initialize the array with He kaiming uniform algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
>>> w = np.empty(3, 5)
>>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
np.random.seed(0)
data = np.random.uniform(-bound, bound, arr.shape)
_assignment(arr, data)
class KaimingNormal(KaimingInit):
r"""
Initialize the array with He kaiming normal algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
>>> w = np.empty(3, 5)
>>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
std = self.gain / math.sqrt(fan)
np.random.seed(0)
data = np.random.normal(0, std, arr.shape)
_assignment(arr, data)
def default_recurisive_init(custom_cell):
"""default_recurisive_init"""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
bound = 1 / math.sqrt(fan_in)
np.random.seed(0)
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
cell.bias.default_input.dtype)
elif isinstance(cell, nn.Dense):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
bound = 1 / math.sqrt(fan_in)
np.random.seed(0)
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
cell.bias.default_input.dtype)
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass

View File

@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
warm up cosine annealing learning rate.
"""
import math
import numpy as np
from .linear_warmup import linear_warmup_lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
"""warm up cosine annealing learning rate."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)

View File

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
warm up step learning rate.
"""
from collections import Counter
import numpy as np
from .linear_warmup import linear_warmup_lr
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""warmup_step_lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma**milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)

View File

@ -0,0 +1,289 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train ImageNet."""
import os
import time
import argparse
import datetime
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore import ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, Callback
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from src.dataset import classification_dataset
from src.crossentropy import CrossEntropy
from src.warmup_step_lr import warmup_step_lr
from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
from src.utils.logging import get_logger
from src.utils.optimizers__init__ import get_param_groups
from src.image_classification import get_network
from src.config import config
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target="Ascend", save_graphs=False, device_id=devid)
class BuildTrainNetwork(nn.Cell):
"""build training network"""
def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
def construct(self, input_data, label):
output = self.network(input_data)
loss = self.criterion(output, label)
return loss
class ProgressMonitor(Callback):
"""monitor loss and time"""
def __init__(self, args):
super(ProgressMonitor, self).__init__()
self.me_epoch_start_time = 0
self.me_epoch_start_step_num = 0
self.args = args
self.ckpt_history = []
def begin(self, run_context):
self.args.logger.info('start network train...')
def epoch_begin(self, run_context):
pass
def epoch_end(self, run_context, *me_args):
cb_params = run_context.original_args()
me_step = cb_params.cur_step_num - 1
real_epoch = me_step // self.args.steps_per_epoch
time_used = time.time() - self.me_epoch_start_time
fps_mean = self.args.per_batch_size * (me_step-self.me_epoch_start_step_num) * self.args.group_size / time_used
self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_fps:{:.2f}'
'imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean))
if self.args.rank_save_ckpt_flag:
import glob
ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt'))
for ckpt in ckpts:
ckpt_fn = os.path.basename(ckpt)
if not ckpt_fn.startswith('{}-'.format(self.args.rank)):
continue
if ckpt in self.ckpt_history:
continue
self.ckpt_history.append(ckpt)
self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},'
'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn))
self.me_epoch_start_step_num = me_step
self.me_epoch_start_time = time.time()
def step_begin(self, run_context):
pass
def step_end(self, run_context, *me_args):
pass
def end(self, run_context):
self.args.logger.info('end network train...')
def parse_args(cloud_args=None):
"""parameters"""
parser = argparse.ArgumentParser('mindspore classification training')
# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--per_batch_size', default=128, type=int, help='batch size for per gpu')
# network related
parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')
# distributed related
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
args.image_size = config.image_size
args.num_classes = config.num_classes
args.lr = config.lr
args.lr_scheduler = config.lr_scheduler
args.lr_epochs = config.lr_epochs
args.lr_gamma = config.lr_gamma
args.eta_min = config.eta_min
args.T_max = config.T_max
args.max_epoch = config.max_epoch
args.backbone = config.backbone
args.warmup_epochs = config.warmup_epochs
args.weight_decay = config.weight_decay
args.momentum = config.momentum
args.is_dynamic_loss_scale = config.is_dynamic_loss_scale
args.loss_scale = config.loss_scale
args.label_smooth = config.label_smooth
args.label_smooth_factor = config.label_smooth_factor
args.ckpt_interval = config.ckpt_interval
args.ckpt_path = config.ckpt_path
args.is_save_on_master = config.is_save_on_master
args.rank = config.rank
args.group_size = config.group_size
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
args.image_size = list(map(int, args.image_size.split(',')))
return args
def merge_args(args, cloud_args):
"""dictionary"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args
def train(cloud_args=None):
"""training process"""
args = parse_args(cloud_args)
# init distributed
if args.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
if args.is_dynamic_loss_scale == 1:
args.loss_scale = 1 # for dynamic loss scale can not set loss scale in momentum opt
# select for master rank save ckpt or all rank save, compatiable for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
# dataloader
de_dataset = classification_dataset(args.data_dir, args.image_size,
args.per_batch_size, args.max_epoch,
args.rank, args.group_size)
de_dataset.map_model = 4 # !!!important
args.steps_per_epoch = de_dataset.get_dataset_size()
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
# get network and init
network = get_network(args.backbone, args.num_classes)
if network is None:
raise NotImplementedError('not implement {}'.format(args.backbone))
network.add_flags_recursive(fp16=True)
# loss
if not args.label_smooth:
args.label_smooth_factor = 0.0
criterion = CrossEntropy(smooth_factor=args.label_smooth_factor,
num_classes=args.num_classes)
# load pretrain model
if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[key[8:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load model {} success'.format(args.pretrained))
# lr scheduler
if args.lr_scheduler == 'exponential':
lr = warmup_step_lr(args.lr,
args.lr_epochs,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
gamma=args.lr_gamma,
)
elif args.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.T_max,
args.eta_min)
else:
raise NotImplementedError(args.lr_scheduler)
# optimizer
opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
criterion.add_flags_recursive(fp32=True)
# package training process, adjust lr + forward + backward + optimizer
train_net = BuildTrainNetwork(network, criterion)
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
else:
parallel_mode = ParallelMode.STAND_ALONE
if args.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
# Model api changed since TR5_branch 2020/03/09
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
parameter_broadcast=True, mirror_mean=True)
model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager)
# checkpoint save
progress_cb = ProgressMonitor(args)
callbacks = [progress_cb,]
if args.rank_save_ckpt_flag:
ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
keep_checkpoint_max=ckpt_max_num)
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=args.outputs_dir,
prefix='{}'.format(args.rank))
callbacks.append(ckpt_cb)
model.train(args.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True)
if __name__ == "__main__":
train()