cleancode for cpm

This commit is contained in:
gaojing 2021-06-23 03:33:40 -04:00
parent fcfa6626c8
commit cea687dbd0
10 changed files with 212 additions and 82 deletions

View File

@ -31,7 +31,7 @@
# CPM CPM-Description
This is the fine-tune code warehouse of CPM model, which can be used for multi-machine and multi-card training/testing of model finetune. CPM[Project Home Page](https://cpm.baai.ac.cn/) was proposed in 2020 and a large-scale model based on Chinese processing. CPM is mainly used in the field of Chinese natural language processing (NLP) and generating tasks, such as machine translation, word selection and text summarization.
This is the code warehouse of CPM model, which can be used for multi-card training/testing of model finetune. CPM[Project Home Page](https://cpm.baai.ac.cn/) was proposed in 2020 and a large-scale model based on Chinese processing. CPM is mainly used in the field of Chinese natural language processing (NLP) and generating tasks, such as machine translation, word selection and text summarization.
[Paper](https://arxiv.org/abs/2012.00413): Zhang Z, Han X, Zhou H, et al. CPM: A Large-scale Generative Chinese Pre-trained Language Model[J]. arXiv preprint arXiv:2012.00413, 2020.
@ -57,7 +57,7 @@ CPM is implemented by GPT, which includes multi-layer decoder module.
# Quick Start
After dataset preparation, you can start zero-shot inference and finetune, evaluation as follows:
After dataset preparation, you can start zero-shot inference, finetune and evaluation as follows:
```bash
# run zero-shot inference example
@ -72,9 +72,9 @@ sh run_distribute_train_ascend_single_machine.sh /path/train.mindrecord /path/cp
cd scripts
bash run_eval_distribute_ascend.sh /path/finetune_test.mindrecord /path/test.json /path/ckpt_dictionary/ 8 /path/rank_table_2p.json
# Selects the best model on the dev dataset, and then tests the example on the test dataset
# Selects the best model on the dev dataset
cd scripts
bash run_test_distribute_ascend.sh /path/finetune_dev.mindrecord /path/dev.json /path/finetune_test.mindrecord /path/test.json /path/ckpt_dictionary/ 8 /path/rank_table_2p.json
bash run_test_standalone_ascend.sh /path/finetune_dev.mindrecord /path/dev.json /path/finetune_test.mindrecord /path/test.json /path/ckpt_dictionary/ 8 0
```
# Script Description
@ -90,6 +90,7 @@ bash run_test_distribute_ascend.sh /path/finetune_dev.mindrecord /path/dev.json
├─run_zero-shot_inference_distribute_ascend.sh // Shell script for distributed zero-shot on ascend.
├─run_distribute_train_ascend_single_machine.sh // Shell script for distributed finetune on ascend with single machine.
├─run_distribute_train_ascend_multi_machine.sh // Shell script for distributed finetune on ascend with multi-machine.
├─run_test_standalone_ascend.sh // Shell script for standalone evaluation and test on ascend.
├─run_test_distribute_ascend.sh // Shell script for distributed evaluation and test on ascend.
└─run_eval_distribute_ascend.sh // Shell script for distributed evaluation on ascend.
├─data_process
@ -109,11 +110,12 @@ bash run_test_distribute_ascend.sh /path/finetune_dev.mindrecord /path/dev.json
├─util.py // User interface.
└─weight_init.py // Weight init.
├─gpt_ckpt_2_mindspore.py // Transform the model that MindSpore can load.
├─requirements.txt // Requirements of third party package.
├─requirements.txt // Requirements of third party package.
├─zero-shot.py // Zero-shot api entry.
├─export.py // Export model.
├─sort.py // Sort the accuracy on dev dataset.
├─train.py // Train api entry.
├─test.py // Evaluation and test api entry.
├─test.py // Examples of evaluation and test.
└─eval.py // Infer api entry.
```
@ -145,12 +147,12 @@ Parameters for dataset and network (Training/Evaluation):
### Pre-training Model Download
- The CPM network pre training model can be downloaded here: [Model Download](https://cpm.baai.ac.cn/download.html).
- The pre-trained model of CPM network may be downloaded from here: [Model Download](https://cpm.baai.ac.cn/download.html).
Suppose you have the following documents:
- CPM-large/latest_checkpointed_iteration.txt
- CPM-large/80000/mp_rank_00_model_states.pt
- CPM-large/80000/mp_rank_01_model_states.pt
Next, the model integration script[change_mp.py](https://github.com/TsinghuaAI/CPM-Generate/blob/main/change_mp.py) is used to synthesize the above two fragment models into a complete single model.
Next, you may use the model integration linked script [change_mp.py](https://github.com/TsinghuaAI/CPM-Generate/blob/main/change_mp.py) to synthesize the above two fragment models into a complete single model.
```[bash]
python change_mp.py /path/to/CPM 1
@ -159,10 +161,10 @@ Parameters for dataset and network (Training/Evaluation):
The complete single model is as follows:
- CPM-large_MP1/latest_checkpointed_iteration.txt
- CPM-large_MP1/iter_0080000/mp_rank_01_model_states.pt
Then run the file`gpt_ckpt_2_mindspore.py` in the warehouse to convert the model into the model that can be loaded directly by Mindstore in the warehouse. Pay attention to modify the input and output file address in the file.
We get the model that mindpool can load, such as:`cpm_mindspore_1p_fp32.ckpt`.
Then run the file `gpt_ckpt_2_mindspore.py` in the warehouse to convert the model into the model that can be loaded directly by Mindstore. Pay attention to modify the input and output file address in the file.
We get the model that mindspore can load, such as:`cpm_mindspore_1p_fp32.ckpt`.
- Word segmentation Download: [Model Download](https://github.com/TsinghuaAI/CPM-Finetune/tree/main/bpe_3w_new).
- Word segmentation may be downloaded from here: [Model Download](https://github.com/TsinghuaAI/CPM-Finetune/tree/main/bpe_3w_new).
Suppose you have the following documents:
- bpe_3w_new/chinese_vocab.model
- bpe_3w_new/chinese_vocab.vocab
@ -179,7 +181,7 @@ Parameters for dataset and network (Training/Evaluation):
- chid_json/test.json
- chid_json/test_answer.json
- Data preprocessing: you may use [preprocess_chid_zeroshot.py](https://github.com/TsinghuaAI/CPM-Finetune/blob/main/preprocess_chid_zeroshot.py)to process the original data into the corresponding JSON format.
- Data preprocessing: you may use this linked file [preprocess_chid_zeroshot.py](https://github.com/TsinghuaAI/CPM-Finetune/blob/main/preprocess_chid_zeroshot.py) to process the original data into the corresponding JSON format.
```[bash]
python preprocess_chid_zeroshot.py --data_dir ${PATH_TO_DATA_DIR} --tokenizer_path ${PATH_TO_TOKENIZER VOCAB} --output_dir ${PATH_TO_OUTPUT_JSON}
@ -255,7 +257,7 @@ After reasoning, the accuracy rate will be generated. Please refer to the`zero-s
## Finetune
In addition to zero shot reasoning, the pre training model can also be trained by finetune.
In addition to zero shot inference, the pre-trained model can also be trained by finetune.
### Dataset Preparation
@ -268,7 +270,7 @@ In addition to zero shot reasoning, the pre training model can also be trained b
- chid_json/test.json
- chid_json/test_answer.json
- Data preprocessing: you may use [preprocess_chid_finetune.py](https://github.com/TsinghuaAI/CPM-Finetune/blob/main/preprocess_chid_finetune.py) scripts process the original data into the corresponding JSON format.
- Data preprocessing: you may use this linked [preprocess_chid_finetune.py](https://github.com/TsinghuaAI/CPM-Finetune/blob/main/preprocess_chid_finetune.py) script process the original data into the corresponding JSON format.
```[bash]
python preprocess_chid_finetune.py --data_dir ${PATH_TO_DATA_DIR} --tokenizer_path ${PATH_TO_TOKENIZER VOCAB} --output_dir ${PATH_TO_OUTPUT_JSON}
@ -278,7 +280,7 @@ Mainly, `data_dir` is the address of the json data, such as `/home/dataset/chid_
`tokenizer_path` is the address folder for the dictionary, such as `/home/vocab/`.
`output_dir` is the preprocessing output address, such as `/home/dataset/finetune_dataset`.
The template is defined and implemented in `process_sample` function of the `preprocess_chid_finetune.py` file. Finally, the data format generated by the file is as follows:
The template is defined and implemented in `process_sample` function of the `preprocess_chid_finetune.py` file. Finally, the data format generated by the file is as follows:
```[python]
[
@ -290,20 +292,20 @@ The template is defined and implemented in `process_sample` function of the `pre
]
```
After processing, three files, namely `train.json`, `valid.json` and `test.json` will be generated in the output directory `--output_dir`.
After processing, three files, namely `train.json`, `dev.json` and `test.json` will be generated in the output directory `--output_dir`.
- After data preprocessing, the JSON data is transformed into mindrecord data set.
- After data preprocessing, the JSON data is transformed into mindrecord dataset.
```[bash]
cd ./data_process/
python3 make_finetune_mindrecord.py --data_file ${PATH_TO_OUTPUT_JSON} --vocab_path ${PATH_TO_TOKENIZER VOCAB} --output_path ${PATH_TO_OUTPUT FILE} --num_patitions ${NUMBER_OF_MINDRECORD_PARTITIONS}
python3 make_finetune_mindrecord.py --data_file ${PATH_TO_OUTPUT_JSON} --vocab_path ${PATH_TO_TOKENIZER VOCAB} --output_path ${PATH_TO_OUTPUT FILE}
```
Mainly, `data_file` is the JSON data address, such as`/home/dataset/finetune_dataset/train.json` and `/home/dataset/finetune_dataset/test.json`.
`vocab_path` is the address folder directory of the dictionary, its definition is the same as before.
`output_path` is the preprocessing output address, such as`/home/dataset/finetune_dataset/`.
After processing, the mindrecord file of training and reasoning is generated in the specified directory `--output_path`, such as`train.mindrecord` and`test.mindrecord`.
After processing, the mindrecord file of training and reasoning is generated in the specified directory `--output_path`, such as `train.mindrecord`, `dev.mindrecord` and `test.mindrecord`.
### Finetune Training Process
@ -320,13 +322,14 @@ After processing, the mindrecord file of training and reasoning is generated in
``` bash
cd scripts
bash run_distribute_train_ascend_multi_machine.sh Dataset_addr PreTrain_ckpt_addr Rank_table_addr SERVER_ID
bash run_distribute_train_ascend_multi_machine.sh Dataset_addr PreTrain_ckpt_addr Rank_table_addr SERVER_ID RANK_SIZE_ALL
```
Mainly, `Dataset_addr` is the address of the dataset, such as `/home/dataset/finetune_dataset/train.mindrecord`.
`PreTrain_ckpt_addr` is the address for the pre training model, such as `/home/cpm_mindspore_1p_fp32.ckpt`.
`Rank_table_addr` is a rank address for distributed training, such as `/home/rank_table_8p.json`.
`Rank_table_addr` is a rank address for distributed training, such as `/home/rank_table_32p.json`.
`SERVER_ID` is the sequence of the machine numbers from 0 in the multi machine process, such as: 0.
`RANK_SIZE_ALL` 'is the total number of cards used, that is, the total number in `Rank_table_addr`, such as: 32.
**Attention**Because the CPM model is too large to train on one card, distributed training is needed, including model parallel and data parallel.
In distributed parallel training, the device of the machine is the device of the device_ ID is numbered from 1 and incremented by 1.
@ -343,11 +346,11 @@ Mainly, `Dataset_addr` is the address of the dataset, such as `/home/dataset/fin
bash run_eval_distribute_ascend.sh Test_MindRecord_addr Test_json_addr Model_addr Model_patition_number Rank_table_addr
```
In general, we select the model with the highest accuracy on the dev dataset, then infer on the test dataset, and finally generate the accuracy on the test dataset. Model selection can refer to the `run_test_distribute_ascend.sh` and `test.py` files for details.
In general, we select the model with the highest accuracy on the dev dataset, then infer on the test dataset, and finally generate the accuracy on the test dataset. Model selection can refer to the `run_test_standalone_ascend.sh` files for details.
```bash
cd scripts
bash run_test_distribute_ascend.sh Dev_MindRecord_addr Dev_json_addr Test_MindRecord_addr Test_json_addr Model_addr Model_patition_number Rank_table_addr
bash run_test_standalone_ascend.sh Dev_MindRecord_addr Dev_json_addr Test_MindRecord_addr Test_json_addr Model_addr Model_patition_number DEVICEID
```
Mainly, `Test_MindRecord_addr` is the address of the test dataset, such as `/home/dataset/finetune_dataset/test.mindrecord`.
@ -356,7 +359,7 @@ Mainly, `Test_MindRecord_addr` is the address of the test dataset, such as `/hom
`Dev_json_addr` is the dev JSON file after data preprocessing, such as `/home/dataset/finetune_dataset/dev.json`.
`Model_addr` is used to infer the partition model of the folder, such as `/home/finetune_model/`.
`Model_patition_number`is the number of fragmentation modelsexcluding policy file`train_strategy.ckpt`. For example, after 8-card training, the number of partitioned models is 8.
`Rank_table_addr` is a rank address for distributed evaluation, such as `/home/rank_table_2p.json`.
`DEVICEID` is the number of card for standalone evaluation, such as 0.
**Attention**: the dataset preprocessing methods of zero-shot and finetuene are different.
@ -372,7 +375,7 @@ The inference performance and accuracy of zero-shot single machine and dual card
| MindSpore Version | 1.3.0 |
| Dataset | ChID |
| Number of parallel models | 2 |
| Speed | 152ms/step (2pcs) |
| Speed | 140ms/step (2pcs) |
| batch_size | 2 |
| Output | Accuracy |
| Accuracy | accuracy=67.94% |

View File

@ -73,9 +73,9 @@ sh run_distribute_train_ascend_single_machine.sh /path/train.mindrecord /path/cp
cd scripts
bash run_eval_distribute_ascend.sh /path/finetune_test.mindrecord /path/test.json /path/ckpt_dictionary/ 8 /path/rank_table_2p.json
# Finetune模型在dev数据集上选最优再在test数据集上测试示例
# Finetune模型在dev数据集上选最优checkpoint测试示例
cd scripts
bash run_test_distribute_ascend.sh /path/finetune_dev.mindrecord /path/dev.json /path/finetune_test.mindrecord /path/test.json /path/ckpt_dictionary/ 8 /path/rank_table_2p.json
bash run_test_standalone_ascend.sh /path/finetune_dev.mindrecord /path/dev.json /path/finetune_test.mindrecord /path/test.json /path/ckpt_dictionary/ 8 0
```
# 脚本说明
@ -91,6 +91,7 @@ bash run_test_distribute_ascend.sh /path/finetune_dev.mindrecord /path/dev.json
├─run_zero-shot_inference_distribute_ascend.sh // Shell script for distributed zero-shot on ascend.
├─run_distribute_train_ascend_single_machine.sh // Shell script for distributed finetune on ascend with single machine.
├─run_distribute_train_ascend_multi_machine.sh // Shell script for distributed finetune on ascend with multi-machine.
├─run_test_standalone_ascend.sh // Shell script for standalone evaluation and test on ascend.
├─run_test_distribute_ascend.sh // Shell script for distributed evaluation and test on ascend.
└─run_eval_distribute_ascend.sh // Shell script for distributed evaluation on ascend.
├─data_process
@ -113,6 +114,7 @@ bash run_test_distribute_ascend.sh /path/finetune_dev.mindrecord /path/dev.json
├─requirements.txt // Requirements of third party package.
├─zero-shot.py // Zero-shot api entry.
├─export.py // Export model.
├─sort.py // Sort the accuracy on dev dataset.
├─train.py // Train api entry.
├─test.py // Evaluation and test api entry.
└─eval.py // Infer api entry.
@ -151,7 +153,7 @@ Parameters for dataset and network (Training/Evaluation):
- CPM-large/latest_checkpointed_iteration.txt
- CPM-large/80000/mp_rank_00_model_states.pt
- CPM-large/80000/mp_rank_01_model_states.pt
接下来,您可能会使用模型合并脚本[change_mp.py](https://github.com/TsinghuaAI/CPM-Generate/blob/main/change_mp.py)将上述两个分片模型合成完整的单个模型:
接下来,您可能会使用模型合并脚本链接[change_mp.py](https://github.com/TsinghuaAI/CPM-Generate/blob/main/change_mp.py)将上述两个分片模型合成完整的单个模型:
```[bash]
python change_mp.py /path/to/CPM 1
@ -171,7 +173,7 @@ Parameters for dataset and network (Training/Evaluation):
### Zero-shot准备数据集
- 原始数据集下载地址[ChiD-Dataset](https://drive.google.com/drive/folders/1gL01xbFBcrgP0TmgOhJ_uplkeG-BCwvM),可参考[ChiD-Dataset说明](https://github.com/chujiezheng/ChID-Dataset)。
- 原始数据集下载地址[ChiD-Dataset](https://drive.google.com/drive/folders/1gL01xbFBcrgP0TmgOhJ_uplkeG-BCwvM),可参考[ChiD-Dataset说明](https://github.com/chujiezheng/ChID-Dataset)。
假设您已获得下列文件:
- chid_json/train.json
- chid_json/train_answer.json
@ -180,7 +182,7 @@ Parameters for dataset and network (Training/Evaluation):
- chid_json/test.json
- chid_json/test_answer.json
- 数据预处理:您可能会使用脚本[preprocess_chid_zeroshot.py](https://github.com/TsinghuaAI/CPM-Finetune/blob/main/preprocess_chid_zeroshot.py)将原始数据处理成相应的json格式。
- 数据预处理:您可能会使用脚本链接[preprocess_chid_zeroshot.py](https://github.com/TsinghuaAI/CPM-Finetune/blob/main/preprocess_chid_zeroshot.py)(点击该链接)将原始数据处理成相应的json格式。
```[bash]
python preprocess_chid_zeroshot.py --data_dir ${PATH_TO_DATA_DIR} --tokenizer_path ${PATH_TO_TOKENIZER VOCAB} --output_dir ${PATH_TO_OUTPUT_JSON}
@ -226,7 +228,7 @@ Parameters for dataset and network (Training/Evaluation):
预处理完成后,在上述指定的`--output_dir`输出目录下会生成`test.json`文件。
- 将上一步得到的`--output_dir`路径下产生的json数据转换为MindRecord数据格式
- 在本工程下将上一步得到的`--output_dir`路径下产生的json数据转换为MindRecord数据格式
```[bash]
python make_zero_shot_mindrecord.py --data_file ${PATH_TO_DATA_FILE} --vocab_path ${PATH_TO_TOKENIZER VOCAB} --output_path ${PATH_TO_OUTPUT FILE}
@ -270,7 +272,7 @@ Parameters for dataset and network (Training/Evaluation):
- chid_json/test.json
- chid_json/test_answer.json
- 数据预处理:您可能会使用脚本[preprocess_chid_finetune.py](https://github.com/TsinghuaAI/CPM-Finetune/blob/main/preprocess_chid_finetune.py)将原始数据处理成相应的json格式。
- 数据预处理:您可能会使用脚本链接[preprocess_chid_finetune.py](https://github.com/TsinghuaAI/CPM-Finetune/blob/main/preprocess_chid_finetune.py)将原始数据处理成相应的json格式。
```[bash]
python preprocess_chid_finetune.py --data_dir ${PATH_TO_DATA_DIR} --tokenizer_path ${PATH_TO_TOKENIZER VOCAB} --output_dir ${PATH_TO_OUTPUT_JSON}
@ -291,19 +293,19 @@ Parameters for dataset and network (Training/Evaluation):
]
```
处理完成后,在上述指定的`--output_dir`输出目录下会生成 `train.json`, `valid.json`, `test.json` 三个文件。
处理完成后,在上述指定的`--output_dir`输出目录下会生成 `train.json`, `dev.json`, `test.json` 三个文件。
- 将上一步得到的`--output_dir`路径下产生的json数据转换为MindRecord数据格式进行训练
- 在本工程里将上一步得到的`--output_dir`路径下产生的json数据转换为MindRecord数据格式进行训练
```[bash]
cd ./data_process/
python3 make_finetune_mindrecord.py --data_file ${PATH_TO_OUTPUT_JSON} --vocab_path ${PATH_TO_TOKENIZER VOCAB} --output_path ${PATH_TO_OUTPUT FILE} --num_patitions ${NUMBER_OF_MINDRECORD_PARTITIONS}
python3 make_finetune_mindrecord.py --data_file ${PATH_TO_OUTPUT_JSON} --vocab_path ${PATH_TO_TOKENIZER VOCAB} --output_path ${PATH_TO_OUTPUT FILE}
```
主要地,`--data_file`是数据地址,如`/home/dataset/finetune_dataset/train.json``--vocab_path`为字典的地址文件夹目录,同上;
`--output_path`为生成的mindrecord的输出结果文件夹目录,如`/home/dataset/finetune_dataset/`
处理完成后,指定的`--output_path`目录下生成训练和推理的mindrecord文件如`train.mindrecord`和`test.mindrecord`。
处理完成后,指定的`--output_path`目录下生成训练和推理的mindrecord文件如`train.mindrecord`、`dev.mindrecord`和`test.mindrecord`。
### Finetune训练过程
@ -320,13 +322,14 @@ Parameters for dataset and network (Training/Evaluation):
``` bash
cd scripts
bash run_distribute_train_ascend_multi_machine.sh Dataset_addr PreTrain_ckpt_addr Rank_table_addr SERVER_ID
bash run_distribute_train_ascend_multi_machine.sh Dataset_addr PreTrain_ckpt_addr Rank_table_addr SERVER_ID RANK_SIZE_ALL
```
主要地,`Dataset_addr` 是数据地址,如`/home/dataset/finetune_dataset/train.mindrecord`
`PreTrain_ckpt_addr` 为预训练模型的地址,如`/home/cpm_mindspore_1p_fp32.ckpt`
`Rank_table_addr` 为Rank_table的地址,如`/home/rank_table_8p.json`
`SERVER_ID` 为多机过程中机器从0开始编号的的依次顺序0。
`SERVER_ID` 为多机过程中机器从0开始编号的的依次顺序0
`RANK_SIZE_ALL`为使用的卡的总数即Rank_table_addr里的卡的数量。
**注意**由于本CPM模型较大无法在一张卡上训练需要进行分布式训练包括模型并行和数据并行。
分布式并行训练时机器的device的device_id从1开始编号依次递增1。
@ -344,11 +347,11 @@ Parameters for dataset and network (Training/Evaluation):
bash run_eval_distribute_ascend.sh Test_MindRecord_addr Test_json_addr Model_addr Model_patition_number Rank_table_addr
```
通常我们会选择在dev数据集上精度最高的模型再在test数据集上进行推理最后会生成测试集上的准确率模型选择可参考`run_test_distribute_ascend.sh`或`test.py`文件
通常我们会选择在dev数据集上精度最高的模型再在test数据集上进行推理最后会生成测试集上的准确率模型选择可参考`run_test_standalone_ascend.sh`
```bash
cd scripts
bash run_test_distribute_ascend.sh Dev_MindRecord_addr Dev_json_addr Test_MindRecord_addr Test_json_addr Model_addr Model_patition_number Rank_table_addr
bash run_test_standalone_ascend.sh Dev_MindRecord_addr Dev_json_addr Test_MindRecord_addr Test_json_addr Model_addr Model_patition_number DEVICEID
```
主要地, `Test_MindRecord_addr`为test数据集mindrecord如`/home/dataset/finetune_dataset/test.mindrecord`
@ -357,7 +360,7 @@ Parameters for dataset and network (Training/Evaluation):
`Dev_json_addr`为预处理后的dev数据集的json文件如`/home/dataset/finetune_dataset/dev.json`
`Model_addr`为Finetune得到的模型文件夹如`/home/finetune_model/`
`Model_patition_number`为Finetune得到的模型的分片数量不包括策略文件`train_strategy.ckpt`, 如单机8卡得到的为8
`Rank_table_addr`为进行推理的时候的分布式推理的rank_table地址如`/home/rank_table_2p.json`
`DEVICEID`为进行推理的卡如0
注意Finetune的推理的数据集预处理和zero-shot的数据集预处理方式不一样。
@ -373,7 +376,7 @@ Zero-shot单机双卡推理性能和精度如下
| MindSpore版本 | 1.3.0 |
| 数据集 | ChID数据集 |
| 模型并行数 | 2 |
| 速度 | 152毫秒/步 |
| 速度 | 140毫秒/步 |
| Ascend芯片使用数量 | 2 |
| batch_size | 2 |
| 输出 | 准确率 |

View File

@ -91,9 +91,6 @@ if __name__ == '__main__':
parser.add_argument("--output_path", type=str, required=False,
default="./output/train.mindrecord",
help="mindrecord dataset output path.")
parser.add_argument("--num_partitions", type=int, required=False,
default=1, help="the number of mindrecord partitions.")
args = parser.parse_args()
# get the tokenizer
@ -109,7 +106,7 @@ if __name__ == '__main__':
"labels": {"type": "int64", "shape": [-1]},
"size": {"type": "int64"}}
writer = FileWriter(file_name=args.output_path, shard_num=args.num_partitions)
writer = FileWriter(file_name=args.output_path)
writer.add_schema(chid_schema, "preprocessed chid dataset")
data = []
for i in trange(len(chidDataset)):

View File

@ -104,10 +104,6 @@ if __name__ == '__main__':
parser.add_argument("--output_path", type=str, required=False,
default="./output/test.mindrecord",
help="mindrecord dataset output path.")
parser.add_argument("--num_partitions", type=int, required=False,
default=4, help="the number of mindrecord partitions.")
args = parser.parse_args()
# get the tokenizer
@ -124,7 +120,7 @@ if __name__ == '__main__':
"labels": {"type": "int64", "shape": [-1]},
"size": {"type": "int64"}}
writer = FileWriter(file_name=args.output_path, shard_num=args.num_partitions)
writer = FileWriter(file_name=args.output_path)
writer.add_schema(chid_schema, "preprocessed chid dataset")
data = []
for i in tqdm(range(len(chidDataset))):

View File

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ] ; then
if [ $# != 5 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_train_ascend_multi_machine.sh DATASET_PATH CKPT_PATH RANK_TABLE_PATH SERVER_ID"
echo "sh run_distribute_train_ascend_multi_machine.sh DATASET_PATH CKPT_PATH RANK_TABLE_PATH SERVER_ID RANK_SIZE_ALL"
echo "for example:"
echo "sh run_distribute_train_ascend_multi_machine.sh /disk0/dataset/finetune_dataset/train.mindrecord /disk0/cpm_ckpt_ms/cpm_mindspore_1p_fp32.ckpt /disk0/rank_table_32p.json 0"
echo "sh run_distribute_train_ascend_multi_machine.sh /disk0/dataset/finetune_dataset/train.mindrecord /disk0/cpm_ckpt_ms/cpm_mindspore_1p_fp32.ckpt /disk0/rank_table_32p.json 0 32"
echo "It is better to use absolute path."
echo "=============================================================================================================="
exit 1;
@ -37,6 +37,7 @@ echo $DATASET
PRECKPT=$(get_real_path $2)
RANK_TABLE_PATH=$(get_real_path $3)
SERVER_ID=$4
RANK_SIZE_ALL=$5
echo $DATANAME
@ -47,7 +48,7 @@ export RANK_TABLE_FILE=$RANK_TABLE_PATH
echo $RANK_TABLE_FILE
export RANK_SIZE=8
export RANK_SIZE=$RANK_SIZE_ALL
export DEVICE_NUM=8
RANK_START=$(($DEVICE_NUM * $SERVER_ID))

View File

@ -0,0 +1,82 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 7 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_test_standalone_ascend.sh DEV_DATASET_PATH DEV_JSON_PATH TEST_DATASET_PATH TEST_JSON_PATH MODEL_CKPT CKPT_NUMBER DEVICEID"
echo "for example:"
echo "sh run_test_standalone_ascend.sh /disk0/dataset/finetune_dataset/finetune_dev.mindrecord /disk0/dataset/finetune_dataset/dev.json /disk0/dataset/finetune_dataset/finetune_test.mindrecord /disk0/dataset/finetune_dataset/test.json /disk2/ckpt_8p 8 4"
echo "It is better to use absolute path."
echo "=============================================================================================================="
exit 1;
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DEV_DATASET=$(get_real_path $1)
echo $DEV_DATASET
DEV_LABEL=$(get_real_path $2)
echo $DEV_LABEL
TEST_DATASET=$(get_real_path $3)
echo $TEST_DATASET
TEST_LABEL=$(get_real_path $4)
echo $TEST_LABEL
MODEL_CKPT=$(get_real_path $5)
echo $MODEL_CKPT
CKPT_NUMBER=$6
echo $CKPT_NUMBER
DEVICEID=$7
echo $DEVICEID
current_exec_path=$(pwd)
echo ${current_exec_path}
result_path=${current_exec_path}/result.txt
rm -rf $result_path
echo ${result_path}
export RANK_SIZE=1
export DEVICE_NUM=1
for((ckptepoch=4;ckptepoch<=10;ckptepoch++));
do
rm -rf ${current_exec_path}/eval_${ckptepoch}
mkdir ${current_exec_path}/eval_${ckptepoch}
cd ${current_exec_path}/eval_${ckptepoch}
cp -r ../../*.py ./
cp -r ../../src ./
cp -r ../../scripts/*.sh ./
export RANK_ID=0
export DEVICE_ID=$DEVICEID
echo "start eval for rank $RANK_ID, device $DEVICE_ID"
echo "start eval for ckpt_epoch: $ckptepoch, result_path: ${result_path}"
env > env.log
python ../../test.py --dev_dataset $DEV_DATASET --dev_data_path $DEV_LABEL \
--test_dataset $TEST_DATASET --test_data_path $TEST_LABEL \
--ckpt_path $MODEL_CKPT --ckpt_partition $CKPT_NUMBER \
--ckpt_epoch $ckptepoch --result_path $result_path \
--distribute False --has_train_strategy True> log_cpm.log 2>&1
cd ${current_exec_path}
done
cd ${current_exec_path}
python ../sort.py --result_path=$result_path > log_result.log 2>&1

View File

@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sort."""
import os
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Accuracy sort.")
parser.add_argument("--result_path", type=str, default="scripts/result.txt",
help='Text save address.')
args_eval = parser.parse_args()
result_path = args_eval.result_path
if not os.path.exists(result_path):
print("The result file not found!")
with open(result_path, "r") as file:
epoch = []
result_dev = []
result_test = []
for i, line in enumerate(file):
if i == 0:
continue
curLine = line.strip().split(" ")
epoch.append(curLine[0])
result_dev.append(curLine[1])
result_test.append(curLine[2])
print(epoch, " ", result_dev, ",", result_test)
index_max_dev = result_dev.index(max(result_dev))
acc_last = result_test[index_max_dev]
print("++++ Then accuracy on the test dataset is:", acc_last)

View File

@ -87,7 +87,7 @@ finetune_test_distrubute = ed({
"num_attention_heads": 32
})
config_train_8p = ed({
config_train_single_machine = ed({
"dp": 4,
"mp": 2,
"epoch": 10,
@ -109,7 +109,7 @@ config_train_8p = ed({
"sink_size": 1
})
config_train_32p = ed({
config_train_multi_machine = ed({
"dp": 16,
"mp": 2,
"epoch": 10,

View File

@ -34,7 +34,6 @@ context.set_context(mode=context.GRAPH_MODE,
device_id=device_id)
def set_parallel_env():
r"""
Parallel environment.
@ -63,8 +62,12 @@ if __name__ == '__main__':
help='Whether distributed evaluation with model parallel.')
parser.add_argument("--has_train_strategy", type=ast.literal_eval, default=True,
help='Whether the loaded checkpoints have distributed training strategy.')
parser.add_argument("--result_path", type=str, default="/home/result.txt",
help='Text save address.')
parser.add_argument("--ckpt_epoch", type=int, default=4,
help='The number of checkpoint epochs.')
args_eval = parser.parse_args()
if args_eval.distribute:
set_parallel_env()
print("Start validation on 2 devices.")
@ -80,29 +83,22 @@ if __name__ == '__main__':
strategy_ckpt_load_file=train_strategy_list[0]
)
# start run in dev dataset.
result_dev = []
for i in range(4, 11):
ckpt_file_list_dev = None
if args_eval.has_train_strategy:
# Get the checkpoint slice.
ckpt_file_list_dev = create_ckpt_file_list(args_eval, i)
print("++++ Get sliced checkpoint file, lists: ", ckpt_file_list_dev, flush=True)
result_i = 0.0
if args_eval.distribute:
result_i = run_eval(args_eval, finetune_dev_distrubute, ckpt_file_list_dev)
else:
result_i = run_eval(args_eval, finetune_dev_standalone, ckpt_file_list_dev)
print("+++++ i=", i, ", dev_dataset Accuracy: ", result_i)
result_dev.append(result_i)
print("++++ The accuracy of each checkpoint on the validation dataset is:", result_dev)
print("++++ Then we take the model with the highest accuracy ")
print(" on the validation dataset to predict on the test dataset.")
index_max_dev = result_dev.index(max(result_dev)) + 4
ckpt_file_list_dev = None
if args_eval.has_train_strategy:
# Get the checkpoint slice.
ckpt_file_list_dev = create_ckpt_file_list(args_eval, args_eval.ckpt_epoch)
print("++++ Get sliced checkpoint file, lists: ", ckpt_file_list_dev, flush=True)
result_i = 0.0
if args_eval.distribute:
result_i = run_eval(args_eval, finetune_dev_distrubute, ckpt_file_list_dev)
else:
result_i = run_eval(args_eval, finetune_dev_standalone, ckpt_file_list_dev)
print("+++++ ckpt_epoch=", args_eval.ckpt_epoch, ", dev_dataset Accuracy: ", result_i)
print("++++ Then we take the model to predict on the test dataset.")
ckpt_file_list_test = None
if args_eval.has_train_strategy:
# Get the best precision checkpoint slice.
ckpt_file_list_test = create_ckpt_file_list(args_eval, index_max_dev)
ckpt_file_list_test = create_ckpt_file_list(args_eval, args_eval.ckpt_epoch)
args_eval.dataset = args_eval.test_dataset
args_eval.data_path = args_eval.test_data_path
@ -113,3 +109,12 @@ if __name__ == '__main__':
else:
result_last = run_eval(args_eval, finetune_test_standalone, ckpt_file_list_test)
print("++++ Accuracy on test dataset is: ", result_last)
# write to file.
result_path = args_eval.result_path
if not os.path.exists(result_path):
with open(result_path, "w") as file:
file.write("CkptEpcoh Accuracy_dev Accuracy_test\n")
with open(result_path, "a") as file:
file.write(str(args_eval.ckpt_epoch) + " " + str(result_i) + " " + str(result_last) + "\n")

View File

@ -34,7 +34,7 @@ import mindspore.common.dtype as mstype
import mindspore.dataset.transforms.c_transforms as C
from mindspore.parallel import set_algo_parameters
from src.config import config_train_8p, config_train_32p
from src.config import config_train_single_machine, config_train_multi_machine
from src.cpm_train import CPMWithLoss, CPMTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell, \
CPMTrainAccuStepsWithLossScaleCell
from src.lr_schedule import CPMLearningRate
@ -263,7 +263,7 @@ if __name__ == '__main__':
args = parser.parse_args()
if args.multi_machine:
print("Training on multiple machines")
train_paralle(args.dataset, args.pretrain_ckpt_path, config_train_32p)
train_paralle(args.dataset, args.pretrain_ckpt_path, config_train_multi_machine)
else:
print("Training on single machine and using 8 cards.")
train_paralle(args.dataset, args.pretrain_ckpt_path, config_train_8p)
train_paralle(args.dataset, args.pretrain_ckpt_path, config_train_single_machine)