forked from mindspore-Ecosystem/mindspore
yolov3 network directory rectification
This commit is contained in:
parent
5d7b9d959e
commit
84b0834659
|
@ -19,10 +19,10 @@ import argparse
|
|||
import time
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithEval
|
||||
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
|
||||
from config import ConfigYOLOV3ResNet18
|
||||
from util import metrics
|
||||
from src.yolov3 import yolov3_resnet18, YoloWithEval
|
||||
from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image
|
||||
from src.config import ConfigYOLOV3ResNet18
|
||||
from src.utils import metrics
|
||||
|
||||
def yolo_eval(dataset_path, ckpt_path):
|
||||
"""Yolov3 evaluation."""
|
|
@ -45,6 +45,9 @@ echo "After running the scipt, the network runs in the background. The log will
|
|||
export MINDSPORE_HCCL_CONFIG_PATH=$6
|
||||
export RANK_SIZE=$1
|
||||
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
cd $BASE_PATH/../ || exit
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
|
@ -56,6 +59,7 @@ do
|
|||
rm -rf LOG$i
|
||||
mkdir ./LOG$i
|
||||
cp *.py ./LOG$i
|
||||
cp -r ./src ./LOG$i
|
||||
cd ./LOG$i || exit
|
||||
export RANK_ID=$i
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
|
@ -63,7 +67,7 @@ do
|
|||
|
||||
if [ $# == 6 ]
|
||||
then
|
||||
taskset -c $cmdopt python ../train.py \
|
||||
taskset -c $cmdopt python train.py \
|
||||
--distribute=1 \
|
||||
--lr=0.005 \
|
||||
--device_num=$RANK_SIZE \
|
||||
|
@ -76,7 +80,7 @@ do
|
|||
|
||||
if [ $# == 8 ]
|
||||
then
|
||||
taskset -c $cmdopt python ../train.py \
|
||||
taskset -c $cmdopt python train.py \
|
||||
--distribute=1 \
|
||||
--lr=0.005 \
|
||||
--device_num=$RANK_SIZE \
|
|
@ -20,4 +20,7 @@ echo "sh run_eval.sh DEVICE_ID CKPT_PATH MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
|
|||
echo "for example: sh run_eval.sh 0 yolo.ckpt ./Mindrecord_eval ./dataset ./dataset/eval.txt"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
cd $BASE_PATH/../ || exit
|
||||
|
||||
python eval.py --device_id=$1 --ckpt_path=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
|
|
@ -27,6 +27,9 @@ then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
cd $BASE_PATH/../ || exit
|
||||
|
||||
if [ $# == 5 ]
|
||||
then
|
||||
python train.py --device_id=$1 --epoch_size=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
|
|
@ -25,7 +25,7 @@ class ConfigYOLOV3ResNet18:
|
|||
"""
|
||||
img_shape = [352, 640]
|
||||
feature_shape = [32, 3, 352, 640]
|
||||
num_classes = 80
|
||||
num_classes = 2
|
||||
nms_max_num = 50
|
||||
|
||||
backbone_input_shape = [64, 64, 128, 256]
|
|
@ -23,7 +23,7 @@ from PIL import Image
|
|||
import mindspore.dataset as de
|
||||
from mindspore.mindrecord import FileWriter
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
from config import ConfigYOLOV3ResNet18
|
||||
from src.config import ConfigYOLOV3ResNet18
|
||||
|
||||
iter_cnt = 0
|
||||
_NUM_BOXES = 50
|
|
@ -15,7 +15,7 @@
|
|||
"""metrics utils"""
|
||||
|
||||
import numpy as np
|
||||
from config import ConfigYOLOV3ResNet18
|
||||
from src.config import ConfigYOLOV3ResNet18
|
||||
|
||||
|
||||
def calc_iou(bbox_pred, bbox_ground):
|
|
@ -33,9 +33,9 @@ from mindspore.train import Model, ParallelMode
|
|||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
|
||||
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
|
||||
from config import ConfigYOLOV3ResNet18
|
||||
from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
|
||||
from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image
|
||||
from src.config import ConfigYOLOV3ResNet18
|
||||
|
||||
|
||||
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
|
Loading…
Reference in New Issue