!21379 mcnn bug fix

Merge pull request !21379 from wukesong/wl_1.1_mcnn
This commit is contained in:
i-robot 2021-08-09 02:40:36 +00:00 committed by Gitee
commit 9c133b6f70
6 changed files with 37 additions and 45 deletions

View File

@ -53,6 +53,8 @@ Dataset used: [ShanghaitechA](<https://www.dropbox.com/s/fipgjqxl7uj8hd5/Shangha
├─ground_truth_csv
```
- Note: the `formatted_trainval` directory is generated by the script [create_training_set_shtech.m](https://github.com/svishwa/crowdcount-mcnn/blob/master/data_preparation/create_training_set_shtech.m).
# [Environment Requirements](#contents)
- Hardware (Ascend)
@ -68,10 +70,10 @@ Dataset used: [ShanghaitechA](<https://www.dropbox.com/s/fipgjqxl7uj8hd5/Shangha
After installing MindSpore via the official website, you can start training and evaluation as follows:
```bash
# enter script dir, train AlexNet
sh run_standalone_train_ascend.sh [DATA_PATH] [CKPT_SAVE_PATH]
# enter script dir, evaluate AlexNet
sh run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_NAME]
# enter script dir, train MCNN example
sh run_standalone_train_ascend.sh 0 ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt
# enter script dir, evaluate MCNN example
sh run_standalone_eval_ascend.sh 0 ./original/shanghaitech/part_A_final/test_data/images ./original/shanghaitech/part_A_final/test_data/ground_truth_csv ./train/ckpt/best.ckpt
```
# [Script Description](#contents)
@ -123,11 +125,10 @@ Major parameters in train.py and config.py as follows:
- running on Ascend
```bash
# python train.py
# or enter script dir, and run the distribute script
sh run_distribute_train.sh
# or enter script dir, and run the standalone script
sh run_standalone_train.sh
# enter script dir, and run the distribute script
sh run_distribute_train.sh ./hccl_table.json ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt
# enter script dir, and run the standalone script
sh run_standalone_train_ascend.sh 0 ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt
```
After training, the loss values will be reported as follows:
@ -154,9 +155,8 @@ Before running the command below, please check the checkpoint path used for eval
- running on Ascend
```bash
# python eval.py
# or enter script dir, and run the script
sh run_eval.sh
# enter script dir, and run the script
sh run_standalone_eval_ascend.sh 0 ./original/shanghaitech/part_A_final/test_data/images ./original/shanghaitech/part_A_final/test_data/ground_truth_csv ./train/ckpt/best.ckpt
```
You can view the results in the file "eval_log". The accuracy on the test dataset will be as follows:

View File

@ -36,17 +36,18 @@ ckptpath = "obs://lhb1234/MCNN/ckpt"
parser = argparse.ArgumentParser(description='MindSpore MCNN Example')
parser.add_argument('--run_offline', type=ast.literal_eval,
default=False, help='run in offline is False or True')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
default=True, help='run in offline is False or True')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)')
parser.add_argument('--ckpt_path', type=str, default="/cache/train_output", help='Location of ckpt.')
parser.add_argument('--data_url', default=None, help='Location of data.')
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
parser.add_argument('--val_path', required=True,
default='obs://lhb1234/mcnn-pure/data/formatted_trainval/shanghaitech_part_A_patches_9/val',
default='/data/mcnn/original/shanghaitech/part_A_final/test_data/images',
help='Location of data.')
parser.add_argument('--val_gt_path', required=True,
default='obs://lhb1234/mcnn-pure/data/formatted_trainval/shanghaitech_part_A_patches_9/val_den',
default='/data/mcnn/original/shanghaitech/part_A_final/test_data/ground_truth_csv',
help='Location of data.')
args = parser.parse_args()
set_seed(64678)
@ -54,14 +55,13 @@ set_seed(64678)
if __name__ == "__main__":
device_num = int(os.getenv("RANK_SIZE"))
device_id = int(os.getenv("DEVICE_ID"))
device_target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(save_graphs=False)
if device_target == "Ascend":
context.set_context(device_id=device_id)
context.set_context(device_id=args.device_id)
else:
raise ValueError("Unsupported platform.")
@ -81,7 +81,7 @@ if __name__ == "__main__":
ds_val = ds_val.batch(1)
network = MCNN()
model_name = os.path.join(local_ckpt_url, 'best.ckpt')
model_name = local_ckpt_url
print(model_name)
mae = 0.0
mse = 0.0

View File

@ -15,10 +15,9 @@
# ============================================================================
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=/home/wks/hccl_8p_01234567_127.0.0.1.json
export RUN_OFFLINE=$1
export DEVICE_NUM=8
export RANK_TABLE_FILE=$1
export TRAIN_PATH=$2
export TRAIN_GT_PATH=$3
export VAL_PATH=$4
@ -40,7 +39,7 @@ do
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --run_offline=$RUN_OFFLINE --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
--val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
python -u train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
--val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
cd ..
done

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,14 +16,13 @@
if [ $# != 4 ]
then
echo "Usage: sh run_eval.sh [se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] [0|1|2|3|4|5|6|7] "
echo "Usage: sh run_eval.sh [DEVICE_ID] [VAL_PATH] [VAL_GT_PATH] [CKPT_PATH] "
exit 1
fi
ulimit -u unlimited
export DEVICE_ID=0
export RANK_SIZE=1
export RUN_OFFLINE=$1
export DEVICE_ID=$1
export VAL_PATH=$2
export VAL_GT_PATH=$3
export CKPT_PATH=$4
@ -40,6 +39,6 @@ cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --run_offline=$RUN_OFFLINE --val_path=$VAL_PATH \
--val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
python -u eval.py --device_id=$DEVICE_ID --val_path=$VAL_PATH \
--val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
cd ..

View File

@ -14,9 +14,8 @@
# limitations under the License.
# ============================================================================
ulimit -u unlimited
export DEVICE_ID=1
export RANK_SIZE=1
export RUN_OFFLINE=$1
export DEVICE_ID=$1
export TRAIN_PATH=$2
export TRAIN_GT_PATH=$3
export VAL_PATH=$4
@ -37,8 +36,8 @@ env > env.
if [ $# == 6 ]
then
python train.py --run_offline=$RUN_OFFLINE --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
--val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
python -u train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
--val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
fi
cd ..

View File

@ -36,9 +36,10 @@ from src.Mcnn_Callback import mcnn_callback
parser = argparse.ArgumentParser(description='MindSpore MCNN Example')
parser.add_argument('--run_offline', type=ast.literal_eval,
default=False, help='run in offline is False or True')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
default=True, help='run in offline is False or True')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)')
parser.add_argument('--ckpt_path', type=str, default="/cache/train_output", help='Location of ckpt.')
parser.add_argument('--data_url', default=None, help='Location of data.')
@ -47,10 +48,10 @@ parser.add_argument('--train_url', default=None, help='Location of training outp
parser.add_argument('--train_path', required=True, default=None, help='Location of data.')
parser.add_argument('--train_gt_path', required=True, default=None, help='Location of data.')
parser.add_argument('--val_path', required=True,
default='/lhb1234/mcnn/data/formatted_trainval/shanghaitech_part_A_patches_9/val',
default='/data/formatted_trainval/shanghaitech_part_A_patches_9/val',
help='Location of data.')
parser.add_argument('--val_gt_path', required=True,
default='/lhb1234/mcnn/data/formatted_trainval/shanghaitech_part_A_patches_9/val_den',
default='/data/formatted_trainval/shanghaitech_part_A_patches_9/val_den',
help='Location of data.')
args = parser.parse_args()
rand_seed = 64678
@ -58,26 +59,20 @@ np.random.seed(rand_seed)
if __name__ == "__main__":
device_num = int(os.getenv("RANK_SIZE"))
device_id = int(os.getenv("DEVICE_ID"))
print("device_id:", device_id)
print("device_num:", device_num)
device_target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(save_graphs=False)
if device_target == "Ascend":
context.set_context(device_id=device_id)
context.set_context(device_id=args.device_id)
if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
# local_data1_url=os.path.join(local_data1_url,str(device_id)) # can be deleted
# local_data2_url=os.path.join(local_data2_url,str(device_id))
# local_data3_url=os.path.join(local_data3_url,str(device_id))
# local_data4_url=os.path.join(local_data4_url,str(device_id))
else:
raise ValueError("Unsupported platform.")
if args.run_offline: