forked from OSSInnovation/mindspore
train paramters for GPU
This commit is contained in:
parent
8d41931456
commit
9748a3d2ee
|
@ -32,13 +32,21 @@ then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
cd $BASE_PATH/../ || exit
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=$3
|
||||
export RANK_SIZE=$3
|
||||
export RANK_TABLE_FILE=$1
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
|
@ -46,12 +54,12 @@ do
|
|||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp *.py ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
|
||||
env > env.log
|
||||
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
|
||||
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 > log 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
|
|
|
@ -46,17 +46,17 @@ config_gpu = ed({
|
|||
"loss_scale": 128,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 5e-4,
|
||||
"epoch_size": 45,
|
||||
"epoch_size": 40,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 15,
|
||||
"save_checkpoint_path": "./",
|
||||
"use_label_smooth": True,
|
||||
"label_smooth_factor": 0.1,
|
||||
"lr_init": 0.04,
|
||||
"lr_decay": 5,
|
||||
"lr_end_epoch": 58,
|
||||
"damping_init": 0.02,
|
||||
"damping_decay": 0.87,
|
||||
"lr_init": 0.05672,
|
||||
"lr_decay": 4.9687,
|
||||
"lr_end_epoch": 50,
|
||||
"damping_init": 0.02345,
|
||||
"damping_decay": 0.5467,
|
||||
"frequency": 834,
|
||||
})
|
||||
|
|
|
@ -109,6 +109,7 @@ if __name__ == '__main__':
|
|||
init()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([107])
|
||||
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
|
||||
|
||||
# create dataset
|
||||
|
|
Loading…
Reference in New Issue