yolov3_darknet53 suit for gpu

This commit is contained in:
hanhuifeng2020 2020-08-25 15:51:30 +08:00
parent 98528bbc16
commit 1f0a760cdb
9 changed files with 318 additions and 49 deletions

View File

@ -53,8 +53,8 @@ Dataset used: [COCO2014](https://cocodataset.org/#download)
# [Environment Requirements](#contents) # [Environment Requirements](#contents)
- HardwareAscend - HardwareAscend/GPU
- Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework - Framework
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
- For more information, please check the resources below - For more information, please check the resources below
@ -65,7 +65,7 @@ Dataset used: [COCO2014](https://cocodataset.org/#download)
# [Quick Start](#contents) # [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation in Ascend as follows: After installing MindSpore via the official website, you can start training and evaluation in as follows. If running on GPU, please add `--device_target=GPU` in the python command or use the "_gpu" shell script ("xxx_gpu.sh").
``` ```
# The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper. # The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper.
@ -87,9 +87,12 @@ python train.py \
# standalone training example(1p) by shell script # standalone training example(1p) by shell script
sh run_standalone_train.sh dataset/coco2014 darknet53_backbone.ckpt sh run_standalone_train.sh dataset/coco2014 darknet53_backbone.ckpt
# distributed training example(8p) by shell script # For Ascend device, distributed training example(8p) by shell script
sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json
# For GPU device, distributed training example(8p) by shell script
sh run_distribute_train_gpu.sh dataset/coco2014 darknet53_backbone.ckpt
# run evaluation by python command # run evaluation by python command
python eval.py \ python eval.py \
--data_dir=./dataset/coco2014 \ --data_dir=./dataset/coco2014 \
@ -113,6 +116,9 @@ sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt
├─run_standalone_train.sh # launch standalone training(1p) in ascend ├─run_standalone_train.sh # launch standalone training(1p) in ascend
├─run_distribute_train.sh # launch distributed training(8p) in ascend ├─run_distribute_train.sh # launch distributed training(8p) in ascend
└─run_eval.sh # launch evaluating in ascend └─run_eval.sh # launch evaluating in ascend
├─run_standalone_train_gpu.sh # launch standalone training(1p) in gpu
├─run_distribute_train_gpu.sh # launch distributed training(8p) in gpu
└─run_eval_gpu.sh # launch evaluating in gpu
├─src ├─src
├─__init__.py # python init file ├─__init__.py # python init file
├─config.py # parameter configuration ├─config.py # parameter configuration
@ -138,6 +144,7 @@ Major parameters in train.py as follow.
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
--data_dir DATA_DIR Train dataset directory. --data_dir DATA_DIR Train dataset directory.
--per_batch_size PER_BATCH_SIZE --per_batch_size PER_BATCH_SIZE
Batch size for Training. Default: 32. Batch size for Training. Default: 32.
@ -212,7 +219,7 @@ python train.py \
--lr_scheduler=cosine_annealing > log.txt 2>&1 & --lr_scheduler=cosine_annealing > log.txt 2>&1 &
``` ```
The python command above will run in the background, you can view the results through the file `log.txt`. The python command above will run in the background, you can view the results through the file `log.txt`. If running on GPU, please add `--device_target=GPU` in the python command.
After training, you'll get some checkpoint files under the outputs folder by default. The loss value will be achieved as follows: After training, you'll get some checkpoint files under the outputs folder by default. The loss value will be achieved as follows:
@ -228,9 +235,14 @@ The model checkpoint will be saved in outputs directory.
### Distributed Training ### Distributed Training
For Ascend device, distributed training example(8p) by shell script
``` ```
sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json
``` ```
For GPU device, distributed training example(8p) by shell script
```
sh run_distribute_train_gpu.sh dataset/coco2014 darknet53_backbone.ckpt
```
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log.txt`. The loss value will be achieved as follows: The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log.txt`. The loss value will be achieved as follows:
@ -254,7 +266,7 @@ epoch[319], iter[102300], loss:35.430038, 423.49 imgs/sec, lr:2.409552052995423e
### Evaluation ### Evaluation
Before running the command below. Before running the command below. If running on GPU, please add `--device_target=GPU` in the python command or use the "_gpu" shell script ("xxx_gpu.sh").
``` ```
python eval.py \ python eval.py \

View File

@ -35,9 +35,6 @@ from src.logger import get_logger
from src.yolo_dataset import create_yolo_dataset from src.yolo_dataset import create_yolo_dataset
from src.config import ConfigYOLOV3DarkNet53 from src.config import ConfigYOLOV3DarkNet53
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
class Redirct: class Redirct:
def __init__(self): def __init__(self):
@ -208,6 +205,10 @@ def parse_args():
"""Parse arguments.""" """Parse arguments."""
parser = argparse.ArgumentParser('mindspore coco testing') parser = argparse.ArgumentParser('mindspore coco testing')
# device related
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
# dataset related # dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir') parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu') parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')
@ -243,10 +244,13 @@ def test():
start_time = time.time() start_time = time.time()
args = parse_args() args = parse_args()
devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True, device_id=devid)
# logger # logger
args.outputs_dir = os.path.join(args.log_path, args.outputs_dir = os.path.join(args.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
rank_id = int(os.environ.get('RANK_ID')) rank_id = int(os.environ.get('RANK_ID')) if os.environ.get('RANK_ID') else 0
args.logger = get_logger(args.outputs_dir, rank_id) args.logger = get_logger(args.outputs_dir, rank_id)
context.reset_auto_parallel_context() context.reset_auto_parallel_context()

View File

@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH] [PRETRAINED_BACKBONE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
PRETRAINED_BACKBONE=$(get_real_path $2)
echo $DATASET_PATH
echo $PRETRAINED_BACKBONE
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
export DEVICE_NUM=8
rm -rf ./train_parallel
mkdir ./train_parallel
cp ../*.py ./train_parallel
cp -r ../src ./train_parallel
cd ./train_parallel || exit
env > env.log
mpirun --allow-run-as-root -n ${DEVICE_NUM} python train.py \
--data_dir=$DATASET_PATH \
--pretrained_backbone=$PRETRAINED_BACKBONE \
--device_target=GPU \
--is_distributed=1 \
--lr=0.1 \
--T_max=320 \
--max_epoch=320 \
--warmup_epochs=4 \
--training_shape=416 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..

View File

@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
CHECKPOINT_PATH=$(get_real_path $2)
echo $DATASET_PATH
echo $CHECKPOINT_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $CHECKPOINT_PATH ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py \
--device_target="GPU" \
--data_dir=$DATASET_PATH \
--pretrained=$CHECKPOINT_PATH \
--testing_shape=416 > log.txt 2>&1 &
cd ..

View File

@ -0,0 +1,75 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_standalone_train_gpu.sh [DATASET_PATH] [PRETRAINED_BACKBONE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
PRETRAINED_BACKBONE=$(get_real_path $2)
echo $PRETRAINED_BACKBONE
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py \
--device_targe="GPU" \
--data_dir=$DATASET_PATH \
--pretrained_backbone=$PRETRAINED_BACKBONE \
--is_distributed=0 \
--lr=0.1 \
--T_max=320 \
--max_epoch=320 \
--warmup_epochs=4 \
--training_shape=416 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..

View File

@ -465,6 +465,11 @@ class MultiScaleTrans:
self.seed_list = self.generate_seed_list(seed_num=self.seed_num) self.seed_list = self.generate_seed_list(seed_num=self.seed_num)
self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate)) self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate))
self.device_num = device_num self.device_num = device_num
self.anchor_scales = config.anchor_scales
self.num_classes = config.num_classes
self.max_box = config.max_box
self.label_smooth = config.label_smooth
self.label_smooth_factor = config.label_smooth_factor
def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)): def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)):
seed_list = [] seed_list = []
@ -474,13 +479,20 @@ class MultiScaleTrans:
seed_list.append(seed) seed_list.append(seed)
return seed_list return seed_list
def __call__(self, imgs, annos, batchInfo): def __call__(self, imgs, annos, x1, x2, x3, x4, x5, x6, batchInfo):
epoch_num = batchInfo.get_epoch_num() epoch_num = batchInfo.get_epoch_num()
size_idx = int(batchInfo.get_batch_num() / self.resize_rate) size_idx = int(batchInfo.get_batch_num() / self.resize_rate)
seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num] seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num]
ret_imgs = [] ret_imgs = []
ret_annos = [] ret_annos = []
bbox1 = []
bbox2 = []
bbox3 = []
gt1 = []
gt2 = []
gt3 = []
if self.size_dict.get(seed_key, None) is None: if self.size_dict.get(seed_key, None) is None:
random.seed(seed_key) random.seed(seed_key)
new_size = random.choice(self.config.multi_scale) new_size = random.choice(self.config.multi_scale)
@ -491,8 +503,19 @@ class MultiScaleTrans:
for img, anno in zip(imgs, annos): for img, anno in zip(imgs, annos):
img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num) img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num)
ret_imgs.append(img.transpose(2, 0, 1).copy()) ret_imgs.append(img.transpose(2, 0, 1).copy())
ret_annos.append(anno) bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
return np.array(ret_imgs), np.array(ret_annos) _preprocess_true_boxes(true_boxes=anno, anchors=self.anchor_scales, in_shape=img.shape[0:2],
num_classes=self.num_classes, max_boxes=self.max_box,
label_smooth=self.label_smooth, label_smooth_factor=self.label_smooth_factor)
bbox1.append(bbox_true_1)
bbox2.append(bbox_true_2)
bbox3.append(bbox_true_3)
gt1.append(gt_box1)
gt2.append(gt_box2)
gt3.append(gt_box3)
ret_annos.append(0)
return np.array(ret_imgs), np.array(ret_annos), np.array(bbox1), np.array(bbox2), np.array(bbox3), \
np.array(gt1), np.array(gt2), np.array(gt3)
def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2, def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2,

View File

@ -15,6 +15,9 @@
"""Util class or function.""" """Util class or function."""
from mindspore.train.serialization import load_checkpoint from mindspore.train.serialization import load_checkpoint
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.common.dtype as mstype
from .yolo import YoloLossBlock
class AverageMeter: class AverageMeter:
@ -175,3 +178,10 @@ class ShapeRecord:
for key in self.shape_record: for key in self.shape_record:
rate = self.shape_record[key] / float(self.shape_record['total']) rate = self.shape_record[key] / float(self.shape_record['total'])
logger.info('shape {}: {:.2f}%'.format(key, rate*100)) logger.info('shape {}: {:.2f}%'.format(key, rate*100))
def keep_loss_fp32(network):
"""Keep loss of network with float32"""
for _, cell in network.cells_and_names():
if isinstance(cell, (YoloLossBlock,)):
cell.to_float(mstype.float32)

View File

@ -15,6 +15,7 @@
"""YOLOV3 dataset.""" """YOLOV3 dataset."""
import os import os
import multiprocessing
from PIL import Image from PIL import Image
from pycocotools.coco import COCO from pycocotools.coco import COCO
import mindspore.dataset as de import mindspore.dataset as de
@ -126,7 +127,7 @@ class COCOYoloDataset:
tmp.append(int(label)) tmp.append(int(label))
# tmp [x_min y_min x_max y_max, label] # tmp [x_min y_min x_max y_max, label]
out_target.append(tmp) out_target.append(tmp)
return img, out_target return img, out_target, [], [], [], [], [], []
def __len__(self): def __len__(self):
return len(self.img_ids) return len(self.img_ids)
@ -155,20 +156,22 @@ def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num,
hwc_to_chw = CV.HWC2CHW() hwc_to_chw = CV.HWC2CHW()
config.dataset_size = len(yolo_dataset) config.dataset_size = len(yolo_dataset)
num_parallel_workers1 = int(64 / device_num) cores = multiprocessing.cpu_count()
num_parallel_workers2 = int(16 / device_num) num_parallel_workers = int(cores / device_num)
if is_training: if is_training:
multi_scale_trans = MultiScaleTrans(config, device_num) multi_scale_trans = MultiScaleTrans(config, device_num)
dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
"gt_box1", "gt_box2", "gt_box3"]
if device_num != 8: if device_num != 8:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names,
num_parallel_workers=num_parallel_workers1, num_parallel_workers=min(32, num_parallel_workers),
sampler=distributed_sampler) sampler=distributed_sampler)
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'], ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
num_parallel_workers=num_parallel_workers2, drop_remainder=True) num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
else: else:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler) ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler)
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'], ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
num_parallel_workers=8, drop_remainder=True) num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
else: else:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"], ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
sampler=distributed_sampler) sampler=distributed_sampler)

View File

@ -28,6 +28,8 @@ from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
import mindspore as ms import mindspore as ms
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import amp
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
from src.logger import get_logger from src.logger import get_logger
@ -37,13 +39,7 @@ from src.lr_scheduler import warmup_step_lr, warmup_cosine_annealing_lr, \
from src.yolo_dataset import create_yolo_dataset from src.yolo_dataset import create_yolo_dataset
from src.initializer import default_recurisive_init from src.initializer import default_recurisive_init
from src.config import ConfigYOLOV3DarkNet53 from src.config import ConfigYOLOV3DarkNet53
from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single from src.util import keep_loss_fp32
from src.util import ShapeRecord
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target="Ascend", save_graphs=True, device_id=devid)
class BuildTrainNetwork(nn.Cell): class BuildTrainNetwork(nn.Cell):
@ -62,6 +58,10 @@ def parse_args():
"""Parse train arguments.""" """Parse train arguments."""
parser = argparse.ArgumentParser('mindspore coco training') parser = argparse.ArgumentParser('mindspore coco training')
# device related
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
# dataset related # dataset related
parser.add_argument('--data_dir', type=str, help='Train dataset directory.') parser.add_argument('--data_dir', type=str, help='Train dataset directory.')
parser.add_argument('--per_batch_size', default=32, type=int, help='Batch size for Training. Default: 32.') parser.add_argument('--per_batch_size', default=32, type=int, help='Batch size for Training. Default: 32.')
@ -136,9 +136,16 @@ def train():
"""Train function.""" """Train function."""
args = parse_args() args = parse_args()
devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, save_graphs=True, device_id=devid)
# init distributed # init distributed
if args.is_distributed: if args.is_distributed:
if args.device_target == "Ascend":
init() init()
else:
init("nccl")
args.rank = get_rank() args.rank = get_rank()
args.group_size = get_group_size() args.group_size = get_group_size()
@ -259,7 +266,17 @@ def train():
momentum=args.momentum, momentum=args.momentum,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
loss_scale=args.loss_scale) loss_scale=args.loss_scale)
enable_amp = False
is_gpu = context.get_context("device_target") == "GPU"
if is_gpu:
enable_amp = True
if enable_amp:
loss_scale_value = 1.0
loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False)
network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale,
level="O2", keep_batchnorm_fp32=True)
keep_loss_fp32(network)
else:
network = TrainingWrapper(network, opt) network = TrainingWrapper(network, opt)
network.set_train() network.set_train()
@ -282,28 +299,19 @@ def train():
t_end = time.time() t_end = time.time()
data_loader = ds.create_dict_iterator() data_loader = ds.create_dict_iterator()
shape_record = ShapeRecord()
for i, data in enumerate(data_loader): for i, data in enumerate(data_loader):
images = data["image"] images = data["image"]
input_shape = images.shape[2:4] input_shape = images.shape[2:4]
args.logger.info('iter[{}], shape{}'.format(i, input_shape[0])) args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
shape_record.set(input_shape)
images = Tensor(images) images = Tensor(images)
annos = data["annotation"]
if args.group_size == 1:
batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
batch_preprocess_true_box(annos, config, input_shape)
else:
batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
batch_preprocess_true_box_single(annos, config, input_shape)
batch_y_true_0 = Tensor(batch_y_true_0) batch_y_true_0 = Tensor(data['bbox1'])
batch_y_true_1 = Tensor(batch_y_true_1) batch_y_true_1 = Tensor(data['bbox2'])
batch_y_true_2 = Tensor(batch_y_true_2) batch_y_true_2 = Tensor(data['bbox3'])
batch_gt_box0 = Tensor(batch_gt_box0) batch_gt_box0 = Tensor(data['gt_box1'])
batch_gt_box1 = Tensor(batch_gt_box1) batch_gt_box1 = Tensor(data['gt_box2'])
batch_gt_box2 = Tensor(batch_gt_box2) batch_gt_box2 = Tensor(data['gt_box3'])
input_shape = Tensor(tuple(input_shape[::-1]), ms.float32) input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,