!20639 Repair mobilenetv2 network training with 1p
Merge pull request !20639 from huchunmei/master
This commit is contained in:
commit
67c68c099e
|
@ -24,7 +24,7 @@ lr_init: 0.00
|
|||
lr_end: 0.00
|
||||
lr_max: 0.4
|
||||
momentum: 0.9
|
||||
weight_decay: 0.00004 # 4e-5
|
||||
weight_decay: 0.00004
|
||||
label_smooth: 0.1
|
||||
loss_scale: 1024
|
||||
save_checkpoint: True
|
||||
|
@ -32,10 +32,10 @@ save_checkpoint_epochs: 1
|
|||
keep_checkpoint_max: 200
|
||||
save_checkpoint_path: "./"
|
||||
platform: 'Ascend'
|
||||
device_id: int(os.getenv('DEVICE_ID', '0'))
|
||||
rank_id: int(os.getenv('RANK_ID', '0'))
|
||||
rank_size: int(os.getenv('RANK_SIZE', '1'))
|
||||
run_distribute: int(os.getenv('RANK_SIZE', '1')) > 1.
|
||||
device_id: 0
|
||||
rank_id: 0
|
||||
rank_size: 1
|
||||
run_distribute: False
|
||||
activation: "Softmax"
|
||||
|
||||
# Image classification trian. train_parse_args():return train_args
|
||||
|
|
|
@ -35,7 +35,7 @@ run_ascend()
|
|||
exit 1
|
||||
fi;
|
||||
|
||||
if [ $2 -lt 1 ] && [ $2 -gt 8 ]
|
||||
if [ $2 -lt 1 ] || [ $2 -gt 8 ]
|
||||
then
|
||||
echo "error: DEVICE_NUM=$2 is not in (1-8)"
|
||||
exit 1
|
||||
|
@ -116,11 +116,14 @@ run_gpu()
|
|||
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi;
|
||||
if [ $2 -lt 1 ] && [ $2 -gt 8 ]
|
||||
then
|
||||
if [ $2 -eq 1 ] ; then
|
||||
RUN_DISTRIBUTE=False
|
||||
elif [ $2 -gt 1 ] && [ $2 -le 8 ] ; then
|
||||
RUN_DISTRIBUTE=True
|
||||
else
|
||||
echo "error: DEVICE_NUM=$2 is not in (1-8)"
|
||||
exit 1
|
||||
fi
|
||||
exit 1
|
||||
fi;
|
||||
|
||||
if [ ! -d $4 ]
|
||||
then
|
||||
|
@ -144,6 +147,7 @@ run_gpu()
|
|||
python ${BASEPATH}/../train.py \
|
||||
--config_path=$CONFIG_FILE \
|
||||
--platform=$1 \
|
||||
--run_distribute=$RUN_DISTRIBUTE \
|
||||
--dataset_path=$4 \
|
||||
--pretrain_ckpt=$PRETRAINED_CKPT \
|
||||
--freeze_layer=$FREEZE_LAYER \
|
||||
|
|
|
@ -41,10 +41,6 @@ from src.model_utils.device_adapter import get_device_id, get_device_num, get_ra
|
|||
|
||||
|
||||
set_seed(1)
|
||||
config.device_id = get_device_id()
|
||||
config.rank_id = get_rank_id()
|
||||
config.rank_size = get_device_num()
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
def unzip(zip_file, save_dir):
|
||||
|
@ -105,6 +101,13 @@ def modelarts_pre_process():
|
|||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train_mobilenetv2():
|
||||
config.dataset_path = os.path.join(config.dataset_path, 'train')
|
||||
|
||||
config.device_id = get_device_id()
|
||||
config.rank_id = get_rank_id()
|
||||
config.rank_size = get_device_num()
|
||||
if config.platform == 'Ascend':
|
||||
config.run_distribute = config.rank_size > 1.
|
||||
|
||||
print('\nconfig: \n', config)
|
||||
start = time.time()
|
||||
# set context and device init
|
||||
|
|
Loading…
Reference in New Issue