forked from mindspore-Ecosystem/mindspore
yolov3_darknet53 suit for gpu
This commit is contained in:
parent
98528bbc16
commit
1f0a760cdb
|
@ -53,8 +53,8 @@ Dataset used: [COCO2014](https://cocodataset.org/#download)
|
||||||
|
|
||||||
# [Environment Requirements](#contents)
|
# [Environment Requirements](#contents)
|
||||||
|
|
||||||
- Hardware(Ascend)
|
- Hardware(Ascend/GPU)
|
||||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||||
- Framework
|
- Framework
|
||||||
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
|
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
|
||||||
- For more information, please check the resources below:
|
- For more information, please check the resources below:
|
||||||
|
@ -65,7 +65,7 @@ Dataset used: [COCO2014](https://cocodataset.org/#download)
|
||||||
|
|
||||||
# [Quick Start](#contents)
|
# [Quick Start](#contents)
|
||||||
|
|
||||||
After installing MindSpore via the official website, you can start training and evaluation in Ascend as follows:
|
After installing MindSpore via the official website, you can start training and evaluation in as follows. If running on GPU, please add `--device_target=GPU` in the python command or use the "_gpu" shell script ("xxx_gpu.sh").
|
||||||
|
|
||||||
```
|
```
|
||||||
# The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper.
|
# The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper.
|
||||||
|
@ -87,9 +87,12 @@ python train.py \
|
||||||
# standalone training example(1p) by shell script
|
# standalone training example(1p) by shell script
|
||||||
sh run_standalone_train.sh dataset/coco2014 darknet53_backbone.ckpt
|
sh run_standalone_train.sh dataset/coco2014 darknet53_backbone.ckpt
|
||||||
|
|
||||||
# distributed training example(8p) by shell script
|
# For Ascend device, distributed training example(8p) by shell script
|
||||||
sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json
|
sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json
|
||||||
|
|
||||||
|
# For GPU device, distributed training example(8p) by shell script
|
||||||
|
sh run_distribute_train_gpu.sh dataset/coco2014 darknet53_backbone.ckpt
|
||||||
|
|
||||||
# run evaluation by python command
|
# run evaluation by python command
|
||||||
python eval.py \
|
python eval.py \
|
||||||
--data_dir=./dataset/coco2014 \
|
--data_dir=./dataset/coco2014 \
|
||||||
|
@ -113,6 +116,9 @@ sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt
|
||||||
├─run_standalone_train.sh # launch standalone training(1p) in ascend
|
├─run_standalone_train.sh # launch standalone training(1p) in ascend
|
||||||
├─run_distribute_train.sh # launch distributed training(8p) in ascend
|
├─run_distribute_train.sh # launch distributed training(8p) in ascend
|
||||||
└─run_eval.sh # launch evaluating in ascend
|
└─run_eval.sh # launch evaluating in ascend
|
||||||
|
├─run_standalone_train_gpu.sh # launch standalone training(1p) in gpu
|
||||||
|
├─run_distribute_train_gpu.sh # launch distributed training(8p) in gpu
|
||||||
|
└─run_eval_gpu.sh # launch evaluating in gpu
|
||||||
├─src
|
├─src
|
||||||
├─__init__.py # python init file
|
├─__init__.py # python init file
|
||||||
├─config.py # parameter configuration
|
├─config.py # parameter configuration
|
||||||
|
@ -138,6 +144,7 @@ Major parameters in train.py as follow.
|
||||||
|
|
||||||
optional arguments:
|
optional arguments:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
|
--device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
|
||||||
--data_dir DATA_DIR Train dataset directory.
|
--data_dir DATA_DIR Train dataset directory.
|
||||||
--per_batch_size PER_BATCH_SIZE
|
--per_batch_size PER_BATCH_SIZE
|
||||||
Batch size for Training. Default: 32.
|
Batch size for Training. Default: 32.
|
||||||
|
@ -212,7 +219,7 @@ python train.py \
|
||||||
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
|
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
|
||||||
```
|
```
|
||||||
|
|
||||||
The python command above will run in the background, you can view the results through the file `log.txt`.
|
The python command above will run in the background, you can view the results through the file `log.txt`. If running on GPU, please add `--device_target=GPU` in the python command.
|
||||||
|
|
||||||
After training, you'll get some checkpoint files under the outputs folder by default. The loss value will be achieved as follows:
|
After training, you'll get some checkpoint files under the outputs folder by default. The loss value will be achieved as follows:
|
||||||
|
|
||||||
|
@ -228,9 +235,14 @@ The model checkpoint will be saved in outputs directory.
|
||||||
|
|
||||||
### Distributed Training
|
### Distributed Training
|
||||||
|
|
||||||
|
For Ascend device, distributed training example(8p) by shell script
|
||||||
```
|
```
|
||||||
sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json
|
sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json
|
||||||
```
|
```
|
||||||
|
For GPU device, distributed training example(8p) by shell script
|
||||||
|
```
|
||||||
|
sh run_distribute_train_gpu.sh dataset/coco2014 darknet53_backbone.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log.txt`. The loss value will be achieved as follows:
|
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log.txt`. The loss value will be achieved as follows:
|
||||||
|
|
||||||
|
@ -254,7 +266,7 @@ epoch[319], iter[102300], loss:35.430038, 423.49 imgs/sec, lr:2.409552052995423e
|
||||||
|
|
||||||
### Evaluation
|
### Evaluation
|
||||||
|
|
||||||
Before running the command below.
|
Before running the command below. If running on GPU, please add `--device_target=GPU` in the python command or use the "_gpu" shell script ("xxx_gpu.sh").
|
||||||
|
|
||||||
```
|
```
|
||||||
python eval.py \
|
python eval.py \
|
||||||
|
|
|
@ -35,9 +35,6 @@ from src.logger import get_logger
|
||||||
from src.yolo_dataset import create_yolo_dataset
|
from src.yolo_dataset import create_yolo_dataset
|
||||||
from src.config import ConfigYOLOV3DarkNet53
|
from src.config import ConfigYOLOV3DarkNet53
|
||||||
|
|
||||||
devid = int(os.getenv('DEVICE_ID'))
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
|
|
||||||
|
|
||||||
|
|
||||||
class Redirct:
|
class Redirct:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -208,6 +205,10 @@ def parse_args():
|
||||||
"""Parse arguments."""
|
"""Parse arguments."""
|
||||||
parser = argparse.ArgumentParser('mindspore coco testing')
|
parser = argparse.ArgumentParser('mindspore coco testing')
|
||||||
|
|
||||||
|
# device related
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||||
|
help='device where the code will be implemented. (Default: Ascend)')
|
||||||
|
|
||||||
# dataset related
|
# dataset related
|
||||||
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
|
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
|
||||||
parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')
|
parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')
|
||||||
|
@ -243,10 +244,13 @@ def test():
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
|
devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True, device_id=devid)
|
||||||
|
|
||||||
# logger
|
# logger
|
||||||
args.outputs_dir = os.path.join(args.log_path,
|
args.outputs_dir = os.path.join(args.log_path,
|
||||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||||
rank_id = int(os.environ.get('RANK_ID'))
|
rank_id = int(os.environ.get('RANK_ID')) if os.environ.get('RANK_ID') else 0
|
||||||
args.logger = get_logger(args.outputs_dir, rank_id)
|
args.logger = get_logger(args.outputs_dir, rank_id)
|
||||||
|
|
||||||
context.reset_auto_parallel_context()
|
context.reset_auto_parallel_context()
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 2 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH] [PRETRAINED_BACKBONE]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
DATASET_PATH=$(get_real_path $1)
|
||||||
|
PRETRAINED_BACKBONE=$(get_real_path $2)
|
||||||
|
echo $DATASET_PATH
|
||||||
|
echo $PRETRAINED_BACKBONE
|
||||||
|
|
||||||
|
if [ ! -d $DATASET_PATH ]
|
||||||
|
then
|
||||||
|
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||||
|
then
|
||||||
|
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
export DEVICE_NUM=8
|
||||||
|
|
||||||
|
rm -rf ./train_parallel
|
||||||
|
mkdir ./train_parallel
|
||||||
|
cp ../*.py ./train_parallel
|
||||||
|
cp -r ../src ./train_parallel
|
||||||
|
cd ./train_parallel || exit
|
||||||
|
env > env.log
|
||||||
|
mpirun --allow-run-as-root -n ${DEVICE_NUM} python train.py \
|
||||||
|
--data_dir=$DATASET_PATH \
|
||||||
|
--pretrained_backbone=$PRETRAINED_BACKBONE \
|
||||||
|
--device_target=GPU \
|
||||||
|
--is_distributed=1 \
|
||||||
|
--lr=0.1 \
|
||||||
|
--T_max=320 \
|
||||||
|
--max_epoch=320 \
|
||||||
|
--warmup_epochs=4 \
|
||||||
|
--training_shape=416 \
|
||||||
|
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
|
||||||
|
cd ..
|
|
@ -0,0 +1,67 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 2 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
DATASET_PATH=$(get_real_path $1)
|
||||||
|
CHECKPOINT_PATH=$(get_real_path $2)
|
||||||
|
echo $DATASET_PATH
|
||||||
|
echo $CHECKPOINT_PATH
|
||||||
|
|
||||||
|
if [ ! -d $DATASET_PATH ]
|
||||||
|
then
|
||||||
|
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $CHECKPOINT_PATH ]
|
||||||
|
then
|
||||||
|
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export DEVICE_ID=0
|
||||||
|
export RANK_SIZE=$DEVICE_NUM
|
||||||
|
export RANK_ID=0
|
||||||
|
|
||||||
|
if [ -d "eval" ];
|
||||||
|
then
|
||||||
|
rm -rf ./eval
|
||||||
|
fi
|
||||||
|
mkdir ./eval
|
||||||
|
cp ../*.py ./eval
|
||||||
|
cp -r ../src ./eval
|
||||||
|
cd ./eval || exit
|
||||||
|
env > env.log
|
||||||
|
echo "start infering for device $DEVICE_ID"
|
||||||
|
python eval.py \
|
||||||
|
--device_target="GPU" \
|
||||||
|
--data_dir=$DATASET_PATH \
|
||||||
|
--pretrained=$CHECKPOINT_PATH \
|
||||||
|
--testing_shape=416 > log.txt 2>&1 &
|
||||||
|
cd ..
|
|
@ -0,0 +1,75 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 2 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_standalone_train_gpu.sh [DATASET_PATH] [PRETRAINED_BACKBONE]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
DATASET_PATH=$(get_real_path $1)
|
||||||
|
echo $DATASET_PATH
|
||||||
|
PRETRAINED_BACKBONE=$(get_real_path $2)
|
||||||
|
echo $PRETRAINED_BACKBONE
|
||||||
|
|
||||||
|
if [ ! -d $DATASET_PATH ]
|
||||||
|
then
|
||||||
|
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||||
|
then
|
||||||
|
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export DEVICE_ID=0
|
||||||
|
export RANK_ID=0
|
||||||
|
export RANK_SIZE=1
|
||||||
|
|
||||||
|
if [ -d "train" ];
|
||||||
|
then
|
||||||
|
rm -rf ./train
|
||||||
|
fi
|
||||||
|
mkdir ./train
|
||||||
|
cp ../*.py ./train
|
||||||
|
cp -r ../src ./train
|
||||||
|
cd ./train || exit
|
||||||
|
echo "start training for device $DEVICE_ID"
|
||||||
|
env > env.log
|
||||||
|
|
||||||
|
python train.py \
|
||||||
|
--device_targe="GPU" \
|
||||||
|
--data_dir=$DATASET_PATH \
|
||||||
|
--pretrained_backbone=$PRETRAINED_BACKBONE \
|
||||||
|
--is_distributed=0 \
|
||||||
|
--lr=0.1 \
|
||||||
|
--T_max=320 \
|
||||||
|
--max_epoch=320 \
|
||||||
|
--warmup_epochs=4 \
|
||||||
|
--training_shape=416 \
|
||||||
|
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
|
||||||
|
cd ..
|
|
@ -465,6 +465,11 @@ class MultiScaleTrans:
|
||||||
self.seed_list = self.generate_seed_list(seed_num=self.seed_num)
|
self.seed_list = self.generate_seed_list(seed_num=self.seed_num)
|
||||||
self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate))
|
self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate))
|
||||||
self.device_num = device_num
|
self.device_num = device_num
|
||||||
|
self.anchor_scales = config.anchor_scales
|
||||||
|
self.num_classes = config.num_classes
|
||||||
|
self.max_box = config.max_box
|
||||||
|
self.label_smooth = config.label_smooth
|
||||||
|
self.label_smooth_factor = config.label_smooth_factor
|
||||||
|
|
||||||
def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)):
|
def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)):
|
||||||
seed_list = []
|
seed_list = []
|
||||||
|
@ -474,13 +479,20 @@ class MultiScaleTrans:
|
||||||
seed_list.append(seed)
|
seed_list.append(seed)
|
||||||
return seed_list
|
return seed_list
|
||||||
|
|
||||||
def __call__(self, imgs, annos, batchInfo):
|
def __call__(self, imgs, annos, x1, x2, x3, x4, x5, x6, batchInfo):
|
||||||
epoch_num = batchInfo.get_epoch_num()
|
epoch_num = batchInfo.get_epoch_num()
|
||||||
size_idx = int(batchInfo.get_batch_num() / self.resize_rate)
|
size_idx = int(batchInfo.get_batch_num() / self.resize_rate)
|
||||||
seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num]
|
seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num]
|
||||||
ret_imgs = []
|
ret_imgs = []
|
||||||
ret_annos = []
|
ret_annos = []
|
||||||
|
|
||||||
|
bbox1 = []
|
||||||
|
bbox2 = []
|
||||||
|
bbox3 = []
|
||||||
|
gt1 = []
|
||||||
|
gt2 = []
|
||||||
|
gt3 = []
|
||||||
|
|
||||||
if self.size_dict.get(seed_key, None) is None:
|
if self.size_dict.get(seed_key, None) is None:
|
||||||
random.seed(seed_key)
|
random.seed(seed_key)
|
||||||
new_size = random.choice(self.config.multi_scale)
|
new_size = random.choice(self.config.multi_scale)
|
||||||
|
@ -491,8 +503,19 @@ class MultiScaleTrans:
|
||||||
for img, anno in zip(imgs, annos):
|
for img, anno in zip(imgs, annos):
|
||||||
img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num)
|
img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num)
|
||||||
ret_imgs.append(img.transpose(2, 0, 1).copy())
|
ret_imgs.append(img.transpose(2, 0, 1).copy())
|
||||||
ret_annos.append(anno)
|
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
|
||||||
return np.array(ret_imgs), np.array(ret_annos)
|
_preprocess_true_boxes(true_boxes=anno, anchors=self.anchor_scales, in_shape=img.shape[0:2],
|
||||||
|
num_classes=self.num_classes, max_boxes=self.max_box,
|
||||||
|
label_smooth=self.label_smooth, label_smooth_factor=self.label_smooth_factor)
|
||||||
|
bbox1.append(bbox_true_1)
|
||||||
|
bbox2.append(bbox_true_2)
|
||||||
|
bbox3.append(bbox_true_3)
|
||||||
|
gt1.append(gt_box1)
|
||||||
|
gt2.append(gt_box2)
|
||||||
|
gt3.append(gt_box3)
|
||||||
|
ret_annos.append(0)
|
||||||
|
return np.array(ret_imgs), np.array(ret_annos), np.array(bbox1), np.array(bbox2), np.array(bbox3), \
|
||||||
|
np.array(gt1), np.array(gt2), np.array(gt3)
|
||||||
|
|
||||||
|
|
||||||
def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2,
|
def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2,
|
||||||
|
|
|
@ -15,6 +15,9 @@
|
||||||
"""Util class or function."""
|
"""Util class or function."""
|
||||||
from mindspore.train.serialization import load_checkpoint
|
from mindspore.train.serialization import load_checkpoint
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
|
||||||
|
from .yolo import YoloLossBlock
|
||||||
|
|
||||||
|
|
||||||
class AverageMeter:
|
class AverageMeter:
|
||||||
|
@ -175,3 +178,10 @@ class ShapeRecord:
|
||||||
for key in self.shape_record:
|
for key in self.shape_record:
|
||||||
rate = self.shape_record[key] / float(self.shape_record['total'])
|
rate = self.shape_record[key] / float(self.shape_record['total'])
|
||||||
logger.info('shape {}: {:.2f}%'.format(key, rate*100))
|
logger.info('shape {}: {:.2f}%'.format(key, rate*100))
|
||||||
|
|
||||||
|
|
||||||
|
def keep_loss_fp32(network):
|
||||||
|
"""Keep loss of network with float32"""
|
||||||
|
for _, cell in network.cells_and_names():
|
||||||
|
if isinstance(cell, (YoloLossBlock,)):
|
||||||
|
cell.to_float(mstype.float32)
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
"""YOLOV3 dataset."""
|
"""YOLOV3 dataset."""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pycocotools.coco import COCO
|
from pycocotools.coco import COCO
|
||||||
import mindspore.dataset as de
|
import mindspore.dataset as de
|
||||||
|
@ -126,7 +127,7 @@ class COCOYoloDataset:
|
||||||
tmp.append(int(label))
|
tmp.append(int(label))
|
||||||
# tmp [x_min y_min x_max y_max, label]
|
# tmp [x_min y_min x_max y_max, label]
|
||||||
out_target.append(tmp)
|
out_target.append(tmp)
|
||||||
return img, out_target
|
return img, out_target, [], [], [], [], [], []
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.img_ids)
|
return len(self.img_ids)
|
||||||
|
@ -155,20 +156,22 @@ def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num,
|
||||||
hwc_to_chw = CV.HWC2CHW()
|
hwc_to_chw = CV.HWC2CHW()
|
||||||
|
|
||||||
config.dataset_size = len(yolo_dataset)
|
config.dataset_size = len(yolo_dataset)
|
||||||
num_parallel_workers1 = int(64 / device_num)
|
cores = multiprocessing.cpu_count()
|
||||||
num_parallel_workers2 = int(16 / device_num)
|
num_parallel_workers = int(cores / device_num)
|
||||||
if is_training:
|
if is_training:
|
||||||
multi_scale_trans = MultiScaleTrans(config, device_num)
|
multi_scale_trans = MultiScaleTrans(config, device_num)
|
||||||
|
dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
|
||||||
|
"gt_box1", "gt_box2", "gt_box3"]
|
||||||
if device_num != 8:
|
if device_num != 8:
|
||||||
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"],
|
ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names,
|
||||||
num_parallel_workers=num_parallel_workers1,
|
num_parallel_workers=min(32, num_parallel_workers),
|
||||||
sampler=distributed_sampler)
|
sampler=distributed_sampler)
|
||||||
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
|
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
|
||||||
num_parallel_workers=num_parallel_workers2, drop_remainder=True)
|
num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
|
||||||
else:
|
else:
|
||||||
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
|
ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler)
|
||||||
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
|
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
|
||||||
num_parallel_workers=8, drop_remainder=True)
|
num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
|
||||||
else:
|
else:
|
||||||
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
|
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
|
||||||
sampler=distributed_sampler)
|
sampler=distributed_sampler)
|
||||||
|
|
|
@ -28,6 +28,8 @@ from mindspore.train.callback import ModelCheckpoint, RunContext
|
||||||
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
|
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore import amp
|
||||||
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
|
|
||||||
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
|
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
|
||||||
from src.logger import get_logger
|
from src.logger import get_logger
|
||||||
|
@ -37,13 +39,7 @@ from src.lr_scheduler import warmup_step_lr, warmup_cosine_annealing_lr, \
|
||||||
from src.yolo_dataset import create_yolo_dataset
|
from src.yolo_dataset import create_yolo_dataset
|
||||||
from src.initializer import default_recurisive_init
|
from src.initializer import default_recurisive_init
|
||||||
from src.config import ConfigYOLOV3DarkNet53
|
from src.config import ConfigYOLOV3DarkNet53
|
||||||
from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single
|
from src.util import keep_loss_fp32
|
||||||
from src.util import ShapeRecord
|
|
||||||
|
|
||||||
|
|
||||||
devid = int(os.getenv('DEVICE_ID'))
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
|
||||||
device_target="Ascend", save_graphs=True, device_id=devid)
|
|
||||||
|
|
||||||
|
|
||||||
class BuildTrainNetwork(nn.Cell):
|
class BuildTrainNetwork(nn.Cell):
|
||||||
|
@ -62,6 +58,10 @@ def parse_args():
|
||||||
"""Parse train arguments."""
|
"""Parse train arguments."""
|
||||||
parser = argparse.ArgumentParser('mindspore coco training')
|
parser = argparse.ArgumentParser('mindspore coco training')
|
||||||
|
|
||||||
|
# device related
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||||
|
help='device where the code will be implemented. (Default: Ascend)')
|
||||||
|
|
||||||
# dataset related
|
# dataset related
|
||||||
parser.add_argument('--data_dir', type=str, help='Train dataset directory.')
|
parser.add_argument('--data_dir', type=str, help='Train dataset directory.')
|
||||||
parser.add_argument('--per_batch_size', default=32, type=int, help='Batch size for Training. Default: 32.')
|
parser.add_argument('--per_batch_size', default=32, type=int, help='Batch size for Training. Default: 32.')
|
||||||
|
@ -136,9 +136,16 @@ def train():
|
||||||
"""Train function."""
|
"""Train function."""
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
|
devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||||
|
device_target=args.device_target, save_graphs=True, device_id=devid)
|
||||||
|
|
||||||
# init distributed
|
# init distributed
|
||||||
if args.is_distributed:
|
if args.is_distributed:
|
||||||
|
if args.device_target == "Ascend":
|
||||||
init()
|
init()
|
||||||
|
else:
|
||||||
|
init("nccl")
|
||||||
args.rank = get_rank()
|
args.rank = get_rank()
|
||||||
args.group_size = get_group_size()
|
args.group_size = get_group_size()
|
||||||
|
|
||||||
|
@ -259,7 +266,17 @@ def train():
|
||||||
momentum=args.momentum,
|
momentum=args.momentum,
|
||||||
weight_decay=args.weight_decay,
|
weight_decay=args.weight_decay,
|
||||||
loss_scale=args.loss_scale)
|
loss_scale=args.loss_scale)
|
||||||
|
enable_amp = False
|
||||||
|
is_gpu = context.get_context("device_target") == "GPU"
|
||||||
|
if is_gpu:
|
||||||
|
enable_amp = True
|
||||||
|
if enable_amp:
|
||||||
|
loss_scale_value = 1.0
|
||||||
|
loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False)
|
||||||
|
network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale,
|
||||||
|
level="O2", keep_batchnorm_fp32=True)
|
||||||
|
keep_loss_fp32(network)
|
||||||
|
else:
|
||||||
network = TrainingWrapper(network, opt)
|
network = TrainingWrapper(network, opt)
|
||||||
network.set_train()
|
network.set_train()
|
||||||
|
|
||||||
|
@ -282,28 +299,19 @@ def train():
|
||||||
t_end = time.time()
|
t_end = time.time()
|
||||||
data_loader = ds.create_dict_iterator()
|
data_loader = ds.create_dict_iterator()
|
||||||
|
|
||||||
shape_record = ShapeRecord()
|
|
||||||
for i, data in enumerate(data_loader):
|
for i, data in enumerate(data_loader):
|
||||||
images = data["image"]
|
images = data["image"]
|
||||||
input_shape = images.shape[2:4]
|
input_shape = images.shape[2:4]
|
||||||
args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
|
args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
|
||||||
shape_record.set(input_shape)
|
|
||||||
|
|
||||||
images = Tensor(images)
|
images = Tensor(images)
|
||||||
annos = data["annotation"]
|
|
||||||
if args.group_size == 1:
|
|
||||||
batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
|
|
||||||
batch_preprocess_true_box(annos, config, input_shape)
|
|
||||||
else:
|
|
||||||
batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
|
|
||||||
batch_preprocess_true_box_single(annos, config, input_shape)
|
|
||||||
|
|
||||||
batch_y_true_0 = Tensor(batch_y_true_0)
|
batch_y_true_0 = Tensor(data['bbox1'])
|
||||||
batch_y_true_1 = Tensor(batch_y_true_1)
|
batch_y_true_1 = Tensor(data['bbox2'])
|
||||||
batch_y_true_2 = Tensor(batch_y_true_2)
|
batch_y_true_2 = Tensor(data['bbox3'])
|
||||||
batch_gt_box0 = Tensor(batch_gt_box0)
|
batch_gt_box0 = Tensor(data['gt_box1'])
|
||||||
batch_gt_box1 = Tensor(batch_gt_box1)
|
batch_gt_box1 = Tensor(data['gt_box2'])
|
||||||
batch_gt_box2 = Tensor(batch_gt_box2)
|
batch_gt_box2 = Tensor(data['gt_box3'])
|
||||||
|
|
||||||
input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
|
input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
|
||||||
loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
|
loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
|
||||||
|
|
Loading…
Reference in New Issue