diff --git a/model_zoo/official/cv/mobilenetv2/default_config.yaml b/model_zoo/official/cv/mobilenetv2/default_config.yaml index 6da0abc210b..b7b3194466b 100644 --- a/model_zoo/official/cv/mobilenetv2/default_config.yaml +++ b/model_zoo/official/cv/mobilenetv2/default_config.yaml @@ -24,7 +24,7 @@ lr_init: 0.00 lr_end: 0.00 lr_max: 0.4 momentum: 0.9 -weight_decay: 0.00004 # 4e-5 +weight_decay: 0.00004 label_smooth: 0.1 loss_scale: 1024 save_checkpoint: True @@ -32,10 +32,10 @@ save_checkpoint_epochs: 1 keep_checkpoint_max: 200 save_checkpoint_path: "./" platform: 'Ascend' -device_id: int(os.getenv('DEVICE_ID', '0')) -rank_id: int(os.getenv('RANK_ID', '0')) -rank_size: int(os.getenv('RANK_SIZE', '1')) -run_distribute: int(os.getenv('RANK_SIZE', '1')) > 1. +device_id: 0 +rank_id: 0 +rank_size: 1 +run_distribute: False activation: "Softmax" # Image classification trian. train_parse_args():return train_args diff --git a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh index 0fcab0d43a7..0054a477455 100644 --- a/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh +++ b/model_zoo/official/cv/mobilenetv2/scripts/run_train.sh @@ -35,7 +35,7 @@ run_ascend() exit 1 fi; - if [ $2 -lt 1 ] && [ $2 -gt 8 ] + if [ $2 -lt 1 ] || [ $2 -gt 8 ] then echo "error: DEVICE_NUM=$2 is not in (1-8)" exit 1 @@ -116,11 +116,14 @@ run_gpu() GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]" exit 1 fi; - if [ $2 -lt 1 ] && [ $2 -gt 8 ] - then + if [ $2 -eq 1 ] ; then + RUN_DISTRIBUTE=False + elif [ $2 -gt 1 ] && [ $2 -le 8 ] ; then + RUN_DISTRIBUTE=True + else echo "error: DEVICE_NUM=$2 is not in (1-8)" - exit 1 - fi + exit 1 + fi; if [ ! -d $4 ] then @@ -144,6 +147,7 @@ run_gpu() python ${BASEPATH}/../train.py \ --config_path=$CONFIG_FILE \ --platform=$1 \ + --run_distribute=$RUN_DISTRIBUTE \ --dataset_path=$4 \ --pretrain_ckpt=$PRETRAINED_CKPT \ --freeze_layer=$FREEZE_LAYER \ diff --git a/model_zoo/official/cv/mobilenetv2/train.py b/model_zoo/official/cv/mobilenetv2/train.py index 4e66c7e59ea..dadacc4ee96 100644 --- a/model_zoo/official/cv/mobilenetv2/train.py +++ b/model_zoo/official/cv/mobilenetv2/train.py @@ -41,10 +41,6 @@ from src.model_utils.device_adapter import get_device_id, get_device_num, get_ra set_seed(1) -config.device_id = get_device_id() -config.rank_id = get_rank_id() -config.rank_size = get_device_num() - def modelarts_pre_process(): def unzip(zip_file, save_dir): @@ -105,6 +101,13 @@ def modelarts_pre_process(): @moxing_wrapper(pre_process=modelarts_pre_process) def train_mobilenetv2(): config.dataset_path = os.path.join(config.dataset_path, 'train') + + config.device_id = get_device_id() + config.rank_id = get_rank_id() + config.rank_size = get_device_num() + if config.platform == 'Ascend': + config.run_distribute = config.rank_size > 1. + print('\nconfig: \n', config) start = time.time() # set context and device init