This commit is contained in:
huchunmei 2021-07-21 14:03:01 +08:00
parent 5c40ca1f23
commit 6779c4b18b
3 changed files with 21 additions and 14 deletions

View File

@ -24,7 +24,7 @@ lr_init: 0.00
lr_end: 0.00
lr_max: 0.4
momentum: 0.9
weight_decay: 0.00004 # 4e-5
weight_decay: 0.00004
label_smooth: 0.1
loss_scale: 1024
save_checkpoint: True
@ -32,10 +32,10 @@ save_checkpoint_epochs: 1
keep_checkpoint_max: 200
save_checkpoint_path: "./"
platform: 'Ascend'
device_id: int(os.getenv('DEVICE_ID', '0'))
rank_id: int(os.getenv('RANK_ID', '0'))
rank_size: int(os.getenv('RANK_SIZE', '1'))
run_distribute: int(os.getenv('RANK_SIZE', '1')) > 1.
device_id: 0
rank_id: 0
rank_size: 1
run_distribute: False
activation: "Softmax"
# Image classification trian. train_parse_args():return train_args

View File

@ -35,7 +35,7 @@ run_ascend()
exit 1
fi;
if [ $2 -lt 1 ] && [ $2 -gt 8 ]
if [ $2 -lt 1 ] || [ $2 -gt 8 ]
then
echo "error: DEVICE_NUM=$2 is not in (1-8)"
exit 1
@ -116,11 +116,14 @@ run_gpu()
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]"
exit 1
fi;
if [ $2 -lt 1 ] && [ $2 -gt 8 ]
then
if [ $2 -eq 1 ] ; then
RUN_DISTRIBUTE=False
elif [ $2 -gt 1 ] && [ $2 -le 8 ] ; then
RUN_DISTRIBUTE=True
else
echo "error: DEVICE_NUM=$2 is not in (1-8)"
exit 1
fi
exit 1
fi;
if [ ! -d $4 ]
then
@ -144,6 +147,7 @@ run_gpu()
python ${BASEPATH}/../train.py \
--config_path=$CONFIG_FILE \
--platform=$1 \
--run_distribute=$RUN_DISTRIBUTE \
--dataset_path=$4 \
--pretrain_ckpt=$PRETRAINED_CKPT \
--freeze_layer=$FREEZE_LAYER \

View File

@ -41,10 +41,6 @@ from src.model_utils.device_adapter import get_device_id, get_device_num, get_ra
set_seed(1)
config.device_id = get_device_id()
config.rank_id = get_rank_id()
config.rank_size = get_device_num()
def modelarts_pre_process():
def unzip(zip_file, save_dir):
@ -105,6 +101,13 @@ def modelarts_pre_process():
@moxing_wrapper(pre_process=modelarts_pre_process)
def train_mobilenetv2():
config.dataset_path = os.path.join(config.dataset_path, 'train')
config.device_id = get_device_id()
config.rank_id = get_rank_id()
config.rank_size = get_device_num()
if config.platform == 'Ascend':
config.run_distribute = config.rank_size > 1.
print('\nconfig: \n', config)
start = time.time()
# set context and device init