forked from mindspore-Ecosystem/mindspore
update tiny bert script and readme
parent 86c39d34ad
commit fe016d321c
README.md:
@@ -44,12 +44,12 @@ After installing MindSpore via the official website, you can start general disti
 # run standalone general distill example
 bash scripts/run_standalone_gd.sh
 
-Before running the shell script, please set the `load_teacher_ckpt_path`, `data_dir` and `schema_dir` in the run_standalone_gd.sh file first. If running on GPU, please set the `device_target=GPU`.
+Before running the shell script, please set the `load_teacher_ckpt_path`, `data_dir`, `schema_dir` and `dataset_type` in the run_standalone_gd.sh file first. If running on GPU, please set the `device_target=GPU`.
 
 # For Ascend device, run distributed general distill example
 bash scripts/run_distributed_gd_ascend.sh 8 1 /path/hccl.json
 
-Before running the shell script, please set the `load_teacher_ckpt_path`, `data_dir` and `schema_dir` in the run_distributed_gd_ascend.sh file first.
+Before running the shell script, please set the `load_teacher_ckpt_path`, `data_dir`, `schema_dir` and `dataset_type` in the run_distributed_gd_ascend.sh file first.
 
 # For GPU device, run distributed general distill example
 bash scripts/run_distributed_gd_gpu.sh 8 1 /path/data/ /path/schema.json /path/teacher.ckpt
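The hunk above tells the user to edit run_standalone_gd.sh before launching. As a rough sketch of what the filled-in launch line might look like with the new option; every path below is a placeholder, not anything from the commit:

```
# hypothetical tail of scripts/run_standalone_gd.sh -- all paths are placeholders
python run_general_distill.py \
    --device_target="Ascend" \
    --load_teacher_ckpt_path="/path/to/teacher.ckpt" \
    --data_dir="/path/to/gd_data/" \
    --schema_dir="/path/to/schema.json" \
    --dataset_type="tfrecord" > log.txt 2>&1 &
```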
@@ -57,7 +57,7 @@ bash scripts/run_distributed_gd_gpu.sh 8 1 /path/data/ /path/schema.json /path/t
 # run task distill and evaluation example
 bash scripts/run_standalone_td.sh
 
-Before running the shell script, please set the `task_name`, `load_teacher_ckpt_path`, `load_gd_ckpt_path`, `train_data_dir`, `eval_data_dir` and `schema_dir` in the run_standalone_td.sh file first.
+Before running the shell script, please set the `task_name`, `load_teacher_ckpt_path`, `load_gd_ckpt_path`, `train_data_dir`, `eval_data_dir`, `schema_dir` and `dataset_type` in the run_standalone_td.sh file first.
 If running on GPU, please set the `device_target=GPU`.
 ```
 
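The run_standalone_td.sh edit works the same way. A minimal sketch of the settings named above, with placeholder values throughout (`task_name` must be one of the choices listed in the options section further down):

```
# hypothetical values inside scripts/run_standalone_td.sh -- placeholders only
task_name="SST-2"                               # one of SST-2 | QNLI | MNLI
load_teacher_ckpt_path="/path/to/teacher.ckpt"
load_gd_ckpt_path="/path/to/gd.ckpt"
train_data_dir="/path/to/train_data/"
eval_data_dir="/path/to/eval_data/"
schema_dir="/path/to/schema.json"
dataset_type="tfrecord"                         # or "mindrecord"
```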
@@ -101,7 +101,7 @@ usage: run_general_distill.py [--distribute DISTRIBUTE] [--epoch_size N] [----
 [--save_ckpt_path SAVE_CKPT_PATH]
 [--load_teacher_ckpt_path LOAD_TEACHER_CKPT_PATH]
 [--save_checkpoint_step N] [--max_ckpt_num N]
-[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] [train_steps N]
+[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE] [train_steps N]
 
 options:
 --device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
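Matching the updated usage string, a hypothetical direct invocation that selects the MindRecord reader; all paths are placeholders:

```
python run_general_distill.py \
    --device_target="GPU" \
    --load_teacher_ckpt_path="/path/to/teacher.ckpt" \
    --data_dir="/path/to/data/" \
    --schema_dir="/path/to/schema.json" \
    --dataset_type="mindrecord"
```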
@@ -118,6 +118,7 @@ options:
 --load_teacher_ckpt_path path to load teacher checkpoint files: PATH, default is ""
 --data_dir path to dataset directory: PATH, default is ""
 --schema_dir path to schema.json file, PATH, default is ""
+--dataset_type the dataset type which can be tfrecord/mindrecord, default is tfrecord
 ```
 
 ### Task Distill
@@ -132,7 +133,7 @@ usage: run_general_task.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN
 [--load_td1_ckpt_path LOAD_TD1_CKPT_PATH]
 [--train_data_dir TRAIN_DATA_DIR]
 [--eval_data_dir EVAL_DATA_DIR]
-[--task_name TASK_NAME] [--schema_dir SCHEMA_DIR]
+[--task_name TASK_NAME] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE]
 
 options:
 --device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
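And the task-distill counterpart. The usage string names run_general_task.py, while the shell script near the end of this diff invokes run_task_distill.py; the sketch below follows the script, with placeholder paths:

```
python run_task_distill.py \
    --device_target="GPU" \
    --task_name="SST-2" \
    --load_teacher_ckpt_path="/path/to/teacher.ckpt" \
    --train_data_dir="/path/to/train_data/" \
    --eval_data_dir="/path/to/eval_data/" \
    --schema_dir="/path/to/schema.json" \
    --dataset_type="mindrecord"
```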
@@ -153,6 +154,7 @@ options:
 --eval_data_dir path to eval dataset directory: PATH, default is ""
 --task_name classification task: "SST-2" | "QNLI" | "MNLI", default is ""
 --schema_dir path to schema.json file, PATH, default is ""
+--dataset_type the dataset type which can be tfrecord/mindrecord, default is tfrecord
 ```
 
 ## Options and Parameters
@@ -344,4 +346,4 @@ In run_general_distill.py, we set the random seed to make sure distribute traini
 
 # [ModelZoo Homepage](#contents)
 
-Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
run_general_distill.py:
@@ -55,7 +55,8 @@ def run_general_distill():
     parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
-    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
+    parser.add_argument("--dataset_type", type=str, default="tfrecord",
+                        help="dataset type tfrecord/mindrecord, default is tfrecord")
     args_opt = parser.parse_args()
 
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
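Note that the reworded help string documents the two supported values, but the argument itself remains a free-form string: unlike `--task_name` in the next hunk, it carries no `choices` restriction, so argparse will accept any typo. A wrapper could guard against that before launching; a minimal sketch:

```
# hypothetical pre-flight check, since the parser does not restrict --dataset_type
dataset_type="tfrecord"   # or "mindrecord"
case "$dataset_type" in
    tfrecord|mindrecord) ;;
    *) echo "unsupported dataset_type: $dataset_type" >&2; exit 1 ;;
esac
```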
run_task_distill.py:
@@ -68,7 +68,8 @@ def parse_args():
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
     parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
                         help="The name of the task to train.")
-    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
+    parser.add_argument("--dataset_type", type=str, default="tfrecord",
+                        help="dataset type tfrecord/mindrecord, default is tfrecord")
     args = parser.parse_args()
     return args
 
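Because both entry points build their options with argparse, the updated help text can be confirmed without reading the source:

```
# prints all options, including the new dataset_type description
python run_task_distill.py --help
```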
scripts/run_distributed_gd_ascend.sh:
@@ -65,6 +65,7 @@ do
         --max_ckpt_num=1 \
         --load_teacher_ckpt_path="" \
         --data_dir="" \
-        --schema_dir="" > log.txt 2>&1 &
+        --schema_dir="" \
+        --dataset_type="tfrecord" > log.txt 2>&1 &
     cd ../
 done
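With the per-device loop above edited (the empty `--*=""` values filled in, and `--dataset_type` switched if needed), the launch command is the one the README already gives:

```
bash scripts/run_distributed_gd_ascend.sh 8 1 /path/hccl.json
```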
scripts/run_distributed_gd_gpu.sh:
@@ -37,5 +37,6 @@ mpirun --allow-run-as-root -n $RANK_SIZE \
     --save_ckpt_path="" \
     --data_dir=$DATA_DIR \
     --schema_dir=$SCHEMA_DIR \
+    --dataset_type="tfrecord" \
     --enable_data_sink=False \
     --load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 &
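Here `dataset_type` is pinned to "tfrecord" inside the script rather than exposed as a positional argument, so the README launch line is unchanged; reading MindRecord data would mean editing that line of the script. The documented invocation:

```
bash scripts/run_distributed_gd_gpu.sh 8 1 /path/data/ /path/schema.json /path/teacher.ckpt
```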
scripts/run_standalone_td.sh:
@@ -43,5 +43,6 @@ python ${PROJECT_DIR}/../run_task_distill.py \
     --load_td1_ckpt_path="" \
     --train_data_dir="" \
     --eval_data_dir="" \
-    --schema_dir="" > log.txt 2>&1 &
+    --schema_dir="" \
+    --dataset_type="tfrecord" > log.txt 2>&1 &
 
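Finally, every launch line in this commit ends with `> log.txt 2>&1 &`, i.e. the Python process is backgrounded with its output redirected, so progress is checked from the log file:

```
tail -f log.txt
```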