diff --git a/model_zoo/official/cv/MCNN/README.md b/model_zoo/official/cv/MCNN/README.md index 0d7514aa5d2..b738cd467e2 100644 --- a/model_zoo/official/cv/MCNN/README.md +++ b/model_zoo/official/cv/MCNN/README.md @@ -53,6 +53,8 @@ Dataset used: [ShanghaitechA]( env.log - python train.py --run_offline=$RUN_OFFLINE --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \ - --val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log & + python -u train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \ + --val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log & cd .. done diff --git a/model_zoo/official/cv/MCNN/scripts/run_eval.sh b/model_zoo/official/cv/MCNN/scripts/run_eval.sh index a7fec47c8b4..5a138dc548f 100644 --- a/model_zoo/official/cv/MCNN/scripts/run_eval.sh +++ b/model_zoo/official/cv/MCNN/scripts/run_eval.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,14 +16,13 @@ if [ $# != 4 ] then - echo "Usage: sh run_eval.sh [se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] [0|1|2|3|4|5|6|7] " + echo "Usage: sh run_eval.sh [DEVICE_ID] [VAL_PATH] [VAL_GT_PATH] [CKPT_PATH] " exit 1 fi ulimit -u unlimited -export DEVICE_ID=0 export RANK_SIZE=1 -export RUN_OFFLINE=$1 +export DEVICE_ID=$1 export VAL_PATH=$2 export VAL_GT_PATH=$3 export CKPT_PATH=$4 @@ -40,6 +39,6 @@ cp -r ../src ./eval cd ./eval || exit env > env.log echo "start evaluation for device $DEVICE_ID" -python eval.py --run_offline=$RUN_OFFLINE --val_path=$VAL_PATH \ - --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log & +python -u eval.py --device_id=$DEVICE_ID --val_path=$VAL_PATH \ + --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log & cd .. diff --git a/model_zoo/official/cv/MCNN/scripts/run_standalone_train.sh b/model_zoo/official/cv/MCNN/scripts/run_standalone_train.sh index f5fdcf08f1e..c0733eefa95 100644 --- a/model_zoo/official/cv/MCNN/scripts/run_standalone_train.sh +++ b/model_zoo/official/cv/MCNN/scripts/run_standalone_train.sh @@ -14,9 +14,8 @@ # limitations under the License. # ============================================================================ ulimit -u unlimited -export DEVICE_ID=1 export RANK_SIZE=1 -export RUN_OFFLINE=$1 +export DEVICE_ID=$1 export TRAIN_PATH=$2 export TRAIN_GT_PATH=$3 export VAL_PATH=$4 @@ -37,8 +36,8 @@ env > env. if [ $# == 6 ] then - python train.py --run_offline=$RUN_OFFLINE --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \ - --val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log & + python -u train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \ + --val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log & fi cd .. diff --git a/model_zoo/official/cv/MCNN/train.py b/model_zoo/official/cv/MCNN/train.py index d159df12021..f4dfd969826 100644 --- a/model_zoo/official/cv/MCNN/train.py +++ b/model_zoo/official/cv/MCNN/train.py @@ -36,9 +36,10 @@ from src.Mcnn_Callback import mcnn_callback parser = argparse.ArgumentParser(description='MindSpore MCNN Example') parser.add_argument('--run_offline', type=ast.literal_eval, - default=False, help='run in offline is False or True') -parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], + default=True, help='run in offline is False or True') +parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'], help='device where the code will be implemented (default: Ascend)') +parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)') parser.add_argument('--ckpt_path', type=str, default="/cache/train_output", help='Location of ckpt.') parser.add_argument('--data_url', default=None, help='Location of data.') @@ -47,10 +48,10 @@ parser.add_argument('--train_url', default=None, help='Location of training outp parser.add_argument('--train_path', required=True, default=None, help='Location of data.') parser.add_argument('--train_gt_path', required=True, default=None, help='Location of data.') parser.add_argument('--val_path', required=True, - default='/lhb1234/mcnn/data/formatted_trainval/shanghaitech_part_A_patches_9/val', + default='/data/formatted_trainval/shanghaitech_part_A_patches_9/val', help='Location of data.') parser.add_argument('--val_gt_path', required=True, - default='/lhb1234/mcnn/data/formatted_trainval/shanghaitech_part_A_patches_9/val_den', + default='/data/formatted_trainval/shanghaitech_part_A_patches_9/val_den', help='Location of data.') args = parser.parse_args() rand_seed = 64678 @@ -58,26 +59,20 @@ np.random.seed(rand_seed) if __name__ == "__main__": device_num = int(os.getenv("RANK_SIZE")) - device_id = int(os.getenv("DEVICE_ID")) - print("device_id:", device_id) print("device_num:", device_num) device_target = args.device_target context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(save_graphs=False) if device_target == "Ascend": - context.set_context(device_id=device_id) + context.set_context(device_id=args.device_id) if device_num > 1: context.reset_auto_parallel_context() context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) init() - # local_data1_url=os.path.join(local_data1_url,str(device_id)) # 可以删除 - # local_data2_url=os.path.join(local_data2_url,str(device_id)) - # local_data3_url=os.path.join(local_data3_url,str(device_id)) - # local_data4_url=os.path.join(local_data4_url,str(device_id)) else: raise ValueError("Unsupported platform.") if args.run_offline: