!13114 retinanet performance improvement

From: @chenmai1102
Commit: 0d1d043d80
@@ -165,11 +165,11 @@ MSCOCO2017

# 8-device distributed training example:

Create the RANK_TABLE_FILE
sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)

# Single-device training example:

sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)

```
@@ -182,6 +182,9 @@ sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional)

```run
# Training example

Before training, first create the MindRecord files. Taking the COCO dataset as an example:
python create_data.py --dataset coco

python:
The data path and the MindRecord output path are set in config
@@ -193,12 +196,12 @@ sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional)

# 8-device distributed training example (run from the retinanet directory):

sh scripts/run_distribute_train.sh 8 500 0.1 coco RANK_TABLE_FILE(path of the created RANK_TABLE_FILE) PRE_TRAINED(path of the pre-trained checkpoint) PRE_TRAINED_EPOCH_SIZE(number of pre-trained epochs)
For example: sh scripts/run_distribute_train.sh 8 500 0.1 coco scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322
sh scripts/run_distribute_train.sh 8 500 0.1 RANK_TABLE_FILE(path of the created RANK_TABLE_FILE) PRE_TRAINED(path of the pre-trained checkpoint) PRE_TRAINED_EPOCH_SIZE(number of pre-trained epochs)
For example: sh scripts/run_distribute_train.sh 8 500 0.1 scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322

# Single-device training example (run from the retinanet directory):

sh scripts/run_single_train.sh 0 500 0.1 coco /dataset/retinanet-322_458.ckpt 322
sh scripts/run_single_train.sh 0 500 0.1 /dataset/retinanet-322_458.ckpt 322
```

#### Results
@@ -0,0 +1,25 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create mindrecord for training retinanet."""

import argparse
from src.dataset import create_mindrecord

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="retinanet dataset create")
    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
    args_opt = parser.parse_args()
    mindrecord_file = create_mindrecord(args_opt.dataset, "retinanet.mindrecord", True)
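create_data.py writes the retinanet.mindrecord* files into the directory configured in src/config.py, and train.py (updated later in this patch) only resolves that path and assumes the files already exist. A minimal sketch of that hand-off; the existence check is added here purely for illustration:

```python
# Sketch: how the files produced by create_data.py are picked up at training time.
import os
from src.config import config

# train.py builds the same path from config.mindrecord_dir (see the train.py hunk below).
mindrecord_file = os.path.join(config.mindrecord_dir, "retinanet.mindrecord0")
if not os.path.exists(mindrecord_file):
    raise FileNotFoundError("Run `python create_data.py --dataset coco` before training.")
```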
@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export for retinanet"""
import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.retinanet import retinanet50, resnet50, retinanetInferWithDecoder
from src.config import config
from src.box_utils import default_boxes


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='retinanet export')
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend",),
                        help="run platform, only support Ascend.")
    parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--file_name", type=str, default="retinanet", help="output file name.")
    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)

    backbone = resnet50(config.num_classes)
    net = retinanet50(backbone, config)
    net = retinanetInferWithDecoder(net, Tensor(default_boxes), config)
    param_dict = load_checkpoint(config.checkpoint_path)
    net.init_parameters_data()
    load_param_into_net(net, param_dict)
    net.set_train(False)
    shape = [args_opt.batch_size, 3] + config.img_shape
    input_data = Tensor(np.zeros(shape), mstype.float32)
    export(net, input_data, file_name=args_opt.file_name, file_format=args_opt.file_format)
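Based on the arguments defined above, a typical invocation would be `python export.py --device_id 0 --batch_size 1 --file_name retinanet --file_format MINDIR`; note that the checkpoint to export is read from config.checkpoint_path rather than from a command-line flag.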
@@ -21,27 +21,24 @@ echo "for example: sh run_distribute_train.sh 8 500 0.1 coco /data/hccl.json /op
echo "It is better to use absolute path."
echo "================================================================================================================="

if [ $# != 5 ] && [ $# != 7 ]
if [ $# != 4 ] && [ $# != 6 ]
then
    echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \
    echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] \
[RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
    exit 1
fi

# Before start distribute train, first create mindrecord files.
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
python train.py --only_create_dataset=True
core_num=`cat /proc/cpuinfo |grep "processor"|wc -l`
process_cores=$(($core_num/8))

echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"

export RANK_SIZE=$1
EPOCH_SIZE=$2
LR=$3
DATASET=$4
PRE_TRAINED=$6
PRE_TRAINED_EPOCH_SIZE=$7
export RANK_TABLE_FILE=$5
PRE_TRAINED=$5
PRE_TRAINED_EPOCH_SIZE=$6
export RANK_TABLE_FILE=$4

for((i=0;i<RANK_SIZE;i++))
do
@@ -51,27 +48,30 @@ do
    cp ./*.py ./LOG$i
    cp -r ./src ./LOG$i
    cp -r ./scripts ./LOG$i
    start=`expr $i \* $process_cores`
    end=`expr $start \+ $(($process_cores-1))`
    cmdopt=$start"-"$end
    cd ./LOG$i || exit
    export RANK_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    if [ $# == 5 ]
    if [ $# == 4 ]
    then
        python train.py \
        taskset -c $cmdopt python train.py \
        --workers=$process_cores \
        --distribute=True \
        --lr=$LR \
        --dataset=$DATASET \
        --device_num=$RANK_SIZE \
        --device_id=$DEVICE_ID \
        --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
    fi

    if [ $# == 7 ]
    if [ $# == 6 ]
    then
        python train.py \
        taskset -c $cmdopt python train.py \
        --workers=$process_cores \
        --distribute=True \
        --lr=$LR \
        --dataset=$DATASET \
        --device_num=$RANK_SIZE \
        --device_id=$DEVICE_ID \
        --pre_trained=$PRE_TRAINED \
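The distributed launcher now pins each rank to its own slice of host CPU cores with taskset and passes the same count to --workers, matching the num_parallel_workers default that drops from 64 to 24 in create_retinanet_dataset later in this patch. A small sketch of the arithmetic, assuming a 96-core host (the real count comes from /proc/cpuinfo):

```python
# Illustration of the per-rank CPU binding computed in run_distribute_train.sh.
core_num = 96                          # assumed host core count for this example
process_cores = core_num // 8          # cores handed to each of the 8 ranks
for rank in range(8):
    start = rank * process_cores
    end = start + process_cores - 1
    print(f"rank {rank}: taskset -c {start}-{end} python train.py --workers={process_cores} ...")
# rank 0 -> cores 0-11, rank 1 -> cores 12-23, ..., rank 7 -> cores 84-95
```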
@@ -8,6 +8,7 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -20,9 +21,9 @@ echo "for example: sh run_single_train.sh 0 500 0.1 coco /opt/retinanet-500_458.
echo "It is better to use absolute path."
echo "================================================================================================================="

if [ $# != 4 ] && [ $# != 6 ]
if [ $# != 3 ] && [ $# != 5 ]
then
    echo "Usage: sh run_single_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] [DATASET] \
    echo "Usage: sh run_single_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] \
[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
    exit 1
fi
@@ -37,9 +38,8 @@ echo "After running the script, the network runs in the background. The log will
export DEVICE_ID=$1
EPOCH_SIZE=$2
LR=$3
DATASET=$4
PRE_TRAINED=$5
PRE_TRAINED_EPOCH_SIZE=$6
PRE_TRAINED=$4
PRE_TRAINED_EPOCH_SIZE=$5

rm -rf LOG$1
mkdir ./LOG$1
@@ -48,23 +48,21 @@ cp -r ./src ./LOG$1
cd ./LOG$1 || exit
echo "start training for device $1"
env > env.log
if [ $# == 4 ]
if [ $# == 3 ]
then
    python train.py \
    --distribute=False \
    --lr=$LR \
    --dataset=$DATASET \
    --device_num=1 \
    --device_id=$DEVICE_ID \
    --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
fi

if [ $# == 6 ]
if [ $# == 5 ]
then
    python train.py \
    --distribute=False \
    --lr=$LR \
    --dataset=$DATASET \
    --device_num=1 \
    --device_id=$DEVICE_ID \
    --pre_trained=$PRE_TRAINED \
@@ -389,7 +389,7 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="reti


def create_retinanet_dataset(mindrecord_file, batch_size, repeat_num, device_num=1, rank=0,
                             is_training=True, num_parallel_workers=64):
                             is_training=True, num_parallel_workers=24):
    """Create retinanet dataset with MindDataset."""
    ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
                        shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
@@ -251,9 +251,15 @@ class retinanetWithLossCell(nn.Cell):
        self.expand_dims = P.ExpandDims()
        self.class_loss = SigmoidFocalClassificationLoss(config.gamma, config.alpha)
        self.loc_loss = nn.SmoothL1Loss()
        self.cast = P.Cast()

        self.network.to_float(mstype.float16)

    def construct(self, x, gt_loc, gt_label, num_matched_boxes):
        pred_loc, pred_label = self.network(x)
        pred_loc = self.cast(pred_loc, mstype.float32)
        pred_label = self.cast(pred_label, mstype.float32)

        mask = F.cast(self.less(0, gt_label), mstype.float32)
        num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32))
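The loss cell now runs the backbone in float16 (self.network.to_float(mstype.float16)) and casts the predictions back to float32 before computing the focal and SmoothL1 losses. Keeping the loss math in float32 avoids the underflow that small float16 products are prone to; a toy numpy illustration (not MindSpore code):

```python
# Toy illustration of why the loss terms are computed in float32.
import numpy as np

pred = np.float16(1e-4)                      # small value typical of per-anchor terms
print(pred * pred)                           # 0.0 -- the product underflows float16
print(np.float32(pred) * np.float32(pred))   # ~1e-08 -- preserved once cast to float32
```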
@@ -18,7 +18,6 @@
import os
import argparse
import ast
import mindspore
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication.management import init, get_rank
@@ -29,7 +28,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.retinanet import retinanetWithLossCell, TrainingWrapper, retinanet50, resnet50
from src.config import config
from src.dataset import create_retinanet_dataset, create_mindrecord
from src.dataset import create_retinanet_dataset
from src.lr_schedule import get_lr
from src.init_params import init_net_param, filter_checkpoint_parameter
@@ -59,15 +58,14 @@ class Monitor(Callback):

def main():
    parser = argparse.ArgumentParser(description="retinanet training")
    parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False,
                        help="If set it true, only create Mindrecord, default is False.")

    parser.add_argument("--distribute", type=ast.literal_eval, default=False,
                        help="Run distribute, default is False.")
    parser.add_argument("--workers", type=int, default=24, help="Num parallel workers.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.")
    parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
    parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
    parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
@@ -98,56 +96,55 @@ def main():
    else:
        raise ValueError("Unsupported platform.")

    mindrecord_file = create_mindrecord(args_opt.dataset, "retinanet.mindrecord", True)
    mindrecord_file = os.path.join(config.mindrecord_dir, "retinanet.mindrecord0")

    if not args_opt.only_create_dataset:
        loss_scale = float(args_opt.loss_scale)
    loss_scale = float(args_opt.loss_scale)

        # When creating the MindDataset, use the first mindrecord file, such as retinanet.mindrecord0.
        dataset = create_retinanet_dataset(mindrecord_file, repeat_num=1,
                                           batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
    # When creating the MindDataset, use the first mindrecord file, such as retinanet.mindrecord0.
    dataset = create_retinanet_dataset(mindrecord_file, repeat_num=1,
                                       num_parallel_workers=args_opt.workers,
                                       batch_size=args_opt.batch_size, device_num=device_num, rank=rank)

        dataset_size = dataset.get_dataset_size()
        print("Create dataset done!")
    dataset_size = dataset.get_dataset_size()
    print("Create dataset done!")

        backbone = resnet50(config.num_classes)
        retinanet = retinanet50(backbone, config)
        net = retinanetWithLossCell(retinanet, config)
        net.to_float(mindspore.float16)
        init_net_param(net)
    backbone = resnet50(config.num_classes)
    retinanet = retinanet50(backbone, config)
    net = retinanetWithLossCell(retinanet, config)
    init_net_param(net)

        if args_opt.pre_trained:
            if args_opt.pre_trained_epoch_size <= 0:
                raise KeyError("pre_trained_epoch_size must be greater than 0.")
            param_dict = load_checkpoint(args_opt.pre_trained)
            if args_opt.filter_weight:
                filter_checkpoint_parameter(param_dict)
            load_param_into_net(net, param_dict)
    if args_opt.pre_trained:
        if args_opt.pre_trained_epoch_size <= 0:
            raise KeyError("pre_trained_epoch_size must be greater than 0.")
        param_dict = load_checkpoint(args_opt.pre_trained)
        if args_opt.filter_weight:
            filter_checkpoint_parameter(param_dict)
        load_param_into_net(net, param_dict)

        lr = Tensor(get_lr(global_step=config.global_step,
                           lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
                           warmup_epochs1=config.warmup_epochs1, warmup_epochs2=config.warmup_epochs2,
                           warmup_epochs3=config.warmup_epochs3, warmup_epochs4=config.warmup_epochs4,
                           warmup_epochs5=config.warmup_epochs5, total_epochs=args_opt.epoch_size,
                           steps_per_epoch=dataset_size))
        opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
                          config.momentum, config.weight_decay, loss_scale)
        net = TrainingWrapper(net, opt, loss_scale)
        model = Model(net)
        print("Start train retinanet, the first epoch will be slower because of the graph compilation.")
        cb = [TimeMonitor(), LossMonitor()]
        cb += [Monitor(lr_init=lr.asnumpy())]
        config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="retinanet", directory=config.save_checkpoint_path, config=config_ck)
        if args_opt.distribute:
            if rank == 0:
                cb += [ckpt_cb]
            model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
        else:
    lr = Tensor(get_lr(global_step=config.global_step,
                       lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
                       warmup_epochs1=config.warmup_epochs1, warmup_epochs2=config.warmup_epochs2,
                       warmup_epochs3=config.warmup_epochs3, warmup_epochs4=config.warmup_epochs4,
                       warmup_epochs5=config.warmup_epochs5, total_epochs=args_opt.epoch_size,
                       steps_per_epoch=dataset_size))
    opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
                      config.momentum, config.weight_decay, loss_scale)
    net = TrainingWrapper(net, opt, loss_scale)
    model = Model(net)
    print("Start train retinanet, the first epoch will be slower because of the graph compilation.")
    cb = [TimeMonitor(), LossMonitor()]
    cb += [Monitor(lr_init=lr.asnumpy())]
    config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs,
                                 keep_checkpoint_max=config.keep_checkpoint_max)
    ckpt_cb = ModelCheckpoint(prefix="retinanet", directory=config.save_checkpoint_path, config=config_ck)
    if args_opt.distribute:
        if rank == 0:
            cb += [ckpt_cb]
            model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
        model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
    else:
        cb += [ckpt_cb]
        model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)

if __name__ == '__main__':
    main()