change some settings in YOLOv3

zhaoting 2020-04-14 16:27:12 +08:00
parent cc6258d6ac
commit 59604af98b
3 changed files with 12 additions and 6 deletions

dataset.py:

@@ -22,7 +22,6 @@ from PIL import Image
 from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
 import mindspore.dataset as de
 from mindspore.mindrecord import FileWriter
-import mindspore.dataset.transforms.vision.py_transforms as P
 import mindspore.dataset.transforms.vision.c_transforms as C
 from config import ConfigYOLOV3ResNet18
@@ -301,13 +300,12 @@ def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num
     compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
     if is_training:
-        hwc_to_chw = P.HWC2CHW()
+        hwc_to_chw = C.HWC2CHW()
         ds = ds.map(input_columns=["image", "annotation"],
                     output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                     columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                     operations=compose_map_func, num_parallel_workers=num_parallel_workers)
         ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
-        ds = ds.shuffle(buffer_size=256)
         ds = ds.batch(batch_size, drop_remainder=True)
         ds = ds.repeat(repeat_num)
     else:
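
Note on the transform swap above: P.HWC2CHW (py_transforms) and C.HWC2CHW (c_transforms) perform the same layout change, (H, W, C) -> (C, H, W); the c_transforms version is backed by C++ and operates on numpy arrays directly, which is generally faster inside the MindSpore data pipeline. For illustration, a numpy equivalent of what the op does (the image shape here is hypothetical):

    import numpy as np

    # A dummy HWC image: height 300, width 300, 3 channels
    img_hwc = np.zeros((300, 300, 3), dtype=np.float32)
    # HWC2CHW moves the channel axis to the front
    img_chw = np.transpose(img_hwc, (2, 0, 1))
    assert img_chw.shape == (3, 300, 300)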

run_distribute_train.sh:

@@ -19,6 +19,7 @@ echo "Please run the script as: "
 echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
 echo "for example: sh run_distribute_train.sh 8 100 /data/Mindrecord_train /data /data/train.txt /data/hccl.json"
 echo "It is better to use absolute path."
+echo "The learning rate is 0.005 by default; if you want another lr, please change the value in this script."
 echo "=============================================================================================================="
 EPOCH_SIZE=$2
@@ -38,6 +39,11 @@ export RANK_SIZE=$1
 for((i=0;i<RANK_SIZE;i++))
 do
     export DEVICE_ID=$i
+    start=`expr $i \* 12`
+    end=`expr $start \+ 11`
+    cmdopt=$start"-"$end
     rm -rf LOG$i
     mkdir ./LOG$i
     cp *.py ./LOG$i
@@ -45,8 +51,9 @@ do
     export RANK_ID=$i
     echo "start training for rank $i, device $DEVICE_ID"
     env > env.log
-    python ../train.py \
+    taskset -c $cmdopt python ../train.py \
     --distribute=1 \
+    --lr=0.005 \
     --device_num=$RANK_SIZE \
     --device_id=$DEVICE_ID \
     --mindrecord_dir=$MINDRECORD_DIR \
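
The three new expr lines compute a CPU affinity range for each rank, and taskset -c pins that rank's python process to it: rank i gets the 12 cores i*12 through i*12+11, so concurrent ranks do not compete for the same cores. A small Python sketch of the mapping (the 12-cores-per-rank figure comes from the script; adjust it to your host's core count):

    CORES_PER_RANK = 12  # matches `expr $i \* 12` in the script

    def core_range(rank: int) -> str:
        """Return one rank's core list in the form taskset -c accepts."""
        start = rank * CORES_PER_RANK
        return f"{start}-{start + CORES_PER_RANK - 1}"

    for rank in range(8):  # RANK_SIZE=8, as in the usage example
        print(rank, core_range(rank))  # 0 -> 0-11, 1 -> 12-23, ..., 7 -> 84-95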

train.py:

@@ -67,6 +67,7 @@ if __name__ == '__main__':
     parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
+    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate, default is 0.001.")
     parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink")
     parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
@@ -137,8 +138,8 @@ if __name__ == '__main__':
     ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
     ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
-    lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size,
-                       decay_step=1000, decay_rate=0.95))
+    lr = Tensor(get_lr(learning_rate=args_opt.lr, start_step=0, global_step=args_opt.epoch_size * dataset_size,
+                       decay_step=1000, decay_rate=0.95, steps=True))
     opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
     net = TrainingWrapper(net, opt, loss_scale)
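
get_lr itself is not shown in this diff; judging from the call site, it builds a per-step learning-rate schedule with exponential decay (decay_rate=0.95 per decay_step=1000 steps), and the new steps=True argument presumably selects staircase decay, where the rate drops only at whole multiples of decay_step rather than continuously. A minimal sketch under those assumptions, not the repository's exact code:

    import numpy as np

    def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
        """Per-step learning rates with (optionally staircase) exponential decay."""
        lr_each_step = []
        for i in range(global_step):
            # staircase (integer division) vs. smooth (float division) decay
            exponent = i // decay_step if steps else i / decay_step
            lr_each_step.append(learning_rate * decay_rate ** exponent)
        return np.array(lr_each_step[start_step:], dtype=np.float32)

    # With --lr=0.005, steps 0-999 use 0.005, steps 1000-1999 use 0.00475, and so on.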