forked from mindspore-Ecosystem/mindspore
!5708 support startup script of bert_thor with relative path
Merge pull request !5708 from wangmin0104/master
This commit is contained in:
commit 1cd4ac2ad4
@@ -128,12 +128,12 @@ Parameters for both training and inference can be set in config.py.
 ```
 sh run_distribute_pretrain.sh [DEVICE_NUM] [EPOCH_SIZE] [DATA_DIR] [SCHEMA_DIR] [RANK_TABLE_FILE]
 ```
-We need three parameters for this script.
+We need five parameters for this script.
 - `DEVICE_NUM`: the device number for distributed training.
 - `EPOCH_SIZE`: epoch size used in the model.
 - `DATA_DIR`: data path; it is better to use an absolute path.
 - `SCHEMA_DIR`: schema path; it is better to use an absolute path.
-- `RANK_TABLE_FILE`: the path of rank_table.json.
+- `RANK_TABLE_FILE`: rank table file in JSON format.

 Training results are stored in the current path, in a folder whose name begins with the file name the user defines; there you can find the checkpoint files together with results like the following in the log.
 ```
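For reference, a launch under the new five-argument convention could look like the following; the device count, epoch size, and every path below are illustrative placeholders, not values taken from this PR.

```bash
# Hypothetical invocation; substitute your own dataset, schema, and rank table paths.
# Launching via a relative path works thanks to the BASE_PATH change further down.
sh run_distribute_pretrain.sh 8 40 /path/to/data /path/to/schema.json /path/to/rank_table.json
```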
@@ -153,10 +153,8 @@ def MLM_eval():
     net = Model(net_for_pretraining, eval_network=net_for_pretraining, eval_indexes=[0, 1, 2],
                 metrics={'name': myMetric()})
     res = net.eval(dataset, dataset_sink_mode=False)
-    print("==============================================================")
     for _, v in res.items():
         print("Accuracy is: ", v)
-    print("==============================================================")


 if __name__ == "__main__":
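The `metrics={'name': myMetric()}` argument plugs a user-defined metric into `Model.eval`. A minimal sketch of what such a metric can look like, assuming MindSpore's `mindspore.nn.metrics.Metric` interface; `MyAccuracy` and its input layout are illustrative stand-ins, not the repository's actual `myMetric`:

```python
from mindspore.nn.metrics import Metric

class MyAccuracy(Metric):
    """Hypothetical accuracy metric; the (logits, labels) input layout is an assumption."""
    def __init__(self):
        super().__init__()
        self.clear()

    def clear(self):
        # Reset internal state before each evaluation pass.
        self._correct = 0
        self._total = 0

    def update(self, *inputs):
        # Assumes inputs arrive as (logits, labels).
        preds = inputs[0].asnumpy().argmax(axis=-1)
        labels = inputs[1].asnumpy()
        self._correct += int((preds == labels).sum())
        self._total += labels.size

    def eval(self):
        return self._correct / max(self._total, 1)
```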
@@ -25,6 +25,9 @@ EPOCH_SIZE=$2
 DATA_DIR=$3
 SCHEMA_DIR=$4

+BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
+cd $BASE_PATH/ || exit
+
 ulimit -u unlimited
 export RANK_TABLE_FILE=$5
 export RANK_SIZE=$1
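These two added lines are what give the PR its title: the script resolves its own directory and cds there, so launching it through a relative path no longer breaks the relative references that follow. A sketch of the same idiom in its more defensive modern spelling, assuming a POSIX shell; the behavior is equivalent, but `$(...)` nests quoting more safely than backticks:

```bash
# Resolve the directory containing this script and run from it, so that a
# relative invocation (e.g. ./scripts/run_distribute_pretrain.sh) behaves
# the same as an absolute one.
BASE_PATH=$(cd "$(dirname "$0")" || exit; pwd)
cd "$BASE_PATH" || exit
```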
@@ -55,6 +58,7 @@ do
 --load_checkpoint_path="" \
 --save_checkpoint_path='./' \
 --save_checkpoint_steps=1000 \
+--train_steps=3000 \
 --save_checkpoint_num=30 \
 --data_dir=$DATA_DIR \
 --schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
@@ -24,6 +24,9 @@ EPOCH_SIZE=$2
 DATA_DIR=$3
 SCHEMA_DIR=$4

+BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
+cd $BASE_PATH/ || exit
+
 ulimit -u unlimited
 export DEVICE_ID=$1
 export RANK_SIZE=1
@@ -51,6 +54,7 @@ python run_pretrain.py \
 --load_checkpoint_path="" \
 --save_checkpoint_path='./' \
 --save_checkpoint_steps=5000 \
+--train_steps=-1 \
 --save_checkpoint_num=20 \
 --data_dir=$DATA_DIR \
 --schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
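Note the asymmetry between the two scripts: the distributed launcher pins `--train_steps=3000` while the standalone one passes `-1`. A plausible reading, sketched below, is that a positive value caps training at that many steps and `-1` falls back to the epoch-derived count; this is an assumption about run_pretrain.py, whose flag handling is not part of this diff.

```python
# Hypothetical resolution of a --train_steps flag; not code from run_pretrain.py.
def resolve_train_steps(train_steps: int, epoch_size: int, steps_per_epoch: int) -> int:
    if train_steps > 0:
        return train_steps                   # explicit cap, e.g. 3000 above
    return epoch_size * steps_per_epoch      # -1: derive from epochs and data

print(resolve_train_steps(3000, 40, 500))    # -> 3000
print(resolve_train_steps(-1, 40, 500))      # -> 20000
```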
@@ -55,7 +55,7 @@ def get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps,
     return learning_rate


-# bert kfac hyperparam setting
+# bert thor hyperparam setting
 def get_bert_lr():
     learning_rate = Tensor(
         get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=3.1e-3, warmup_steps=0, total_steps=30000,
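`get_bert_lr` builds its schedule from `get_poly_lr`, whose body is truncated out of this diff. A minimal sketch of a polynomial-decay-with-warmup schedule matching the visible signature; the decay power and the exact warmup/decay formulas are assumptions:

```python
import numpy as np

def poly_lr_sketch(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, power=1.0):
    """Illustrative polynomial decay with linear warmup; not the repo's get_poly_lr."""
    rates = []
    for step in range(global_step, total_steps):
        if warmup_steps and step < warmup_steps:
            # Linear warmup from lr_init up to lr_max.
            lr = lr_init + (lr_max - lr_init) * step / warmup_steps
        else:
            # Polynomial decay from lr_max down to lr_end.
            decay_steps = max(total_steps - warmup_steps, 1)
            frac = 1.0 - min(step - warmup_steps, decay_steps) / decay_steps
            lr = lr_end + (lr_max - lr_end) * frac ** power
        rates.append(lr)
    return np.array(rates, dtype=np.float32)

# With the values shown above: lr_max=3.1e-3 decaying to lr_end=1e-6 over 30000 steps.
```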