forked from mindspore-Ecosystem/mindspore
!21379 mcnn bug fix
Merge pull request !21379 from wukesong/wl_1.1_mcnn
This commit is contained in:
commit
9c133b6f70
|
@ -53,6 +53,8 @@ Dataset used: [ShanghaitechA](<https://www.dropbox.com/s/fipgjqxl7uj8hd5/Shangha
|
|||
├─ground_truth_csv
|
||||
```
|
||||
|
||||
- Note: the `formatted_trainval` directory is generated by the script [create_training_set_shtech](https://github.com/svishwa/crowdcount-mcnn/blob/master/data_preparation/create_training_set_shtech.m).
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware (Ascend)
|
||||
|
@ -68,10 +70,10 @@ Dataset used: [ShanghaitechA](<https://www.dropbox.com/s/fipgjqxl7uj8hd5/Shangha
|
|||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
```bash
|
||||
# enter script dir, train AlexNet
|
||||
sh run_standalone_train_ascend.sh [DATA_PATH] [CKPT_SAVE_PATH]
|
||||
# enter script dir, evaluate AlexNet
|
||||
sh run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_NAME]
|
||||
# enter script dir, train MCNN example
|
||||
sh run_standalone_train_ascend.sh 0 ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt
|
||||
# enter script dir, evaluate MCNN example
|
||||
sh run_standalone_eval_ascend.sh 0 ./original/shanghaitech/part_A_final/test_data/images ./original/shanghaitech/part_A_final/test_data/ground_truth_csv ./train/ckpt/best.ckpt
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
@ -123,11 +125,10 @@ Major parameters in train.py and config.py as follows:
|
|||
- running on Ascend
|
||||
|
||||
```bash
|
||||
# python train.py
|
||||
# or enter script dir, and run the distribute script
|
||||
sh run_distribute_train.sh
|
||||
# or enter script dir, and run the standalone script
|
||||
sh run_standalone_train.sh
|
||||
# enter script dir, and run the distribute script
|
||||
sh run_distribute_train.sh ./hccl_table.json ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt
|
||||
# enter script dir, and run the standalone script
|
||||
sh run_standalone_train_ascend.sh 0 ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt
|
||||
```
|
||||
|
||||
After training, you will obtain loss values like the following:
|
||||
|
@ -154,9 +155,8 @@ Before running the command below, please check the checkpoint path used for eval
|
|||
- running on Ascend
|
||||
|
||||
```bash
|
||||
# python eval.py
|
||||
# or enter script dir, and run the script
|
||||
sh run_eval.sh
|
||||
# enter script dir, and run the script
|
||||
sh run_standalone_eval_ascend.sh 0 ./original/shanghaitech/part_A_final/test_data/images ./original/shanghaitech/part_A_final/test_data/ground_truth_csv ./train/ckpt/best.ckpt
|
||||
```
|
||||
|
||||
You can view the results in the file "eval_log". The accuracy on the test dataset will be as follows:
|
||||
|
|
|
@ -36,17 +36,18 @@ ckptpath = "obs://lhb1234/MCNN/ckpt"
|
|||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MCNN Example')
|
||||
parser.add_argument('--run_offline', type=ast.literal_eval,
|
||||
default=False, help='run in offline is False or True')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
|
||||
default=True, help='run in offline is False or True')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)')
|
||||
parser.add_argument('--ckpt_path', type=str, default="/cache/train_output", help='Location of ckpt.')
|
||||
parser.add_argument('--data_url', default=None, help='Location of data.')
|
||||
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
|
||||
parser.add_argument('--val_path', required=True,
|
||||
default='obs://lhb1234/mcnn-pure/data/formatted_trainval/shanghaitech_part_A_patches_9/val',
|
||||
default='/data/mcnn/original/shanghaitech/part_A_final/test_data/images',
|
||||
help='Location of data.')
|
||||
parser.add_argument('--val_gt_path', required=True,
|
||||
default='obs://lhb1234/mcnn-pure/data/formatted_trainval/shanghaitech_part_A_patches_9/val_den',
|
||||
default='/data/mcnn/original/shanghaitech/part_A_final/test_data/ground_truth_csv',
|
||||
help='Location of data.')
|
||||
args = parser.parse_args()
|
||||
set_seed(64678)
|
||||
|
@ -54,14 +55,13 @@ set_seed(64678)
|
|||
|
||||
if __name__ == "__main__":
|
||||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
device_id = int(os.getenv("DEVICE_ID"))
|
||||
|
||||
device_target = args.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
context.set_context(save_graphs=False)
|
||||
|
||||
if device_target == "Ascend":
|
||||
context.set_context(device_id=device_id)
|
||||
context.set_context(device_id=args.device_id)
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
|
@ -81,7 +81,7 @@ if __name__ == "__main__":
|
|||
ds_val = ds_val.batch(1)
|
||||
network = MCNN()
|
||||
|
||||
model_name = os.path.join(local_ckpt_url, 'best.ckpt')
|
||||
model_name = local_ckpt_url
|
||||
print(model_name)
|
||||
mae = 0.0
|
||||
mse = 0.0
|
||||
|
|
|
@ -15,10 +15,9 @@
|
|||
# ============================================================================
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=/home/wks/hccl_8p_01234567_127.0.0.1.json
|
||||
export RUN_OFFLINE=$1
|
||||
export DEVICE_NUM=8
|
||||
export RANK_TABLE_FILE=$1
|
||||
export TRAIN_PATH=$2
|
||||
export TRAIN_GT_PATH=$3
|
||||
export VAL_PATH=$4
|
||||
|
@ -40,7 +39,7 @@ do
|
|||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --run_offline=$RUN_OFFLINE --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
|
||||
--val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
|
||||
python -u train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
|
||||
--val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
|
||||
cd ..
|
||||
done
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -16,14 +16,13 @@
|
|||
|
||||
if [ $# != 4 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] [0|1|2|3|4|5|6|7] "
|
||||
echo "Usage: sh run_eval.sh [DEVICE_ID] [VAL_PATH] [VAL_GT_PATH] [CKPT_PATH] "
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=1
|
||||
export RUN_OFFLINE=$1
|
||||
export DEVICE_ID=$1
|
||||
export VAL_PATH=$2
|
||||
export VAL_GT_PATH=$3
|
||||
export CKPT_PATH=$4
|
||||
|
@ -40,6 +39,6 @@ cp -r ../src ./eval
|
|||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --run_offline=$RUN_OFFLINE --val_path=$VAL_PATH \
|
||||
--val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
|
||||
python -u eval.py --device_id=$DEVICE_ID --val_path=$VAL_PATH \
|
||||
--val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
|
||||
cd ..
|
||||
|
|
|
@ -14,9 +14,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
ulimit -u unlimited
|
||||
export DEVICE_ID=1
|
||||
export RANK_SIZE=1
|
||||
export RUN_OFFLINE=$1
|
||||
export DEVICE_ID=$1
|
||||
export TRAIN_PATH=$2
|
||||
export TRAIN_GT_PATH=$3
|
||||
export VAL_PATH=$4
|
||||
|
@ -37,8 +36,8 @@ env > env.
|
|||
|
||||
if [ $# == 6 ]
|
||||
then
|
||||
python train.py --run_offline=$RUN_OFFLINE --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
|
||||
--val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
|
||||
python -u train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
|
||||
--val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
|
||||
fi
|
||||
cd ..
|
||||
|
||||
|
|
|
@ -36,9 +36,10 @@ from src.Mcnn_Callback import mcnn_callback
|
|||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MCNN Example')
|
||||
parser.add_argument('--run_offline', type=ast.literal_eval,
|
||||
default=False, help='run in offline is False or True')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
|
||||
default=True, help='run in offline is False or True')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)')
|
||||
parser.add_argument('--ckpt_path', type=str, default="/cache/train_output", help='Location of ckpt.')
|
||||
|
||||
parser.add_argument('--data_url', default=None, help='Location of data.')
|
||||
|
@ -47,10 +48,10 @@ parser.add_argument('--train_url', default=None, help='Location of training outp
|
|||
parser.add_argument('--train_path', required=True, default=None, help='Location of data.')
|
||||
parser.add_argument('--train_gt_path', required=True, default=None, help='Location of data.')
|
||||
parser.add_argument('--val_path', required=True,
|
||||
default='/lhb1234/mcnn/data/formatted_trainval/shanghaitech_part_A_patches_9/val',
|
||||
default='/data/formatted_trainval/shanghaitech_part_A_patches_9/val',
|
||||
help='Location of data.')
|
||||
parser.add_argument('--val_gt_path', required=True,
|
||||
default='/lhb1234/mcnn/data/formatted_trainval/shanghaitech_part_A_patches_9/val_den',
|
||||
default='/data/formatted_trainval/shanghaitech_part_A_patches_9/val_den',
|
||||
help='Location of data.')
|
||||
args = parser.parse_args()
|
||||
rand_seed = 64678
|
||||
|
@ -58,26 +59,20 @@ np.random.seed(rand_seed)
|
|||
|
||||
if __name__ == "__main__":
|
||||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
device_id = int(os.getenv("DEVICE_ID"))
|
||||
|
||||
print("device_id:", device_id)
|
||||
print("device_num:", device_num)
|
||||
device_target = args.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
context.set_context(save_graphs=False)
|
||||
|
||||
if device_target == "Ascend":
|
||||
context.set_context(device_id=device_id)
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if device_num > 1:
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
# local_data1_url=os.path.join(local_data1_url,str(device_id)) # can be deleted
|
||||
# local_data2_url=os.path.join(local_data2_url,str(device_id))
|
||||
# local_data3_url=os.path.join(local_data3_url,str(device_id))
|
||||
# local_data4_url=os.path.join(local_data4_url,str(device_id))
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
if args.run_offline:
|
||||
|
|
Loading…
Reference in New Issue