From c53cd6bb222617337456f26145eb628e0fef23b3 Mon Sep 17 00:00:00 2001 From: Payne Date: Thu, 17 Sep 2020 13:08:32 +0800 Subject: [PATCH] align to docs and r1.0 --- .../scripts/{run_infer.sh => run_eval.sh} | 37 ++++++++++----- .../cv/mobilenetv2/scripts/run_train.sh | 47 +++++++++++++++---- model_zoo/official/cv/mobilenetv2/src/args.py | 13 +++-- .../official/cv/mobilenetv2/src/config.py | 3 +- .../official/cv/mobilenetv2/src/launch.py | 12 ++--- .../official/cv/mobilenetv2/src/models.py | 15 +----- .../official/cv/mobilenetv2/src/utils.py | 5 +- model_zoo/official/cv/mobilenetv2/train.py | 15 +++++- 8 files changed, 96 insertions(+), 51 deletions(-) rename model_zoo/official/cv/mobilenetv2/scripts/{run_infer.sh => run_eval.sh} (64%) diff --git a/model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh b/model_zoo/official/cv/mobilenetv2/scripts/run_eval.sh similarity index 64% rename from model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh rename to model_zoo/official/cv/mobilenetv2/scripts/run_eval.sh index c596e33a23e..e6ae0769935 100644 --- a/model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh +++ b/model_zoo/official/cv/mobilenetv2/scripts/run_eval.sh @@ -13,13 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -if [ $# != 3 ] + +run_ascend() +{ + export DEVICE_ID=0 + export RANK_ID=0 + export RANK_SIZE=1 + + +} + +if [ $# -gt 4 ] || [ $# -lt 3 ] then - echo "Ascend: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH] \ - GPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]" + echo "Usage: + Ascend: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] + GPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] + CPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]" exit 1 fi + # check dataset path if [ ! -d $2 ] then @@ -30,16 +43,13 @@ fi # check checkpoint file if [ ! -f $3 ] then - echo "error: CHECKPOINT_PATH=$3 is not a file" + echo "error: PRETRAIN_CKPT=$3 is not a file" exit 1 fi # set environment BASEPATH=$(cd "`dirname $0`" || exit; pwd) export PYTHONPATH=${BASEPATH}:$PYTHONPATH -export DEVICE_ID=0 -export RANK_ID=0 -export RANK_SIZE=1 if [ -d "../eval" ]; then rm -rf ../eval @@ -47,9 +57,14 @@ fi mkdir ../eval cd ../eval || exit +if [ $1 = "CPU" ] ; then + run_ascend "$@" +fi; + # launch python ${BASEPATH}/../eval.py \ - --device_target=$1 \ - --dataset_path=$2 \ - --checkpoint_path=$3 \ - &> ../infer.log & # dataset val folder path + --platform=$1 \ + --dataset_path=$2 \ + --pretrain_ckpt=$3 \ + --head_ckpt=$4 \ + &> ../eval.log & # dataset val folder path \ No newline at end of file diff --git a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh index 0cd24a0f2fe..7c1634c98df 100644 --- a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh +++ b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh @@ -38,12 +38,14 @@ run_ascend() mkdir ../train cd ../train || exit python ${BASEPATH}/../src/launch.py \ + --platform=$1 \ --nproc_per_node=$2 \ --visible_devices=$3 \ --training_script=${BASEPATH}/../train.py \ --dataset_path=$5 \ - --pre_trained=$6 \ - --device_target=$1 &> ../train.log & # dataset train folder + --train_method=$6 \ + --pretrain_ckpt=$7 \ + &> ../train.log & # dataset train folder } run_gpu() @@ -72,17 +74,43 @@ run_gpu() export CUDA_VISIBLE_DEVICES="$3" mpirun -n $2 --allow-run-as-root \ python ${BASEPATH}/../train.py \ + --platform=$1 \ --dataset_path=$4 \ - --pre_trained=$5 \ - --device_target=$1 \ + --train_method=$5 \ + --pretrain_ckpt=$6 \ &> ../train.log & # dataset train folder } -if [ $# -gt 6 ] || [ $# -lt 4 ] +run_cpu() +{ + if [ ! -d $2 ] + then + echo "error: DATASET_PATH=$2 is not a directory" + exit 1 + fi + + BASEPATH=$(cd "`dirname $0`" || exit; pwd) + export PYTHONPATH=${BASEPATH}:$PYTHONPATH + if [ -d "../train" ]; + then + rm -rf ../train + fi + mkdir ../train + cd ../train || exit + + python ${BASEPATH}/../train.py \ + --platform=$1 \ + --dataset_path=$2 \ + --train_method=$3 \ + --pretrain_ckpt=$4 \ + &> ../train.log & # dataset train folder +} +if [ $# -gt 7 ] || [ $# -lt 4 ] then echo "Usage:\n \ - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH]\n \ - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ + Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]\n \ + GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]\n \ + CPU: sh run_train.sh CPU [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]\n \ " exit 1 fi @@ -91,7 +119,8 @@ if [ $1 = "Ascend" ] ; then run_ascend "$@" elif [ $1 = "GPU" ] ; then run_gpu "$@" +elif [ $1 = "CPU" ] ; then + run_cpu "$@" else - echo "Unsupported device_target." + echo "Unsupported platform." fi; - diff --git a/model_zoo/official/cv/mobilenetv2/src/args.py b/model_zoo/official/cv/mobilenetv2/src/args.py index 8ce819bb0bd..48e2aeb9f69 100644 --- a/model_zoo/official/cv/mobilenetv2/src/args.py +++ b/model_zoo/official/cv/mobilenetv2/src/args.py @@ -38,25 +38,24 @@ def launch_parse_args(): def train_parse_args(): train_parser = argparse.ArgumentParser(description='Image classification trian') - train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path') train_parser.add_argument('--platform', type=str, default="Ascend", choices=("CPU", "GPU", "Ascend"), \ help='run platform, only support CPU, GPU and Ascend') + train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path') + train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \ + help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \ + train from initialization model") train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \ for fine tune or incremental learning') train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute') - train_parser.add_argument('--train_method', type=str, required=True, choices=("train", "fine_tune", \ - "incremental_learn"), help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after \ - loading the ckpt, \"train\" to train from initialization model") - train_args = train_parser.parse_args() return train_args def eval_parse_args(): eval_parser = argparse.ArgumentParser(description='Image classification eval') - eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path') eval_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \ help='run platform, only support GPU, CPU and Ascend') - eval_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \ + eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path') + eval_parser.add_argument('--pretrain_ckpt', type=str, required=True, help='Pretrained checkpoint path \ for fine tune or incremental learning') eval_parser.add_argument('--head_ckpt', type=str, default=None, help='Pretrained checkpoint path \ for fine tune or incremental learning') diff --git a/model_zoo/official/cv/mobilenetv2/src/config.py b/model_zoo/official/cv/mobilenetv2/src/config.py index 08623bb4059..d16e9abe45d 100644 --- a/model_zoo/official/cv/mobilenetv2/src/config.py +++ b/model_zoo/official/cv/mobilenetv2/src/config.py @@ -37,7 +37,8 @@ def set_config(args): "save_checkpoint_epochs": 1, "keep_checkpoint_max": 20, "save_checkpoint_path": "./checkpoint", - "platform": args.platform + "platform": args.platform, + "run_distribute": False }) config_gpu = ed({ "num_classes": 1000, diff --git a/model_zoo/official/cv/mobilenetv2/src/launch.py b/model_zoo/official/cv/mobilenetv2/src/launch.py index 8785186dcf9..e890790338c 100644 --- a/model_zoo/official/cv/mobilenetv2/src/launch.py +++ b/model_zoo/official/cv/mobilenetv2/src/launch.py @@ -38,17 +38,17 @@ def main(): for rank_id in range(0, args.nproc_per_node): os.chdir(cur_path) device_id = visible_devices[rank_id] - device_dir = os.path.join(cur_path, 'device{}'.format(rank_id)) + rank_dir = os.path.join(cur_path, f'rank{rank_id}') env['RANK_ID'] = str(rank_id) env['DEVICE_ID'] = str(device_id) - if os.path.exists(device_dir): - shutil.rmtree(device_dir) - os.mkdir(device_dir) - os.chdir(device_dir) + if os.path.exists(rank_dir): + shutil.rmtree(rank_dir) + os.mkdir(rank_dir) + os.chdir(rank_dir) cmd = [sys.executable, '-u'] cmd.append(args.training_script) cmd.extend(args.training_script_args) - log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') + log_file = open(f'{rank_dir}/log{rank_id}.log', 'w') process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) processes.append(process) cmds.append(cmd) diff --git a/model_zoo/official/cv/mobilenetv2/src/models.py b/model_zoo/official/cv/mobilenetv2/src/models.py index 5d895a754f3..6456645a89f 100644 --- a/model_zoo/official/cv/mobilenetv2/src/models.py +++ b/model_zoo/official/cv/mobilenetv2/src/models.py @@ -119,20 +119,9 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True): for param in network.get_parameters(): param.requires_grad = False -def define_net(args, config): - backbone_net = MobileNetV2Backbone(platform=args.platform) +def define_net(config): + backbone_net = MobileNetV2Backbone(platform=config.platform) head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) net = mobilenet_v2(backbone_net, head_net) - # load the ckpt file to the network for fine tune or incremental leaning - if args.pretrain_ckpt: - if args.train_method == "fine_tune": - load_ckpt(net, args.pretrain_ckpt) - elif args.train_method == "incremental_learn": - load_ckpt(backbone_net, args.pretrain_ckpt, trainable=False) - elif args.train_method == "train": - pass - else: - raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None") - return backbone_net, head_net, net diff --git a/model_zoo/official/cv/mobilenetv2/src/utils.py b/model_zoo/official/cv/mobilenetv2/src/utils.py index b1e2fd6c399..4fd8a8d0129 100644 --- a/model_zoo/official/cv/mobilenetv2/src/utils.py +++ b/model_zoo/official/cv/mobilenetv2/src/utils.py @@ -23,6 +23,7 @@ from mindspore.common import dtype as mstype from mindspore.train.model import ParallelMode from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.communication.management import get_rank, init, get_group_size +from mindspore.parallel._auto_parallel_context import auto_parallel_context from src.models import Monitor @@ -58,8 +59,8 @@ def context_device_init(config): if config.run_distribute: context.set_auto_parallel_context(device_num=config.rank_size, parallel_mode=ParallelMode.DATA_PARALLEL, - parameter_broadcast=True, gradients_mean=True, - all_reduce_fusion_config=[140]) + parameter_broadcast=True, mirror_mean=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([140]) init() else: raise ValueError("Only support CPU, GPU and Ascend.") diff --git a/model_zoo/official/cv/mobilenetv2/train.py b/model_zoo/official/cv/mobilenetv2/train.py index dbcf32a252a..e845745c3c2 100644 --- a/model_zoo/official/cv/mobilenetv2/train.py +++ b/model_zoo/official/cv/mobilenetv2/train.py @@ -35,7 +35,7 @@ from src.config import set_config from src.args import train_parse_args from src.utils import context_device_init, switch_precision, config_ckpoint, set_seed -from src.models import CrossEntropyWithLabelSmooth, define_net +from src.models import CrossEntropyWithLabelSmooth, define_net, load_ckpt set_seed(1) @@ -50,7 +50,18 @@ if __name__ == '__main__': context_device_init(config) # define network - backbone_net, head_net, net = define_net(args_opt, config) + backbone_net, head_net, net = define_net(config) + + # load the ckpt file to the network for fine tune or incremental leaning + if args_opt.pretrain_ckpt: + if args_opt.train_method == "fine_tune": + load_ckpt(net, args_opt.pretrain_ckpt) + elif args_opt.train_method == "incremental_learn": + load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False) + elif args_opt.train_method == "train": + pass + else: + raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None") # CPU only support "incremental_learn" if args_opt.train_method == "incremental_learn":