add gpt2 to model zoo

ouyangyangxy 2020-12-29 17:01:03 +08:00
parent 44e068f643
commit a465a48fb7
55 changed files with 60443 additions and 0 deletions


@@ -0,0 +1,931 @@
# Contents
<!-- TOC -->
- [Contents](#contents)
- [GPT-2 Model](#gpt-2-model)
- [Model Architecture](#model-architecture)
- [Downstream Tasks](#downstream-tasks)
- [Script Description](#script-description)
- [Model Conversion](#model-conversion)
- [Dataset Preparation](#dataset-preparation)
- [Language Modeling Task](#language-modeling-task)
- [Children's Book Test Task](#childrens-book-test-task)
- [LAMBADA Task](#lambada-task)
- [Reading Comprehension Task](#reading-comprehension-task)
- [Summarization Task](#summarization-task)
- [Translation Task](#translation-task)
- [Configuration](#configuration)
- [Fine-tuning & Evaluation](#fine-tuning--evaluation)
- [Language Modeling Task](#language-modeling-task)
    - Fine-tuning
    - Evaluation
- [Children's Book Test Task](#childrens-book-test-task)
    - Evaluation
- [LAMBADA Task](#lambada-task)
    - Evaluation
- [Reading Comprehension Task](#reading-comprehension-task)
    - Evaluation
- [Summarization Task](#summarization-task)
    - Evaluation
- [Translation Task](#translation-task)
    - Evaluation
- [Requirements](#requirements)
- [Platform](#platform)
- [Other Requirements](#other-requirements)
- [Performance](#performance)
- [Inference Performance](#inference-performance)
- [Language Modeling Task](#language-modeling-task)
- [Children's Book Test Task](#childrens-book-test-task)
- [LAMBADA Task](#lambada-task)
- [Reading Comprehension Task](#reading-comprehension-task)
- [Summarization Task](#summarization-task)
- [Translation Task](#translation-task)
- [Others](#others)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# GPT-2 Model
[GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf) was released by OpenAI in 2019. It is the successor of the GPT model: a very large language model trained to predict the next word. By parameter count, GPT-2 comes in four sizes: small (117M), medium (345M), large (762M), and xlarge (1542M).
[GPT-2 introduction](https://openai.com/blog/better-language-models/)
[GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf): Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.
# Model Architecture
GPT-2 is built from the Transformer decoder. The Transformer consists of encoder and decoder stacks, but GPT-2 uses only the decoder part.
For fine-tuning, the pretrained model is fine-tuned on a task-specific dataset.
For evaluation, predictions are produced with the fine-tuned model; for some tasks the pretrained model can be evaluated directly in a zero-shot setting.
# Downstream Tasks
Six downstream tasks are covered:
- Language Modeling
- Children's Book Test
- LAMBADA
- Reading Comprehension
- Summarization
- Translation
For details about the datasets, see the [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf).
## Script Description
The GPT-2 scripts and code are organized as follows:
```text
├── GPT-2
  ├── README.md                         // GPT-2 model description
  ├── scripts
  │   ├── run_cbt.sh                    // fine-tuning & evaluation script for the CBT task
  │   ├── run_lambada.sh                // fine-tuning & evaluation script for the LAMBADA task
  │   ├── run_language_model.sh         // fine-tuning & evaluation script for the language modeling task
  │   ├── run_read_comprehension.sh     // fine-tuning & evaluation script for the reading comprehension task
  │   ├── run_summarization.sh          // fine-tuning & evaluation script for the summarization task
  │   ├── run_translation.sh            // fine-tuning & evaluation script for the translation task
  ├── src
  │   ├── clip_grad_utils.py            // gradient clipping
  │   ├── dataset.py                    // dataset loading for fine-tuning and inference
  │   ├── finetune_eval_config.py       // configuration for fine-tuning and inference
  │   ├── gpt2_for_finetune.py          // GPT-2 fine-tuning wrapper (training network)
  │   ├── GPT2_generation.py            // generation module
  │   ├── GPT2_model.py                 // GPT-2 model script
  │   ├── GPT2ForCBT.py                 // model script for the CBT task
  │   ├── GPT2ForLanguageModel.py       // model script for the language modeling task
  │   ├── GPT2ForReadComprehension.py   // model script for the reading comprehension task
  │   ├── GPT2ForSummarization.py       // model script for the summarization task
  │   ├── GPT2ForTranslation.py         // model script for the translation task
  │   ├── weight_init.py                // weight initialization
  │   ├── utils
  │   │   ├── bleu_score.py             // BLEU score computation
  │   │   ├── rouge_score.py            // ROUGE score computation
  │   │   ├── CrossEntropy.py           // cross-entropy loss
  │   │   ├── data_preprocess.py        // dataset preprocessing script
  │   │   ├── generation_utils.py       // generation helpers, including sampling methods
  │   │   ├── get_config_setting.py     // configuration access helpers
  │   │   ├── task_utils.py             // helper functions for downstream tasks
  │   │   ├── lr_schedule.py            // learning-rate schedule script
  │   │   ├── metric_method.py          // evaluation metrics for downstream tasks
  │   │   ├── tensor_manipulations.py   // tensor manipulation helpers
  │   │   ├── tokenization.py           // tokenization, including BPE encoding and decoding
  │   │   ├── pretrain-data
  │   │   │   ├── stopwords.txt         // stop-word filter for the LAMBADA task
  ├── create_cbt_data.py                // builds MindRecord data for the CBT task
  ├── create_lambada_data.py            // builds MindRecord data for the LAMBADA task
  ├── create_lm_data.py                 // builds MindRecord data for the other tasks
  ├── create_summary_data.py            // builds MindRecord data for the summarization task
  ├── download_cnn_dailymail.py         // downloads the CNN & DailyMail dataset
  ├── cnn_dataset_sampler.py            // sampler for the CNN & DailyMail training set
  ├── eval_rc_addition_answer.py        // evaluates reading comprehension with addition_answer
  ├── run_CBT_task.py                   // fine-tuning & inference entry for the CBT task
  ├── run_lambada.py                    // fine-tuning & inference entry for the LAMBADA task
  ├── run_language_model.py             // fine-tuning & inference entry for the language modeling task
  ├── run_ReadComprehension.py          // fine-tuning & inference entry for the reading comprehension task
  ├── run_summarization.py              // fine-tuning & inference entry for the summarization task
  ├── run_translation.py                // fine-tuning & inference entry for the translation task
  ├── task_dataset_preprocess.py        // dataset preprocessing entry for all tasks
  ├── convert_tf_ckpt
  │   ├── read_weight_tf.py             // reads the TensorFlow pretrained checkpoint
  │   ├── trans_dict.py                 // parameter-name mapping dictionary
  │   ├── save_weight_ms.py             // generates the MindSpore checkpoint
  ├── third_party
  │   ├── gpt2-merges.txt
  │   ├── gpt2-vocab.json               // GPT-2 pretrained vocabulary
  │   ├── bleu.py                       // third-party helper for BLEU computation
```
## Model Conversion
- Download the GPT-2 pretrained model: [GPT-2 pretrained model download](https://github.com/openai/gpt-2/blob/master/download_model.py)
- In a TensorFlow environment, run `read_weight_tf.py`, for example:
`python read_weight_tf.py --ckpt_file_path=/{path}/model.ckpt`
- In a MindSpore environment, run `save_weight_ms.py`, for example:
`python save_weight_ms.py --output_file_name="mindspore_gpt2_small.ckpt"`
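To check that the conversion produced a loadable MindSpore checkpoint, a quick inspection such as the following can be used (a minimal sketch, assuming the output file name chosen above; it is not part of the repository):

```python
# Minimal sanity check for the converted checkpoint (illustrative sketch only).
from mindspore.train.serialization import load_checkpoint

param_dict = load_checkpoint("mindspore_gpt2_small.ckpt")  # file produced by save_weight_ms.py
print("number of parameters:", len(param_dict))
for name in list(param_dict)[:5]:
    print(name, tuple(param_dict[name].shape))  # e.g. gpt2_decoder.layers.0. ... .c_attn.weight
```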
## Dataset Preparation
### Language Modeling Task
#### WikiText2, WikiText103, PTB, and 1BW Datasets
- [WikiText2 dataset download](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip): after extraction, use `wikitext-2/wiki.test.tokens` as the test set.
- [WikiText103 dataset download](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip): after extraction, use `wikitext-103/wiki.test.tokens` as the test set.
- [PTB dataset download](http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz): after extraction, use `/simple-examples/data/ptb.test.txt` as the test set and `/simple-examples/data/ptb.train.txt` as the training set.
- [1BW dataset download](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz): after extraction, use `1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050` as the test set, and randomly sample 30,000 training examples from `1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/news.en-00001-of-00100` as the training set.

The datasets above can be cleaned with `task_dataset_preprocess.py`.
The main arguments of `task_dataset_preprocess.py` are:
```bash
--task: The GPT-2 downstream task, including [LanguageModeling, CBT, Translation, Lambada, Summarization, ReadingComprehension].
--input_file: The raw dataset path.
--dataset: The name of dataset which should be processed, only for LanguageModeling task.
--output_file: The output dataset path after preprocessing.
--condition: Process train or test dataset, including [train, test], only for 1BW and CNN & DailyMail dataset.
```
For example, cleaning the PTB training and test sets (shown here for the test set):
```bash
python task_dataset_preprocess.py --task "LanguageModeling" --input_file /{path}/ptb.test.txt --dataset "ptb" --output_file /{path}/ptb_clean_test.txt --condition "test"
```
Use `create_lm_data.py` to convert the cleaned datasets to MindRecord format.
The main arguments of `create_lm_data.py` are:
```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The MindRecord file will be split into the number of partition.
--max_seq_length: Maximum sequence length.
--vocab_file: url of gpt2-vocab.json.
--merge_file: url of gpt2-merges.txt
```
For example:
```bash
python create_lm_data.py --input_file /{path}/ptb.test.txt --output_file /{path}/ptb-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```
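Optionally, the generated MindRecord file can be inspected before evaluation; a minimal sketch using MindSpore's `FileReader` (the field names depend on how `create_lm_data.py` writes the records):

```python
# Sketch: inspect the first record of a generated MindRecord file (illustrative only).
from mindspore.mindrecord import FileReader

reader = FileReader("ptb-test-mindrecord")      # file produced by create_lm_data.py
for item in reader.get_next():
    # print field names and array shapes of the first record
    print({key: getattr(value, "shape", value) for key, value in item.items()})
    break
reader.close()
```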
### Children's Book Test Task
#### CBT-CN / CBT-NE Datasets
- [CBT dataset download](http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz): use `cbtest_CN_valid_2000ex.txt` and `cbtest_NE_valid_2000ex.txt` under the `/data` directory as the evaluation sets for this task. Clean the dataset, for example:
```bash
python task_dataset_preprocess.py --task "CBT" --input_file /{path}/cbtest_CN_valid_2000ex.txt --dataset "cbt" --output_file /{path}/cbt_cn_valid.txt
```
Use `create_cbt_data.py` to convert the cleaned dataset to MindRecord format.
The main arguments of `create_cbt_data.py` are:
```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The MindRecord file will be split into the number of partition.
--max_seq_length: Maximum sequence length.
--num_choice: Number of choices.
--vocab_file: url of gpt2-vocab.json.
--merge_file: url of gpt2-merges.txt
```
For example:
```bash
python create_cbt_data.py --input_file /{path}/cbt_cn_valid.txt --output_file /{path}/cbt-cn-valid-mindrecord --num_splits 1 --max_length 1024 --num_choice 10 --vocab_file={path} --merge_file={path}
```
### LAMBADA Task
#### LAMBADA Dataset
- [LAMBADA dataset download](https://zenodo.org/record/2630551#.X-yCSTTithH): use `lambada_test_plain_text.txt` as the evaluation set for this task. Clean the dataset, for example:
```bash
python task_dataset_preprocess.py --task "LAMBADA" --input_file /{path}/lambada_test_plain_text.txt --dataset "LAMBADA" --output_file /{path}/lambada_test_clean.txt
```
Use `create_lambada_data.py` to convert the cleaned dataset to MindRecord format.
The main arguments of `create_lambada_data.py` are:
```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The MindRecord file will be split into the number of partition.
--max_seq_length: Maximum sequence length.
--vocab_file: url of gpt2-vocab.json.
--merge_file: url of gpt2-merges.txt
```
For example:
```bash
python create_lambada_data.py --input_file /{path}/lambada_test_clean.txt --output_file /{path}/lambada-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```
### Reading Comprehension Task
#### CoQA Dataset
- [CoQA dataset download](http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json): use `coqa-dev-v1.0.json` as the evaluation set for this task. Clean the dataset, for example:
```bash
python task_dataset_preprocess.py --task "ReadingComprehension" --input_file /{path}/coqa-dev-v1.0.json --dataset "coqa" --output_file /{path}/coqa_dev.txt
```
Use `create_lm_data.py` to convert the cleaned dataset to MindRecord format.
For example:
```bash
python create_lm_data.py --input_file /{path}/coqa_dev.txt --output_file /{path}/coqa-dev-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```
### Summarization Task
#### CNN & DailyMail Dataset
- Download the dataset with the `download_cnn_dailymail.py` script, for example:
```bash
# Download the test set
python download_cnn_dailymail.py --dir ./cnn_dailymail/ --split test
# Download the training set
python download_cnn_dailymail.py --dir ./cnn_dailymail/ --split train
```
Randomly sample 10,000 examples from the training set as the final fine-tuning training set. Use the `cnn_dataset_sampler.py` script to perform the sampling and generate the new training set, for example:
```bash
# GPT-2 small and GPT-2 medium use seq_length=1024 for training, so set max_length=1022 here
python cnn_dataset_sampler.py --input_path="/{path}/cnn_train.txt"
--output_path="/{path}/cnn_train_hint_small.txt"
--replace_hint="true"
--sample="true"
--max_length=1022
--prob=0.25
--max_items=10000
--hint="TL;DR:"
# GPT-2 large uses seq_length=768 for training, so set max_length=766 here
python cnn_dataset_sampler.py --input_path="/{path}/cnn_train.txt"
--output_path="/{path}/cnn_train_hint_large.txt"
--replace_hint="true"
--sample="true"
--max_length=766
--prob=0.25
--max_items=10000
--hint="TL;DR:"
```
Use `create_summary_data.py` to convert the dataset to MindRecord format.
For example:
```bash
python create_summary_data.py --input_file /{path}/cnn_dailymail_test.txt --output_file /{path}/cnn_dailymail-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path} --mode 'cnn_dailymail'
```
### Translation Task
#### WMT14 En-Fr Dataset
- [WMT14 En-Fr dataset download](http://statmt.org/wmt14/test-full.tgz): use `newstest2014-fren-ref.en.sgm` and `newstest2014-fren-ref.fr.sgm` as the evaluation set for this task. Merge and clean the dataset, for example:
```bash
python task_dataset_preprocess.py --task "Translation" --input_file /{path}/test-full --dataset "wmt14" --output_file /{path}/wmt14
```
Two files, `wmt14.en_fr.txt` and `wmt14.fr_en.txt`, are generated under the `output_file` path; they are used to evaluate `En-Fr` and `Fr-En`, respectively.
Use `create_lm_data.py` to convert them to MindRecord format.
For example:
```bash
python create_lm_data.py --input_file /{path}/wmt14.en_fr.txt --output_file /{path}/en-fr-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
python create_lm_data.py --input_file /{path}/wmt14.fr_en.txt --output_file /{path}/fr-en-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```
## Configuration
`src/finetune_eval_config.py` is the configuration file for GPT-2 fine-tuning and inference. It sets most options and parameters, including the GPT-2 model size, the model hyperparameters, and the optimizer parameters.
For details of the available options, see `src/finetune_eval_config.py`.
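As a rough illustration of what such a configuration file typically contains (a hedged sketch only, assuming the `easydict`-style configs commonly used in the ModelZoo; the actual keys and values are defined in `src/finetune_eval_config.py` and may differ):

```python
# Illustrative sketch of a fine-tune/eval configuration; not the repository's actual file.
from easydict import EasyDict as edict

cfg = edict({
    "gpt2_network": "small",    # one of: small / medium / large
    "optimizer": "Lamb",        # one of: momentum / adam / lamb
})

# per-size model settings referenced in this README (example values)
gpt2_small_net_cfg = edict({
    "batch_size": 8,
    "seq_length": 1024,         # GPT-2 small/medium use 1024; large uses 768 for summarization
})
```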
## Fine-tuning & Evaluation
### Language Modeling Task
#### Fine-tuning
- PTB dataset
The GPT-2 small / medium / large models need to be fine-tuned on the PTB training set. Fine-tuning only requires the shell script `scripts/run_language_model.sh`, which sets the environment variables and runs the `run_language_model.py` entry under `GPT-2`.
Before fine-tuning, configure the options in `src/finetune_eval_config.py`:
set `gpt2_network` under `cfg` to the desired GPT-2 model size (`[small/medium/large]`);
set `optimizer` under `cfg` to the optimizer to use, e.g. `Lamb` (`momentum/adam/lamb` are available);
after selecting the GPT-2 model size, set its model parameters, including `batch_size` and `seq_length`.
Then run the `scripts/run_language_model.sh` shell script:
```bash
sh scripts/run_language_model.sh --device_target="Ascend"
--do_train="true"
--do_eval="false"
--epoch_num=1
--train_data_shuffle="true"
--eval_data_shuffle="false"
--save_finetune_ckpt_path={save_finetune_ckpt_path}
--load_pretrain_ckpt_path={load_pretrain_ckpt_path}
--train_data_file_path={train_data_file_path}
```
Logs and output files are saved under `./ms_log/`.
```bash
sh scripts/run_language_model.sh [--options]
```
The usage of `run_language_model.sh` is as follows:
```text
usage: run_language_model.sh [--device_target DEVICE_TARGET] [--device_id N]
[--metric_method METRIC_METHOD]
[--do_train DO_TRAIN] [--do_eval DO_EVAL]
[--eval_type EVAL_TYPE] [--epoch_num N]
[--train_data_shuffle TRAIN_DATA_SHUFFLE]
[--eval_data_shuffle EVAL_DATA_SHUFFLE]
[--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
[--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
[--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
[--train_data_file_path TRAIN_DATA_FILE_PATH]
[--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
--device_target Device type. Default: "Ascend"
--device_id ID of target device
--metric_method The eval method including [PPL]. Default: "PPL"
--do_train Enable train. Default: "false"
--do_eval Enable evaluation. Default: "true"
--eval_type The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
--epoch_num Epoch number. Default: 1
--train_data_shuffle Enable train data shuffle. Default: "true"
--eval_data_shuffle Enable eval data shuffle. Default: "false"
--save_finetune_ckpt_path Save the finetuned checkpoint path
--load_pretrain_ckpt_path Load the checkpoint file path for train
--load_finetune_ckpt_path Load the checkpoint file path for evaluation
--train_data_file_path Data path, it is better to use absolute path
--eval_data_file_path Data path, it is better to use absolute path
```
- 1BW dataset
The GPT-2 large model needs to be fine-tuned on the 1BW training set. Fine-tuning again only requires the shell script `run_language_model.sh`, which sets the environment variables and runs the `run_language_model.py` entry under `GPT-2`. The procedure is the same as for the PTB dataset.
#### Evaluation
The GPT-2 model can be evaluated on the `WikiText2/WikiText103/PTB/1BW` test sets. For these datasets the evaluation metric is PPL, i.e. set `--metric_method="PPL"`.
Evaluation only requires the shell script `run_language_model.sh`, which sets the environment variables and runs the `run_language_model.py` entry under `GPT-2`.
Before evaluation, configure `src/finetune_eval_config.py`, then run the `scripts/run_language_model.sh` shell script. If the model was fine-tuned on a dataset, set `--eval_type="finetuned"` when evaluating on the corresponding test set; otherwise set `--eval_type="zero-shot"`. In addition, `--load_finetune_ckpt_path` is the path of the fine-tuned checkpoint file:
```bash
sh scripts/run_language_model.sh --device_target="Ascend"
--metric_method="PPL"
--do_train="false"
--do_eval="true"
--eval_type="finetuned"
--train_data_shuffle="true"
--eval_data_shuffle="false"
--load_finetune_ckpt_path={load_eval_ckpt_path}
--eval_data_file_path={eval_data_file_path}
```
Logs and output files are saved under `./ms_log/`.
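For reference, the PPL metric is simply the exponential of the average per-token cross-entropy over the test set; a minimal sketch of the computation (not the repository's `metric_method.py`):

```python
# Sketch: perplexity from summed negative log-likelihoods (illustrative only).
import math

def perplexity(nll_sums, token_counts):
    """nll_sums: summed token NLL per batch; token_counts: number of scored tokens per batch."""
    return math.exp(sum(nll_sums) / sum(token_counts))

print(perplexity([120.0, 80.0], [50, 30]))  # exp(200 / 80) ~= 12.18
```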
### Children's Book Test Task
#### Evaluation
The GPT-2 model can be evaluated on the `CBT-CN/CBT-NE` validation sets. For these datasets the evaluation metric is Accuracy, i.e. set `--metric_method="Accuracy"`.
Evaluation only requires the shell script `run_cbt.sh`, which sets the environment variables and runs the `run_CBT_task.py` entry under `GPT-2`.
Before evaluation, configure `src/finetune_eval_config.py`, then run the `scripts/run_cbt.sh` shell script with `eval_type="zero-shot"`. In this case `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file:
```bash
sh scripts/run_cbt.sh --device_target="Ascend"
--num_choice=10
--metric_method="Accuarcy"
--do_train="false"
--do_eval="true"
--eval_type="zero-shot"
--train_data_shuffle="true"
--eval_data_shuffle="false"
--load_finetune_ckpt_path={load_eval_ckpt_path}
--eval_data_file_path={eval_data_file_path}
```
Logs and output files are saved under `./ms_log/`.
```bash
sh scripts/run_cbt.sh [--options]
```
The usage of `run_cbt.sh` is as follows:
```text
usage: run_CBT_task.sh [--device_target DEVICE_TARGET] [--device_id N][--num_choice N]
[--metric_method METRIC_METHOD]
[--do_train DO_TRAIN] [--do_eval DO_EVAL]
[--eval_type EVAL_TYPE] [--epoch_num N]
[--train_data_shuffle TRAIN_DATA_SHUFFLE]
[--eval_data_shuffle EVAL_DATA_SHUFFLE]
[--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
[--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
[--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
[--train_data_file_path TRAIN_DATA_FILE_PATH]
[--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
--device_target Device type. Default: "Ascend"
--device_id ID of target device
--num_choice The number of choice in CBT task
--metric_method The eval method including [Accuracy]. Default: "Accuracy"
--do_train Enable train. Default: "false"
--do_eval Enable evaluation. Default: "true"
--eval_type The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
--epoch_num Epoch number. Default: 1
--train_data_shuffle Enable train data shuffle. Default: "true"
--eval_data_shuffle Enable eval data shuffle. Default: "false"
--save_finetune_ckpt_path Save the finetuned checkpoint path
--load_pretrain_ckpt_path Load the checkpoint file path for train
--load_finetune_ckpt_path Load the checkpoint file path for evaluation
--train_data_file_path Data path, it is better to use absolute path
--eval_data_file_path Data path, it is better to use absolute path
```
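Conceptually, the CBT cloze evaluation scores each of the `--num_choice` candidate answers with the language model and predicts the most probable one; a schematic sketch of that selection step (not the repository's implementation):

```python
# Sketch: multiple-choice selection for one CBT example (illustrative only).
import numpy as np

def pick_candidate(candidate_logprobs):
    """candidate_logprobs: one summed log-probability per candidate answer."""
    return int(np.argmax(candidate_logprobs))

# ten candidates, each scored by the LM on the sentence with the blank filled in
scores = [-42.1, -39.7, -45.0, -40.2, -44.8, -41.3, -38.9, -43.5, -46.2, -40.9]
print("predicted choice:", pick_candidate(scores))  # index 6
```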
### LAMBADA Task
#### Evaluation
The GPT-2 model can be evaluated on the `LAMBADA` test set. For this dataset both Accuracy and PPL are used, i.e. set `--metric_method="Accuracy"` or `--metric_method="PPL"`.
Evaluation only requires the shell script `run_lambada.sh`, which sets the environment variables and runs the `run_lambada.py` entry under `GPT-2`.
Before evaluation, configure `src/finetune_eval_config.py`, then run the `scripts/run_lambada.sh` shell script with `eval_type="zero-shot"`. In this case `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file.
Accuracy evaluation:
```bash
sh scripts/run_lambada.sh --device_target="Ascend"
--metric_method="Accuarcy"
--do_train="false"
--do_eval="true"
--eval_type="zero-shot"
--train_data_shuffle="true"
--eval_data_shuffle="false"
--generate_length_dynamically="true"
--load_finetune_ckpt_path={load_eval_ckpt_path}
--eval_data_file_path={eval_data_file_path}
--tokenizer_file_path={tokenizer_file_path}
--stop_word_file_path={stop_word_file_path}
```
PPL evaluation:
```bash
sh scripts/run_lambada.sh --device_target="Ascend"
--metric_method="PPL"
--do_train="false"
--do_eval="true"
--eval_type="zero-shot"
--train_data_shuffle="true"
--eval_data_shuffle="false"
--load_finetune_ckpt_path={load_eval_ckpt_path}
--eval_data_file_path={eval_data_file_path}
```
Logs and output files are saved under `./ms_log/`.
```bash
sh scripts/run_lambada.sh [--options]
```
```text
usage: run_lambada.sh [--device_target DEVICE_TARGET] [--device_id N]
[--metric_method METRIC_METHOD]
[--do_train DO_TRAIN] [--do_eval DO_EVAL]
[--eval_type EVAL_TYPE] [--epoch_num N]
[--train_data_shuffle TRAIN_DATA_SHUFFLE]
[--eval_data_shuffle EVAL_DATA_SHUFFLE]
[--generate_length_dynamically GENERATE_LENGTH_DYNAMICALLY]
[--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
[--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
[--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
[--train_data_file_path TRAIN_DATA_FILE_PATH]
[--eval_data_file_path EVAL_DATA_FILE_PATH]
[--tokenizer_file_path TOKENIZER_FILE_PATH]
[--stop_word_file_path STOP_WORD_FILE_PATH]
options:
--device_target Device type. Default: "Ascend"
--device_id ID of target device
--metric_method The eval method including [Accuracy, PPL]. Default: "Accuracy"
--do_train Enable train. Default: "false"
--do_eval Enable evaluation. Default: "true"
--eval_type The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
--epoch_num Epoch number. Default: 1
--train_data_shuffle Enable train data shuffle. Default: "true"
--eval_data_shuffle Enable eval data shuffle. Default: "false"
--generate_length_dynamically Enable generate_length_Dynamically. Default: "true"
--save_finetune_ckpt_path Save the checkpoint path
--load_pretrain_ckpt_path Load the checkpoint file path
--load_finetune_ckpt_path Load the checkpoint file path
--train_data_file_path Data path, it is better to use absolute path
--eval_data_file_path Data path, it is better to use absolute path
--tokenizer_file_path pretrained vocab and merge file path
--stop_word_file_path The stop word file path
```
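For the Accuracy metric, LAMBADA requires predicting the final word of each passage; the generated continuation is compared with the target word, and the stop-word list in `src/utils/pretrain-data/stopwords.txt` is used to filter generated words. A schematic sketch of the check (not the repository's implementation):

```python
# Sketch: last-word accuracy check for LAMBADA (illustrative only).
def last_word_correct(generated_text, target_word, stop_words):
    """Return True if the first non-stop-word generated equals the target word."""
    for word in generated_text.strip().split():
        word = word.strip(".,!?\"'").lower()
        if word and word not in stop_words:
            return word == target_word.lower()
    return False

stop_words = {"the", "a", "an", "of"}   # loaded from stopwords.txt in practice
print(last_word_correct("the road", "road", stop_words))  # True
```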
### Reading Comprehension Task
#### Evaluation
The GPT-2 model can be evaluated on the `CoQA` development set. For this dataset the evaluation metric is F1, i.e. set `--metric_method="F1"`.
Evaluation only requires the shell script `run_read_comprehension.sh`, which sets the environment variables and runs the `run_ReadComprehension.py` entry under `GPT-2`.
Before evaluation, configure `src/finetune_eval_config.py`, then run the `scripts/run_read_comprehension.sh` shell script with `eval_type="zero-shot"`. In this case `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file:
```bash
sh scripts/run_read_comprehension.sh --device_target="Ascend"
--metric_method="F1"
--do_train="false"
--do_eval="true"
--eval_type="zero-shot"
--train_data_shuffle="true"
--eval_data_shuffle="false"
--load_finetune_ckpt_path={load_eval_ckpt_path}
--eval_data_file_path={eval_data_file_path}
--tokenizer_file_path={tokenizer_file_path}
--generate_length=55
--top_k=1
--top_p="1.0"
--temperature="1.0"
```
Logs and output files are saved under `./ms_log/`. The resulting log file is then used as the `input_file` of the `eval_rc_addition_answer.py` script, together with the original CoQA development set `coqa-dev-v1.0.json` as the `addition_file`.
Run `python eval_rc_addition_answer.py --input_file={path} --addition_file={path}` to obtain the final F1 score.
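The F1 score here is the usual token-overlap F1 between the generated answer and the reference answer; a minimal sketch of the metric (not the repository's `metric_method.py`):

```python
# Sketch: token-level F1 between a predicted and a gold answer (illustrative only).
from collections import Counter

def token_f1(prediction, reference):
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

print(token_f1("in the garden", "the garden"))  # 0.8
```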
```bash
sh scripts/run_read_comprehension.sh [--options]
```
```text
usage: run_read_comprehension.sh [--device_target DEVICE_TARGET] [--device_id N]
[--metric_method METRIC_METHOD]
[--do_train DO_TRAIN] [--do_eval DO_EVAL]
[--eval_type EVAL_TYPE] [--epoch_num N]
[--train_data_shuffle TRAIN_DATA_SHUFFLE]
[--eval_data_shuffle EVAL_DATA_SHUFFLE]
[--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
[--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
[--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
[--train_data_file_path TRAIN_DATA_FILE_PATH]
[--eval_data_file_path EVAL_DATA_FILE_PATH]
[--tokenizer_file_path TOKENIZER_FILE_PATH]
[--generate_length N] [--top_k N] [--top_p TOP_P]
[--temperature TEMPERATURE]
options:
--device_target Device type. Default: "Ascend"
--device_id ID of target device
--metric_method The eval method including [F1]. Default: "F1"
--do_train Enable train. Default: "false"
--do_eval Enable evaluation. Default: "false"
--eval_type The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
--epoch_num Epoch number. Default: 1
--train_data_shuffle Enable train data shuffle. Default: "true"
--eval_data_shuffle Enable eval data shuffle. Default: "false"
--save_finetune_ckpt_path Save the checkpoint path
--load_pretrain_ckpt_path Load the checkpoint file path
--load_finetune_ckpt_path Load the checkpoint file path
--train_data_file_path Data path, it is better to use absolute path
--eval_data_file_path Data path, it is better to use absolute path
--tokenizer_file_path pretrained vocab and merge file path
--generate_length The generation length of answer sentence
--top_k Parameter for Top-K sampling
--top_p Parameter for Top-P sampling
--temperature Parameter for generation, greater if generation more diverse
```
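The `--top_k`, `--top_p`, and `--temperature` options above control how the next token is sampled during answer generation; a generic sketch of temperature/top-k/top-p sampling (the repository's own logic lives in `src/utils/generation_utils.py` and may differ):

```python
# Sketch: temperature + top-k + top-p (nucleus) sampling over a logits vector.
import numpy as np

def sample_token(logits, top_k=0, top_p=1.0, temperature=1.0):
    logits = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(-probs)                 # indices sorted by descending probability
    if top_k > 0:
        order = order[:top_k]                  # keep only the k most probable tokens
    cumulative = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1   # smallest prefix with mass >= top_p
    keep = order[:max(1, cutoff)]
    kept_probs = probs[keep] / probs[keep].sum()
    return int(np.random.choice(keep, p=kept_probs))

print(sample_token([2.0, 1.0, 0.1], top_k=1))  # greedy decoding: always returns 0
```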
### Summarization Task
#### Evaluation
The GPT-2 model can be evaluated on the `CNN_Dailymail` dataset. For this dataset the evaluation metric is ROUGE, i.e. set `--metric_method="Rouge"`.
Evaluation only requires the shell script `run_summarization.sh`, which sets the environment variables and runs the `run_summarization.py` entry under `GPT-2`.
Before evaluation, configure `src/finetune_eval_config.py`, then run the `scripts/run_summarization.sh` shell script. For the `hint` setting, use `eval_type="finetuned"` and point `--load_finetune_ckpt_path` to the fine-tuned checkpoint; for the `no hint` setting, use `eval_type="zero-shot"` and point `--load_finetune_ckpt_path` to the pretrained checkpoint:
```bash
sh scripts/run_summarization.sh --device_target="Ascend"
--do_train="false"
--do_eval="true"
--metric_method="Rouge"
--train_data_shuffle="true"
--eval_data_shuffle="false"
--generate_length=100
--top_k=2
--top_p="1.0"
--temperature="1.0"
--eval_type="finetuned"
--load_finetune_ckpt_path={load_eval_ckpt_path}
--eval_data_file_path={eval_data_file_path}
--tokenizer_file_path={tokenizer_file_path}
```
Logs and output files are saved under `./ms_log/`.
```bash
sh scripts/run_summarization.sh [--options]
```
The usage of `run_summarization.sh` is as follows:
```text
usage: run_summarization.sh [--device_target DEVICE_TARGET] [--device_id N][--num_choice N]
[--metric_method METRIC_METHOD]
[--do_train DO_TRAIN] [--do_eval DO_EVAL]
[--eval_type EVAL_TYPE] [--epoch_num N]
[--train_data_shuffle TRAIN_DATA_SHUFFLE]
[--eval_data_shuffle EVAL_DATA_SHUFFLE]
[--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
[--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
[--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
[--train_data_file_path TRAIN_DATA_FILE_PATH]
[--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
--device_target Device type. Default: "Ascend"
--device_id ID of target device
--do_train Enable train. Default: false.
--do_eval Enable evaluation. Default: false.
--metric_method The eval method including [Rouge(Rouge1,Rouge2,RougeL,Rouge Avg)]. Default: "Rouge"
--epoch_num Epoch number. Default: 2.
--train_data_shuffle Enable train data shuffle. Default: true.
--eval_data_shuffle Enable eval data shuffle. Default: false.
--save_finetune_ckpt_path Save the checkpoint path.
--load_pretrain_ckpt_path Load the checkpoint file path.
--load_finetune_ckpt_path Load the checkpoint file path.
--train_data_file_path Data path, it is better to use absolute path.
--eval_data_file_path Data path, it is better to use absolute path.
--eval_type The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.
--top_k Top k tokens chosen for sampling.
--top_p Top p accumulated probability threshold for logit to be counted.
--generate_length The number of generated tokens.
--temperature Temperature on logits for sampling.
--tokenizer_file_path Vocab & merge file path.
```
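The ROUGE scores can also be reproduced offline with the `rouge` package listed under Other Requirements; a short sketch of a typical call (the repository's own implementation is `src/utils/rouge_score.py`):

```python
# Sketch: ROUGE-1/2/L F-scores with the third-party `rouge` package (illustrative only).
from rouge import Rouge

hypotheses = ["the cat sat on the mat"]          # generated summaries
references = ["a cat was sitting on the mat"]    # reference summaries

scores = Rouge().get_scores(hypotheses, references, avg=True)
print(scores["rouge-1"]["f"], scores["rouge-2"]["f"], scores["rouge-l"]["f"])
```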
### Translation Task
#### Evaluation
The GPT-2 model can be evaluated on the `WMT14 En-Fr` and `WMT14 Fr-En` test sets. For these datasets the evaluation metric is BLEU, i.e. set `--metric_method="BLEU"`.
Note: you need to download the `bleu.py` script yourself ([script link](https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py)) and place it under the `src/utils/` directory.
Evaluation only requires the shell script `run_translation.sh`, which sets the environment variables and runs the `run_translation.py` entry under `GPT-2`.
Before evaluation, configure `src/finetune_eval_config.py`, then run the `scripts/run_translation.sh` shell script with `eval_type="zero-shot"`. In this case `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file:
```bash
sh scripts/run_translation.sh --device_target="Ascend"
--metric_method="BLEU"
--do_train="false"
--do_eval="true"
--eval_type="zero-shot"
--train_data_shuffle="true"
--eval_data_shuffle="false"
--load_finetune_ckpt_path={load_eval_ckpt_path}
--eval_data_file_path={eval_data_file_path}
--tokenizer_file_path={tokenizer_file_path}
--generate_length=100
--top_k=1
--top_p="1.0"
--temperature="1.0"
```
```bash
sh scripts/run_translation.sh [--options]
```
```text
usage: run_translation.sh [--device_target DEVICE_TARGET] [--device_id N]
[--metric_method METRIC_METHOD]
[--do_train DO_TRAIN] [--do_eval DO_EVAL]
[--eval_type EVAL_TYPE] [--epoch_num N]
[--train_data_shuffle TRAIN_DATA_SHUFFLE]
[--eval_data_shuffle EVAL_DATA_SHUFFLE]
[--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
[--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
[--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
[--train_data_file_path TRAIN_DATA_FILE_PATH]
[--eval_data_file_path EVAL_DATA_FILE_PATH]
[--tokenizer_file_path TOKENIZER_FILE_PATH]
[--generate_length N] [--top_k N] [--top_p TOP_P]
[--temperature TEMPERATURE]
options:
--device_target Device type. Default: "Ascend"
--device_id ID of target device
--metric_method The eval method including [BLEU]. Default: "BLEU"
--do_train Enable train. Default: "false"
--do_eval Enable evaluation. Default: "true"
--eval_type The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
--epoch_num Epoch number. Default: 1
--train_data_shuffle Enable train data shuffle. Default: "true"
--eval_data_shuffle Enable eval data shuffle. Default: "false"
--save_finetune_ckpt_path Save the checkpoint path
--load_pretrain_ckpt_path Load the checkpoint file path
--load_finetune_ckpt_path Load the checkpoint file path
--train_data_file_path Data path, it is better to use absolute path
--eval_data_file_path Data path, it is better to use absolute path
--tokenizer_file_path pretrained vocab and merge file path
--generate_length The generation length of translation sentence
--top_k Parameter for Top-K sampling
--top_p Parameter for Top-P sampling
--temperature Parameter for generation, greater if generation more diverse
```
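The downloaded `bleu.py` (tensorflow/nmt version) exposes a `compute_bleu` helper that works on tokenized corpora; a short sketch of calling it directly, under the assumption that the script is on the import path:

```python
# Sketch: corpus BLEU with the downloaded third-party bleu.py (illustrative only).
from bleu import compute_bleu   # the script downloaded from tensorflow/nmt

# one list of references per sentence; each reference and hypothesis is a token list
references = [[["the", "cat", "is", "on", "the", "mat"]]]
hypotheses = [["the", "cat", "sat", "on", "the", "mat"]]

bleu, precisions, bp, ratio, trans_len, ref_len = compute_bleu(references, hypotheses)
print("BLEU:", bleu * 100)
```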
# Requirements
## Platform
- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor. To apply for a trial of Ascend processors, send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Resources are granted once the application is approved.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information about MindSpore, see:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
## Other Requirements
```text
math
numpy
copy
collections
re
rouge 1.0.0
datasets >=0.4.0
json
tensorflow
```
# Performance
## Inference Performance
### Language Modeling Task
The table below shows the PPL scores of the GPT-2 small, medium, and large models on the Language Modeling task.
| Model | dataset | device | eval_type | PPL | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | WikiText2 | Ascend | zero-shot | 24.5 | 29.41 |
| GPT-2 medium | WikiText2 | Ascend | zero-shot | 19.41 | 22.76 |
| GPT-2 large | WikiText2 | Ascend | zero-shot | 17.08 | 19.93 |
| GPT-2 small | WikiText103 | Ascend | zero-shot | 26.89 | 37.5 |
| GPT-2 medium | WikiText103 | Ascend | zero-shot | 20.23 | 26.37 |
| GPT-2 large | WikiText103 | Ascend | zero-shot | 17.48 | 22.05 |
| GPT-2 small | PTB | Ascend | finetune | 23.91 | 65.85 |
| GPT-2 medium | PTB | Ascend | finetune | 20.06 | 47.33 |
| GPT-2 large | PTB | Ascend | finetune | 18.84 | 40.31 |
| GPT-2 small | 1BW | Ascend | zero-shot | 63.13 | 75.2 |
| GPT-2 medium | 1BW | Ascend | zero-shot | 50.98 | 55.72 |
| GPT-2 large | 1BW | Ascend | finetune | 29.28 | 44.575 |
### Children's Book Test Task
The table below shows the Accuracy scores of the GPT-2 small, medium, and large models on the Children's Book Test task.
| Model | dataset | device | eval_type | ACC | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CBT-CN valid | Ascend | zero-shot | 87.85 | 87.65 |
| GPT-2 medium | CBT-CN valid | Ascend | zero-shot | 92.1 | 92.35 |
| GPT-2 large | CBT-CN valid | Ascend | zero-shot | 93.7 | 93.45 |
| GPT-2 small | CBT-NE valid | Ascend | zero-shot | 85.1 | 83.4 |
| GPT-2 medium | CBT-NE valid | Ascend | zero-shot | 87.55 | 87.1 |
| GPT-2 large | CBT-NE valid | Ascend | zero-shot | 89.1 | 88 |
### LAMBADA Task
The table below shows the Accuracy and PPL scores of the GPT-2 small, medium, and large models on the LAMBADA task.
| Model | dataset | device | eval_type | ACC | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | Lambada-test | Ascend | zero-shot | 45.99 | 45.99 |
| GPT-2 medium | Lambada-test | Ascend | zero-shot | 58.59 | 55.48 |
| GPT-2 large | Lambada-test | Ascend | zero-shot | 62.74 | 60.12 |
| Model | dataset | device | eval_type | PPL | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | Lambada-test | Ascend | zero-shot | 22.95 | 35.13 |
| GPT-2 medium | Lambada-test | Ascend | zero-shot | 10.69 | 15.6 |
| GPT-2 large | Lambada-test | Ascend | zero-shot | 8.64 | 10.87 |
### Reading Comprehension Task
The table below shows the F1 scores of the GPT-2 small, medium, and large models on the Reading Comprehension task.
| Model | dataset | device | eval_type | F1 | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CoQA | Ascend | zero-shot | 25.94 | 25~26 |
| GPT-2 medium | CoQA | Ascend | zero-shot | 43.69 | 42~43 |
| GPT-2 large | CoQA | Ascend | zero-shot | 49.39 | 49~51 |
### Summarization Task
The table below shows the ROUGE scores of the GPT-2 small, medium, and large models on the Summarization task.
| Model | dataset | device | eval_type | ROUGE | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CNN_Dailymail(TL;DR) | Ascend | finetune | 21.4 | 16.8~17 |
| GPT-2 medium | CNN_Dailymail(TL;DR) | Ascend | finetune | 25.94 | 20.6~20.9 |
| GPT-2 large | CNN_Dailymail(TL;DR) | Ascend | finetune | 26.73 | 21.5~21.6 |
| Model | dataset | device | eval_type | ROUGE | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.08 | 15.03(xlarge) |
| GPT-2 medium | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.16 | 15.03(xlarge) |
| GPT-2 large | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.29 | 15.03(xlarge) |
### Translation Task
The table below shows the BLEU scores of the GPT-2 small, medium, and large models on the Translation task.
| Model | dataset | device | eval_type | BLEU | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | WMT-14 Fr-En | Ascend | zero-shot | 4.49 | 0.7~0.8 |
| GPT-2 medium | WMT-14 Fr-En | Ascend | zero-shot | 7.09 | 2.0~3.0 |
| GPT-2 large | WMT-14 Fr-En | Ascend | zero-shot | 7.97 | 6.5~7.0 |
| GPT-2 small | WMT-14 En-Fr | Ascend | zero-shot | 2.81 | 5(xlarge) |
| GPT-2 medium | WMT-14 En-Fr | Ascend | zero-shot | 3.2 | 5(xlarge) |
| GPT-2 large | WMT-14 En-Fr | Ascend | zero-shot | 3.06 | 5(xlarge) |
# Others
The model has been validated in the Ascend environment.
# ModelZoo Homepage
[Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)


@@ -0,0 +1,141 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CNN & DailyMail train dataset sampler
"""
import os
import sys
import shutil
import argparse
from random import random
from src.utils.tokenization import Tokenizer
def replace_split_word(read_path, output_path, tldr_str="TL;DR:", original_split='\t'):
"""
append tldr str
"""
with open(read_path, "r") as r, open(output_path, "a") as w:
line = r.readline()
while line:
article = line[:line.find(original_split)] + ' ' + tldr_str + ' '
ref = line[line.rfind(original_split) + 1:]
w.write(article + ref)
line = r.readline()
def sample(read_path, out_path, threshold=1.0, max_items=0xFFFFFFF):
"""
sample function
"""
cnt = 0
total_cnt = 0
with open(read_path, "r") as r, open(out_path, "a") as w:
line = r.readline()
while line:
total_cnt += 1
if cnt >= max_items:
break
if random() > threshold:
line = r.readline()
continue
w.write(line)
if (cnt + 1) % 3000 == 0:
print("Now Processed Samples: {}, total: {}".format(cnt, total_cnt))
cnt += 1
line = r.readline()
def clip_article(input_path, out_path, hint, max_length):
"""
clip article that the sample (article + summary) exceed max_length
"""
tokenizer = Tokenizer()
cnt = 0
with open(input_path, "r") as r, open(out_path, "a+") as a:
line = r.readline()
while line:
pos = line.rfind(hint)
article = line[:pos]
summary = line[pos:]
if len(tokenizer.encode(line)) > max_length:
l_article = tokenizer.encode(article)[:max_length - len(tokenizer.encode(summary))]
article = tokenizer.decode(l_article) + " "
if cnt % 1000 == 0:
print(article + summary)
print("==============================")
cnt += 1
a.write(article + summary)
line = r.readline()
def sampler_dataset():
"""
run CNN & DailyMail train dataset sampler
"""
parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str, default="",
help="input file path")
parser.add_argument("--output_path", type=str, default="",
help="out file path")
parser.add_argument("--replace_hint", type=str, default="true")
parser.add_argument("--sample", type=str, default="true",
help="do sample? true or false")
parser.add_argument("--max_length", type=int, default=1022,
help="max seq_length of input_raw_dataset")
parser.add_argument("--prob", type=float, default=0.25,
help="sample rate")
parser.add_argument("--max_items", type=int, default=10000,
help="max number of document")
parser.add_argument("--hint", type=str, default="TL:DR;",
help="hint text")
args = parser.parse_args()
# temp_files, one for storing inputs in every stage, the other for storing middle results.
temp_file_input = sys.path[0] + '/temp_file1_by_sampler_py.txt'
temp_file_proc = sys.path[0] + '/temp_file2_by_sampler_py.txt'
read_path = args.input_path
output_path = args.output_path
prob = args.prob
max_items = args.max_items
hint = args.hint
max_length = args.max_length
split_str = '\t'
shutil.copyfile(read_path, temp_file_input)
clip_article(temp_file_input, temp_file_proc, hint=split_str, max_length=max_length)
shutil.copyfile(temp_file_proc, temp_file_input)
os.remove(temp_file_proc)
if args.replace_hint.lower() == "true":
replace_split_word(temp_file_input, temp_file_proc, hint, split_str)
shutil.copyfile(temp_file_proc, temp_file_input)
os.remove(temp_file_proc)
if args.sample.lower() == "true":
sample(temp_file_input, temp_file_proc, prob, max_items)
shutil.copyfile(temp_file_proc, temp_file_input)
os.remove(temp_file_proc)
shutil.copyfile(temp_file_input, output_path)
os.remove(temp_file_input)
if __name__ == "__main__":
sampler_dataset()


@@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Read weight using tensorflow
to read the parameters of the gpt-2 pretrained model from tensorflow checkpoint
and save them into npy files for mindspore to load.
*this script is based on the gpt-2 model downloaded from openai.*
"""
import argparse
import tensorflow as tf
import numpy as np
from trans_dict import trans_dict_tf
def read_weight(ckpt_path):
"""
read weight
Args:
ckpt_path: the path of tensorflow checkpoint
"""
# model path and model name
init_vars = tf.train.list_variables(ckpt_path)
# load the model parameters into vars
save_param_num = 0
for name, _ in init_vars:
array = tf.train.load_variable(ckpt_path, name)
# By this you can understand the next step easily
name = name[6:].replace(r"/", ".")
# skip 'model/' and change var names to avoid path mistake
if name not in trans_dict_tf.keys():
print(name + " is not in this model")
else:
np.save(trans_dict_tf[name] + ".npy", array)
save_param_num = save_param_num + 1
# save the parameters by 'npy'
print("finished!")
print("save {num} parameters.".format(num=save_param_num))
def main():
parser = argparse.ArgumentParser(description="Read GPT-2 model checkpoint weight")
parser.add_argument("--ckpt_file_path", type=str, default="",
help="The tensorflow GPT-2 model checkpoint file path")
args_opt = parser.parse_args()
ckpt_path = args_opt.ckpt_file_path
read_weight(ckpt_path=ckpt_path)
if __name__ == "__main__":
main()


@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Save weights using MindSpore: load the GPT-2 model parameters from npy files and save them as a MindSpore checkpoint.
The npy files are expected to be in the current working directory; otherwise adjust the paths in this script.
"""
import os
import argparse
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint
from trans_dict import trans_dict_tf
def trans_model_parameter(ckpt_name):
"""
transform model parameters
Args:
ckpt_name (str): the name of the transformed checkpoint.
"""
file_names = [name for name in os.listdir() if name.endswith(".npy")]
# to find all file names with suffix '.npy' in the current path.
new_params_list = []
for file_name in file_names:
var_name = file_name[:-4]
param_dict = {"name": var_name, "data": Tensor(np.load(file_name))}
if var_name in trans_dict_tf.values():
new_params_list.append(param_dict)
print(var_name+" has been saved")
save_checkpoint(new_params_list, ckpt_name)
# to load the parameters from npy files and save them as mindspore checkpoint
print("Finished:the parameters have been saved into mindspore checkpoint.")
def main():
parser = argparse.ArgumentParser(description="Read GPT-2 model checkpoint weight")
parser.add_argument("--output_file_name", type=str, default="",
help="The name of output checkpoint name")
args_opt = parser.parse_args()
ckpt_name = args_opt.output_file_name
trans_model_parameter(ckpt_name=ckpt_name)
if __name__ == "__main__":
main()


@@ -0,0 +1,892 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""transform diction"""
trans_dict_tf = {
'h0.attn.c_attn.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h0.attn.c_attn.w': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h0.attn.c_proj.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h0.attn.c_proj.w': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h0.ln_1.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h0.ln_1.g': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h0.ln_2.b': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.beta',
'h0.ln_2.g': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.gamma',
'h0.mlp.c_fc.b': 'gpt2_decoder.layers.0.feedforward.c_fc.bias',
'h0.mlp.c_fc.w': 'gpt2_decoder.layers.0.feedforward.c_fc.weight',
'h0.mlp.c_proj.b': 'gpt2_decoder.layers.0.feedforward.c_proj.bias',
'h0.mlp.c_proj.w': 'gpt2_decoder.layers.0.feedforward.c_proj.weight',
'h1.attn.c_attn.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h1.attn.c_attn.w': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h1.attn.c_proj.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h1.attn.c_proj.w': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h1.ln_1.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h1.ln_1.g': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h1.ln_2.b': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.beta',
'h1.ln_2.g': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.gamma',
'h1.mlp.c_fc.b': 'gpt2_decoder.layers.1.feedforward.c_fc.bias',
'h1.mlp.c_fc.w': 'gpt2_decoder.layers.1.feedforward.c_fc.weight',
'h1.mlp.c_proj.b': 'gpt2_decoder.layers.1.feedforward.c_proj.bias',
'h1.mlp.c_proj.w': 'gpt2_decoder.layers.1.feedforward.c_proj.weight',
'h2.attn.c_attn.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h2.attn.c_attn.w': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h2.attn.c_proj.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h2.attn.c_proj.w': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h2.ln_1.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h2.ln_1.g': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h2.ln_2.b': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.beta',
'h2.ln_2.g': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.gamma',
'h2.mlp.c_fc.b': 'gpt2_decoder.layers.2.feedforward.c_fc.bias',
'h2.mlp.c_fc.w': 'gpt2_decoder.layers.2.feedforward.c_fc.weight',
'h2.mlp.c_proj.b': 'gpt2_decoder.layers.2.feedforward.c_proj.bias',
'h2.mlp.c_proj.w': 'gpt2_decoder.layers.2.feedforward.c_proj.weight',
'h3.attn.c_attn.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h3.attn.c_attn.w': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h3.attn.c_proj.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h3.attn.c_proj.w': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h3.ln_1.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h3.ln_1.g': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h3.ln_2.b': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.beta',
'h3.ln_2.g': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.gamma',
'h3.mlp.c_fc.b': 'gpt2_decoder.layers.3.feedforward.c_fc.bias',
'h3.mlp.c_fc.w': 'gpt2_decoder.layers.3.feedforward.c_fc.weight',
'h3.mlp.c_proj.b': 'gpt2_decoder.layers.3.feedforward.c_proj.bias',
'h3.mlp.c_proj.w': 'gpt2_decoder.layers.3.feedforward.c_proj.weight',
'h4.attn.c_attn.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h4.attn.c_attn.w': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h4.attn.c_proj.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h4.attn.c_proj.w': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h4.ln_1.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h4.ln_1.g': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h4.ln_2.b': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.beta',
'h4.ln_2.g': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.gamma',
'h4.mlp.c_fc.b': 'gpt2_decoder.layers.4.feedforward.c_fc.bias',
'h4.mlp.c_fc.w': 'gpt2_decoder.layers.4.feedforward.c_fc.weight',
'h4.mlp.c_proj.b': 'gpt2_decoder.layers.4.feedforward.c_proj.bias',
'h4.mlp.c_proj.w': 'gpt2_decoder.layers.4.feedforward.c_proj.weight',
'h5.attn.c_attn.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h5.attn.c_attn.w': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h5.attn.c_proj.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h5.attn.c_proj.w': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h5.ln_1.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h5.ln_1.g': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h5.ln_2.b': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.beta',
'h5.ln_2.g': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.gamma',
'h5.mlp.c_fc.b': 'gpt2_decoder.layers.5.feedforward.c_fc.bias',
'h5.mlp.c_fc.w': 'gpt2_decoder.layers.5.feedforward.c_fc.weight',
'h5.mlp.c_proj.b': 'gpt2_decoder.layers.5.feedforward.c_proj.bias',
'h5.mlp.c_proj.w': 'gpt2_decoder.layers.5.feedforward.c_proj.weight',
'h6.attn.c_attn.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h6.attn.c_attn.w': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h6.attn.c_proj.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h6.attn.c_proj.w': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h6.ln_1.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h6.ln_1.g': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h6.ln_2.b': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.beta',
'h6.ln_2.g': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.gamma',
'h6.mlp.c_fc.b': 'gpt2_decoder.layers.6.feedforward.c_fc.bias',
'h6.mlp.c_fc.w': 'gpt2_decoder.layers.6.feedforward.c_fc.weight',
'h6.mlp.c_proj.b': 'gpt2_decoder.layers.6.feedforward.c_proj.bias',
'h6.mlp.c_proj.w': 'gpt2_decoder.layers.6.feedforward.c_proj.weight',
'h7.attn.c_attn.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h7.attn.c_attn.w': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h7.attn.c_proj.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h7.attn.c_proj.w': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h7.ln_1.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h7.ln_1.g': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h7.ln_2.b': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.beta',
'h7.ln_2.g': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.gamma',
'h7.mlp.c_fc.b': 'gpt2_decoder.layers.7.feedforward.c_fc.bias',
'h7.mlp.c_fc.w': 'gpt2_decoder.layers.7.feedforward.c_fc.weight',
'h7.mlp.c_proj.b': 'gpt2_decoder.layers.7.feedforward.c_proj.bias',
'h7.mlp.c_proj.w': 'gpt2_decoder.layers.7.feedforward.c_proj.weight',
'h8.attn.c_attn.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h8.attn.c_attn.w': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h8.attn.c_proj.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h8.attn.c_proj.w': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h8.ln_1.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h8.ln_1.g': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h8.ln_2.b': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.beta',
'h8.ln_2.g': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.gamma',
'h8.mlp.c_fc.b': 'gpt2_decoder.layers.8.feedforward.c_fc.bias',
'h8.mlp.c_fc.w': 'gpt2_decoder.layers.8.feedforward.c_fc.weight',
'h8.mlp.c_proj.b': 'gpt2_decoder.layers.8.feedforward.c_proj.bias',
'h8.mlp.c_proj.w': 'gpt2_decoder.layers.8.feedforward.c_proj.weight',
'h9.attn.c_attn.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h9.attn.c_attn.w': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h9.attn.c_proj.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h9.attn.c_proj.w': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h9.ln_1.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h9.ln_1.g': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h9.ln_2.b': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.beta',
'h9.ln_2.g': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.gamma',
'h9.mlp.c_fc.b': 'gpt2_decoder.layers.9.feedforward.c_fc.bias',
'h9.mlp.c_fc.w': 'gpt2_decoder.layers.9.feedforward.c_fc.weight',
'h9.mlp.c_proj.b': 'gpt2_decoder.layers.9.feedforward.c_proj.bias',
'h9.mlp.c_proj.w': 'gpt2_decoder.layers.9.feedforward.c_proj.weight',
'h10.attn.c_attn.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h10.attn.c_attn.w': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h10.attn.c_proj.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h10.attn.c_proj.w': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h10.ln_1.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h10.ln_1.g': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h10.ln_2.b': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.beta',
'h10.ln_2.g': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.gamma',
'h10.mlp.c_fc.b': 'gpt2_decoder.layers.10.feedforward.c_fc.bias',
'h10.mlp.c_fc.w': 'gpt2_decoder.layers.10.feedforward.c_fc.weight',
'h10.mlp.c_proj.b': 'gpt2_decoder.layers.10.feedforward.c_proj.bias',
'h10.mlp.c_proj.w': 'gpt2_decoder.layers.10.feedforward.c_proj.weight',
'h11.attn.c_attn.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h11.attn.c_attn.w': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h11.attn.c_proj.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h11.attn.c_proj.w': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h11.ln_1.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h11.ln_1.g': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h11.ln_2.b': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.beta',
'h11.ln_2.g': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.gamma',
'h11.mlp.c_fc.b': 'gpt2_decoder.layers.11.feedforward.c_fc.bias',
'h11.mlp.c_fc.w': 'gpt2_decoder.layers.11.feedforward.c_fc.weight',
'h11.mlp.c_proj.b': 'gpt2_decoder.layers.11.feedforward.c_proj.bias',
'h11.mlp.c_proj.w': 'gpt2_decoder.layers.11.feedforward.c_proj.weight',
'h12.attn.c_attn.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h12.attn.c_attn.w': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h12.attn.c_proj.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h12.attn.c_proj.w': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h12.ln_1.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h12.ln_1.g': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h12.ln_2.b': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.beta',
'h12.ln_2.g': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.gamma',
'h12.mlp.c_fc.b': 'gpt2_decoder.layers.12.feedforward.c_fc.bias',
'h12.mlp.c_fc.w': 'gpt2_decoder.layers.12.feedforward.c_fc.weight',
'h12.mlp.c_proj.b': 'gpt2_decoder.layers.12.feedforward.c_proj.bias',
'h12.mlp.c_proj.w': 'gpt2_decoder.layers.12.feedforward.c_proj.weight',
'h13.attn.c_attn.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h13.attn.c_attn.w': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h13.attn.c_proj.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h13.attn.c_proj.w': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h13.ln_1.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h13.ln_1.g': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h13.ln_2.b': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.beta',
'h13.ln_2.g': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.gamma',
'h13.mlp.c_fc.b': 'gpt2_decoder.layers.13.feedforward.c_fc.bias',
'h13.mlp.c_fc.w': 'gpt2_decoder.layers.13.feedforward.c_fc.weight',
'h13.mlp.c_proj.b': 'gpt2_decoder.layers.13.feedforward.c_proj.bias',
'h13.mlp.c_proj.w': 'gpt2_decoder.layers.13.feedforward.c_proj.weight',
'h14.attn.c_attn.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h14.attn.c_attn.w': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h14.attn.c_proj.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h14.attn.c_proj.w': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h14.ln_1.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h14.ln_1.g': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h14.ln_2.b': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.beta',
'h14.ln_2.g': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.gamma',
'h14.mlp.c_fc.b': 'gpt2_decoder.layers.14.feedforward.c_fc.bias',
'h14.mlp.c_fc.w': 'gpt2_decoder.layers.14.feedforward.c_fc.weight',
'h14.mlp.c_proj.b': 'gpt2_decoder.layers.14.feedforward.c_proj.bias',
'h14.mlp.c_proj.w': 'gpt2_decoder.layers.14.feedforward.c_proj.weight',
'h15.attn.c_attn.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h15.attn.c_attn.w': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h15.attn.c_proj.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h15.attn.c_proj.w': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h15.ln_1.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h15.ln_1.g': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h15.ln_2.b': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.beta',
'h15.ln_2.g': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.gamma',
'h15.mlp.c_fc.b': 'gpt2_decoder.layers.15.feedforward.c_fc.bias',
'h15.mlp.c_fc.w': 'gpt2_decoder.layers.15.feedforward.c_fc.weight',
'h15.mlp.c_proj.b': 'gpt2_decoder.layers.15.feedforward.c_proj.bias',
'h15.mlp.c_proj.w': 'gpt2_decoder.layers.15.feedforward.c_proj.weight',
'h16.attn.c_attn.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h16.attn.c_attn.w': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h16.attn.c_proj.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h16.attn.c_proj.w': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h16.ln_1.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h16.ln_1.g': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h16.ln_2.b': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.beta',
'h16.ln_2.g': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.gamma',
'h16.mlp.c_fc.b': 'gpt2_decoder.layers.16.feedforward.c_fc.bias',
'h16.mlp.c_fc.w': 'gpt2_decoder.layers.16.feedforward.c_fc.weight',
'h16.mlp.c_proj.b': 'gpt2_decoder.layers.16.feedforward.c_proj.bias',
'h16.mlp.c_proj.w': 'gpt2_decoder.layers.16.feedforward.c_proj.weight',
'h17.attn.c_attn.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h17.attn.c_attn.w': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h17.attn.c_proj.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h17.attn.c_proj.w': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h17.ln_1.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h17.ln_1.g': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h17.ln_2.b': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.beta',
'h17.ln_2.g': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.gamma',
'h17.mlp.c_fc.b': 'gpt2_decoder.layers.17.feedforward.c_fc.bias',
'h17.mlp.c_fc.w': 'gpt2_decoder.layers.17.feedforward.c_fc.weight',
'h17.mlp.c_proj.b': 'gpt2_decoder.layers.17.feedforward.c_proj.bias',
'h17.mlp.c_proj.w': 'gpt2_decoder.layers.17.feedforward.c_proj.weight',
'h18.attn.c_attn.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h18.attn.c_attn.w': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h18.attn.c_proj.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h18.attn.c_proj.w': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h18.ln_1.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h18.ln_1.g': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h18.ln_2.b': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.beta',
'h18.ln_2.g': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.gamma',
'h18.mlp.c_fc.b': 'gpt2_decoder.layers.18.feedforward.c_fc.bias',
'h18.mlp.c_fc.w': 'gpt2_decoder.layers.18.feedforward.c_fc.weight',
'h18.mlp.c_proj.b': 'gpt2_decoder.layers.18.feedforward.c_proj.bias',
'h18.mlp.c_proj.w': 'gpt2_decoder.layers.18.feedforward.c_proj.weight',
'h19.attn.c_attn.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h19.attn.c_attn.w': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h19.attn.c_proj.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h19.attn.c_proj.w': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h19.ln_1.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h19.ln_1.g': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h19.ln_2.b': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.beta',
'h19.ln_2.g': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.gamma',
'h19.mlp.c_fc.b': 'gpt2_decoder.layers.19.feedforward.c_fc.bias',
'h19.mlp.c_fc.w': 'gpt2_decoder.layers.19.feedforward.c_fc.weight',
'h19.mlp.c_proj.b': 'gpt2_decoder.layers.19.feedforward.c_proj.bias',
'h19.mlp.c_proj.w': 'gpt2_decoder.layers.19.feedforward.c_proj.weight',
'h20.attn.c_attn.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h20.attn.c_attn.w': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h20.attn.c_proj.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h20.attn.c_proj.w': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h20.ln_1.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h20.ln_1.g': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h20.ln_2.b': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.beta',
'h20.ln_2.g': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.gamma',
'h20.mlp.c_fc.b': 'gpt2_decoder.layers.20.feedforward.c_fc.bias',
'h20.mlp.c_fc.w': 'gpt2_decoder.layers.20.feedforward.c_fc.weight',
'h20.mlp.c_proj.b': 'gpt2_decoder.layers.20.feedforward.c_proj.bias',
'h20.mlp.c_proj.w': 'gpt2_decoder.layers.20.feedforward.c_proj.weight',
'h21.attn.c_attn.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h21.attn.c_attn.w': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h21.attn.c_proj.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h21.attn.c_proj.w': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h21.ln_1.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h21.ln_1.g': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h21.ln_2.b': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.beta',
'h21.ln_2.g': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.gamma',
'h21.mlp.c_fc.b': 'gpt2_decoder.layers.21.feedforward.c_fc.bias',
'h21.mlp.c_fc.w': 'gpt2_decoder.layers.21.feedforward.c_fc.weight',
'h21.mlp.c_proj.b': 'gpt2_decoder.layers.21.feedforward.c_proj.bias',
'h21.mlp.c_proj.w': 'gpt2_decoder.layers.21.feedforward.c_proj.weight',
'h22.attn.c_attn.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h22.attn.c_attn.w': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h22.attn.c_proj.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h22.attn.c_proj.w': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h22.ln_1.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h22.ln_1.g': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h22.ln_2.b': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.beta',
'h22.ln_2.g': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.gamma',
'h22.mlp.c_fc.b': 'gpt2_decoder.layers.22.feedforward.c_fc.bias',
'h22.mlp.c_fc.w': 'gpt2_decoder.layers.22.feedforward.c_fc.weight',
'h22.mlp.c_proj.b': 'gpt2_decoder.layers.22.feedforward.c_proj.bias',
'h22.mlp.c_proj.w': 'gpt2_decoder.layers.22.feedforward.c_proj.weight',
'h23.attn.c_attn.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h23.attn.c_attn.w': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h23.attn.c_proj.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h23.attn.c_proj.w': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h23.ln_1.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h23.ln_1.g': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h23.ln_2.b': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.beta',
'h23.ln_2.g': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.gamma',
'h23.mlp.c_fc.b': 'gpt2_decoder.layers.23.feedforward.c_fc.bias',
'h23.mlp.c_fc.w': 'gpt2_decoder.layers.23.feedforward.c_fc.weight',
'h23.mlp.c_proj.b': 'gpt2_decoder.layers.23.feedforward.c_proj.bias',
'h23.mlp.c_proj.w': 'gpt2_decoder.layers.23.feedforward.c_proj.weight',
'h24.attn.c_attn.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h24.attn.c_attn.w': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h24.attn.c_proj.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h24.attn.c_proj.w': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h24.ln_1.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h24.ln_1.g': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h24.ln_2.b': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.beta',
'h24.ln_2.g': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.gamma',
'h24.mlp.c_fc.b': 'gpt2_decoder.layers.24.feedforward.c_fc.bias',
'h24.mlp.c_fc.w': 'gpt2_decoder.layers.24.feedforward.c_fc.weight',
'h24.mlp.c_proj.b': 'gpt2_decoder.layers.24.feedforward.c_proj.bias',
'h24.mlp.c_proj.w': 'gpt2_decoder.layers.24.feedforward.c_proj.weight',
'h25.attn.c_attn.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h25.attn.c_attn.w': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h25.attn.c_proj.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h25.attn.c_proj.w': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h25.ln_1.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h25.ln_1.g': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h25.ln_2.b': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.beta',
'h25.ln_2.g': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.gamma',
'h25.mlp.c_fc.b': 'gpt2_decoder.layers.25.feedforward.c_fc.bias',
'h25.mlp.c_fc.w': 'gpt2_decoder.layers.25.feedforward.c_fc.weight',
'h25.mlp.c_proj.b': 'gpt2_decoder.layers.25.feedforward.c_proj.bias',
'h25.mlp.c_proj.w': 'gpt2_decoder.layers.25.feedforward.c_proj.weight',
'h26.attn.c_attn.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h26.attn.c_attn.w': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h26.attn.c_proj.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h26.attn.c_proj.w': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h26.ln_1.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h26.ln_1.g': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h26.ln_2.b': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.beta',
'h26.ln_2.g': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.gamma',
'h26.mlp.c_fc.b': 'gpt2_decoder.layers.26.feedforward.c_fc.bias',
'h26.mlp.c_fc.w': 'gpt2_decoder.layers.26.feedforward.c_fc.weight',
'h26.mlp.c_proj.b': 'gpt2_decoder.layers.26.feedforward.c_proj.bias',
'h26.mlp.c_proj.w': 'gpt2_decoder.layers.26.feedforward.c_proj.weight',
'h27.attn.c_attn.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h27.attn.c_attn.w': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h27.attn.c_proj.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h27.attn.c_proj.w': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h27.ln_1.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h27.ln_1.g': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h27.ln_2.b': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.beta',
'h27.ln_2.g': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.gamma',
'h27.mlp.c_fc.b': 'gpt2_decoder.layers.27.feedforward.c_fc.bias',
'h27.mlp.c_fc.w': 'gpt2_decoder.layers.27.feedforward.c_fc.weight',
'h27.mlp.c_proj.b': 'gpt2_decoder.layers.27.feedforward.c_proj.bias',
'h27.mlp.c_proj.w': 'gpt2_decoder.layers.27.feedforward.c_proj.weight',
'h28.attn.c_attn.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h28.attn.c_attn.w': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h28.attn.c_proj.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h28.attn.c_proj.w': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h28.ln_1.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h28.ln_1.g': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h28.ln_2.b': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.beta',
'h28.ln_2.g': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.gamma',
'h28.mlp.c_fc.b': 'gpt2_decoder.layers.28.feedforward.c_fc.bias',
'h28.mlp.c_fc.w': 'gpt2_decoder.layers.28.feedforward.c_fc.weight',
'h28.mlp.c_proj.b': 'gpt2_decoder.layers.28.feedforward.c_proj.bias',
'h28.mlp.c_proj.w': 'gpt2_decoder.layers.28.feedforward.c_proj.weight',
'h29.attn.c_attn.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h29.attn.c_attn.w': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h29.attn.c_proj.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h29.attn.c_proj.w': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h29.ln_1.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h29.ln_1.g': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h29.ln_2.b': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.beta',
'h29.ln_2.g': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.gamma',
'h29.mlp.c_fc.b': 'gpt2_decoder.layers.29.feedforward.c_fc.bias',
'h29.mlp.c_fc.w': 'gpt2_decoder.layers.29.feedforward.c_fc.weight',
'h29.mlp.c_proj.b': 'gpt2_decoder.layers.29.feedforward.c_proj.bias',
'h29.mlp.c_proj.w': 'gpt2_decoder.layers.29.feedforward.c_proj.weight',
'h30.attn.c_attn.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h30.attn.c_attn.w': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h30.attn.c_proj.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h30.attn.c_proj.w': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h30.ln_1.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h30.ln_1.g': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h30.ln_2.b': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.beta',
'h30.ln_2.g': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.gamma',
'h30.mlp.c_fc.b': 'gpt2_decoder.layers.30.feedforward.c_fc.bias',
'h30.mlp.c_fc.w': 'gpt2_decoder.layers.30.feedforward.c_fc.weight',
'h30.mlp.c_proj.b': 'gpt2_decoder.layers.30.feedforward.c_proj.bias',
'h30.mlp.c_proj.w': 'gpt2_decoder.layers.30.feedforward.c_proj.weight',
'h31.attn.c_attn.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h31.attn.c_attn.w': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h31.attn.c_proj.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h31.attn.c_proj.w': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h31.ln_1.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h31.ln_1.g': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h31.ln_2.b': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.beta',
'h31.ln_2.g': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.gamma',
'h31.mlp.c_fc.b': 'gpt2_decoder.layers.31.feedforward.c_fc.bias',
'h31.mlp.c_fc.w': 'gpt2_decoder.layers.31.feedforward.c_fc.weight',
'h31.mlp.c_proj.b': 'gpt2_decoder.layers.31.feedforward.c_proj.bias',
'h31.mlp.c_proj.w': 'gpt2_decoder.layers.31.feedforward.c_proj.weight',
'h32.attn.c_attn.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h32.attn.c_attn.w': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h32.attn.c_proj.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h32.attn.c_proj.w': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h32.ln_1.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h32.ln_1.g': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h32.ln_2.b': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.beta',
'h32.ln_2.g': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.gamma',
'h32.mlp.c_fc.b': 'gpt2_decoder.layers.32.feedforward.c_fc.bias',
'h32.mlp.c_fc.w': 'gpt2_decoder.layers.32.feedforward.c_fc.weight',
'h32.mlp.c_proj.b': 'gpt2_decoder.layers.32.feedforward.c_proj.bias',
'h32.mlp.c_proj.w': 'gpt2_decoder.layers.32.feedforward.c_proj.weight',
'h33.attn.c_attn.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h33.attn.c_attn.w': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h33.attn.c_proj.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h33.attn.c_proj.w': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h33.ln_1.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h33.ln_1.g': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h33.ln_2.b': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.beta',
'h33.ln_2.g': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.gamma',
'h33.mlp.c_fc.b': 'gpt2_decoder.layers.33.feedforward.c_fc.bias',
'h33.mlp.c_fc.w': 'gpt2_decoder.layers.33.feedforward.c_fc.weight',
'h33.mlp.c_proj.b': 'gpt2_decoder.layers.33.feedforward.c_proj.bias',
'h33.mlp.c_proj.w': 'gpt2_decoder.layers.33.feedforward.c_proj.weight',
'h34.attn.c_attn.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h34.attn.c_attn.w': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h34.attn.c_proj.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h34.attn.c_proj.w': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h34.ln_1.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h34.ln_1.g': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h34.ln_2.b': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.beta',
'h34.ln_2.g': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.gamma',
'h34.mlp.c_fc.b': 'gpt2_decoder.layers.34.feedforward.c_fc.bias',
'h34.mlp.c_fc.w': 'gpt2_decoder.layers.34.feedforward.c_fc.weight',
'h34.mlp.c_proj.b': 'gpt2_decoder.layers.34.feedforward.c_proj.bias',
'h34.mlp.c_proj.w': 'gpt2_decoder.layers.34.feedforward.c_proj.weight',
'h35.attn.c_attn.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h35.attn.c_attn.w': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h35.attn.c_proj.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h35.attn.c_proj.w': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h35.ln_1.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h35.ln_1.g': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h35.ln_2.b': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.beta',
'h35.ln_2.g': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.gamma',
'h35.mlp.c_fc.b': 'gpt2_decoder.layers.35.feedforward.c_fc.bias',
'h35.mlp.c_fc.w': 'gpt2_decoder.layers.35.feedforward.c_fc.weight',
'h35.mlp.c_proj.b': 'gpt2_decoder.layers.35.feedforward.c_proj.bias',
'h35.mlp.c_proj.w': 'gpt2_decoder.layers.35.feedforward.c_proj.weight',
'ln_f.b': 'layer_norm.layer_norm.beta',
'ln_f.g': 'layer_norm.layer_norm.gamma',
'wpe': 'gpt2_embedding_postprocess.position_embedding_table',
'wte': 'gpt2_embedding_lookup.embedding_table'
}  # name mapping: released GPT-2 checkpoint parameter names -> MindSpore parameter names
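# --- Usage sketch (illustrative only, not part of the original conversion code) ---
# The dictionary above maps checkpoint parameter names to the MindSpore parameter
# names used by this GPT-2 implementation. A renaming helper along the following
# lines could apply such a mapping; the function name `rename_parameters` and the
# fallback handling of unmapped keys are assumptions made here for illustration.
def rename_parameters(raw_params, trans_dict):
    """Return a copy of `raw_params` whose keys follow the MindSpore naming scheme."""
    renamed = {}
    for old_name, value in raw_params.items():
        new_name = trans_dict.get(old_name)
        if new_name is None:
            # Parameters without a mapping (if any) are kept under their original name.
            renamed[old_name] = value
        else:
            renamed[new_name] = value
    return renamed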
trans_dict_py = {  # name mapping: PyTorch-style parameter names -> MindSpore parameter names
'h.0.attn.c_attn.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.0.attn.c_attn.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.0.attn.c_proj.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.0.attn.c_proj.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.0.ln_1.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.0.ln_1.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.0.ln_2.bias': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.beta',
'h.0.ln_2.weight': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.gamma',
'h.0.mlp.c_fc.bias': 'gpt2_decoder.layers.0.feedforward.c_fc.bias',
'h.0.mlp.c_fc.weight': 'gpt2_decoder.layers.0.feedforward.c_fc.weight',
'h.0.mlp.c_proj.bias': 'gpt2_decoder.layers.0.feedforward.c_proj.bias',
'h.0.mlp.c_proj.weight': 'gpt2_decoder.layers.0.feedforward.c_proj.weight',
'h.1.attn.c_attn.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.1.attn.c_attn.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.1.attn.c_proj.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.1.attn.c_proj.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.1.ln_1.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.1.ln_1.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.1.ln_2.bias': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.beta',
'h.1.ln_2.weight': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.gamma',
'h.1.mlp.c_fc.bias': 'gpt2_decoder.layers.1.feedforward.c_fc.bias',
'h.1.mlp.c_fc.weight': 'gpt2_decoder.layers.1.feedforward.c_fc.weight',
'h.1.mlp.c_proj.bias': 'gpt2_decoder.layers.1.feedforward.c_proj.bias',
'h.1.mlp.c_proj.weight': 'gpt2_decoder.layers.1.feedforward.c_proj.weight',
'h.2.attn.c_attn.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.2.attn.c_attn.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.2.attn.c_proj.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.2.attn.c_proj.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.2.ln_1.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.2.ln_1.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.2.ln_2.bias': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.beta',
'h.2.ln_2.weight': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.gamma',
'h.2.mlp.c_fc.bias': 'gpt2_decoder.layers.2.feedforward.c_fc.bias',
'h.2.mlp.c_fc.weight': 'gpt2_decoder.layers.2.feedforward.c_fc.weight',
'h.2.mlp.c_proj.bias': 'gpt2_decoder.layers.2.feedforward.c_proj.bias',
'h.2.mlp.c_proj.weight': 'gpt2_decoder.layers.2.feedforward.c_proj.weight',
'h.3.attn.c_attn.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.3.attn.c_attn.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.3.attn.c_proj.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.3.attn.c_proj.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.3.ln_1.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.3.ln_1.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.3.ln_2.bias': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.beta',
'h.3.ln_2.weight': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.gamma',
'h.3.mlp.c_fc.bias': 'gpt2_decoder.layers.3.feedforward.c_fc.bias',
'h.3.mlp.c_fc.weight': 'gpt2_decoder.layers.3.feedforward.c_fc.weight',
'h.3.mlp.c_proj.bias': 'gpt2_decoder.layers.3.feedforward.c_proj.bias',
'h.3.mlp.c_proj.weight': 'gpt2_decoder.layers.3.feedforward.c_proj.weight',
'h.4.attn.c_attn.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.4.attn.c_attn.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.4.attn.c_proj.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.4.attn.c_proj.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.4.ln_1.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.4.ln_1.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.4.ln_2.bias': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.beta',
'h.4.ln_2.weight': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.gamma',
'h.4.mlp.c_fc.bias': 'gpt2_decoder.layers.4.feedforward.c_fc.bias',
'h.4.mlp.c_fc.weight': 'gpt2_decoder.layers.4.feedforward.c_fc.weight',
'h.4.mlp.c_proj.bias': 'gpt2_decoder.layers.4.feedforward.c_proj.bias',
'h.4.mlp.c_proj.weight': 'gpt2_decoder.layers.4.feedforward.c_proj.weight',
'h.5.attn.c_attn.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.5.attn.c_attn.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.5.attn.c_proj.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.5.attn.c_proj.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.5.ln_1.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.5.ln_1.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.5.ln_2.bias': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.beta',
'h.5.ln_2.weight': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.gamma',
'h.5.mlp.c_fc.bias': 'gpt2_decoder.layers.5.feedforward.c_fc.bias',
'h.5.mlp.c_fc.weight': 'gpt2_decoder.layers.5.feedforward.c_fc.weight',
'h.5.mlp.c_proj.bias': 'gpt2_decoder.layers.5.feedforward.c_proj.bias',
'h.5.mlp.c_proj.weight': 'gpt2_decoder.layers.5.feedforward.c_proj.weight',
'h.6.attn.c_attn.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.6.attn.c_attn.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.6.attn.c_proj.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.6.attn.c_proj.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.6.ln_1.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.6.ln_1.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.6.ln_2.bias': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.beta',
'h.6.ln_2.weight': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.gamma',
'h.6.mlp.c_fc.bias': 'gpt2_decoder.layers.6.feedforward.c_fc.bias',
'h.6.mlp.c_fc.weight': 'gpt2_decoder.layers.6.feedforward.c_fc.weight',
'h.6.mlp.c_proj.bias': 'gpt2_decoder.layers.6.feedforward.c_proj.bias',
'h.6.mlp.c_proj.weight': 'gpt2_decoder.layers.6.feedforward.c_proj.weight',
'h.7.attn.c_attn.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.7.attn.c_attn.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.7.attn.c_proj.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.7.attn.c_proj.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.7.ln_1.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.7.ln_1.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.7.ln_2.bias': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.beta',
'h.7.ln_2.weight': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.gamma',
'h.7.mlp.c_fc.bias': 'gpt2_decoder.layers.7.feedforward.c_fc.bias',
'h.7.mlp.c_fc.weight': 'gpt2_decoder.layers.7.feedforward.c_fc.weight',
'h.7.mlp.c_proj.bias': 'gpt2_decoder.layers.7.feedforward.c_proj.bias',
'h.7.mlp.c_proj.weight': 'gpt2_decoder.layers.7.feedforward.c_proj.weight',
'h.8.attn.c_attn.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.8.attn.c_attn.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.8.attn.c_proj.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.8.attn.c_proj.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.8.ln_1.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.8.ln_1.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.8.ln_2.bias': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.beta',
'h.8.ln_2.weight': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.gamma',
'h.8.mlp.c_fc.bias': 'gpt2_decoder.layers.8.feedforward.c_fc.bias',
'h.8.mlp.c_fc.weight': 'gpt2_decoder.layers.8.feedforward.c_fc.weight',
'h.8.mlp.c_proj.bias': 'gpt2_decoder.layers.8.feedforward.c_proj.bias',
'h.8.mlp.c_proj.weight': 'gpt2_decoder.layers.8.feedforward.c_proj.weight',
'h.9.attn.c_attn.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.9.attn.c_attn.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.9.attn.c_proj.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.9.attn.c_proj.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.9.ln_1.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.9.ln_1.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.9.ln_2.bias': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.beta',
'h.9.ln_2.weight': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.gamma',
'h.9.mlp.c_fc.bias': 'gpt2_decoder.layers.9.feedforward.c_fc.bias',
'h.9.mlp.c_fc.weight': 'gpt2_decoder.layers.9.feedforward.c_fc.weight',
'h.9.mlp.c_proj.bias': 'gpt2_decoder.layers.9.feedforward.c_proj.bias',
'h.9.mlp.c_proj.weight': 'gpt2_decoder.layers.9.feedforward.c_proj.weight',
'h.10.attn.c_attn.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.10.attn.c_attn.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.10.attn.c_proj.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.10.attn.c_proj.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.10.ln_1.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.10.ln_1.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.10.ln_2.bias': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.beta',
'h.10.ln_2.weight': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.gamma',
'h.10.mlp.c_fc.bias': 'gpt2_decoder.layers.10.feedforward.c_fc.bias',
'h.10.mlp.c_fc.weight': 'gpt2_decoder.layers.10.feedforward.c_fc.weight',
'h.10.mlp.c_proj.bias': 'gpt2_decoder.layers.10.feedforward.c_proj.bias',
'h.10.mlp.c_proj.weight': 'gpt2_decoder.layers.10.feedforward.c_proj.weight',
'h.11.attn.c_attn.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.11.attn.c_attn.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.11.attn.c_proj.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.11.attn.c_proj.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.11.ln_1.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.11.ln_1.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.11.ln_2.bias': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.beta',
'h.11.ln_2.weight': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.gamma',
'h.11.mlp.c_fc.bias': 'gpt2_decoder.layers.11.feedforward.c_fc.bias',
'h.11.mlp.c_fc.weight': 'gpt2_decoder.layers.11.feedforward.c_fc.weight',
'h.11.mlp.c_proj.bias': 'gpt2_decoder.layers.11.feedforward.c_proj.bias',
'h.11.mlp.c_proj.weight': 'gpt2_decoder.layers.11.feedforward.c_proj.weight',
'h.12.attn.c_attn.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.12.attn.c_attn.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.12.attn.c_proj.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.12.attn.c_proj.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.12.ln_1.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.12.ln_1.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.12.ln_2.bias': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.beta',
'h.12.ln_2.weight': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.gamma',
'h.12.mlp.c_fc.bias': 'gpt2_decoder.layers.12.feedforward.c_fc.bias',
'h.12.mlp.c_fc.weight': 'gpt2_decoder.layers.12.feedforward.c_fc.weight',
'h.12.mlp.c_proj.bias': 'gpt2_decoder.layers.12.feedforward.c_proj.bias',
'h.12.mlp.c_proj.weight': 'gpt2_decoder.layers.12.feedforward.c_proj.weight',
'h.13.attn.c_attn.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.13.attn.c_attn.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.13.attn.c_proj.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.13.attn.c_proj.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.13.ln_1.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.13.ln_1.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.13.ln_2.bias': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.beta',
'h.13.ln_2.weight': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.gamma',
'h.13.mlp.c_fc.bias': 'gpt2_decoder.layers.13.feedforward.c_fc.bias',
'h.13.mlp.c_fc.weight': 'gpt2_decoder.layers.13.feedforward.c_fc.weight',
'h.13.mlp.c_proj.bias': 'gpt2_decoder.layers.13.feedforward.c_proj.bias',
'h.13.mlp.c_proj.weight': 'gpt2_decoder.layers.13.feedforward.c_proj.weight',
'h.14.attn.c_attn.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.14.attn.c_attn.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.14.attn.c_proj.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.14.attn.c_proj.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.14.ln_1.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.14.ln_1.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.14.ln_2.bias': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.beta',
'h.14.ln_2.weight': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.gamma',
'h.14.mlp.c_fc.bias': 'gpt2_decoder.layers.14.feedforward.c_fc.bias',
'h.14.mlp.c_fc.weight': 'gpt2_decoder.layers.14.feedforward.c_fc.weight',
'h.14.mlp.c_proj.bias': 'gpt2_decoder.layers.14.feedforward.c_proj.bias',
'h.14.mlp.c_proj.weight': 'gpt2_decoder.layers.14.feedforward.c_proj.weight',
'h.15.attn.c_attn.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.15.attn.c_attn.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.15.attn.c_proj.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.15.attn.c_proj.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.15.ln_1.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.15.ln_1.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.15.ln_2.bias': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.beta',
'h.15.ln_2.weight': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.gamma',
'h.15.mlp.c_fc.bias': 'gpt2_decoder.layers.15.feedforward.c_fc.bias',
'h.15.mlp.c_fc.weight': 'gpt2_decoder.layers.15.feedforward.c_fc.weight',
'h.15.mlp.c_proj.bias': 'gpt2_decoder.layers.15.feedforward.c_proj.bias',
'h.15.mlp.c_proj.weight': 'gpt2_decoder.layers.15.feedforward.c_proj.weight',
'h.16.attn.c_attn.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.16.attn.c_attn.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.16.attn.c_proj.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.16.attn.c_proj.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.16.ln_1.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.16.ln_1.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.16.ln_2.bias': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.beta',
'h.16.ln_2.weight': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.gamma',
'h.16.mlp.c_fc.bias': 'gpt2_decoder.layers.16.feedforward.c_fc.bias',
'h.16.mlp.c_fc.weight': 'gpt2_decoder.layers.16.feedforward.c_fc.weight',
'h.16.mlp.c_proj.bias': 'gpt2_decoder.layers.16.feedforward.c_proj.bias',
'h.16.mlp.c_proj.weight': 'gpt2_decoder.layers.16.feedforward.c_proj.weight',
'h.17.attn.c_attn.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.17.attn.c_attn.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.17.attn.c_proj.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.17.attn.c_proj.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.17.ln_1.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.17.ln_1.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.17.ln_2.bias': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.beta',
'h.17.ln_2.weight': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.gamma',
'h.17.mlp.c_fc.bias': 'gpt2_decoder.layers.17.feedforward.c_fc.bias',
'h.17.mlp.c_fc.weight': 'gpt2_decoder.layers.17.feedforward.c_fc.weight',
'h.17.mlp.c_proj.bias': 'gpt2_decoder.layers.17.feedforward.c_proj.bias',
'h.17.mlp.c_proj.weight': 'gpt2_decoder.layers.17.feedforward.c_proj.weight',
'h.18.attn.c_attn.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.18.attn.c_attn.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.18.attn.c_proj.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.18.attn.c_proj.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.18.ln_1.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.18.ln_1.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.18.ln_2.bias': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.beta',
'h.18.ln_2.weight': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.gamma',
'h.18.mlp.c_fc.bias': 'gpt2_decoder.layers.18.feedforward.c_fc.bias',
'h.18.mlp.c_fc.weight': 'gpt2_decoder.layers.18.feedforward.c_fc.weight',
'h.18.mlp.c_proj.bias': 'gpt2_decoder.layers.18.feedforward.c_proj.bias',
'h.18.mlp.c_proj.weight': 'gpt2_decoder.layers.18.feedforward.c_proj.weight',
'h.19.attn.c_attn.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.19.attn.c_attn.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.19.attn.c_proj.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.19.attn.c_proj.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.19.ln_1.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.19.ln_1.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.19.ln_2.bias': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.beta',
'h.19.ln_2.weight': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.gamma',
'h.19.mlp.c_fc.bias': 'gpt2_decoder.layers.19.feedforward.c_fc.bias',
'h.19.mlp.c_fc.weight': 'gpt2_decoder.layers.19.feedforward.c_fc.weight',
'h.19.mlp.c_proj.bias': 'gpt2_decoder.layers.19.feedforward.c_proj.bias',
'h.19.mlp.c_proj.weight': 'gpt2_decoder.layers.19.feedforward.c_proj.weight',
'h.20.attn.c_attn.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.20.attn.c_attn.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.20.attn.c_proj.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.20.attn.c_proj.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.20.ln_1.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.20.ln_1.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.20.ln_2.bias': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.beta',
'h.20.ln_2.weight': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.gamma',
'h.20.mlp.c_fc.bias': 'gpt2_decoder.layers.20.feedforward.c_fc.bias',
'h.20.mlp.c_fc.weight': 'gpt2_decoder.layers.20.feedforward.c_fc.weight',
'h.20.mlp.c_proj.bias': 'gpt2_decoder.layers.20.feedforward.c_proj.bias',
'h.20.mlp.c_proj.weight': 'gpt2_decoder.layers.20.feedforward.c_proj.weight',
'h.21.attn.c_attn.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.21.attn.c_attn.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.21.attn.c_proj.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.21.attn.c_proj.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.21.ln_1.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.21.ln_1.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.21.ln_2.bias': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.beta',
'h.21.ln_2.weight': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.gamma',
'h.21.mlp.c_fc.bias': 'gpt2_decoder.layers.21.feedforward.c_fc.bias',
'h.21.mlp.c_fc.weight': 'gpt2_decoder.layers.21.feedforward.c_fc.weight',
'h.21.mlp.c_proj.bias': 'gpt2_decoder.layers.21.feedforward.c_proj.bias',
'h.21.mlp.c_proj.weight': 'gpt2_decoder.layers.21.feedforward.c_proj.weight',
'h.22.attn.c_attn.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.22.attn.c_attn.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.22.attn.c_proj.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.22.attn.c_proj.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.22.ln_1.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.22.ln_1.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.22.ln_2.bias': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.beta',
'h.22.ln_2.weight': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.gamma',
'h.22.mlp.c_fc.bias': 'gpt2_decoder.layers.22.feedforward.c_fc.bias',
'h.22.mlp.c_fc.weight': 'gpt2_decoder.layers.22.feedforward.c_fc.weight',
'h.22.mlp.c_proj.bias': 'gpt2_decoder.layers.22.feedforward.c_proj.bias',
'h.22.mlp.c_proj.weight': 'gpt2_decoder.layers.22.feedforward.c_proj.weight',
'h.23.attn.c_attn.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.23.attn.c_attn.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.23.attn.c_proj.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.23.attn.c_proj.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.23.ln_1.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.23.ln_1.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.23.ln_2.bias': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.beta',
'h.23.ln_2.weight': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.gamma',
'h.23.mlp.c_fc.bias': 'gpt2_decoder.layers.23.feedforward.c_fc.bias',
'h.23.mlp.c_fc.weight': 'gpt2_decoder.layers.23.feedforward.c_fc.weight',
'h.23.mlp.c_proj.bias': 'gpt2_decoder.layers.23.feedforward.c_proj.bias',
'h.23.mlp.c_proj.weight': 'gpt2_decoder.layers.23.feedforward.c_proj.weight',
'h.24.attn.c_attn.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.24.attn.c_attn.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.24.attn.c_proj.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.24.attn.c_proj.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.24.ln_1.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.24.ln_1.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.24.ln_2.bias': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.beta',
'h.24.ln_2.weight': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.gamma',
'h.24.mlp.c_fc.bias': 'gpt2_decoder.layers.24.feedforward.c_fc.bias',
'h.24.mlp.c_fc.weight': 'gpt2_decoder.layers.24.feedforward.c_fc.weight',
'h.24.mlp.c_proj.bias': 'gpt2_decoder.layers.24.feedforward.c_proj.bias',
'h.24.mlp.c_proj.weight': 'gpt2_decoder.layers.24.feedforward.c_proj.weight',
'h.25.attn.c_attn.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.25.attn.c_attn.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.25.attn.c_proj.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.25.attn.c_proj.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.25.ln_1.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.25.ln_1.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.25.ln_2.bias': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.beta',
'h.25.ln_2.weight': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.gamma',
'h.25.mlp.c_fc.bias': 'gpt2_decoder.layers.25.feedforward.c_fc.bias',
'h.25.mlp.c_fc.weight': 'gpt2_decoder.layers.25.feedforward.c_fc.weight',
'h.25.mlp.c_proj.bias': 'gpt2_decoder.layers.25.feedforward.c_proj.bias',
'h.25.mlp.c_proj.weight': 'gpt2_decoder.layers.25.feedforward.c_proj.weight',
'h.26.attn.c_attn.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.26.attn.c_attn.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.26.attn.c_proj.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.26.attn.c_proj.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.26.ln_1.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.26.ln_1.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.26.ln_2.bias': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.beta',
'h.26.ln_2.weight': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.gamma',
'h.26.mlp.c_fc.bias': 'gpt2_decoder.layers.26.feedforward.c_fc.bias',
'h.26.mlp.c_fc.weight': 'gpt2_decoder.layers.26.feedforward.c_fc.weight',
'h.26.mlp.c_proj.bias': 'gpt2_decoder.layers.26.feedforward.c_proj.bias',
'h.26.mlp.c_proj.weight': 'gpt2_decoder.layers.26.feedforward.c_proj.weight',
'h.27.attn.c_attn.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.27.attn.c_attn.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.27.attn.c_proj.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.27.attn.c_proj.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.27.ln_1.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.27.ln_1.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.27.ln_2.bias': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.beta',
'h.27.ln_2.weight': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.gamma',
'h.27.mlp.c_fc.bias': 'gpt2_decoder.layers.27.feedforward.c_fc.bias',
'h.27.mlp.c_fc.weight': 'gpt2_decoder.layers.27.feedforward.c_fc.weight',
'h.27.mlp.c_proj.bias': 'gpt2_decoder.layers.27.feedforward.c_proj.bias',
'h.27.mlp.c_proj.weight': 'gpt2_decoder.layers.27.feedforward.c_proj.weight',
'h.28.attn.c_attn.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.28.attn.c_attn.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.28.attn.c_proj.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.28.attn.c_proj.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.28.ln_1.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.28.ln_1.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.28.ln_2.bias': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.beta',
'h.28.ln_2.weight': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.gamma',
'h.28.mlp.c_fc.bias': 'gpt2_decoder.layers.28.feedforward.c_fc.bias',
'h.28.mlp.c_fc.weight': 'gpt2_decoder.layers.28.feedforward.c_fc.weight',
'h.28.mlp.c_proj.bias': 'gpt2_decoder.layers.28.feedforward.c_proj.bias',
'h.28.mlp.c_proj.weight': 'gpt2_decoder.layers.28.feedforward.c_proj.weight',
'h.29.attn.c_attn.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.29.attn.c_attn.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.29.attn.c_proj.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.29.attn.c_proj.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.29.ln_1.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.29.ln_1.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.29.ln_2.bias': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.beta',
'h.29.ln_2.weight': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.gamma',
'h.29.mlp.c_fc.bias': 'gpt2_decoder.layers.29.feedforward.c_fc.bias',
'h.29.mlp.c_fc.weight': 'gpt2_decoder.layers.29.feedforward.c_fc.weight',
'h.29.mlp.c_proj.bias': 'gpt2_decoder.layers.29.feedforward.c_proj.bias',
'h.29.mlp.c_proj.weight': 'gpt2_decoder.layers.29.feedforward.c_proj.weight',
'h.30.attn.c_attn.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.30.attn.c_attn.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.30.attn.c_proj.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.30.attn.c_proj.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.30.ln_1.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.30.ln_1.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.30.ln_2.bias': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.beta',
'h.30.ln_2.weight': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.gamma',
'h.30.mlp.c_fc.bias': 'gpt2_decoder.layers.30.feedforward.c_fc.bias',
'h.30.mlp.c_fc.weight': 'gpt2_decoder.layers.30.feedforward.c_fc.weight',
'h.30.mlp.c_proj.bias': 'gpt2_decoder.layers.30.feedforward.c_proj.bias',
'h.30.mlp.c_proj.weight': 'gpt2_decoder.layers.30.feedforward.c_proj.weight',
'h.31.attn.c_attn.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.31.attn.c_attn.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.31.attn.c_proj.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.31.attn.c_proj.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.31.ln_1.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.31.ln_1.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.31.ln_2.bias': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.beta',
'h.31.ln_2.weight': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.gamma',
'h.31.mlp.c_fc.bias': 'gpt2_decoder.layers.31.feedforward.c_fc.bias',
'h.31.mlp.c_fc.weight': 'gpt2_decoder.layers.31.feedforward.c_fc.weight',
'h.31.mlp.c_proj.bias': 'gpt2_decoder.layers.31.feedforward.c_proj.bias',
'h.31.mlp.c_proj.weight': 'gpt2_decoder.layers.31.feedforward.c_proj.weight',
'h.32.attn.c_attn.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.32.attn.c_attn.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.32.attn.c_proj.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.32.attn.c_proj.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.32.ln_1.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.32.ln_1.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.32.ln_2.bias': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.beta',
'h.32.ln_2.weight': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.gamma',
'h.32.mlp.c_fc.bias': 'gpt2_decoder.layers.32.feedforward.c_fc.bias',
'h.32.mlp.c_fc.weight': 'gpt2_decoder.layers.32.feedforward.c_fc.weight',
'h.32.mlp.c_proj.bias': 'gpt2_decoder.layers.32.feedforward.c_proj.bias',
'h.32.mlp.c_proj.weight': 'gpt2_decoder.layers.32.feedforward.c_proj.weight',
'h.33.attn.c_attn.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.33.attn.c_attn.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.33.attn.c_proj.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.33.attn.c_proj.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.33.ln_1.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.33.ln_1.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.33.ln_2.bias': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.beta',
'h.33.ln_2.weight': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.gamma',
'h.33.mlp.c_fc.bias': 'gpt2_decoder.layers.33.feedforward.c_fc.bias',
'h.33.mlp.c_fc.weight': 'gpt2_decoder.layers.33.feedforward.c_fc.weight',
'h.33.mlp.c_proj.bias': 'gpt2_decoder.layers.33.feedforward.c_proj.bias',
'h.33.mlp.c_proj.weight': 'gpt2_decoder.layers.33.feedforward.c_proj.weight',
'h.34.attn.c_attn.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.34.attn.c_attn.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.34.attn.c_proj.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.34.attn.c_proj.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.34.ln_1.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.34.ln_1.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.34.ln_2.bias': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.beta',
'h.34.ln_2.weight': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.gamma',
'h.34.mlp.c_fc.bias': 'gpt2_decoder.layers.34.feedforward.c_fc.bias',
'h.34.mlp.c_fc.weight': 'gpt2_decoder.layers.34.feedforward.c_fc.weight',
'h.34.mlp.c_proj.bias': 'gpt2_decoder.layers.34.feedforward.c_proj.bias',
'h.34.mlp.c_proj.weight': 'gpt2_decoder.layers.34.feedforward.c_proj.weight',
'h.35.attn.c_attn.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.35.attn.c_attn.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.35.attn.c_proj.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.35.attn.c_proj.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.35.ln_1.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.35.ln_1.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.35.ln_2.bias': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.beta',
'h.35.ln_2.weight': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.gamma',
'h.35.mlp.c_fc.bias': 'gpt2_decoder.layers.35.feedforward.c_fc.bias',
'h.35.mlp.c_fc.weight': 'gpt2_decoder.layers.35.feedforward.c_fc.weight',
'h.35.mlp.c_proj.bias': 'gpt2_decoder.layers.35.feedforward.c_proj.bias',
'h.35.mlp.c_proj.weight': 'gpt2_decoder.layers.35.feedforward.c_proj.weight',
'ln_f.bias': 'layer_norm.layer_norm.beta',
'ln_f.weight': 'layer_norm.layer_norm.gamma',
'wpe.weight': 'gpt2_embedding_postprocess.position_embedding_table',
'wte.weight': 'gpt2_embedding_lookup.embedding_table'
}
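# ---------------------------------------------------------------------------
# A minimal conversion sketch (illustration only): it renames a HuggingFace
# GPT-2 PyTorch state_dict into the MindSpore parameter names given by the
# mapping above and saves a .ckpt file. The variable name `name_map` stands in
# for the dict defined earlier in this file, the file names are placeholders,
# and any weight transposition the real conversion script performs is not
# shown here.
#
#   import torch
#   from mindspore import Tensor
#   from mindspore.train.serialization import save_checkpoint
#
#   torch_state = torch.load("pytorch_model.bin", map_location="cpu")
#   params = [{"name": ms_name, "data": Tensor(torch_state[hf_name].numpy())}
#             for hf_name, ms_name in name_map.items()]
#   save_checkpoint(params, "gpt2_ms.ckpt")
# ---------------------------------------------------------------------------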

View File

@ -0,0 +1,148 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create mindrecord data for Children's Book Test task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import logging
import numpy as np
from mindspore.mindrecord import FileWriter
from src.utils.tokenization import Tokenizer
def create_instance(tokenizer, text, max_length=None, num_choice=None):
"""A single sample instance for cbt task."""
text = text.replace(" \t ", "\t ")
sentence = text.strip().split("\t")
context_length = len(tokenizer.encode(sentence[0]))
whole_sentence = sentence[0] + sentence[1]
whole_sentence = whole_sentence.strip()
assert whole_sentence != ""
print(" | whole sentence: ", whole_sentence)
ids = tokenizer.encode(whole_sentence)
input_length = len(ids)
pair_ids = None
output = tokenizer.prepare_for_model(ids=ids,
pair_ids=pair_ids,
add_special_tokens=True,
max_length=max_length,
padding=True,
truncate_direction="RIGHT",
return_overflowing_tokens=False,
return_attention_mask=True)
output["length"] = [context_length + 1] + [input_length + 1]
gold_answer_id = int(sentence[2])
assert gold_answer_id < 10
output["mc_labels"] = gold_answer_id
for name, value in output.items():
print(name)
print(value)
print("==================================")
return output
def write_instance_to_file(writer, instance):
"""write the instance to file"""
input_ids = instance["input_ids"]
input_mask = instance["attention_mask"]
assert len(input_ids) == len(input_mask)
length = instance["length"] # list
mc_labels = instance["mc_labels"]
features = collections.OrderedDict()
features["input_ids"] = np.asarray(input_ids)
features["input_mask"] = np.asarray(input_mask)
features["input_length"] = np.asarray(length)
features["mc_labels"] = mc_labels
writer.write_raw_data([features])
return features
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, required=True, default="", help='Input raw text file. ')
parser.add_argument("--output_file", type=str, required=True, default="", help='Output MindRecord file. ')
parser.add_argument("--num_splits", type=int, default=1,
help='The MindRecord file will be split into the number of partition. ')
parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ')
parser.add_argument("--num_choice", type=int, required=True, help='Number of choices. ')
parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ')
parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ')
args = parser.parse_args()
tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)
num_choice = args.num_choice
input_file = args.input_file
logging.info("***** Reading from input files *****")
logging.info("Input File: %s", input_file)
output_file = args.output_file
logging.info("***** Writing to output files *****")
logging.info("Output File: %s", output_file)
writer = FileWriter(output_file, args.num_splits)
data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
"input_mask": {"type": "int64", "shape": [-1]},
"input_length": {"type": "int64", "shape": [-1]},
"mc_labels": {"type": "int64"}
}
writer.add_schema(data_schema, "cbt-schema")
total_written = 0
total_read = 0
logging.info("***** Reading from %s *****", input_file)
with open(input_file, "r") as f:
while True:
line = f.readline()
if not line:
break
total_read += 1
if total_read % 500 == 0:
logging.info("%d ...", total_read)
output = create_instance(tokenizer, line, args.max_seq_length, num_choice)
features = write_instance_to_file(writer, instance=output)
total_written += 1
if total_written <= 20:
logging.info("***** Example *****")
logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))
for feature_name in features.keys():
feature = features[feature_name]
logging.info("%s: %s", feature_name, feature)
writer.commit()
logging.info("Wrote %d total instances", total_written)
if __name__ == "__main__":
main()
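# Usage sketch (illustration only; all values and file names are illustrative):
# each input line is expected to contain three tab-separated fields,
# "<context>\t<candidate-completed query>\t<gold answer index>", which matches
# the split("\t") and int(sentence[2]) logic above. A hypothetical invocation:
#
#   python create_cbt_data.py --input_file cbt_cn_test.txt \
#       --output_file cbt_cn_test.mindrecord --num_splits 1 \
#       --max_seq_length 1024 --num_choice 10 \
#       --vocab_file gpt2-vocab.json --merge_file gpt2-merges.txt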

View File

@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create mindrecord data for LAMBADA task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import logging
import numpy as np
from mindspore.mindrecord import FileWriter
from src.utils.tokenization import Tokenizer
def create_instance(tokenizer, text, max_length=None):
"""A single sample instance for LAMBADA task."""
text = text.replace(" \t ", "\t ")
sentence = text.strip().split("\t")
context_length = len(tokenizer.encode(sentence[0]))
whole_sentence = sentence[0] + sentence[1]
whole_sentence = whole_sentence.strip()
assert whole_sentence != ""
print(" | whole sentence: ", whole_sentence)
ids = tokenizer.encode(whole_sentence)
input_length = len(ids)
pair_ids = None
output = tokenizer.prepare_for_model(ids=ids,
pair_ids=pair_ids,
add_special_tokens=True,
max_length=max_length,
padding=True,
truncate_direction="RIGHT",
return_overflowing_tokens=False,
return_attention_mask=True)
# input_length = <bos> + text length; the <eos> token is not included
output["length"] = [context_length + 1] + [input_length + 1]
for k, v in output.items():
print(k)
print(v)
print("==================================")
return output
def write_instance_to_file(writer, instance):
"""write the instance to file"""
input_ids = instance["input_ids"]
input_mask = instance["attention_mask"]
assert len(input_ids) == len(input_mask)
length = instance["length"] # list
features = collections.OrderedDict()
features["input_ids"] = np.asarray(input_ids)
features["input_mask"] = np.asarray(input_mask)
features["input_length"] = np.asarray(length)
writer.write_raw_data([features])
return features
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, required=True, help='Input raw text file. ')
parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file. ')
parser.add_argument("--num_splits", type=int, default=1,
help='The MindRecord file will be split into the number of partition. ')
parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ')
parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ')
parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ')
args = parser.parse_args()
tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)
input_file = args.input_file
logging.info("***** Reading from input files *****")
logging.info("Input File: %s", input_file)
output_file = args.output_file
logging.info("***** Writing to output files *****")
logging.info("Output File: %s", output_file)
writer = FileWriter(output_file, args.num_splits)
data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
"input_mask": {"type": "int64", "shape": [-1]},
"input_length": {"type": "int64", "shape": [-1]},
}
writer.add_schema(data_schema, "lambada-schema")
total_written = 0
total_read = 0
logging.info("***** Reading from %s *****", input_file)
with open(input_file, "r") as f:
while True:
line = f.readline()
if not line:
break
total_read += 1
if total_read % 500 == 0:
logging.info("%d ...", total_read)
output = create_instance(tokenizer, line, args.max_seq_length)
features = write_instance_to_file(writer, instance=output)
total_written += 1
if total_written <= 20:
logging.info("***** Example *****")
logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))
for feature_name in features.keys():
feature = features[feature_name]
logging.info("%s: %s", feature_name, feature)
writer.commit()
logging.info("Wrote %d total instances", total_written)
if __name__ == "__main__":
main()
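# Usage sketch (illustration only; script name, values and file names are
# illustrative): each input line is expected to hold the LAMBADA context and
# its final word as two tab-separated fields, so that "length" stores both the
# tokenized context length and the full input length (each including <bos>).
# A hypothetical invocation:
#
#   python create_lambada_data.py --input_file lambada_test.txt \
#       --output_file lambada_test.mindrecord --num_splits 1 \
#       --max_seq_length 1024 \
#       --vocab_file gpt2-vocab.json --merge_file gpt2-merges.txt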

View File

@ -0,0 +1,126 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create mindrecord data for LM task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import logging
import numpy as np
from mindspore.mindrecord import FileWriter
from src.utils.tokenization import Tokenizer
def create_instance(tokenizer, text, max_length=None):
"""A single sample instance for LM task."""
sentence = text.strip().split("\t")
ids = tokenizer.encode(sentence[0])
pair_ids = None
if len(sentence) == 2:
pair_ids = tokenizer.encode(sentence[1])
output = tokenizer.prepare_for_model(ids=ids,
pair_ids=pair_ids,
add_special_tokens=True,
max_length=max_length,
padding=True,
truncate_direction="LEFT",
return_overflowing_tokens=False,
return_attention_mask=True)
return output
def write_instance_to_file(writer, instance):
"""write the instance to file"""
input_ids = instance["input_ids"]
input_mask = instance["attention_mask"]
label_ids = instance["input_ids"]
assert len(input_ids) == len(label_ids)
features = collections.OrderedDict()
features["input_ids"] = np.asarray(input_ids)
features["input_mask"] = np.asarray(input_mask)
features["label_ids"] = np.asarray(label_ids)
writer.write_raw_data([features])
return features
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, required=True, help='Input raw text file. ')
parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file. ')
parser.add_argument("--num_splits", type=int, default=1,
help='The MindRecord file will be split into the number of partition. ')
parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ')
parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ')
parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ')
args = parser.parse_args()
tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)
input_file = args.input_file
logging.info("***** Reading from input files *****")
logging.info("Input File: %s", input_file)
output_file = args.output_file
logging.info("***** Writing to output files *****")
logging.info("Output File: %s", output_file)
writer = FileWriter(output_file, args.num_splits)
data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
"input_mask": {"type": "int64", "shape": [-1]},
"label_ids": {"type": "int64", "shape": [-1]}
}
writer.add_schema(data_schema, "lm-schema")
total_written = 0
total_read = 0
logging.info("***** Reading from %s *****", input_file)
with open(input_file, "r") as f:
while True:
line = f.readline()
if not line:
break
total_read += 1
if total_read % 500 == 0:
logging.info("%d ...", total_read)
output = create_instance(tokenizer, line, args.max_seq_length)
features = write_instance_to_file(writer, instance=output)
total_written += 1
if total_written <= 20:
logging.info("***** Example *****")
logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))
for feature_name in features.keys():
feature = features[feature_name]
logging.info("%s: %s", feature_name, feature)
writer.commit()
logging.info("Wrote %d total instances", total_written)
if __name__ == "__main__":
main()
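# Usage sketch (illustration only; script name, values and file names are
# illustrative): each input line is one sample; if a line contains a tab, the
# two parts are passed to prepare_for_model() as a sentence pair, and
# sequences longer than --max_seq_length are truncated from the LEFT.
# A hypothetical invocation:
#
#   python create_lm_data.py --input_file wikitext2.test.txt \
#       --output_file wikitext2.test.mindrecord --num_splits 1 \
#       --max_seq_length 1024 \
#       --vocab_file gpt2-vocab.json --merge_file gpt2-merges.txt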

View File

@ -0,0 +1,130 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create mindrecord data for Summarization task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import logging
import numpy as np
from mindspore.mindrecord import FileWriter
from src.utils import tokenization
def create_instance(tokenizer, text, max_length=None):
"""A single sample instance for Summarization task."""
sentence = text.strip().split("\t")
ids = tokenizer.encode(sentence[0])
pair_ids = None
if len(sentence) == 2:
pair_ids = tokenizer.encode(sentence[1])
if len(sentence) >= 3:
article = sentence[0]
for i in range(1, len(sentence) - 1):
article += sentence[i]
ids = tokenizer.encode(article)
pair_ids = tokenizer.encode(sentence[-1])
output = tokenizer.prepare_for_model(ids=ids,
pair_ids=pair_ids,
add_special_tokens=True,
max_length=max_length,
padding=True,
return_overflowing_tokens=False,
return_attention_mask=True)
return output
def write_instance_to_file(writer, instance):
"""write the instance to file"""
input_ids = instance["input_ids"]
input_mask = instance["attention_mask"]
label_ids = instance["input_ids"]
assert len(input_ids) == len(label_ids)
features = collections.OrderedDict()
features["input_ids"] = np.asarray(input_ids)
features["input_mask"] = np.asarray(input_mask)
features["label_ids"] = np.asarray(label_ids)
writer.write_raw_data([features])
return features
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, required=True, help='Input raw text file.')
parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
parser.add_argument("--num_splits", type=int, default=1,
help='The MindRecord file will be split into the number of partition. ')
parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length.')
parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ')
parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ')
parser.add_argument("--mode", type=str, required=True, default='cnn_dailymail', help='mode of dataset creation')
args = parser.parse_args()
tokenizer = tokenization.Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file, mode=args.mode)
input_file = args.input_file
logging.info("***** Reading from input files *****")
logging.info("Input File: %s", input_file)
output_file = args.output_file
logging.info("***** Writing to output files *****")
logging.info("Output File: %s", output_file)
writer = FileWriter(output_file, args.num_splits)
data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
"input_mask": {"type": "int64", "shape": [-1]},
"label_ids": {"type": "int64", "shape": [-1]}
}
writer.add_schema(data_schema, "wikitext2-schema")
total_written = 0
total_read = 0
logging.info("***** Reading from %s *****", input_file)
with open(input_file, "r") as f:
while True:
line = f.readline()
if not line:
break
total_read += 1
if total_read % 500 == 0:
logging.info("%d ...", total_read)
output = create_instance(tokenizer, line, args.max_seq_length)
features = write_instance_to_file(writer, instance=output)
total_written += 1
if total_written <= 20:
logging.info("***** Example *****")
logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))
for feature_name in features.keys():
feature = features[feature_name]
logging.info("%s: %s", feature_name, feature)
writer.commit()
logging.info("Wrote %d total instances", total_written)
if __name__ == "__main__":
main()
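# Usage sketch (illustration only; script name, values and file names are
# illustrative): the expected input format is one "<article>\t<highlights>"
# pair per line, which is what the CNN & DailyMail download script produces.
# A hypothetical invocation:
#
#   python create_summary_data.py --input_file test.txt \
#       --output_file cnn_dailymail_test.mindrecord --num_splits 1 \
#       --max_seq_length 1024 --mode cnn_dailymail \
#       --vocab_file gpt2-vocab.json --merge_file gpt2-merges.txt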

View File

@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""download the CNN & DailyMail for Summarization task"""
import argparse
from datasets import load_dataset
def generate_txt(url, split_, number=None, version="3.0.0"):
"""
generate txt file of cnn_dailymail dataset
Args:
url (str): directory where the output txt file is written.
split_ (str): dataset split, "test" or "train".
number (int): number of samples to export; -1 means the whole split.
version (str): "3.0.0" by default
"""
cnn = load_dataset("cnn_dailymail", version, split=split_)
if number == -1:
number = len(cnn)
f = open(url + split_ + '.txt', 'w')
for idx in range(number):
article = cnn[idx]['article']
article = article.replace('\n', ' ')
highlights = cnn[idx]['highlights']
highlights = highlights.replace('\n', ' ')
f.write(article + "\t" + highlights + '\n')
f.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Download CNN_Dailymail 3.0.0 using datasets by Huggingface')
parser.add_argument('--dir', type=str, default="", help="directory of dataset")
parser.add_argument('--split', type=str, default='test', help="[test,train]")
parser.add_argument('--num', type=int, default=-1,
help=" number of samples by default order. "
"If num is -1, it will download whole dataset. Default: -1")
args = parser.parse_args()
data_directory = args.dir
split = args.split
num = args.num
generate_txt(url=data_directory, split_=split, number=num)
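# Usage sketch (illustration only): generate_txt() writes one
# "<article>\t<highlights>" line per sample into "<dir><split>.txt", so the
# directory string should end with a path separator. For example, exporting
# the first 100 test samples into ./data/test.txt (the directory is a
# placeholder and must already exist):
#
#   generate_txt(url="./data/", split_="test", number=100)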

View File

@ -0,0 +1,135 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation reading comprehension result with additional answer."""
import json
import re
import string
import argparse
from collections import Counter
def get_normalize_answer_token(string_):
"""normalize the answer token, Lower text and remove punctuation, article and extra whitespace"""
def remove_articles(text):
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
return re.sub(regex, ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(char for char in text if char not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(string_)))).split()
def calculate_f1(pred_answer, gold_answer):
"""
Calculate the token-level F1 score between a predicted answer and a gold answer.
"""
f1_score = 0
pred_answer = get_normalize_answer_token(pred_answer)
gold_answer = get_normalize_answer_token(gold_answer)
common = Counter(pred_answer) & Counter(gold_answer)
num_same = sum(common.values())
# the number of same tokens between pred_answer and gold_answer
precision = 1.0 * num_same / len(pred_answer) if len(pred_answer) != 0 else 0.0
recall = 1.0 * num_same / len(gold_answer) if len(gold_answer) != 0 else 0.0
if not pred_answer and not gold_answer:
f1_score = 1
else:
f1_score = 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0.0
return f1_score
def main():
parser = argparse.ArgumentParser(description="All Task dataset preprocessing")
parser.add_argument("--input_file", type=str, default="",
help="The log file path of evaluation in Reading Comprehension. ")
parser.add_argument("--addition_file", type=str, default="", help="Coqa-dev-v1.0.json path")
args_opt = parser.parse_args()
input_file = args_opt.input_file
addition_file = args_opt.addition_file
find_word = 'Pred_answer:'
find_word_length = len(find_word)
pred_answer_list = []
with open(input_file, 'r', encoding='utf-8') as f:
while True:
line = f.readline()
if not line:
break
index = line.find(find_word)
if index != -1:
pred_answer = line[index + find_word_length:].strip()
pred_answer_list.append(pred_answer)
dataset = json.load(open(addition_file))
pred_answer_num = 0
total_f1score = 0
average_f1score = 0
data_num = len(pred_answer_list)
for story in dataset['data']:
questions = story['questions']
multiple_answers = [story['answers']]
multiple_answers += story['additional_answers'].values()
for question in questions:
pred_a = pred_answer_list[pred_answer_num]
turn_id = question['turn_id']
max_score = 0
max_group = 0
flag = 0
for i, answer in enumerate(multiple_answers):
gold_a = answer[turn_id - 1]['input_text']
score = calculate_f1(pred_a, gold_a)
if score > max_score:
max_score = score
max_group = i
# keep the best score over the multiple reference answers and record its group index.
gold_a = multiple_answers[max_group][turn_id - 1]['input_text']
pred_answer_num += 1
total_f1score += max_score
average_f1score = total_f1score / pred_answer_num
print('==================== data {} ===================='.format(pred_answer_num))
print('| Gold_answer:{}'.format(gold_a))
print('| Pred_answer:{}'.format(pred_a))
print('| F1_Score:{:.8f}'.format(average_f1score))
print('=====================================================\n')
if pred_answer_num >= data_num:
flag = 1
break
# Stop flag
if flag:
print('Finished evaluation with addition answer! \n')
print("********************** Testing Finished **********************")
print('| Test file name: {}'.format(input_file))
print('| Final F1 score: {:.8f}'.format(average_f1score))
print('| Total data num: {}'.format(pred_answer_num))
print("**************************************************************")
break
if __name__ == "__main__":
main()
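# Worked example (illustration only) of calculate_f1() after normalization:
#   pred answer "The cat sat."  -> ['cat', 'sat']          (article and punctuation removed)
#   gold answer "cat sat down"  -> ['cat', 'sat', 'down']
#   num_same = 2, precision = 2 / 2 = 1.0, recall = 2 / 3 ~= 0.667
#   F1 = 2 * 1.0 * 0.667 / (1.0 + 0.667) = 0.8
# For every question the script keeps the best F1 over the official answer
# plus the additional answers, then reports the running average.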

View File

@ -0,0 +1,270 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for Children's Book Test task.
"""
import argparse
import time
import numpy as np
from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2CBT
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.utils.metric_method import Accuracy
from src.dataset import create_cbt_dataset, create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.task_utils import calculate_choice_prob_for_cbt
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
"""
Do train
Args:
dataset: the train dataset.
network: the network with loss
load_checkpoint_path: the file path which saved pretrained model checkpoint.
save_checkpoint_path: the file path which will save finetuned model checkpoint.
epoch_num: the number of epoch.
"""
if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
# optimizer
if cfg.optimizer == 'AdamWeightDecay':
lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
elif cfg.optimizer == 'Lamb':
lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
end_learning_rate=cfg.Lamb.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.Lamb.power)
optimizer = Lamb(network.trainable_params(), lr_schedule)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
else:
raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
prefix_name = "gpt2_" + "cbt_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" + str(epoch_num) +\
"_bs" + str(gpt2_net_cfg.batch_size)
ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
directory=None if save_checkpoint_path == "" else save_checkpoint_path,
config=ckpt_config)
param_dict = load_checkpoint(load_checkpoint_path)
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(network, final_param_dict)
print("Load pretrained parameter successfully!\n")
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
netwithgrads.set_train(True)
loss_cb = LossMonitor(per_print_times=1)
model = Model(netwithgrads)
callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]
print("==================== Starting Finetuning ====================")
model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
print("==================== Finetuning Success ====================")
def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, num_choice=None):
"""
Do evaluation for CBT task.
Args:
dataset: the eval dataset.
network: the network with loss.
metric: the evaluation method.
load_checkpoint_path: the file path which saved finetuned model checkpoint.
eval_type: the evaluation type, one of [zero-shot, finetuned].
num_choice: the number of candidate choices for each CBT question.
"""
if load_checkpoint_path == "":
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
if metric.lower() == "accuracy":
print("Prepare to calculate the accuracy score ...")
gpt2_cbt = network(config=gpt2_net_cfg,
is_training=False,
use_one_hot_embeddings=False
)
gpt2_cbt.set_train(False)
param_dict = load_checkpoint(load_checkpoint_path)
if eval_type == "zero-shot":
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(gpt2_cbt, final_param_dict)
print("load pretrained parameter successfully!\n")
elif eval_type == "finetuned":
load_param_into_net(gpt2_cbt, param_dict)
print("load finetuned parameter successfully!\n")
else:
raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]")
model = Model(gpt2_cbt)
callback = Accuracy()
columns_list = ["input_ids", "input_mask", "input_length", "mc_labels"]
print("==================== [ACC] Testing ====================")
num_data = 1
all_choice_prob = []
for data in dataset.create_dict_iterator():
input_data = []
for i in columns_list:
input_data.append(data[i])
input_ids, input_mask, input_length, mc_labels = input_data
print("| [ACC] number : {} / {} ".format(num_data, dataset.get_dataset_size()))
# print("mc_labels: {}".format(mc_labels)) # [batch_size]
logits = model.predict(input_ids, input_mask)
# choice_prob_list [batch_size]
choice_prob_list = calculate_choice_prob_for_cbt(logits=logits,
batch_size=gpt2_net_cfg.batch_size,
input_length=input_length,
input_ids=input_ids)
all_choice_prob.append(choice_prob_list)
if (num_data * gpt2_net_cfg.batch_size) % num_choice == 0:
all_choice_prob_np = np.array(all_choice_prob)
all_choice_prob_np = all_choice_prob_np.reshape((-1, num_choice))
print("| all_choice_prob_np: ", all_choice_prob_np)
print("| all_choice_prob_np shape: ", all_choice_prob_np.shape)
mc_labels = np.array([mc_labels.asnumpy()[0]])
callback.update(all_choice_prob_np, mc_labels)
all_choice_prob = []
num_data += 1
print("\n\n")
print("**************************************************************")
print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
callback.acc_num / callback.total_num))
print("********************** Testing Finished **********************")
else:
raise ValueError("metric method not supported, support: [Accuracy]")
def run_cbt_task():
"""
run Children's Book Test (CBT) task
"""
parser = argparse.ArgumentParser(description="Finetune and Evaluate CBT task")
parser.add_argument("--device_target", type=str, default="Ascend",
help="Device type. Default: Ascend.")
parser.add_argument("--device_id", type=int, default=1,
help="ID of target device. ")
parser.add_argument("--num_choice", type=int, default=10,
help="The number of choice in CBT task. ")
parser.add_argument("--metric_method", type=str, default="Accuracy",
help="The eval method including [Accuracy]. Default: Accuracy.")
parser.add_argument("--do_train", type=str, default="false",
help="Enable train. Default: false.")
parser.add_argument("--do_eval", type=str, default="true",
help="Enable evaluation. Default: true.")
parser.add_argument("--eval_type", type=str, default="zero-shot",
help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
parser.add_argument("--epoch_num", type=int, default=1,
help="Epoch number. Default: 1.")
parser.add_argument("--train_data_shuffle", type=str, default="true",
help="Enable train data shuffle. Default: true.")
parser.add_argument("--eval_data_shuffle", type=str, default="false",
help="Enable eval data shuffle. Default: false.")
parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
help="Save the finetuned checkpoint path.")
parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
help="Load the checkpoint file path for train.")
parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
help="Load the checkpoint file path for evaluation.")
parser.add_argument("--train_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path")
parser.add_argument("--eval_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path")
args_opt = parser.parse_args()
epoch_num = args_opt.epoch_num
metric = args_opt.metric_method
save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path
if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
raise ValueError("'train_data_file_path' must be set when do finetune task")
if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
raise ValueError("'eval_data_file_path' must be set when do evaluation task")
device_target = args_opt.device_target
if device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE,
device_target=device_target,
device_id=args_opt.device_id,
max_call_depth=3000)
context.set_auto_parallel_context(parallel_mode="stand_alone")
print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id))
else:
raise Exception("Device target error, Ascend is supported.")
gpt2_loss = GPT2CBT(config=gpt2_net_cfg,
is_training=True,
use_one_hot_embeddings=False)
if args_opt.do_train.lower() == "true":
print("============== Start Loading Train Dataset ============")
print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
dataset_path=args_opt.train_data_file_path)
do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)
if args_opt.do_eval.lower() == "true":
print("============== Start Loading Evaluation Dataset ============")
print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
eval_dataset = create_cbt_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
dataset_path=args_opt.eval_data_file_path)
do_eval(eval_dataset, GPT2CBT, metric, load_finetune_ckpt_path, args_opt.eval_type, args_opt.num_choice)
if __name__ == "__main__":
print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
run_cbt_task()
print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

View File

@ -0,0 +1,293 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for Reading Comprehension task.
"""
import argparse
import time
from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2CoQA
from src.GPT2ForReadComprehension import GPT2CoQAModel
from src.utils.metric_method import F1
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.dataset import create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.tokenization import Tokenizer
from src.GPT2_generation import GenerateForReadComprehension
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
"""
Do train
Args:
dataset: the train dataset.
network: the network with loss
load_checkpoint_path: the file path which saved pretrained model checkpoint.
save_checkpoint_path: the file path which will save finetuned model checkpoint.
epoch_num: the number of epoch.
"""
if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
# optimizer
if cfg.optimizer == 'AdamWeightDecay':
lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
elif cfg.optimizer == 'Lamb':
lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
end_learning_rate=cfg.Lamb.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.Lamb.power)
optimizer = Lamb(network.trainable_params(), lr_schedule)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
else:
raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
prefix_name = "gpt2_rc_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
+ str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
directory=None if save_checkpoint_path == "" else save_checkpoint_path,
config=ckpt_config)
param_dict = load_checkpoint(load_checkpoint_path)
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(network, final_param_dict)
print("Load the pretrained parameter successfully! \n")
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
netwithgrads.set_train(True)
loss_cb = LossMonitor(per_print_times=1)
model = Model(netwithgrads)
callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]
print("=================== Starting Training For Translation Task ====================")
model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
print("=================== Translation Training Success ====================")
def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file_path="",
generate_length=1, top_k=1, top_p=1.0, temperature=1.0):
"""
Do evaluation on the Reading Comprehension (CoQA) task.
Args:
dataset: the eval dataset.
network: the network with loss.
metric: the evaluation method.
load_checkpoint_path: the file path which saved finetuned model checkpoint.
eval_type: the evaluation type, one of [zero-shot, finetuned].
tokenizer_file_path: the directory containing gpt2-vocab.json and gpt2-merges.txt.
generate_length: the maximum number of tokens generated for each answer.
top_k, top_p, temperature: sampling parameters used during generation.
"""
if load_checkpoint_path == "":
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
if metric.lower() == "f1":
print("Prepare to calculate the BLEU score ...")
gpt2_rc = network(config=gpt2_net_cfg,
is_training=False,
use_one_hot_embeddings=False)
gpt2_rc.set_train(False)
param_dict = load_checkpoint(load_checkpoint_path)
if eval_type == "zero-shot":
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.' + name] = param_dict[name]
final_param_dict['dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(gpt2_rc, final_param_dict)
print("load pretrained parameter successfully!\n")
elif eval_type == "finetuned":
load_param_into_net(gpt2_rc, param_dict)
print("load finetuned parameter successfully!\n")
else:
raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]")
model = Model(gpt2_rc)
tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json',
merge_file=tokenizer_file_path + 'gpt2-merges.txt')
callback = F1()
rc_generator = GenerateForReadComprehension(decoder=model,
config=gpt2_net_cfg,
tokenizer=tokenizer,
generate_length=generate_length,
topk_num=top_k,
topp_prob=float(top_p),
temperature=float(temperature)
)
columns_list = ["input_ids", "input_mask", "label_ids"]
print("==================== [F1] Testing ====================")
num_data = 0
for data in dataset.create_dict_iterator():
input_data = []
for i in columns_list:
input_data.append(data[i])
input_ids, _, label_ids = input_data
print("input_ids shape: {}".format(input_ids.shape))
print("label_ids shape: {}".format(label_ids.shape))
passage, pred_answer, gold_answer = rc_generator.generate_for_read_comprehension(input_ids)
for batch_id in range(gpt2_net_cfg.batch_size):
print("============== [F1] {} ================".format(num_data + 1))
print(" | Passage:{}".format(passage[batch_id]))
print(" | Gold_answer:{}".format(gold_answer[batch_id]))
print(" | Pred_answer:{}".format(pred_answer[batch_id]))
pred = callback.get_normalize_answer_token(pred_answer[batch_id])
gold = callback.get_normalize_answer_token(gold_answer[batch_id])
callback.update(pred, gold)
num_data += 1
average_f1_score = callback.f1_score / num_data
print("============== Evaluation =================")
print("| Avg F1 Score:{:.8f}".format(average_f1_score))
print("=============================================\n\n")
print("********************** Testing Finished **********************")
else:
raise ValueError("metric method not supported in Reading Comprehension task, support: [F1]")
def run_Readcomprehension():
"""
run the Reading Comprehension task
"""
parser = argparse.ArgumentParser(description="Finetune and Evaluate Reading Comprehension")
parser.add_argument("--device_target", type=str, default="Ascend",
help="Device type. Default: Ascend.")
parser.add_argument("--device_id", type=int, default=0,
help="ID of target device. ")
parser.add_argument("--metric_method", type=str, default="F1",
help="The eval method including [F1]. Default: F1.")
parser.add_argument("--do_train", type=str, default="false",
help="Enable train. Default: false.")
parser.add_argument("--do_eval", type=str, default="true",
help="Enable evaluation. Default: false.")
parser.add_argument("--eval_type", type=str, default="zero-shot",
help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
parser.add_argument("--epoch_num", type=int, default=1,
help="Epoch number. Default: 1.")
parser.add_argument("--train_data_shuffle", type=str, default="true",
help="Enable train data shuffle. Default: true.")
parser.add_argument("--eval_data_shuffle", type=str, default="false",
help="Enable eval data shuffle. Default: false.")
parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
help="Save the checkpoint path.")
parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
help="Load the checkpoint file path.")
parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
help="Load the checkpoint file path.")
parser.add_argument("--train_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path")
parser.add_argument("--eval_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path")
parser.add_argument("--tokenizer_file_path", type=str, default="",
help="pretrained vocab and merge file path.")
parser.add_argument("--generate_length", type=int, default=55,
help="The generation length of translation sentence.")
parser.add_argument("--top_k", type=int, default=1,
help="Parameter for Top-K sampling.")
parser.add_argument("--top_p", type=str, default="1.0",
help="parameter for Top-P sampling.")
parser.add_argument("--temperature", type=str, default="1.0",
help="Parameter for generation, greater if generation more diverse. ")
args_opt = parser.parse_args()
epoch_num = args_opt.epoch_num
metric = args_opt.metric_method
save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path
if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
raise ValueError("'train_data_file_path' must be set when do finetune task")
if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
raise ValueError("'eval_data_file_path' must be set when do evaluation task")
device_target = args_opt.device_target
if device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE,
device_target=device_target,
device_id=args_opt.device_id,
max_call_depth=3000)
context.set_auto_parallel_context(parallel_mode="stand_alone")
print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id))
else:
raise Exception("Device target error, Ascend is supported.")
gpt2_loss = GPT2CoQA(config=gpt2_net_cfg,
is_training=True,
use_one_hot_embeddings=False)
if args_opt.do_train.lower() == "true":
print("============== Start Loading Translation Train Dataset ==============")
print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
dataset_path=args_opt.train_data_file_path)
do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)
if args_opt.do_eval.lower() == "true":
print("============ Start Loading Translation Evaluation Dataset ============")
print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
dataset_path=args_opt.eval_data_file_path)
do_eval(eval_dataset, GPT2CoQAModel, metric, load_finetune_ckpt_path, args_opt.eval_type,
args_opt.tokenizer_file_path, args_opt.generate_length, args_opt.top_k, args_opt.top_p,
args_opt.temperature)
if __name__ == "__main__":
print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
run_Readcomprehension()
print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

View File

@ -0,0 +1,328 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for LAMBADA task.
"""
import argparse
import math
import time
from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2Lambada
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.utils.metric_method import LastWordAccuracy
from src.dataset import create_language_model_dataset, create_lambada_control_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.task_utils import get_final_word_label
from src.utils.tokenization import Tokenizer
from src.GPT2_generation import GenerateForLambada
from src.utils.CrossEntropy import CrossEntropyCalculationWithMask
from src.utils.get_config_setting import get_train_setting, get_model_setting
from src.utils.task_utils import calculate_final_word_loss
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
"""
Do train
Args:
dataset: the train dataset.
network: the network with loss
load_checkpoint_path: the file path which saved pretrain model checkpoint.
save_checkpoint_path: the file path which will save finetune model checkpoint.
epoch_num: the number of epoch
"""
if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
# optimizer
if cfg.optimizer == 'AdamWeightDecay':
lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
elif cfg.optimizer == 'Lamb':
lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
end_learning_rate=cfg.Lamb.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.Lamb.power)
optimizer = Lamb(network.trainable_params(), lr_schedule)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
else:
raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
prefix_name = "gpt2_" + "lambada_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
+ str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
directory=None if save_checkpoint_path == "" else save_checkpoint_path,
config=ckpt_config)
param_dict = load_checkpoint(load_checkpoint_path)
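# The pretrained checkpoint stores parameters under their backbone names, so they are remapped
# to the 'gpt2.gpt2.' prefix expected by the finetune wrapper; the embedding table is also reused
# as the output projection weight ('gpt2.dense1.weight'), which appears to implement GPT-2's
# tied input/output embeddings.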
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(network, final_param_dict)
print("Load pretrained parameter successfully!\n")
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
netwithgrads.set_train(True)
loss_cb = LossMonitor(per_print_times=1)
model = Model(netwithgrads)
callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]
print("==================== Starting Finetuning ====================")
model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
print("==================== Finetuning Success ====================")
def eval_result_print(metric="accuracy", callback=None):
"""
Print eval result.
"""
if metric.lower() == "accuracy":
print("acc_num {}, total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
callback.acc_num / callback.total_num))
else:
raise ValueError("metric method not supported, support: [accuracy]")
def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, stop_word_file="",
generate_length_dynamic=True, tokenizer_file_path=""):
"""
Do eval
Args:
dataset: the eval dataset.
network: the network with loss.
metric: the evaluation method.
load_checkpoint_path: the file path which saved finetune model checkpoint.
eval_type: the eval type, i.e. zero-shot, finetuned.
        generate_length_dynamic (bool): True if the generation length is determined dynamically, False for a fixed length. Default: True.
tokenizer_file_path: the tokenizer file path for vocab file and merge file.
stop_word_file: stop word file for calculating Accuracy.
"""
if load_checkpoint_path == "":
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json',
merge_file=tokenizer_file_path + 'gpt2-merges.txt')
gpt2_lambada = network(config=gpt2_net_cfg,
is_training=False,
use_one_hot_embeddings=False)
gpt2_lambada.set_train(False)
param_dict = load_checkpoint(load_checkpoint_path)
if eval_type == "zero-shot":
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(gpt2_lambada, final_param_dict)
print("load pretrained parameter successfully!\n")
elif eval_type == "finetuned":
load_param_into_net(gpt2_lambada, param_dict)
print("load finetuned parameter successfully!\n")
model = Model(gpt2_lambada)
if metric.lower() == "accuracy":
print("Prepare to calculate the accuracy score ...")
callback = LastWordAccuracy()
columns_list = ["input_ids", "input_mask", "input_length"]
print("==================== [ACC] Testing ====================")
lambada_generator = GenerateForLambada(decoder=model,
config=gpt2_net_cfg,
tokenizer=tokenizer,
generate_length_dynamic=generate_length_dynamic,
max_iterations=200,
stop_word_file=stop_word_file)
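# LAMBADA accuracy protocol: generate a continuation for each passage and compare the
# predicted final word against the reference last word extracted from input_ids.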
num_data = 1
for data in dataset.create_dict_iterator():
input_data = []
for i in columns_list:
input_data.append(data[i])
input_ids, input_mask, input_length = input_data
print("| [ACC] number : {} / {} ".format(num_data, dataset.get_dataset_size()))
logits = model.predict(input_ids, input_mask)
predict_str = lambada_generator.generate_for_lambada(input_ids=input_ids,
logits=logits,
input_length=input_length)
label_str = get_final_word_label(input_ids=input_ids, input_length=input_length, tokenizer=tokenizer)
callback.update(predict_str, label_str)
eval_result_print(metric, callback)
num_data += 1
print("\n\n")
print("**********************************************************")
eval_result_print(metric, callback)
print("******************** Testing Finished ********************")
elif metric.lower() == "ppl":
print("Prepare to calculate the ppl score ...")
cross_entropy = CrossEntropyCalculationWithMask(is_training=True,
num_labels=gpt2_net_cfg.vocab_size,
config=gpt2_net_cfg)
columns_list = ["input_ids", "input_mask", "input_length"]
num_data = 1
total_loss = 0.0
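# Perplexity is tracked as exp(mean last-word cross-entropy) over all batches seen so far.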
print("==================== [PPL] Testing ====================")
for data in dataset.create_dict_iterator():
input_data = []
for i in columns_list:
input_data.append(data[i])
input_ids, input_mask, input_length = input_data
print("| [PPL] number : {} / {} ".format(num_data, dataset.get_dataset_size()))
logits = model.predict(input_ids, input_mask) # (batch_size, seq_len, vocab_size)
avg_batch_loss = calculate_final_word_loss(logits,
gpt2_net_cfg.batch_size,
input_ids,
input_length,
cross_entropy)
total_loss += avg_batch_loss
avg_total_loss = total_loss / num_data
print(" | Current AVG loss:", avg_total_loss)
print(" | Current AVG ppl:", math.exp(avg_total_loss))
num_data += 1
print("\n\n")
print("**********************************************************")
print("Average PPL: {:.6f}".format(math.exp(avg_total_loss)))
print("******************** Testing Finished ********************")
else:
raise ValueError("metric method not supported, support: [accuracy, ppl]")
def run_lambada():
"""
Run Lambada task.
"""
parser = argparse.ArgumentParser(description="Finetune and Evaluate languagemodel")
parser.add_argument("--device_target", type=str, default="Ascend",
help="Device type. Default: Ascend.")
parser.add_argument("--device_id", type=int, default=2,
help="ID of target device.")
parser.add_argument("--metric_method", type=str, default="PPL",
help="The eval method including [Accuracy, PPL]. Default: Accuracy.")
parser.add_argument("--do_train", type=str, default="false",
help="Enable train. Default: false.")
parser.add_argument("--do_eval", type=str, default="true",
help="Enable evaluation. Default: false.")
parser.add_argument("--eval_type", type=str, default="finetuned",
help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
parser.add_argument("--epoch_num", type=int, default=3,
help="Epoch number. Default: 1.")
parser.add_argument("--train_data_shuffle", type=str, default="false",
help="Enable train data shuffle. Default: true.")
parser.add_argument("--eval_data_shuffle", type=str, default="false",
help="Enable eval data shuffle. Default: false.")
parser.add_argument("--generate_length_dynamically", type=str, default="true",
help="Enable generate_length_Dynamically. Default: true.")
parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
help="Save the checkpoint path.")
parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
help="Load the checkpoint file path.")
parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
help="Load the checkpoint file path.")
parser.add_argument("--train_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path.")
parser.add_argument("--eval_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path.")
parser.add_argument("--tokenizer_file_path", type=str, default="",
help="pretrained vocab and merge file path.")
parser.add_argument("--stop_word_file_path", type=str, default="",
help="The stop word file path.")
args_opt = parser.parse_args()
epoch_num = args_opt.epoch_num
metric = args_opt.metric_method
save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path
if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
raise ValueError("'train_data_file_path' must be set when do finetune task")
if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
raise ValueError("'eval_data_file_path' must be set when do evaluation task")
device = args_opt.device_target
if device == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_auto_parallel_context(parallel_mode="stand_alone")
print(" | Device: {} | Device id: {}".format(device, args_opt.device_id))
else:
raise Exception("Device target error, Ascend is supported.")
gpt2_loss = GPT2Lambada(config=gpt2_net_cfg,
is_training=True,
use_one_hot_embeddings=False)
if args_opt.do_train.lower() == "true":
get_train_setting(cfg)
get_model_setting(cfg, gpt2_net_cfg)
print("============== Start Loading Train Dataset ============")
print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
dataset_path=args_opt.train_data_file_path)
do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)
if args_opt.do_eval.lower() == "true":
get_model_setting(cfg, gpt2_net_cfg)
print("============== Start Loading Evaluation Dataset ============")
print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
eval_dataset = create_lambada_control_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
dataset_path=args_opt.eval_data_file_path)
do_eval(eval_dataset, GPT2Lambada, metric, load_finetune_ckpt_path, args_opt.eval_type,
        args_opt.stop_word_file_path, (args_opt.generate_length_dynamically.lower() == "true"),
        args_opt.tokenizer_file_path)
if __name__ == "__main__":
print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
run_lambada()
print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))


@ -0,0 +1,255 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for Language Modeling task.
"""
import argparse
import math
import time
from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2LM
from src.utils.lr_schedule import GPT2LearningRate
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.dataset import create_language_model_dataset
from src.utils.get_config_setting import get_train_setting, get_model_setting
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
"""
Do train
Args:
dataset: the train dataset.
network: the network with loss
load_checkpoint_path: the file path which saved pretrained model checkpoint.
save_checkpoint_path: the file path which will save finetuned model checkpoint.
epoch_num: the number of epoch.
"""
if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
# optimizer
if cfg.optimizer == 'AdamWeightDecay':
lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
elif cfg.optimizer == 'Lamb':
lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
end_learning_rate=cfg.Lamb.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.Lamb.power)
optimizer = Lamb(network.trainable_params(), lr_schedule)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
else:
raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
prefix_name = "gpt2_language_model_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
+ str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
directory=None if save_checkpoint_path == "" else save_checkpoint_path,
config=ckpt_config)
param_dict = load_checkpoint(load_checkpoint_path)
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(network, final_param_dict)
print("Load pretrained parameter successfully!\n")
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
netwithgrads.set_train(True)
loss_cb = LossMonitor(per_print_times=1)
model = Model(netwithgrads)
callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]
print("==================== Starting Finetuning ====================")
model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
print("==================== Finetuning Success ====================")
def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None):
"""
Do eval
Args:
dataset: the eval dataset.
network: the network with loss.
metric: the evaluation method.
load_checkpoint_path: the file path which saved finetuned model checkpoint.
        eval_type: the evaluation type, i.e. zero-shot or finetuned.
"""
if load_checkpoint_path == "":
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
if metric.lower() == "ppl":
print("Prepare to calculate the ppl score ...")
gpt2_loss = network(config=gpt2_net_cfg,
is_training=True,
use_one_hot_embeddings=False)
gpt2_loss.set_train(False)
param_dict = load_checkpoint(load_checkpoint_path)
if eval_type == "zero-shot":
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(gpt2_loss, final_param_dict)
print("load pretrained parameter successfully!\n")
elif eval_type == "finetuned":
load_param_into_net(gpt2_loss, param_dict)
print("load finetuned parameter successfully!\n")
else:
raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]")
model = Model(gpt2_loss)
columns_list = ["input_ids", "input_mask", "label_ids"]
print("==================== [PPL] Testing ====================")
num_data = 1
total_loss = 0.0
avg_loss = 0.0
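# Language-model PPL: accumulate the per-batch cross-entropy loss and report exp(average loss).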
for data in dataset.create_dict_iterator():
input_data = []
for i in columns_list:
input_data.append(data[i])
input_ids, input_mask, label_ids = input_data
loss = model.predict(input_ids, input_mask, label_ids)
loss = float(loss.asnumpy())
total_loss += loss
avg_loss = float(total_loss / num_data)
print(" | Current Loss: {:.6f}".format(avg_loss))
print(" | Current PPL: {}\n\n".format(math.exp(avg_loss)))
num_data += 1
print("\n\n")
print("**************************************************************")
print("Average Loss: {:.6f}".format(avg_loss))
print("Average PPL: {:.6f}".format(math.exp(avg_loss)))
print("********************** Testing Finished **********************")
else:
raise ValueError("metric method not supported, support: [ppl]")
def run_languagemodel():
"""
run Language Modeling task
"""
parser = argparse.ArgumentParser(description="Finetune and Evaluate language modelings task")
parser.add_argument("--device_target", type=str, default="Ascend",
help="Device type. Default: Ascend.")
parser.add_argument("--device_id", type=int, default=1,
help="ID of target device. ")
parser.add_argument("--metric_method", type=str, default="PPL",
help="The eval method including [PPL]. Default: PPL.")
parser.add_argument("--do_train", type=str, default="false",
help="Enable train. Default: false.")
parser.add_argument("--do_eval", type=str, default="true",
help="Enable evaluation. Default: true.")
parser.add_argument("--eval_type", type=str, default="zero-shot",
help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
parser.add_argument("--epoch_num", type=int, default=1,
help="Epoch number. Default: 1.")
parser.add_argument("--train_data_shuffle", type=str, default="true",
help="Enable train data shuffle. Default: true.")
parser.add_argument("--eval_data_shuffle", type=str, default="false",
help="Enable eval data shuffle. Default: false.")
parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
help="Save the finetuned checkpoint path.")
parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
help="Load the checkpoint file path for train.")
parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
help="Load the checkpoint file path for evaluation.")
parser.add_argument("--train_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path")
parser.add_argument("--eval_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path")
args_opt = parser.parse_args()
epoch_num = args_opt.epoch_num
metric = args_opt.metric_method
save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path
if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
raise ValueError("'train_data_file_path' must be set when do finetune task")
if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
raise ValueError("'eval_data_file_path' must be set when do evaluation task")
device_target = args_opt.device_target
if device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE,
device_target=device_target,
device_id=args_opt.device_id,
max_call_depth=3000)
context.set_auto_parallel_context(parallel_mode="stand_alone")
print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id))
else:
raise Exception("Device target error, Ascend is supported.")
gpt2_loss = GPT2LM(config=gpt2_net_cfg,
is_training=True,
use_one_hot_embeddings=False)
if args_opt.do_train.lower() == "true":
get_train_setting(cfg)
get_model_setting(cfg, gpt2_net_cfg)
print("==================== Start Loading Train Dataset ==================")
print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
dataset_path=args_opt.train_data_file_path)
do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)
if args_opt.do_eval.lower() == "true":
get_model_setting(cfg, gpt2_net_cfg)
print("==================== Start Loading Evaluation Dataset ==================")
print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
dataset_path=args_opt.eval_data_file_path)
do_eval(eval_dataset, GPT2LM, metric, load_finetune_ckpt_path, args_opt.eval_type)
if __name__ == "__main__":
print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
run_languagemodel()
print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))


@ -0,0 +1,296 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for Summarization task.
"""
import time
import argparse
from mindspore import context
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.GPT2ForSummarization import GPT2SummarizationModel
from src.gpt2_for_finetune import GPT2Summarization, GPT2FinetuneCell
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.utils.metric_method import Rouge
from src.dataset import create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.tokenization import Tokenizer
from src.utils.task_utils import clean_hypo, modify_paramdict
from src.GPT2_generation import GenerateForSummarization
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
"""
Do train
Args:
dataset: the train dataset.
network: the network with loss
load_checkpoint_path: the file path which saved pretrain model checkpoint.
save_checkpoint_path: the file path which will save finetune model checkpoint.
epoch_num: the number of epoch
"""
if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
# optimizer
if cfg.optimizer == 'AdamWeightDecay':
lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(
filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
elif cfg.optimizer == 'Lamb':
lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
end_learning_rate=cfg.Lamb.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.Lamb.power)
optimizer = Lamb(network.trainable_params(), lr_schedule)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
else:
raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
prefix_name = "gpt2_summarization_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
+ str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
directory=None if save_checkpoint_path == "" else save_checkpoint_path,
config=ckpt_config)
param_dict = load_checkpoint(load_checkpoint_path)
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(network, final_param_dict)
print("Load pretrained parameter successfully!\n")
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
netwithgrads.set_train(True)
loss_cb = LossMonitor(per_print_times=1)
model = Model(netwithgrads)
callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]
print("============== Starting Finetuning ==============")
model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
print("============== Finetuning Success ==============")
def eval_result_print(metric="Rouge", callback=None):
"""
print eval result
"""
if metric == "Rouge":
print("Rouge-1 {:.8f}, Rouge-2 {:.8f}, Rouge-L {:.8f}, Rouge-AVG{:.8f}".
format(callback.Rouge1 / callback.total_num,
callback.Rouge2 / callback.total_num,
callback.RougeL / callback.total_num,
(callback.Rouge1 + callback.Rouge2 + callback.RougeL) / (3.0 * callback.total_num)))
else:
raise ValueError("metric method '{}' not supported, support: [Rouge]. ".format(str(metric)))
def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file="",
top_k=None, top_p=None, temperature=None, generate_length=None):
"""
Do evaluation on summarization
"""
if load_checkpoint_path == "":
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
if metric.lower() == "rouge":
print("Prepare to calculate the Rouge score ...")
callback = Rouge()
gpt2_loss = network(config=gpt2_net_cfg,
is_training=False,
use_one_hot_embeddings=False)
gpt2_loss.set_train(False)
param_dict = load_checkpoint(load_checkpoint_path)
reorganized_param_dict = modify_paramdict(param_dict, mode=eval_type, model_prefix="gpt2.")
load_param_into_net(gpt2_loss, reorganized_param_dict)
# load nn.Cell into Model and initiate tokenizer and Sample
model = Model(gpt2_loss)
tokenizer = Tokenizer(vocab_file=tokenizer_file + 'gpt2-vocab.json',
merge_file=tokenizer_file + 'gpt2-merges.txt')
# load data and process text generation
columns_list = ["input_ids", "input_mask", "label_ids"]
summarization_generator = GenerateForSummarization(model,
config=gpt2_net_cfg,
tokenizer=tokenizer,
select_sentence=3,
eval_type=eval_type,
topk=top_k,
topp=float(top_p),
temperature=float(temperature),
generate_length=generate_length)
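# The generator returns the hypothesis summary and the reference extracted from the sample;
# select_sentence=3 presumably keeps the first three generated sentences (assumption based on
# the parameter name, not verified against GPT2_generation.py).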
num_data = 1
print("==================== [Summrization] Testing ====================")
for data in dataset.create_dict_iterator():
input_data = []
for value in columns_list:
input_data.append(data[value])
input_ids, _, label_ids = input_data
print(" | [ROUGE] number : {} / {} ".format(num_data, dataset.get_dataset_size()))
print("input_ids shape: {}".format(input_ids.shape))
print("label_ids shape: {}".format(label_ids.shape))
hypothesis, ref = summarization_generator.generate_for_summarization(input_ids)
if ref[0] == '' or ref[0] is None:
print("Sorry ref_list is None, skip it!")
continue
print("REF str:\n ", ref, "\nHYPO str:\n", hypothesis, "\n")
for batch_idx in range(gpt2_net_cfg.batch_size):
hypothesis[batch_idx] = clean_hypo(hypothesis[batch_idx]).lower()
ref[batch_idx] = ref[batch_idx].lower()
callback.update(hypothesis, ref)
num_data += 1
print("\n\n")
print("**********************************************************")
eval_result_print(metric, callback)
print("******************** Testing Finished ********************")
else:
raise ValueError("metric method not supported in summarization, support: [Rouge]")
def run_summarization():
"""
Run Summarization task.
"""
# set argument parser
parser = argparse.ArgumentParser(description="Finetune and Evaluate Summrization")
# context and task settings
parser.add_argument("--device_target", type=str, default="Ascend",
help="Device type. Default: Ascend.")
parser.add_argument("--device_id", type=int, default=4,
help="ID of target device.")
parser.add_argument("--do_train", type=str, default="false",
help="Enable train. Default: false.")
parser.add_argument("--do_eval", type=str, default="true",
help="Enable evaluation. Default: false.")
parser.add_argument("--eval_type", type=str, default="finetuned",
help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
parser.add_argument("--metric_method", type=str, default="Rouge",
help="The eval method including [Rouge(Rouge1,Rouge2,RougeL,Rouge Avg)]. Default: Rouge.")
parser.add_argument("--epoch_num", type=int, default=2,
help="Epoch number. Default: 2.")
# dataset and params_dict file settings
parser.add_argument("--train_data_shuffle", type=str, default="true",
help="Enable train data shuffle. Default: true.")
parser.add_argument("--eval_data_shuffle", type=str, default="false",
help="Enable eval data shuffle. Default: false.")
parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
help="Save the checkpoint path.")
parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
help="Load the checkpoint file path.")
parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
help="Load the checkpoint file path.")
parser.add_argument("--train_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path")
parser.add_argument("--eval_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path")
# sampling settings
parser.add_argument("--top_k", type=int, default=2,
help="top k tokens chosen for sampling")
parser.add_argument("--top_p", type=str, default="1.0",
help="top p accumulated probability threshold for logit to be counted")
parser.add_argument("--generate_length", type=int, default=100,
help="the number of generated tokens.")
parser.add_argument("--temperature", type=str, default="1.0",
help="temperature on logits for sampling")
parser.add_argument("--tokenizer_file_path", type=str, default="",
help="vocab & merge file path")
args_opt = parser.parse_args()
epoch_num = args_opt.epoch_num
metric = args_opt.metric_method
save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path
eval_type = args_opt.eval_type
tokenizer_file = args_opt.tokenizer_file_path
if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
raise ValueError("'train_data_file_path' must be set when do finetune task")
if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
raise ValueError("'eval_data_file_path' must be set when do evaluation task")
device = args_opt.device_target
if device == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_auto_parallel_context(parallel_mode="stand_alone")
print(" | Device: {} | Device id: {}".format(device, args_opt.device_id))
else:
raise Exception("Device target error, Ascend is supported.")
if args_opt.do_train.lower() == "true":
train_data_file_path = args_opt.train_data_file_path
gpt2_loss = GPT2Summarization(config=gpt2_net_cfg,
is_training=True,
use_one_hot_embeddings=False)
print("============== Start Loading Train Dataset ============")
train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
dataset_path=train_data_file_path)
do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)
if args_opt.do_eval.lower() == "true":
eval_dataset_file_path = args_opt.eval_data_file_path
print("============== Start Loading Evaluation Dataset ============")
eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
dataset_path=eval_dataset_file_path)
do_eval(eval_dataset, GPT2SummarizationModel, metric, load_finetune_ckpt_path, eval_type, tokenizer_file,
args_opt.top_k, args_opt.top_p, args_opt.temperature, args_opt.generate_length)
if __name__ == "__main__":
print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
run_summarization()
print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))


@ -0,0 +1,298 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for Translation task.
"""
import argparse
import time
from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.GPT2ForTranslation import GPT2TranslationModel
from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2Translation
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.dataset import create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.tokenization import Tokenizer
from src.utils.metric_method import BLEU
from src.GPT2_generation import GenerateForTranslation
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
"""
Do train
Args:
dataset: the train dataset.
network: the network with loss
load_checkpoint_path: the file path which saved pretrained model checkpoint.
save_checkpoint_path: the file path which will save finetuned model checkpoint.
epoch_num: the number of epoch.
"""
if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
# optimizer
if cfg.optimizer == 'AdamWeightDecay':
lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
elif cfg.optimizer == 'Lamb':
lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
end_learning_rate=cfg.Lamb.end_learning_rate,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
decay_steps=steps_per_epoch * epoch_num,
power=cfg.Lamb.power)
optimizer = Lamb(network.trainable_params(), lr_schedule)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
else:
raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
prefix_name = "gpt2_translation_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
+ str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
directory=None if save_checkpoint_path == "" else save_checkpoint_path,
config=ckpt_config)
param_dict = load_checkpoint(load_checkpoint_path)
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(network, final_param_dict)
print("Load the pretrained parameter successfully! \n")
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
netwithgrads.set_train(True)
loss_cb = LossMonitor(per_print_times=1)
model = Model(netwithgrads)
callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]
print("=================== Starting Training For Translation Task ====================")
model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
print("=================== Translation Training Success ====================")
def eval_result_print(metric="BLEU", callback=None):
""" print eval result"""
if metric == "BLEU":
print(" | BLEU: {:.6f}".format(callback.bleu / float(callback.total_num)))
else:
raise ValueError("metric method '{}' not supported, support: [BLEU]. ".format(str(metric)))
def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file_path="",
generate_length=1, top_k=1, top_p=1.0, temperature=1.0):
"""
Do evaluation on Translation
Args:
dataset: the eval dataset.
network: the network with loss.
metric: the evaluation method.
load_checkpoint_path: the file path which saved finetune model checkpoint.
"""
if load_checkpoint_path == "":
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
if metric.lower() == "bleu":
print("Prepare to calculate the BLEU score ...")
gpt2_translation = network(config=gpt2_net_cfg,
is_training=False,
use_one_hot_embeddings=False)
gpt2_translation.set_train(False)
param_dict = load_checkpoint(load_checkpoint_path)
if eval_type == "zero-shot":
final_param_dict = {}
for name, _ in param_dict.items():
final_param_dict['gpt2.' + name] = param_dict[name]
final_param_dict['dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
load_param_into_net(gpt2_translation, final_param_dict)
print("load pretrained parameter successfully!\n")
elif eval_type == "finetuned":
load_param_into_net(gpt2_translation, param_dict)
print("load finetuned parameter successfully!\n")
else:
raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]")
model = Model(gpt2_translation)
tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json',
merge_file=tokenizer_file_path + 'gpt2-merges.txt')
callback = BLEU(tokenizer)
translation_generator = GenerateForTranslation(decoder=model,
config=gpt2_net_cfg,
tokenizer=tokenizer,
                                                        generate_length=generate_length,
use_hint=True,
select_first_sentence=True,
topk_num=top_k,
topp_prob=float(top_p),
temperature=float(temperature)
)
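# With the default sampling settings (top_k=1, top_p=1.0, temperature=1.0) the sampler reduces
# to greedy decoding, so the generated translations are deterministic.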
columns_list = ["input_ids", "input_mask", "label_ids"]
print("==================== [BLEU] Testing ====================")
num_data = 1
for data in dataset.create_dict_iterator():
input_data = []
for i in columns_list:
input_data.append(data[i])
input_ids, input_mask, label_ids = input_data
print("| Data count: {}".format(num_data * gpt2_net_cfg.batch_size))
print("input_ids shape: {}".format(input_ids.shape))
print("input_mask shape: {}".format(input_mask.shape))
print("label_ids shape: {}".format(label_ids.shape))
ts_predict_list, ref_list = translation_generator.generate_for_translation(input_ids)
print("| Batch Reference translation:\n{}\n".format(ref_list))
if ref_list == '' or ref_list is None:
print("Sorry ref_list is None, skip it!")
continue
else:
print(" | Batch Predict translation:\n{}\n".format(ts_predict_list))
callback.update(ref_list, ts_predict_list)
num_data += 1
print("\n\n")
print("**************************************************************")
eval_result_print(metric, callback)
print("********************** Testing Finished **********************")
else:
raise ValueError("metric method not supported in translation, support: [BLEU]")
def run_translation():
"""
run translation task
"""
parser = argparse.ArgumentParser(description="Finetune and Evaluate translation")
parser.add_argument("--device_target", type=str, default="Ascend",
help="Device type. Default: Ascend.")
parser.add_argument("--device_id", type=int, default=0,
help="ID of target device. ")
parser.add_argument("--metric_method", type=str, default="BLEU",
help="The eval method including [BLEU]. Default: BLEU.")
parser.add_argument("--do_train", type=str, default="false",
help="Enable train. Default: false.")
parser.add_argument("--do_eval", type=str, default="true",
help="Enable evaluation. Default: false.")
parser.add_argument("--eval_type", type=str, default="zero-shot",
help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
parser.add_argument("--epoch_num", type=int, default=1,
help="Epoch number. Default: 1.")
parser.add_argument("--train_data_shuffle", type=str, default="true",
help="Enable train data shuffle. Default: true.")
parser.add_argument("--eval_data_shuffle", type=str, default="false",
help="Enable eval data shuffle. Default: false.")
parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
help="Save the checkpoint path.")
parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
help="Load the checkpoint file path.")
parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
help="Load the checkpoint file path.")
parser.add_argument("--train_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path")
parser.add_argument("--eval_data_file_path", type=str, default="",
help="Data path, it is better to use absolute path")
parser.add_argument("--tokenizer_file_path", type=str, default="",
help="pretrained vocab and merge file path.")
parser.add_argument("--generate_length", type=int, default=150,
help="The generation length of translation sentence.")
parser.add_argument("--top_k", type=int, default=1,
help="Parameter for Top-K sampling.")
parser.add_argument("--top_p", type=str, default="1.0",
help="parameter for Top-P sampling.")
parser.add_argument("--temperature", type=str, default="1.0",
help="Parameter for generation, greater if generation more diverse. ")
args_opt = parser.parse_args()
epoch_num = args_opt.epoch_num
metric = args_opt.metric_method
save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path
if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
raise ValueError("'train_data_file_path' must be set when do finetune task")
if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
raise ValueError("'eval_data_file_path' must be set when do evaluation task")
device_target = args_opt.device_target
if device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE,
device_target=device_target,
device_id=args_opt.device_id,
max_call_depth=3000)
context.set_auto_parallel_context(parallel_mode="stand_alone")
print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id))
else:
raise Exception("Device target error, Ascend is supported.")
gpt2_loss = GPT2Translation(config=gpt2_net_cfg,
is_training=True,
use_one_hot_embeddings=False)
if args_opt.do_train.lower() == "true":
print("============== Start Loading Translation Train Dataset ==============")
print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
dataset_path=args_opt.train_data_file_path)
do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)
if args_opt.do_eval.lower() == "true":
print("============ Start Loading Translation Evaluation Dataset ============")
print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
dataset_path=args_opt.eval_data_file_path)
do_eval(eval_dataset, GPT2TranslationModel, metric, load_finetune_ckpt_path, args_opt.eval_type,
args_opt.tokenizer_file_path, args_opt.generate_length, args_opt.top_k, args_opt.top_p,
args_opt.temperature)
if __name__ == "__main__":
print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
run_translation()
print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))


@ -0,0 +1,60 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_cbt.sh"
echo "for example: bash scripts/run_cbt.sh"
echo "metric method: Accuracy"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="
CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_cbt.log"
# create file and head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log
# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""
# dataset path
train_data_file_path=""
eval_data_file_path=""
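# NOTE: fill in the checkpoint and dataset paths above before launching; for example (hypothetical paths):
# load_eval_ckpt_path="/path/to/gpt2_pretrained.ckpt"
# eval_data_file_path="/path/to/cbt_test.mindrecord"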
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_CBT_task.py \
--device_target="Ascend" \
--device_id=4 \
--num_choice=10 \
--metric_method="Accuracy" \
--do_train="false" \
--do_eval="true" \
--eval_type="zero-shot" \
--epoch_num=1 \
--train_data_shuffle="true" \
--eval_data_shuffle="false" \
--save_finetune_ckpt_path=$save_finetune_ckpt_path \
--load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
--load_finetune_ckpt_path=$load_eval_ckpt_path \
--train_data_file_path=$train_data_file_path \
--eval_data_file_path=$eval_data_file_path >> $output_log 2>&1 &


@ -0,0 +1,68 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_lambada.sh"
echo "for example: bash scripts/run_lambada.sh"
echo "method metric include: [Accuracy, PPL]"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="
CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_lambada.log"
# create file and head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log
# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""
# dataset path
train_data_file_path=""
eval_data_file_path=""
# tokenizer path
tokenizer_file_path=""
# stopword path
stop_word_file_path=""
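# stop_word_file_path should point to the stop-word list used by the LAMBADA filter,
# e.g. src/utils/pretrain-data/stopwords.txt (assumed location in this repository).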
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_lambada.py \
--device_target="Ascend" \
--device_id=1 \
--metric_method="PPL" \
--do_train="false" \
--do_eval="true" \
--eval_type="zero-shot" \
--epoch_num=1 \
--train_data_shuffle="true" \
--eval_data_shuffle="false" \
--generate_length_dynamically="true" \
--save_finetune_ckpt_path=$save_finetune_ckpt_path \
--load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
--load_finetune_ckpt_path=$load_eval_ckpt_path \
--train_data_file_path=$train_data_file_path \
--eval_data_file_path=$eval_data_file_path \
--tokenizer_file_path=$tokenizer_file_path \
--stop_word_file_path=$stop_word_file_path >> $output_log 2>&1 &


@ -0,0 +1,59 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_language_model.sh"
echo "for example: bash scripts/run_language_model.sh"
echo "metric method: PPL"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="
CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_language_model.log"
# create file and head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log
# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""
# dataset path
train_data_file_path=""
eval_data_file_path=""
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_language_model.py \
--device_target="Ascend" \
--device_id=4 \
--metric_method="PPL" \
--do_train="false" \
--do_eval="true" \
--eval_type="zero-shot" \
--epoch_num=1 \
--train_data_shuffle="true" \
--eval_data_shuffle="false" \
--save_finetune_ckpt_path=$save_finetune_ckpt_path \
--load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
--load_finetune_ckpt_path=$load_eval_ckpt_path \
--train_data_file_path=$train_data_file_path \
--eval_data_file_path=$eval_data_file_path >> $output_log 2>&1 &


@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_read_comprehension.sh"
echo "for example: bash scripts/run_read_comprehension.sh"
echo "metric method: F1"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="
CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_read_comprehension.log"
# create file and head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log
# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""
# dataset path
train_data_file_path=""
eval_data_file_path=""
# tokenizer path
tokenizer_file_path=""
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_ReadComprehension.py \
--device_target="Ascend" \
--device_id=7 \
--metric_method="F1" \
--do_train="false" \
--do_eval="true" \
--eval_type="zero-shot" \
--epoch_num=1 \
--train_data_shuffle="true" \
--eval_data_shuffle="false" \
--save_finetune_ckpt_path=$save_finetune_ckpt_path \
--load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
--load_finetune_ckpt_path=$load_eval_ckpt_path \
--train_data_file_path=$train_data_file_path \
--eval_data_file_path=$eval_data_file_path \
--tokenizer_file_path=$tokenizer_file_path \
--generate_length=55 \
--top_k=1 \
--top_p="1.0" \
--temperature="1.0" >> $output_log 2>&1 &

View File

@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_summarization.sh"
echo "for example: bash scripts/run_summarization.sh"
echo "eval_load_param_mode include: [zero-shot, finetuned]. Default: finetuned"
echo "=============================================================================================================="
CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_summarization.log"
# create the log file and write its header line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log
# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""
# dataset path
train_data_file_path=""
eval_data_file_path=""
# tokenizer path
tokenizer_file_path=""
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_summarization.py \
--device_target="Ascend" \
--device_id=0 \
--do_train="false" \
--do_eval="true" \
--metric_method="Rouge" \
--epoch_num=1 \
--train_data_shuffle="true" \
--eval_data_shuffle="false" \
--top_k=2 \
--top_p="1.0" \
--generate_length=100 \
--temperature="1.0" \
--eval_type="finetuned" \
--save_finetune_ckpt_path=$save_finetune_ckpt_path \
--load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
--load_finetune_ckpt_path=$load_eval_ckpt_path \
--train_data_file_path=$train_data_file_path \
--eval_data_file_path=$eval_data_file_path \
--tokenizer_file_path=$tokenizer_file_path >> $output_log 2>&1 &

View File

@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_translation.sh"
echo "for example: bash scripts/run_translation.sh"
echo "metric method: BLEU"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="
CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_translation.log"
# create the log file and write its header line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log
# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""
# dataset path
train_data_file_path=""
eval_data_file_path=""
# tokenizer path
tokenizer_file_path=""
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_translation.py \
--device_target="Ascend" \
--device_id=4 \
--metric_method="BLEU" \
--do_train="false" \
--do_eval="true" \
--eval_type="zero-shot" \
--epoch_num=1 \
--train_data_shuffle="true" \
--eval_data_shuffle="false" \
--save_finetune_ckpt_path=$save_finetune_ckpt_path \
--load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
--load_finetune_ckpt_path=$load_eval_ckpt_path \
--train_data_file_path=$train_data_file_path \
--eval_data_file_path=$eval_data_file_path \
--tokenizer_file_path=$tokenizer_file_path \
--generate_length=100 \
--top_k=1 \
--top_p="1.0" \
--temperature="1.0" >> $output_log 2>&1 &

View File

@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (CBT) model script.
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
from .GPT2_model import GPT2Model
class GPT2CBTModel(nn.Cell):
"""
GPT2CBTModel is responsible for Children's Book Test (CBT) task, i.e. CBT-CN, CBT-NE datasets.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
"""
Args:
config: the configuration of GPT-2 model
is_training (bool): `True` for train (finetune), `False` for evaluation.
use_one_hot_embeddings (bool): default False.
"""
super(GPT2CBTModel, self).__init__()
if not is_training:
config.summary_first_dropout = 0.0
self.is_training = is_training
self.d_model = config.d_model
self.batch_size = config.batch_size
self.seq_length = config.seq_length
self.vocab_size = config.vocab_size
self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
self.cast = P.Cast()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.log_softmax = P.LogSoftmax(axis=-1)
self.dtype = config.dtype
self.lm_head = nn.Dense(config.d_model,
config.vocab_size,
weight_init=TruncatedNormal(config.initializer_range),
has_bias=False).to_float(config.compute_type)
self.first_dropout = nn.Dropout(1 - config.summary_first_dropout)
def construct(self, input_ids, input_mask):
"""
Construct network.
Args:
input_ids (Tensor): shape with [batch_size, seq_len]
input_mask (Tensor): shape [batch_size, seq_len], where 0 indicates a padding position
Returns:
lm_logits (Tensor): language model distribution with log_softmax,
shape with [batch_size, seq_len, vocab_size]
"""
output, _ = self.gpt2(input_ids, input_mask) # output shape is [batch_size, seq_len, d_model]
output = self.cast(output, self.dtype)
output = self.reshape(output, (-1, self.d_model))
output = self.first_dropout(output)
lm_logits = self.lm_head(output) # [batch_size * seq_len, vocab_size]
lm_logits = self.reshape(lm_logits, (self.batch_size, self.seq_length, self.vocab_size))
lm_logits = self.cast(lm_logits, self.dtype)
lm_logits = self.log_softmax(lm_logits)
return lm_logits
def get_lm_head(self):
return self.lm_head.weight

View File

@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (LAMBADA) model script.
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from .GPT2_model import GPT2Model
class GPT2LambadaModel(nn.Cell):
"""
GPT2LambadaModel is responsible for Lambada task, i.e. Lambada-train, Lambada-test datasets.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
"""
Args:
config: the configuration of GPT-2 model
is_training (bool): `True` for train (finetune), `False` for evaluation.
use_one_hot_embeddings (bool): default False.
"""
super(GPT2LambadaModel, self).__init__()
if not is_training:
config.hidden_dropout = 0.0
self.vocab_size = config.vocab_size
self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
self.cast = P.Cast()
self.shape = P.Shape()
self.log_softmax = P.LogSoftmax(axis=-1)
self.dtype = config.dtype
self.dense1 = nn.Dense(config.d_model,
config.vocab_size,
weight_init=TruncatedNormal(config.initializer_range)).to_float(mstype.float16)
self.dropout = nn.Dropout(1 - config.hidden_dropout)
def construct(self, input_ids, input_mask):
"""
Args:
input_ids (Tensor): shape with [batch_size, seq_len]
input_mask (Tensor): shape [batch_size, seq_len], where 0 indicates a padding position
Returns:
lm_logits (Tensor): language model distribution with log_softmax,
shape with [batch_size, seq_len, vocab_size]
"""
output, _ = self.gpt2(input_ids, input_mask)
output = self.cast(output, self.dtype)
output = self.dropout(output)
batch_size, seq_length, d_model = self.shape(output)
output_reshape = P.Reshape()(output, (-1, d_model)) # [batch_size * seq_len, d_model]
logits = self.dense1(output_reshape)
logits = self.cast(logits, self.dtype)
logits = self.log_softmax(logits)
lm_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size))
return lm_logits

View File

@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (Language Modeling) model script.
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
from .GPT2_model import GPT2Model
class GPT2LanguageModel(nn.Cell):
"""
GPT2LanguageModel is responsible for Language Modeling task, i.e. WikiText2, WikiText103, PTB, 1BW datasets.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
"""
Args:
config: the configuration of GPT-2 model
is_training (bool): `True` for train (finetune), `False` for evaluation.
use_one_hot_embeddings (bool): default False.
"""
super(GPT2LanguageModel, self).__init__()
if not is_training:
config.hidden_dropout = 0.0
self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
self.vocab_size = config.vocab_size
self.cast = P.Cast()
self.shape = P.Shape()
self.dtype = config.dtype
self.dense1 = nn.Dense(config.d_model,
config.vocab_size,
weight_init=TruncatedNormal(config.initializer_range),
has_bias=False).to_float(config.compute_type)
self.dropout = nn.Dropout(1 - config.hidden_dropout)
self.log_softmax = P.LogSoftmax(axis=-1)
def construct(self, input_ids, input_mask):
"""
Construct network.
Args:
input_ids (Tensor): input sentences with shape [batch_size, seq_len].
input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len],
where 0 indicates padding position.
Returns:
lm_logits (Tensor): language model distribution with log_softmax, shape [batch_size, seq_len, vocab_size].
"""
output, _ = self.gpt2(input_ids, input_mask)
output = self.cast(output, self.dtype)
batch_size, seq_length, d_model = self.shape(output)
output_reshape = P.Reshape()(output, (-1, d_model)) # [batch_size * seq_len, d_model]
logits = self.dense1(output_reshape)
logits = self.cast(logits, self.dtype)
logits = self.log_softmax(logits)
lm_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size)) # [batch_size, seq_len, vocab]
return lm_logits

View File

@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (Reading Comprehension) model script.
"""
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from .GPT2_model import GPT2Model
class GPT2CoQAModel(nn.Cell):
"""
This class is responsible for the Reading Comprehension task on the CoQA dataset.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
super(GPT2CoQAModel, self).__init__()
if not is_training:
config.hidden_dropout = 0.0
self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
self.weight_init = TruncatedNormal(config.initializer_range)
self.dense1 = nn.Dense(config.d_model,
config.vocab_size,
weight_init=self.weight_init,
has_bias=False).to_float(config.compute_type)
self.log_softmax = P.LogSoftmax(axis=-1)
self.vocab_size = config.vocab_size
self.dtype = config.dtype
def construct(self, input_ids, input_mask):
"""
Construct network.
Args:
input_ids (Tensor): input sentences with shape [batch_size, seq_len].
input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len],
where 0 indicates padding position.
Returns:
logits (Tensor): language model distribution with log_softmax, shape [batch_size, seq_len, vocab_size].
"""
decoder_output, _ = self.gpt2(input_ids, input_mask)
decoder_output = P.Cast()(decoder_output, self.dtype)
batch_size, seq_length, d_model = P.Shape()(decoder_output)
reshaped_output = P.Reshape()(decoder_output, (-1, d_model)) # [batch_size * seq_length, d_model]
logits = self.dense1(reshaped_output)
logits = P.Cast()(logits, self.dtype)
logits = self.log_softmax(logits)
logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size))
return logits

View File

@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (Summarization) model script.
"""
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
from .GPT2_model import GPT2Model
class GPT2SummarizationModel(nn.Cell):
"""
GPT2SummarizationModel is responsible for summary task, i.e. cnn_dailymail datasets.
Args:
config: the configuration of GPT-2 model
is_training (bool): `True` for train (finetune), `False` for evaluation.
use_one_hot_embeddings (bool): default False.
"""
def __init__(self, config, is_training=True, use_one_hot_embeddings=False):
super(GPT2SummarizationModel, self).__init__()
self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
self.lm_head = nn.Dense(config.d_model,
config.vocab_size,
weight_init=TruncatedNormal(config.initializer_range),
has_bias=False).to_float(mstype.float16)
self.reshape = P.Reshape()
self.dtype = config.dtype
self.cast = P.Cast()
self.shape = P.Shape()
def construct(self, input_ids, input_mask):
"""
Construct network.
Args:
input_ids (Tensor): input sentences with shape [batch_size, seq_len].
input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len],
where 0 indicates padding position.
Returns:
lm_logits (Tensor): language model distribution without log_softmax,
shape [batch_size, seq_len, vocab_size].
"""
output, _ = self.gpt2(input_ids, input_mask)
output = self.cast(output, self.dtype)
batch_size, seq_length, d_model = self.shape(output)
hidden_state = self.reshape(output, (-1, d_model))
hidden_state = self.cast(hidden_state, self.dtype)
lm_logits = self.lm_head(hidden_state)
lm_logits = self.cast(lm_logits, self.dtype)
lm_logits = self.reshape(lm_logits, (batch_size, seq_length, -1))
return lm_logits

View File

@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (Translation) model script.
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
from .GPT2_model import GPT2Model
class GPT2TranslationModel(nn.Cell):
"""
GPT2TranslationModel is responsible for translation task, i.e. WMT-14 En-Fr, WMT-14 Fr-En datasets.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
"""
Args:
config: the configuration of GPT-2 model
is_training (bool): `True` for train (finetune), `False` for evaluation.
use_one_hot_embeddings (bool): default False.
"""
super(GPT2TranslationModel, self).__init__()
if not is_training:
config.hidden_dropout = 0.0
self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
self.vocab_size = config.vocab_size
self.cast = P.Cast()
self.shape = P.Shape()
self.dtype = config.dtype
self.dense1 = nn.Dense(config.d_model,
config.vocab_size,
weight_init=TruncatedNormal(config.initializer_range),
has_bias=True).to_float(config.compute_type)
self.dropout = nn.Dropout(1 - config.hidden_dropout)
def construct(self, input_ids, input_mask):
"""
Construct network.
Args:
input_ids (Tensor): input sentences shape with [batch_size, seq_len]
input_mask (Tensor): shape [batch_size, seq_len], where 0 indicates a padding position
Returns:
translation_logits (Tensor): language model distribution without log_softmax,
shape with [batch_size, seq_len, vocab_size]
"""
output, _ = self.gpt2(input_ids, input_mask)
output = self.cast(output, self.dtype)
output = self.dropout(output)
batch_size, seq_length, d_model = self.shape(output)
output_reshape = P.Reshape()(output, (-1, d_model)) # [batch_size * seq_len, d_model]
logits = self.dense1(output_reshape)
logits = self.cast(logits, self.dtype)
translation_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size))
return translation_logits

View File

@ -0,0 +1,366 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
generation classes for the downstream tasks (Summarization, LAMBADA, Reading Comprehension, Translation)
"""
import numpy as np
from .utils.task_utils import extract_logits
from .utils.generation_utils import Sample
from .utils.tensor_manipulations import extract_string_from_tensor
INF = 1. * 1e9
class GenerateForSummarization():
"""
generate for summarization task
"""
def __init__(self,
decoder,
config=None,
tokenizer=None,
select_sentence=3,
eval_type="finetuned",
temperature=1.0,
generate_length=100,
topk=2,
topp=1.0):
self.decoder = decoder
self.config = config
self.tokenizer = tokenizer
self.select_sentence = select_sentence
self.eval_type = eval_type
self.generator = Sample(decoder,
tokenizer=tokenizer,
config=config,
topk_num=topk,
topp_prob=topp,
min_tokens_to_keep=1,
demo_mode=False,
temperature=temperature)
self.generate_length = generate_length
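# Evaluation sketch (mirrors the code below): in "finetuned" mode the "TL;DR:" hint is appended
# to the article, and the first `select_sentence` generated sentences are kept as the hypothesis.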
def generate_for_summarization(self, input_ids):
"""generation function for summarization task"""
# prepare input_str
article_str, summary_str = extract_string_from_tensor(input_ids=input_ids,
mode="pair",
config=self.config,
tokenizer=self.tokenizer)
generated_summary_list = [""] * self.config.batch_size
# clip overflow
for batch_idx in range(self.config.batch_size):
last_dot_pos = max(article_str[batch_idx].rfind(' .'), article_str[batch_idx].rfind('. ')) + 2
article_str[batch_idx] = article_str[batch_idx][:last_dot_pos]
# append the "TL;DR:" hint after the article string (finetuned evaluation only).
tldr_str = "TL;DR:"
if self.eval_type == "finetuned":
for batch_idx in range(self.config.batch_size):
article_str[batch_idx] += (" " + tldr_str)
generate_str_list, _ = self.generator.generate(input_str=article_str, generate_length=self.generate_length)
for batch_idx in range(self.config.batch_size):
generate_str = generate_str_list[batch_idx]
generated_summary = ""
if self.select_sentence > 0:
# check whether the generated text contains `select_sentence` sentences;
# if not, the full generated string is kept.
len_generate_str = len(generate_str)
search_index = -1
for _ in range(self.select_sentence):
search_index = generate_str.find('.', search_index + 1)
if search_index == -1 or search_index >= len_generate_str:
search_index = len_generate_str
break
# increase search_index to add period token('.') if search_index does not overflow.
search_index = search_index + 1 if search_index < len_generate_str else len_generate_str
generated_summary = generate_str[:search_index]
if generated_summary.find(self.tokenizer.eos_token) != -1:
cut_pos = generated_summary.find(self.tokenizer.eos_token, 0)
generated_summary = generated_summary[:cut_pos]
else:
generated_summary = generate_str
# if the whole string has been clipped away, fall back to the full generated string.
if generated_summary == '':
generated_summary = generate_str
# empty str check
if generated_summary == '':
generated_summary = '<empty>'
generated_summary_list[batch_idx] = generated_summary
return generated_summary_list, summary_str # Hypo and Ref
class GenerateForLambada():
"""
generation class for the LAMBADA task, which predicts the final word of a sentence.
"""
def __init__(self,
decoder,
config=None,
tokenizer=None,
generate_length_dynamic=True,
generate_length=1,
max_iterations=200,
stop_word_file=""):
"""
Args:
decoder (Model): GPT-2 model used for generation.
config (object): configuration of the given GPT-2 model.
tokenizer (object): tokenizer, required when input_str is passed to self.generate().
generate_length_dynamic (bool): True if the generation length is dynamic, False for a fixed length. Default: True.
max_iterations (int): number of top-ranked first-token candidates to try, i.e. k = `max_iterations`.
generate_length (int): maximum number of tokens generated for the final word.
stop_word_file (str): path of the stop-word list used as a stop-word filter.
"""
self.decoder = decoder
self.config = config
self.batch_size = config.batch_size
self.tokenizer = tokenizer
self.generate_length_dynamic = generate_length_dynamic
self.generate_length = generate_length
self.max_iterations = max_iterations
self.stop_word_set = self.build_stop_word(stop_word_file)
self.generator = Sample(decoder=decoder,
config=config,
batch_size=1,
tokenizer=tokenizer,
topk_num=1,
topp_prob=1,
return_ids=True
)
self.stop_eos = ['.', ',', '!', '?', '"', " '", " and", " says", " said"]
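# Prediction sketch (mirrors generate_for_lambada below): rank first-token candidates by the
# LM logits at the final-word position, try the top `max_iterations` of them, greedily extend
# each candidate, cut the continuation at the first character in `stop_eos`, and accept the
# first single-word candidate that is not a stop word.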
def build_stop_word(self, stop_word_file):
stop_words_set = set()
with open(stop_word_file, 'r', encoding="utf8") as file:
for line in file.readlines():
line = line.strip('\n')
stop_words_set.add(line)
return stop_words_set
def is_stop_word(self, word):
flag = False
if word in self.stop_word_set:
flag = True
return flag
return flag
def generate_for_lambada(self, input_ids, logits, input_length):
"""
generation function for lambada task
Args:
input_ids (Tensor): input sentences with shape [batch_size, seq_len].
logits (Tensor): the language model distribution.
input_length (Tensor): stores the context length (excluding the final word) and the whole sentence length.
return:
batch_predict_words (list): the list of predict_words
"""
batch_predict_words = ["" for _ in range(self.batch_size)]
input_len_np = input_length.asnumpy()
input_ids_list = input_ids.asnumpy().tolist()
extracted_logits = extract_logits(logits=logits, position=input_len_np) # [batch_size, vocab_size]
extracted_logits = extracted_logits.asnumpy()
sorted_ids = np.argsort(-extracted_logits, axis=-1)[::, :self.max_iterations] # [batch_size, max_iterations]
for batch_idx in range(self.batch_size):
final_word_spos = input_len_np[batch_idx, 0]
context_ids = input_ids_list[batch_idx][1:final_word_spos] # 1 for dropping <bos> token
last_word_token_num = input_len_np[batch_idx, 1] - input_len_np[batch_idx, 0]
if self.generate_length_dynamic:
generate_length = last_word_token_num
else:
generate_length = self.generate_length
for num in range(self.max_iterations):
id_ = sorted_ids[batch_idx][num]
source_ids = context_ids + [id_]
source_string = self.tokenizer.decode(source_ids)
generated_ids_list = self.generator.generate(input_str=source_string,
generate_length=generate_length,
do_sample=False)
predict_tokens_ids = [id_] + generated_ids_list[0]
predict_word = self.tokenizer.decode(predict_tokens_ids)
eos_pos = min(predict_word.find(word) if predict_word.find(word) >= 0
else INF for word in self.stop_eos)
if eos_pos == INF:
continue
else:
predict_word = predict_word[:eos_pos]
predict_word = predict_word.strip()
if predict_word.find(" ") == -1:
if self.is_stop_word(word=predict_word.lower()):
continue
batch_predict_words[batch_idx] = predict_word
print("predict word: {}".format(predict_word))
break
return batch_predict_words
class GenerateForTranslation():
"""
generate class for translation task
"""
def __init__(self,
decoder,
config=None,
tokenizer=None,
generate_length=1,
use_hint=True,
select_first_sentence=True,
topk_num=None,
topp_prob=None,
temperature=None
):
self.decoder = decoder
self.config = config
self.batch_size = config.batch_size
self.tokenizer = tokenizer
self.generate_length = generate_length
self.use_hint = use_hint
self.select_first_sentence = select_first_sentence
self.generator = Sample(decoder=decoder,
config=config,
tokenizer=tokenizer,
topk_num=topk_num,
topp_prob=topp_prob,
temperature=temperature,
min_tokens_to_keep=1,
early_stop=False)
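# Prompting sketch: when `use_hint` is True the source sentence is suffixed with " =", following
# the "english sentence = french sentence" format of the GPT-2 paper, and only the first
# generated sentence is kept as the translation.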
def generate_for_translation(self, input_ids):
"""generation function for translation task"""
source_str_list, ref_str_list = extract_string_from_tensor(input_ids=input_ids,
mode="pair",
config=self.config,
tokenizer=self.tokenizer)
final_predict_translation_list = [""] * self.batch_size
if self.use_hint:
for index in range(self.batch_size):
source_str_list[index] += " =" # now source_str is "english sentence ="
translation_str_list, _ = self.generator.generate(input_str=source_str_list,
generate_length=self.generate_length,
do_sample=False)
for index in range(self.batch_size):
generate_str = translation_str_list[index].replace('<|endoftext|>', '')
predict_translation = ""
# According to the GPT2 paper, the select_first_sentence will be set "True"
if self.select_first_sentence:
# keep only the first generated sentence; if no sentence boundary is found,
# the full generated string is kept.
search_index = generate_str.find('.', 0, len(generate_str))
if search_index == -1:
search_index = len(generate_str)
else:
search_index = search_index + 1
predict_translation = generate_str[:search_index]
else:
predict_translation = generate_str
if predict_translation == '':
predict_translation = '<empty>'
final_predict_translation_list[index] = predict_translation
return final_predict_translation_list, ref_str_list
class GenerateForReadComprehension():
"""
generate class for Reading Comprehension task.
Args:
decoder (Model): GPT-2 model used for generation.
config (object): configuration of the given GPT-2 model.
tokenizer (object): tokenizer, required when input_str is passed to self.generate().
generate_length (int): number of tokens to generate for the answer.
"""
def __init__(self,
decoder,
config=None,
tokenizer=None,
generate_length=1,
topk_num=None,
topp_prob=None,
temperature=None
):
self.decoder = decoder
self.config = config
self.batch_size = config.batch_size
self.tokenizer = tokenizer
self.generate_length = generate_length
self.generator = Sample(decoder=decoder,
config=config,
tokenizer=tokenizer,
topk_num=topk_num,
topp_prob=topp_prob,
temperature=temperature,
min_tokens_to_keep=1,
)
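# Answer extraction sketch (mirrors the code below): generate a continuation of the
# passage+question prompt, strip the '<|endoftext|>' marker, and cut the text at a sentence
# boundary ('.') or at the next 'Q:' marker to obtain the predicted answer.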
def generate_for_read_comprehension(self, input_ids):
"""generation function for reading comprehension task"""
passage_str_list, answer_str_list = extract_string_from_tensor(input_ids=input_ids,
mode="pair",
config=self.config,
tokenizer=self.tokenizer)
passage = passage_str_list[:]
generate_str_list, _ = self.generator.generate(input_str=passage_str_list,
generate_length=self.generate_length,
do_sample=False)
pred_answer = []
for batch_id in range(self.batch_size):
new_str = generate_str_list[batch_id].replace('<|endoftext|>', '')
index_a = new_str.find('.')
index_b = new_str.find('Q:')
if index_a != -1 or index_b != -1:
index = max(index_a, index_b)
pred_answer += [new_str[1:index]] # start from index 1 to skip the leading space of the sentence
else:
pred_answer += [new_str]
return passage, pred_answer, answer_str_list

View File

@ -0,0 +1,896 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 base model
"""
import math
import copy
import numpy as np
import mindspore
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from .weight_init import normal_weight, zero_weight
class GPT2Config:
"""
Configuration for `GPT2Model`.
Args:
batch_size (int): Batch size of input dataset. Default: 512.
seq_length (int): Length of input sequence. Default: 1024.
vocab_size (int): The shape of each embedding vector. Default: 50257.
d_model (int): Size of the GPT-2 decoder layers. Default: 768.
num_hidden_layers (int): Number of hidden layers in the GPT2Transformer decoder block. Default: 12.
num_attention_heads (int): Number of attention heads in the GPT2Transformer decoder block. Default: 12.
intermediate_size (int): Size of intermediate layer in the GPT2Transformer decoder block. Default: 3072.
hidden_act (str): Activation function used in the GPT2Transformer decoder block. Default: "gelu".
hidden_dropout (float): The dropout probability applied to hidden outputs. Default: 0.1.
attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.1.
max_position_embeddings (int): Maximum length of sequences used in this model. Default: 1024.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from dataset.
Default: True.
summary_first_dropout (float): The dropout probability for GPT2CBTModel. Default: 0.1.
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
compute_type (:class:`mindspore.dtype`): Compute type in GPT2Transformer. Default: mstype.float16.
"""
def __init__(self,
batch_size=512,
seq_length=1024,
vocab_size=50257,
d_model=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=1024,
initializer_range=0.02,
input_mask_from_dataset=True,
summary_first_dropout=0.1,
dtype=mstype.float32,
compute_type=mstype.float16,
):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.d_model = d_model
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.input_mask_from_dataset = input_mask_from_dataset
self.summary_first_dropout = summary_first_dropout
self.dtype = dtype
self.compute_type = compute_type
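# Example (illustrative only): the defaults above already match the released GPT-2 small (117M)
# model (12 layers, 12 heads, d_model=768, intermediate_size=3072, vocab_size=50257,
# seq_length=1024); larger variants only change these size-related fields, e.g.
#     GPT2Config(batch_size=1, d_model=1024, num_hidden_layers=24, num_attention_heads=16,
#                intermediate_size=4096)
# would roughly correspond to the medium (345M) model.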
class EmbeddingLookup(nn.Cell):
"""
An embedding lookup table with a fixed dictionary and size.
Args:
vocab_size (int): Size of the dictionary of embeddings.
embedding_dim (int): The size of each embedding vector.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
"""
def __init__(self,
vocab_size,
embedding_dim,
use_one_hot_embeddings=False,
compute_type=mstype.float16):
super(EmbeddingLookup, self).__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.use_one_hot_embeddings = use_one_hot_embeddings
self.compute_type = compute_type
self.embedding_table = Parameter(normal_weight([vocab_size, embedding_dim], embedding_dim),
name='embedding_table')
self.expand = P.ExpandDims()
self.shape_flat = (-1,)
self.gather = P.GatherV2()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.cast = P.Cast()
def construct(self, input_ids):
"""
get embedding according to input_ids.
Args:
input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
Returns:
output (Tensor): the embedding matrix according to the input_ids.
self.embedding_table (Parameter): the whole embedding table of GPT-2 model.
"""
input_shape = self.shape(input_ids) # [batch_size, seq_length]
flat_ids = self.reshape(input_ids, self.shape_flat) # [batch_size * seq_length]
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
# precision transition fp32 -> fp16
one_hot_ids = self.cast(one_hot_ids, self.compute_type)
self.embedding_table = self.cast(self.embedding_table, self.compute_type)
output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
output_for_reshape = self.cast(output_for_reshape, mstype.float32)
else:
# [batch_size * seq_length * embedding_dim]
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
out_shape = input_shape + (self.embedding_dim,)
output = self.reshape(output_for_reshape, out_shape) # [batch_size, seq_length, embedding_dim]
return output, self.embedding_table
class EmbeddingPostprocessor(nn.Cell):
"""
Postprocessors apply positional embeddings to word embeddings.
Args:
embedding_dim (int): The size of each embedding vector.
seq_length (int): the length of input sequence.
max_position_embeddings (int): Maximum length of sequences used in this model. Default: 1024.
dropout_prob (float): The dropout probability. Default: 0.1.
"""
def __init__(self,
embedding_dim=None,
seq_length=None,
max_position_embeddings=1024,
dropout_prob=0.1):
super(EmbeddingPostprocessor, self).__init__()
self.position_embedding_table = Parameter(
normal_weight([max_position_embeddings, embedding_dim], embedding_dim), name='position_embeddings')
self.expand_dims = P.ExpandDims()
self.add = P.TensorAdd()
self.gather = P.GatherV2()
self.input_indices = Tensor(np.array([x for x in range(seq_length)]), mindspore.int32)
self.dropout = nn.Dropout(1 - dropout_prob, dtype=mstype.float32)
self.use_dropout = dropout_prob > 0
def construct(self, word_embeddings):
"""
Add the position embedding table to token embedding table
Args:
word_embeddings (Tensor): the token embedding matrix
Returns:
output (Tensor): the final embedding matrix by adding the position embedding table
to token embedding table.
"""
position_embeddings = self.gather(self.position_embedding_table, self.input_indices, 0)
position_embeddings = self.expand_dims(position_embeddings, 0)
output = self.add(word_embeddings, position_embeddings)
if self.use_dropout:
output = self.dropout(output)
return output
class CastWrapper(nn.Cell):
"""
Cast wrapper
"""
def __init__(self,
dst_type=mstype.float32):
super(CastWrapper, self).__init__()
self.cast = P.Cast()
self.dst_type = dst_type
def construct(self, x):
"""
type cast
Args:
x (Tensor): the input which need to be cast.
Returns:
Tensor, the cast output.
"""
return self.cast(x, self.dst_type)
class LayerNorm(nn.Cell):
"""
Do layer norm
Args:
in_channels (int): In channels number of layer norm
"""
def __init__(self,
in_channels=None):
super(LayerNorm, self).__init__()
self.layer_norm = nn.LayerNorm((in_channels,))
self.cast = P.Cast()
self.get_dtype = P.DType()
def construct(self, input_tensor):
"""
layer norm
Args:
input_tensor (Tensor): the input of layernorm.
Returns:
Tensor, the output after layernorm.
"""
output = self.cast(input_tensor, mstype.float32)
output = self.layer_norm(output)
output = self.cast(output, self.get_dtype(input_tensor))
return output
class ResidualConnection(nn.Cell):
"""
Add residual to output.
Args:
dropout_prob (float): Dropout rate.
"""
def __init__(self, dropout_prob=0.0):
super(ResidualConnection, self).__init__()
self.add = P.TensorAdd()
self.dropout = nn.Dropout(1 - dropout_prob)
self.use_dropout = dropout_prob > 0
def construct(self, hidden_tensor, input_tensor):
"""
Args:
hidden_tensor (Tensor): the output of sublayer.
input_tensor (Tensor): the input tensor.
Returns:
output (Tensor): with the same shape of hidden_tensor.
"""
output = hidden_tensor
if self.use_dropout:
output = self.dropout(output)
output = self.add(output, input_tensor)
return output
class Conv1D(nn.Cell):
"""
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
Basically works like a linear layer but the weights are transposed.
Args:
nx (int): The number of input features.
nf (int): The number of output features.
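Note: the weight is stored with shape [nx, nf] (input features first), so the forward pass
computes output = input @ weight + bias, which is why this behaves like a transposed linear layer.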
"""
def __init__(self,
nx,
nf):
super(Conv1D, self).__init__()
self.nx = nx
self.nf = nf
self.weight = Parameter(normal_weight([nx, nf], nf), name='projection_weight')
self.bias = Parameter(zero_weight(nf), name='projection_bias')
self.matmul = P.MatMul()
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
def construct(self, input_tensor):
"""
Args:
input_tensor (Tensor): the input tensor of Conv1D with shape [batch_size * seq_length, nx]
Returns:
output_tensor (Tensor): the output tensor with shape [batch_size * seq_length, self.nf]
"""
# precision transition fp32 -> fp16
input_tensor = self.cast(input_tensor, mstype.float16)
fp16_weight = self.cast(self.weight, mstype.float16)
output_tensor = self.matmul(input_tensor, fp16_weight) # [batch_size * seq_length, self.nf]
output_tensor = self.cast(output_tensor, mstype.float32)
output_tensor = self.bias_add(output_tensor, self.bias)
return output_tensor
class MaskedSelfAttention(nn.Cell):
"""
Apply masked multi-head attention.
Args:
batch_size (int): Batch size of input datasets. Default: 512.
d_model (int): Size of last dim of input tensor. Default: 768.
seq_length (int): Length of input tensor sequence. Default: 1024.
num_attention_heads (int): Number of attention heads. Default: 12.
dim_per_head (int): Size of each attention head. Default: 64.
has_attention_mask (bool): Specifies whether to use attention mask. Default: True.
attention_dropout (float): The dropout probability for MultiheadAttention. Default: 0.0.
compute_type (:class:`mindspore.dtype`): Compute type in MaskedSelfAttention. Default: mstype.float16.
Returns:
Tensor, with shape [batch_size * seq_length, d_model] when `do_return_2d_tensor` is True, otherwise [batch_size, seq_length, d_model].
"""
def __init__(self,
batch_size=512,
d_model=768,
seq_length=1024,
num_attention_heads=12,
dim_per_head=64,
has_attention_mask=True,
do_return_2d_tensor=True,
attention_dropout=0.0,
compute_type=mstype.float16):
super(MaskedSelfAttention, self).__init__()
self.batch_size = batch_size
self.d_model = d_model
self.seq_length = seq_length
self.num_heads = num_attention_heads
self.dim_per_head = dim_per_head
self.has_attention_mask = has_attention_mask
self.compute_type = compute_type
assert has_attention_mask
self.scale = Tensor([1.0 / math.sqrt(float(self.dim_per_head))], dtype=compute_type) # attention scale
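# scaled dot-product attention: attention_scores = (Q @ K^T) * scale, with scale = 1 / sqrt(dim_per_head)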
self.mask_data = Tensor([-10000.0,], dtype=compute_type)
self.split_head_shape = (-1, self.seq_length, self.num_heads, self.dim_per_head)
self.c_attn = Conv1D(d_model, d_model * 3)
self.c_proj = Conv1D(d_model, d_model)
self.split_for_qkv = P.Split(1, 3)
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.trans_shape = (0, 2, 1, 3)
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.matmul = P.BatchMatMul()
self.multiply = P.Mul()
if self.has_attention_mask:
self.expand_dims = P.ExpandDims()
self.sub = P.Sub()
self.add = P.TensorAdd()
self.cast = P.Cast()
self.get_dtype = P.DType()
if do_return_2d_tensor:
self.shape_return = (-1, d_model)
else:
self.shape_return = (-1, seq_length, d_model)
self.softmax = nn.Softmax()
self.softmax_cast = P.Cast()
self.dropout = nn.Dropout(1 - attention_dropout)
self.use_attention_dropout = attention_dropout > 0
def construct(self, input_tensor, attention_mask):
"""
do masked self-attention
Args:
input_tensor (Tensor): the embedding of input sequence tokens,
shape with [batch_size * seq_length, d_model]
attention_mask (Tensor): mask to avoid performing attention on padding token indices,
shape with [batch_size, seq_len, seq_len].
Returns:
outputs (Tensor): the output of masked self-attention, shape with [batch_size * seq_len, d_model].
"""
input_tensor = self.c_attn(input_tensor) # [batch_size * seq_length, d_model*3]---> eg.[1 * 3, 2304]
input_tensor = self.split_for_qkv(input_tensor)
query = input_tensor[0] # [batch_size * seq_length, d_model] ---> eg. [1 * 3, 768]
key = input_tensor[1]
value = input_tensor[2]
# split head
query = self.reshape(query, self.split_head_shape)
# query shape [batch_size, num_heads, seq_len, dim_per_head] ---> eg. [1, 12, 3, 64]
query = self.transpose(query, self.trans_shape)
key = self.reshape(key, self.split_head_shape)
# key shape [batch_size, num_heads, seq_len, dim_per_head] ---> eg. [1, 12, 3, 64]
key = self.transpose(key, self.trans_shape)
value = self.reshape(value, self.split_head_shape)
# value shape [batch_size, num_heads, seq_len, dim_per_head] ---> eg. [1, 12, 3, 64]
value = self.transpose(value, self.trans_shape)
# attention and mask
# precision transition fp32 -> fp16
query = self.cast(query, self.compute_type)
key = self.cast(key, self.compute_type)
attention_scores = self.matmul_trans_b(query, key) # [batch_size, num_heads, seq_len, seq_len]
attention_scores = self.cast(attention_scores, self.compute_type)
attention_scores = self.multiply(attention_scores, self.scale)
if self.has_attention_mask:
attention_mask = self.expand_dims(attention_mask, 1) # [batch_size, 1, seq_length, seq_length]
multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
self.cast(attention_mask, self.get_dtype(attention_scores))) # fp16
adder = self.multiply(multiply_out, self.mask_data)
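# adder = (1 - attention_mask) * (-10000): masked positions receive a large negative bias,
# so the softmax below assigns them a probability close to zero.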
adder = self.cast(adder, mstype.float32)
attention_scores = self.cast(attention_scores, mstype.float32)
attention_scores = self.add(adder, attention_scores)
attention_scores = self.softmax_cast(attention_scores, mstype.float32)
attention_probs = self.softmax(attention_scores) # [batch_size, num_heads, seq_len, seq_len]
attention_probs = self.softmax_cast(attention_probs, self.get_dtype(key))
if self.use_attention_dropout:
attention_probs = self.dropout(attention_probs)
value = self.cast(value, mstype.float16)
attention_probs = self.cast(attention_probs, self.compute_type)
outputs = self.matmul(attention_probs, value) # [batch_size, num_heads, seq_len, dim_per_head]
outputs = self.cast(outputs, mstype.float32)
# merge heads
outputs = self.transpose(outputs, self.trans_shape) # [batch_size, seq_len, num_heads, dim_per_head]
outputs = self.reshape(outputs,
self.shape_return) # default True, the outputs shape [batch_size * seq_len, d_model]
# project
outputs = self.c_proj(outputs)
return outputs
class FeedForward(nn.Cell):
"""
Apply two-layer feed forward
Args:
in_channels (int): Size of the input layer. Default: 768.
out_channels (int): Size of the output layers. Default: 768.
hidden_size (int): Size of the hidden layer. Default: 3072.
hidden_dropout (float): The dropout probability for hidden outputs. Default: 0.1.
"""
def __init__(self,
in_channels=768,
out_channels=768,
hidden_size=3072,
hidden_dropout=0.1):
super(FeedForward, self).__init__()
self.c_fc = Conv1D(in_channels, hidden_size)
self.c_proj = Conv1D(hidden_size, out_channels)
# self.gelu = Gelu()
self.layernorm = LayerNorm(in_channels=in_channels)
self.residual_connect = ResidualConnection(dropout_prob=hidden_dropout)
self.gelu_act = P.Gelu()
self.dropout = nn.Dropout(1 - hidden_dropout)
self.use_dropout = hidden_dropout > 0
self.reshape = P.Reshape()
def construct(self, input_tensor):
"""
FeedForward construct function with layernorm and residual connection.
Args:
input_tensor (Tensor): the input of FeedForward layer, shape with [batch_size * seq_len, d_model].
Returns:
output (Tensor): the output of FeedForward layer, shape with [batch_size * seq_len, d_model].
"""
# LayerNorm
output = self.layernorm(input_tensor)
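# GPT-2 uses pre-LayerNorm: normalization is applied before the feed-forward sub-layer,
# and the residual connection at the end adds the original (un-normalized) input back.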
# Feed Forward
output = self.c_fc(output) # [batch_size * seq_len, d_model * 4]
output = self.gelu_act(output)
# output = self.gelu(output)
output = self.c_proj(output) # [batch_size * seq_len, d_model]
if self.use_dropout:
output = self.dropout(output)
# Add
output = self.residual_connect(output, input_tensor)
return output
class MaskedMultiHeadAttention(nn.Cell):
"""
Masked multi-head attention block.
"""
def __init__(self,
batch_size=512,
seq_length=1024,
d_model=768,
num_attention_heads=12,
attention_dropout=0.02,
hidden_dropout=0.1,
has_attention_mask=True,
compute_type=mstype.float16
):
super(MaskedMultiHeadAttention, self).__init__()
if d_model % num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
"of attention heads (%d)" % (d_model, num_attention_heads))
self.dim_per_head = int(d_model / num_attention_heads) # 64
self.masked_self_attention = MaskedSelfAttention(
batch_size=batch_size,
d_model=d_model,
seq_length=seq_length,
num_attention_heads=num_attention_heads,
dim_per_head=self.dim_per_head,
has_attention_mask=has_attention_mask,
do_return_2d_tensor=True,
attention_dropout=attention_dropout,
compute_type=compute_type
)
self.layer_norm = LayerNorm(in_channels=d_model)
self.residual_connection = ResidualConnection()
self.reshape = P.Reshape()
self.new_shape = (-1, d_model)
def construct(self, input_tensor, attention_mask):
"""
do masked multi head self-attention with layernorm and residual_connection.
Args:
input_tensor (Tensor): the embedding matrix of input sequence tokens,
shape with [batch_size * seq_length, d_model]
attention_mask (Tensor): mask to avoid performing attention on padding token indices,
shape with [batch_size, seq_len, seq_len].
Returns:
outputs (Tensor): the output of MaskedMultiHeadAttention, shape with [batch_size * seq_len, d_model].
"""
# LayerNorm
output_tensor = self.layer_norm(input_tensor)
# masked multi-head attention
# attention_output shape [batch_size * seq_length, d_model]
attention_output = self.masked_self_attention(output_tensor, attention_mask)
# residual connection
output = self.residual_connection(attention_output, input_tensor)
return output
class DecoderBlock(nn.Cell):
"""
decoder block used in GPT2.
Args:
batch_size (int): Batch size of input dataset. Default: 512.
seq_length (int): Length of input sequence. Default: 1024.
d_model (int): Size of the GPT2 decoder layers. Default: 768.
num_attention_heads (int): Number of attention heads. Default: 12.
intermediate_size (int): Size of intermediate layer. Default: 3072.
attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.02.
hidden_dropout (float): The dropout probability for hidden outputs. Default: 0.1.
has_attention_mask (bool): Specifies whether to use attention mask. Default: True.
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float16.
"""
def __init__(self,
batch_size=512,
seq_length=1024,
d_model=768,
num_attention_heads=12,
intermediate_size=3072,
attention_dropout=0.02,
hidden_dropout=0.1,
has_attention_mask=True,
compute_type=mstype.float16
):
super(DecoderBlock, self).__init__()
if d_model % num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
"of attention heads (%d)" % (d_model, num_attention_heads))
self.dim_per_head = int(d_model / num_attention_heads) # 64
self.masked_multi_head_attention = MaskedMultiHeadAttention(
batch_size=batch_size,
seq_length=seq_length,
d_model=d_model,
num_attention_heads=num_attention_heads,
attention_dropout=attention_dropout,
hidden_dropout=hidden_dropout,
has_attention_mask=has_attention_mask,
compute_type=compute_type
)
self.feedforward = FeedForward(
in_channels=d_model,
out_channels=d_model,
hidden_size=intermediate_size,
hidden_dropout=hidden_dropout
)
self.reshape = P.Reshape()
self.new_shape = (-1, d_model)
def construct(self, input_tensor, attention_mask): # input tensor shape[batch_size, seq_length, d_model]
"""
DecoderBlock with masked_multi_head_attention and feedforward.
Args:
input_tensor (Tensor): the embedding matrix of input sequence tokens,
shape with [batch_size * seq_length, d_model]
attention_mask (Tensor): mask to avoid performing attention on padding token indices,
shape with [batch_size, seq_len, seq_len].
Returns:
outputs (Tensor): the output of DecoderBlock, shape with [batch_size * seq_len, d_model].
"""
input_tensor = self.reshape(input_tensor, self.new_shape)
# masked multi head attention with ln, res
attention_output = self.masked_multi_head_attention(input_tensor, attention_mask)
# feed forward with ln, res
output = self.feedforward(attention_output)
return output
class GPT2Transformer(nn.Cell):
"""
Multi-layer GPT2 transformer.
Args:
batch_size (int): Batch size of input dataset. Default: 512.
d_model (int): Size of the decoder layers. Default: 768.
seq_length (int): Length of input sequence. Default: 1024.
num_hidden_layers (int): Number of hidden layers in decoder cells. Default: 12.
num_attention_heads (int): Number of attention heads in decoder cells. Default: 12.
intermediate_size (int): Size of intermediate layer in decoder cells. Default: 3072.
has_attention_mask (bool): Specifies whether to use attention mask. Default: True.
attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.1.
hidden_dropout (float): The dropout probability applied to hidden outputs. Default: 0.1.
compute_type (:class:`mindspore.dtype`): Compute type in GPT2Transformer. Default: mstype.float16.
"""
def __init__(self,
batch_size=512,
d_model=768,
seq_length=1024,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
has_attention_mask=True,
attention_dropout=0.1,
hidden_dropout=0.1,
compute_type=mstype.float16):
super(GPT2Transformer, self).__init__()
layers = []
for _ in range(num_hidden_layers):
layer = DecoderBlock(batch_size=batch_size,
seq_length=seq_length,
d_model=d_model,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_dropout=attention_dropout,
hidden_dropout=hidden_dropout,
has_attention_mask=has_attention_mask,
compute_type=compute_type)
layers.append(layer)
self.layers = nn.CellList(layers)
self.reshape = P.Reshape()
self.new_shape = (-1, d_model)
# self.out_shape = (batch_size, seq_length, d_model)
self.out_shape = (-1, seq_length, d_model)
def construct(self, input_tensor, attention_mask):
"""
Do Multi DecoderBlock.
Args:
input_tensor (Tensor): the embedding matrix of input sequence tokens,
shape with [batch_size * seq_length, d_model]
attention_mask (Tensor): mask to avoid performing attention on padding token indices,
shape with [batch_size, seq_len, seq_len].
Returns:
outputs (Tensor): the output of GPT2Transformer, shape with [batch_size * seq_len, d_model].
"""
prev_output = self.reshape(input_tensor, self.new_shape)
for layer_module in self.layers:
layer_output = layer_module(prev_output, attention_mask)
prev_output = layer_output
output = self.reshape(prev_output, self.out_shape)
return output
class CreateAttentionMaskFromInputMask(nn.Cell):
"""
Create attention mask according to input mask.
Args:
config (Class): Configuration for GPT2Model.
"""
def __init__(self, config):
super(CreateAttentionMaskFromInputMask, self).__init__()
self.input_mask_from_dataset = config.input_mask_from_dataset
self.input_mask = None
self.compute_type = config.compute_type
assert self.input_mask_from_dataset
self.cast = P.Cast()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.matmul = P.BatchMatMul()
self.multiply = P.Mul()
# mask future positions
ones = np.ones(shape=(config.batch_size, config.seq_length, config.seq_length))
self.lower_triangle_mask = Tensor(np.tril(ones), dtype=mstype.float32)
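# Illustrative example (not executed): with seq_length = 3 and input_mask = [1, 1, 1], the outer
# product of the mask with itself (computed in construct below) is a 3x3 all-ones matrix, and
# multiplying it element-wise by the lower triangle
#     [[1, 0, 0],
#      [1, 1, 0],
#      [1, 1, 1]]
# keeps only current and previous positions, i.e. the causal (auto-regressive) attention mask.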
def construct(self, input_mask, mask_future=True):
"""
Construct network.
Args:
input_mask (Tensor): Tensor mask vectors with shape [batch_size, seq_len].
mask_future (bool): Whether mask future (for decoder training). Default: True.
Returns:
attention_mask (Tensor): shape [batch_size, seq_len, seq_len].
"""
input_shape = self.shape(input_mask)
shape_right = (input_shape[0], 1, input_shape[1]) # [batch_size, 1, seq_len]
shape_left = input_shape + (1,) # [batch_size, seq_len, 1]
input_mask = self.cast(input_mask, mstype.float32)
mask_left = self.reshape(input_mask, shape_left)
mask_right = self.reshape(input_mask, shape_right)
# precision transition fp32 -> fp16
mask_left = self.cast(mask_left, self.compute_type)
mask_right = self.cast(mask_right, self.compute_type)
attention_mask = self.matmul(mask_left, mask_right) # [batch_size, seq_len, seq_len]
attention_mask = self.cast(attention_mask, mstype.float32)
if mask_future:
attention_mask = self.multiply(attention_mask, self.lower_triangle_mask)
return attention_mask
class GPT2Model(nn.Cell):
"""
GPT-2 model: a decoder-only Transformer.
Args:
config (Class): Configuration for GPT2Model.
is_training (bool): True for training mode. False for eval mode.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
"""
def __init__(self,
config,
is_training,
use_one_hot_embeddings=False
):
super(GPT2Model, self).__init__()
self.config = copy.deepcopy(config)
self.is_training = is_training
if not is_training:
self.config.hidden_dropout = 0.0
self.config.attention_dropout = 0.0
self.input_mask_from_dataset = self.config.input_mask_from_dataset
self.batch_size = self.config.batch_size
self.seq_length = self.config.seq_length
self.d_model = self.config.d_model
self.num_hidden_layers = self.config.num_hidden_layers
self.embedding_dim = self.config.d_model
self.last_idx = self.num_hidden_layers - 1
self.gpt2_embedding_lookup = EmbeddingLookup(
vocab_size=self.config.vocab_size,
embedding_dim=self.embedding_dim,
use_one_hot_embeddings=use_one_hot_embeddings,
compute_type=self.config.compute_type
)
self.gpt2_embedding_postprocess = EmbeddingPostprocessor(
embedding_dim=self.embedding_dim,
seq_length=self.seq_length,
max_position_embeddings=self.config.max_position_embeddings,
dropout_prob=self.config.hidden_dropout
)
self.gpt2_decoder = GPT2Transformer(
batch_size=self.batch_size,
d_model=self.d_model,
seq_length=self.seq_length,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.config.num_attention_heads,
intermediate_size=self.config.intermediate_size,
has_attention_mask=True,
attention_dropout=self.config.attention_dropout,
hidden_dropout=self.config.hidden_dropout,
compute_type=self.config.compute_type
)
self.cast_compute_type = CastWrapper(dst_type=self.config.compute_type)
self.layer_norm = LayerNorm(in_channels=self.d_model)
self.dropout = nn.Dropout(1 - self.config.hidden_dropout)
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(self.config)
self.reshape = P.Reshape()
self.new_shape = (-1, self.d_model)
def construct(self, input_ids, input_mask):
"""
Construct network.
Args:
input_ids (Tensor): input sentences with shape [batch_size, seq_len].
input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len],
where 0 indicates padding position.
Returns:
decoder_output (Tensor): shape[batch_size, seq_len, d_model].
embedding_tables (Tensor): word embeddings with shape [vocab_size, d_model]
"""
# Embedding
word_embeddings, embedding_tables = self.gpt2_embedding_lookup(input_ids)
embedding_output = self.gpt2_embedding_postprocess(word_embeddings)
embedding_output = self.dropout(embedding_output)
# Attention mask with shape [batch_size, seq_len, seq_len]
attention_mask = self._create_attention_mask_from_input_mask(input_mask, True)
# GPT2 decoder
decoder_output = self.gpt2_decoder(
self.cast_compute_type(embedding_output),
self.cast_compute_type(attention_mask)
)
# LayerNorm
decoder_output = self.reshape(decoder_output, self.new_shape)
decoder_output = self.layer_norm(decoder_output)
decoder_output = self.reshape(decoder_output, (-1, self.seq_length, self.d_model))
return decoder_output, embedding_tables
def get_token_embeddings(self):
return self.gpt2_embedding_lookup.embedding_table.asnumpy()
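As a quick illustration of how `CreateAttentionMaskFromInputMask` above builds the causal attention mask (an outer product of the padding mask with itself, then an element-wise multiply with a lower-triangular matrix), here is a standalone NumPy sketch; it is illustrative only and not part of the repository files:

```python
import numpy as np

# Toy padding mask: batch_size=1, seq_len=4, last position is padding.
input_mask = np.array([[1, 1, 1, 0]], dtype=np.float32)

# Outer product -> [batch_size, seq_len, seq_len]: 1 where both query and key are real tokens.
mask_left = input_mask[:, :, None]        # [1, 4, 1]
mask_right = input_mask[:, None, :]       # [1, 1, 4]
attention_mask = mask_left @ mask_right   # [1, 4, 4]

# Lower-triangular multiply so position i can only attend to positions <= i (mask_future=True).
lower_triangle = np.tril(np.ones((1, 4, 4), dtype=np.float32))
print(attention_mask * lower_triangle)
```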


@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""clip gradient"""
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore.ops import composite as C
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
# pylint: disable=consider-using-in
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
if clip_type != 0 and clip_type != 1:
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
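For reference, the per-gradient clipping registered above can be mimicked in plain NumPy as follows; this is only a sketch of the semantics (clip by value vs. clip by L2 norm), not a drop-in replacement for the graph-mode `clip_grad`:

```python
import numpy as np

def clip_grad_np(clip_type, clip_value, grad):
    """NumPy analogue of _clip_grad: 0 clips element-wise by value, 1 clips by the tensor's L2 norm."""
    if clip_type not in (0, 1):
        return grad
    if clip_type == 0:
        return np.clip(grad, -clip_value, clip_value)
    norm = np.linalg.norm(grad)
    return grad if norm <= clip_value else grad * (clip_value / norm)

grad = np.array([3.0, -4.0])       # L2 norm = 5
print(clip_grad_np(1, 1.0, grad))  # scaled down to norm 1: [ 0.6 -0.8]
```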


@ -0,0 +1,95 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations"""
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
from .finetune_eval_config import gpt2_net_cfg
def create_language_model_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=True, dataset_path=""):
"""create dataset like language model task"""
type_cast_op = C.TypeCast(mstype.int32)
ds = de.MindDataset(dataset_path,
columns_list=["input_ids", "input_mask", "label_ids"],
shuffle=do_shuffle,
num_shards=device_num,
shard_id=rank_id)
print("batch_size: {}".format(gpt2_net_cfg.batch_size))
ds = ds.map(operations=type_cast_op, input_columns="input_ids")
ds = ds.map(operations=type_cast_op, input_columns="input_mask")
ds = ds.map(operations=type_cast_op, input_columns="label_ids")
ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
print("dataset size: {}".format(ds.get_dataset_size()))
print("repeat count: {}".format(ds.get_repeat_count()))
print("output shape: {}".format(ds.output_shapes()))
print("output type: {}".format(ds.output_types()))
print("============== create dataset successful ===============")
return ds
def create_cbt_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=False, dataset_path=""):
"""create dataset for cbt task"""
type_cast_op = C.TypeCast(mstype.int32)
ds = de.MindDataset(dataset_path,
columns_list=["input_ids", "input_mask", "input_length", "mc_labels"],
shuffle=do_shuffle,
num_shards=device_num,
shard_id=rank_id)
print("batch_size: {}".format(gpt2_net_cfg.batch_size))
ds = ds.map(operations=type_cast_op, input_columns="input_ids")
ds = ds.map(operations=type_cast_op, input_columns="input_mask")
ds = ds.map(operations=type_cast_op, input_columns="input_length")
ds = ds.map(operations=type_cast_op, input_columns="mc_labels")
ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
print("dataset size: {}".format(ds.get_dataset_size()))
print("repeat count: {}".format(ds.get_repeat_count()))
print("output shape: {}".format(ds.output_shapes()))
print("output type: {}".format(ds.output_types()))
print("============== create CBT LM dataset successful ===============")
return ds
def create_lambada_control_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=True, dataset_path=""):
"""create dataset for lambada task"""
type_cast_op = C.TypeCast(mstype.int32)
ds = de.MindDataset(dataset_path,
columns_list=["input_ids", "input_mask", "input_length"],
shuffle=do_shuffle,
num_shards=device_num,
shard_id=rank_id)
print("batch_size: {}".format(gpt2_net_cfg.batch_size))
ds = ds.map(operations=type_cast_op, input_columns="input_ids")
ds = ds.map(operations=type_cast_op, input_columns="input_mask")
ds = ds.map(operations=type_cast_op, input_columns="input_length")
ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
print("dataset size: {}".format(ds.get_dataset_size()))
print("repeat count: {}".format(ds.get_repeat_count()))
print("output shape: {}".format(ds.output_shapes()))
print("output type: {}".format(ds.output_types()))
print("============== create dataset successful ===============")
return ds
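A minimal usage sketch for the dataset helpers above; the import path and the MindRecord file name are placeholders and should be adapted to the actual layout:

```python
from src.dataset import create_language_model_dataset  # import path assumed

# Single-device loading; dataset_path is a placeholder MindRecord file.
ds = create_language_model_dataset(device_num=1, rank_id=0, do_shuffle=True,
                                    dataset_path="/path/to/lm.train.mindrecord")
for batch in ds.create_dict_iterator():
    print(batch["input_ids"].shape, batch["input_mask"].shape, batch["label_ids"].shape)
    break
```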


@ -0,0 +1,104 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPT-2 finetune config and GPT-2 model config"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .GPT2_model import GPT2Config
cfg = edict({
'gpt2_network': 'large',
'optimizer': 'Lamb',
'AdamWeightDecay': edict({
'learning_rate': 1e-5,
'end_learning_rate': 1e-7,
'power': 1.0,
'weight_decay': 0.01,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6,
}),
'Lamb': edict({
'learning_rate': 1e-5,
'end_learning_rate': 1e-7,
'power': 1.0,
'weight_decay': 0.01,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
}),
'Momentum': edict({
'learning_rate': 2e-5,
'momentum': 0.9,
}),
})
"""
Three GPT-2 model configurations: small, medium and large.
"""
if cfg.gpt2_network == 'small':
gpt2_net_cfg = GPT2Config(
batch_size=8,
seq_length=1024,
vocab_size=50257,
d_model=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=1024,
initializer_range=0.02,
input_mask_from_dataset=True,
summary_first_dropout=0.1,
dtype=mstype.float32,
compute_type=mstype.float16,
)
if cfg.gpt2_network == 'medium':
gpt2_net_cfg = GPT2Config(
batch_size=8,
seq_length=1024,
vocab_size=50257,
d_model=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=1024,
initializer_range=0.02,
input_mask_from_dataset=True,
summary_first_dropout=0.1,
dtype=mstype.float32,
compute_type=mstype.float16,
)
if cfg.gpt2_network == 'large':
gpt2_net_cfg = GPT2Config(
batch_size=6,
seq_length=1024,
vocab_size=50257,
d_model=1280,
num_hidden_layers=36,
num_attention_heads=20,
intermediate_size=5120,
hidden_act="gelu",
hidden_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=1024,
initializer_range=0.02,
input_mask_from_dataset=True,
summary_first_dropout=0.1,
dtype=mstype.float32,
compute_type=mstype.float16,
)
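One way the configuration above can be consumed, shown as a small sketch (import path assumed):

```python
from src.finetune_eval_config import cfg, gpt2_net_cfg  # import path assumed

opt_name = cfg.optimizer        # 'Lamb' by default
opt_cfg = cfg[opt_name]         # edict holding learning_rate, weight_decay, ...
print(opt_name, opt_cfg.learning_rate, opt_cfg.weight_decay)
print("network:", cfg.gpt2_network,
      "batch_size:", gpt2_net_cfg.batch_size,
      "layers:", gpt2_net_cfg.num_hidden_layers)
```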


@ -0,0 +1,464 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPT-2 finetune for downstream task"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size
from .utils.CrossEntropy import CrossEntropyCalculationWithMask
from .clip_grad_utils import clip_grad
from .GPT2ForLanguageModel import GPT2LanguageModel
from .GPT2ForLambada import GPT2LambadaModel
from .GPT2ForCBT import GPT2CBTModel
from .GPT2ForTranslation import GPT2TranslationModel
from .GPT2ForReadComprehension import GPT2CoQAModel
from .GPT2ForSummarization import GPT2SummarizationModel
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)
class GPT2FinetuneCell(nn.Cell):
"""
Specifically defined for finetuning where only three input tensors are needed.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(GPT2FinetuneCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
self.gpu_target = False
if context.get_context("device_target") == "GPU":
self.gpu_target = True
self.float_status = P.FloatStatus()
self.addn = P.AddN()
self.reshape = P.Reshape()
else:
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
def construct(self,
input_ids,
input_mask,
label_ids,
sens=None):
"""
GPT2 Finetune.
Args:
input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
label_ids (Tensor): the indices of input sequence tokens in the vocabulary
"""
weights = self.weights
init = False
loss = self.network(input_ids,
input_mask,
label_ids)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
if not self.gpu_target:
init = self.alloc_status()
clear_before_grad = self.clear_before_grad(init)
F.control_depend(loss, init)
self.depend_parameter_use(clear_before_grad, scaling_sens)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
label_ids,
self.cast(scaling_sens,
mstype.float32))
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if self.reducer_flag:
grads = self.grad_reducer(grads)
if not self.gpu_target:
flag = self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
F.control_depend(grads, flag)
F.control_depend(flag, flag_sum)
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum)
flag_sum = self.reshape(flag_sum, (()))
if self.is_distributed:
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond)
return F.depend(ret, succ)
class GPT2LM(nn.Cell):
"""
Train interface for Language Modeling finetuning task.
Args:
config (class): the configuration of GPT-2 model.
is_training (bool): whether to train.
use_one_hot_embeddings (bool): whether to use onehot embeddings.
"""
def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False):
super(GPT2LM, self).__init__()
self.gpt2 = GPT2LanguageModel(config, is_training, use_one_hot_embeddings)
self.num_labels = config.vocab_size
self.loss = CrossEntropyCalculationWithMask(is_training=is_training,
num_labels=self.num_labels,
config=config)
self.is_training = is_training
self.reshape = P.Reshape()
self.shape = P.Shape()
self.cast = P.Cast()
def construct(self, input_ids, input_mask, label_ids):
"""
construct function for Language Modeling
Args:
input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
label_ids (Tensor): the indices of input sequence tokens in the vocabulary
Returns:
lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits,
otherwise, return the computed loss.
"""
lm_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size]
if self.is_training:
shift_logits = lm_logits[::, :-1, ::] # [batch_size, seq_length - 1, vocab_size]
shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size]
label_ids = label_ids[::, 1:]
input_mask = input_mask[::, 1:]
loss = self.loss(shift_logits, label_ids, input_mask)
return loss
return lm_logits
class GPT2Lambada(nn.Cell):
"""
Train interface for Lambada finetuning task.
Args:
config (class): the configuration of GPT-2 model.
is_training (bool): whether to train.
use_one_hot_embeddings (bool): whether to use onehot embeddings.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
super(GPT2Lambada, self).__init__()
self.gpt2 = GPT2LambadaModel(config, is_training, use_one_hot_embeddings)
self.num_labels = config.vocab_size
self.loss = CrossEntropyCalculationWithMask(is_training=is_training,
num_labels=self.num_labels,
config=config)
self.is_training = is_training
self.reshape = P.Reshape()
self.shape = P.Shape()
self.cast = P.Cast()
def construct(self, input_ids, input_mask, label_ids=None):
"""
construct function for Lambada task
Args:
input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
Returns:
lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits,
otherwise, return the computed loss.
"""
lm_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size]
if self.is_training:
shift_logits = lm_logits[:, :-1, :] # [batch_size, seq_length - 1, vocab_size]
shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size]
label_ids = label_ids[::, 1:]
input_mask = input_mask[::, 1:]
loss = self.loss(shift_logits, label_ids, input_mask)
return loss
return lm_logits
class GPT2CBT(nn.Cell):
"""
Train interface for Children's Book Test finetuning task.
Args:
config (class): the configuration of GPT-2 model.
is_training (bool): whether to train.
use_one_hot_embeddings (bool): whether to use onehot embeddings.
"""
def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False):
super(GPT2CBT, self).__init__()
self.gpt2 = GPT2CBTModel(config, is_training, use_one_hot_embeddings)
self.num_labels = config.vocab_size
self.loss = CrossEntropyCalculationWithMask(is_training=is_training,
num_labels=self.num_labels,
config=config)
self.is_training = is_training
self.reshape = P.Reshape()
self.shape = P.Shape()
self.cast = P.Cast()
def construct(self, input_ids, input_mask):
"""
construct function for CBT task
Args:
input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
Returns:
lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits,
otherwise, return the computed loss.
"""
lm_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size]
if self.is_training:
shift_logits = lm_logits[::, :-1, ::] # [batch_size, seq_length - 1, vocab_size]
shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size]
label_ids = input_ids[::, 1:]
input_mask = input_mask[::, 1:]
loss = self.loss(shift_logits, label_ids, input_mask)
return loss
return lm_logits
class GPT2Translation(nn.Cell):
"""
Train interface for Translation finetuning task.
Args:
config (class): the configuration of GPT-2 model.
is_training (bool): whether to train.
use_one_hot_embeddings (bool): whether to use onehot embeddings.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
super(GPT2Translation, self).__init__()
self.gpt2 = GPT2TranslationModel(config, is_training, use_one_hot_embeddings)
self.num_labels = config.vocab_size
self.loss = CrossEntropyCalculationWithMask(is_training=is_training,
num_labels=self.num_labels,
config=config)
self.is_training = is_training
self.log_softmax = P.LogSoftmax(axis=-1)
self.reshape = P.Reshape()
self.shape = P.Shape()
def construct(self, input_ids, input_mask, label_ids):
"""
construct function for Translation task
Args:
input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
label_ids (Tensor): the indices of input sequence tokens in the vocabulary
Returns:
translation_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits,
otherwise, return the computed loss.
"""
translation_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size]
translation_logits = self.log_softmax(translation_logits)
if self.is_training:
shift_logits = translation_logits[::, :-1, ::] # [batch_size, seq_length - 1, vocab_size]
shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size]
label_ids = label_ids[::, 1:]
input_mask = input_mask[::, 1:]
loss = self.loss(shift_logits, label_ids, input_mask)
return loss
return translation_logits
class GPT2Summarization(nn.Cell):
"""
Train interface for Summarization finetuning task.
Args:
config (class): the configuration of GPT-2 model.
is_training (bool): whether to train.
use_one_hot_embeddings (bool): whether to use onehot embeddings.
"""
def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False):
super(GPT2Summarization, self).__init__()
self.gpt2 = GPT2SummarizationModel(config, is_training, use_one_hot_embeddings)
self.is_training = is_training
self.last_idx = (-1,)
self.log_softmax = P.LogSoftmax(axis=-1)
self.reshape = P.Reshape()
self.shape = P.Shape()
self.batch_size = config.batch_size
self.seq_length = config.seq_length
self.vocab_size = config.vocab_size
self.cast = P.Cast()
self.loss_function = CrossEntropyCalculationWithMask(num_labels=self.vocab_size,
is_training=self.is_training,
config=config)
def construct(self, input_ids, input_mask, label_ids):
"""
construct function for Summarization task
Args:
input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
label_ids (Tensor): the indices of input sequence tokens in the vocabulary
Returns:
loss (mstype.float32): if is_training is True, return the computed loss.
"""
output = self.gpt2(input_ids, input_mask)
shift_logits = output[::, :-1, ::]
shift_logits = self.reshape(shift_logits, (-1, self.vocab_size))
shift_logits = self.log_softmax(shift_logits)
label_ids = label_ids[::, 1:]
input_mask = input_mask[::, 1:]
loss = self.loss_function(shift_logits, label_ids, input_mask)
return loss
class GPT2CoQA(nn.Cell):
"""
Train interface for Reading Comprehension finetuning task.
Args:
config (class): the configuration of GPT-2 model.
is_training (bool): whether to train.
use_one_hot_embeddings (bool): whether to use onehot embeddings.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
super(GPT2CoQA, self).__init__()
self.gpt2 = GPT2CoQAModel(config, is_training, use_one_hot_embeddings)
self.num_labels = config.vocab_size
self.loss = CrossEntropyCalculationWithMask(is_training=is_training,
num_labels=self.num_labels,
config=config)
self.is_training = is_training
self.reshape = P.Reshape()
self.log_softmax = P.LogSoftmax(axis=-1)
def construct(self, input_ids, input_mask, label_ids=None):
"""
construct function for reading comprehension task
Args:
input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
label_ids (Tensor): the indices of input sequence tokens in the vocabulary
Returns:
lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits,
otherwise, return the computed loss.
"""
lm_logits = self.gpt2(input_ids, input_mask)
lm_logits = self.log_softmax(lm_logits)
if self.is_training:
shift_logits = lm_logits[::, :-1, ::]
shift_logits = self.reshape(shift_logits, (-1, self.num_labels))
label_ids = label_ids[::, 1:]
input_mask = input_mask[::, 1:]
loss = self.loss(shift_logits, label_ids, input_mask)
return loss
return lm_logits
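A rough wiring sketch of how `GPT2LM` and `GPT2FinetuneCell` can be combined with an optimizer and a dynamic loss scale; the import paths and the MindSpore 1.x APIs used here (`nn.Lamb`, `nn.DynamicLossScaleUpdateCell`, `Model`) are assumptions, and the actual finetune scripts in this commit may wire things differently:

```python
import mindspore.nn as nn
from mindspore.train.model import Model

from src.finetune_eval_config import cfg, gpt2_net_cfg   # import paths assumed
from src.gpt2_for_finetune import GPT2LM, GPT2FinetuneCell

netwithloss = GPT2LM(config=gpt2_net_cfg, is_training=True, use_one_hot_embeddings=False)
optimizer = nn.Lamb(netwithloss.trainable_params(),
                    learning_rate=cfg.Lamb.learning_rate,
                    weight_decay=cfg.Lamb.weight_decay)
update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 16, scale_factor=2, scale_window=1000)
netwithgrads = GPT2FinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
netwithgrads.set_train(True)
model = Model(netwithgrads)
# model.train(epoch_num, train_dataset, dataset_sink_mode=False)
```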


@ -0,0 +1,82 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Calculate Cross Entropy With Mask"""
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
import mindspore.nn as nn
class CrossEntropyCalculationWithMask(nn.Cell):
"""
Cross Entropy loss
"""
def __init__(self, is_training=None, num_labels=None, config=None):
super(CrossEntropyCalculationWithMask, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.reshape = P.Reshape()
self.last_idx = (-1,)
self.neg = P.Neg()
self.cast = P.Cast()
self.is_training = is_training
self.num_labels = num_labels
if config is not None:
# for PPL calculation in evaluation
self.input_mask_length = Tensor(config.batch_size * (config.seq_length - 1), mstype.float32)
def construct(self, logits, label_ids, input_mask=None):
"""
Calculate loss
Args:
logits (Tensor): the probability distribution over vocabulary.
label_ids (Tensor): the indices of input sequence tokens in the vocabulary.
input_mask (Tensor): input sentences padding mask, where 0 indicates padding position.
Returns:
return_value (Tensor, mstype.float32): if is_training is False, directly return the logits, otherwise,
return the computed loss.
"""
# logits [batch * (seq_length-1), vocab_size] label_ids [batch, seq_length-1]
if self.is_training:
label_ids = self.reshape(label_ids, self.last_idx) # label_ids [batch * (seq_length-1)]
one_hot_labels = self.onehot(label_ids, self.num_labels, self.on_value,
self.off_value) # [batch * (seq_length-1), vocab_size]
per_example_loss = self.neg(
self.reduce_sum(one_hot_labels * logits, self.last_idx)) # [batch * (seq_length-1)]
# for PPL calculation in evaluation
if input_mask is not None:
input_mask = self.cast(self.reshape(input_mask, self.last_idx),
mstype.float32) # [batch * (seq_length-1)]
valid_loss_sum = self.reduce_sum(input_mask * per_example_loss, ())
valid_element_sum = self.reduce_sum(input_mask, ()) + self.cast(F.tuple_to_array((1e-5,)),
mstype.float32)
loss = valid_loss_sum / valid_element_sum
else:
loss = self.reduce_mean(per_example_loss, self.last_idx) # a number
return_value = self.cast(loss, mstype.float32)
else:
return_value = logits * 1.0 # [batch * (seq_length-1), vocab_size]
return return_value
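To make the masked loss above concrete, here is a small NumPy walk-through of the same computation (assuming the inputs are log-probabilities, which is how several task heads feed this loss after `log_softmax`):

```python
import numpy as np

# 3 positions (batch * (seq_length - 1)), vocab_size = 4; the last position is padding.
log_probs = np.log(np.array([[0.7, 0.1, 0.1, 0.1],
                             [0.2, 0.5, 0.2, 0.1],
                             [0.25, 0.25, 0.25, 0.25]]))
label_ids = np.array([0, 1, 3])
input_mask = np.array([1.0, 1.0, 0.0])

one_hot = np.eye(4)[label_ids]
per_example_loss = -np.sum(one_hot * log_probs, axis=-1)
loss = np.sum(input_mask * per_example_loss) / (np.sum(input_mask) + 1e-5)
print("loss:", loss, "PPL:", np.exp(loss))
```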


@ -0,0 +1,488 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""data preprocess for downstream task"""
import re
import json
import random
def lambada_detokenizer(string):
string = re.sub(r"``", "-DQ-", string)
string = re.sub(r"`", "-SQ-", string)
string = re.sub(r"''", "-DQ-", string)
string = re.sub(r" '", "-SQ-", string)
string = re.sub("-DQ-", '"', string)
string = re.sub("-SQ-", "'", string)
string = re.sub(r"([,?!.]['\"])(\w)", "\g<1> \g<2>", string)
# contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# number separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
string = string.replace(" 'd", "'d")
string = string.replace(" '", "'")
string = string.replace(" n't", "n't")
string = string.replace(" .", ".")
string = string.replace(" ,", ",")
string = string.replace(" !", "!")
string = string.replace(" ?", "?")
string = string.replace(" :", ":")
string = string.replace(" ;", ";")
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" ,'", ",'")
string = string.replace(" .'", ".'")
string = string.replace(" !'", "!'")
string = string.replace(" ?'", "?'")
string = string.replace("~", "")
string = string.replace("---", "")
string = string.replace("<", "")
string = string.replace(">", "")
string = string.replace("#", "")
string = string.replace(', "', ',"')
string = string.replace('. "', '."')
string = string.replace('! "', '!"')
string = string.replace('? "', '?"')
string = string.replace('"" ', '" "')
string = string.replace('• • •', '')
# sensitive word process
string = string.replace("f ** k", "fuck")
string = string.replace("f ** king", "fucking")
string = string.replace("f ** ked", "fucked")
string = string.replace("c ** k", "cock")
string = string.replace("br ** sts", "breasts")
string = string.replace("n ** ples", "nipples")
string = string.replace("ni ** les", "nipples")
string = string.replace("a ** hole", "asshole")
string = string.replace("ass ** le", "asshole")
string = string.replace("p ** sy", "pussy")
string = string.replace("pu ** y", "pussy")
string = string.replace("na ** d", "naked")
string = string.replace("nak * d", "naked")
string = string.replace("cli ** x", "climax")
string = string.replace("h * ps", "hips")
string = string.replace("c * ck", "cock")
string = string.replace("coc ** ne", "cocaine")
string = string.replace("*", "")
string = re.sub(" "," ",string)
string = re.sub(" "," ",string)
string = re.sub(" "," ",string)
return string
def lambada_dataset_preprocess(input_file, output_file):
sentences = []
count = 0
with open(input_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
line = lambada_detokenizer(line)
split_sentence_list = line.split()
final_word = split_sentence_list[-1]
context = split_sentence_list[:-1]
new_sentence = ' '.join(context) + '\t' + ' ' + final_word
sentences.append(new_sentence)
count += 1
print('read {} file finished!\n total count = {}'.format(input_file, count))
with open(output_file, 'w', encoding='utf-8') as f:
for sentence in sentences:
sentence = sentence.strip()
if sentence:
f.write(sentence)
f.write('\n')
count -= 1
print('write {} file finished!\n total count = {}'.format(output_file, count))
def get_gold_answer_id(gold_answer, candidate_answer_list):
id_ = 0
for candidate in candidate_answer_list:
if gold_answer == candidate:
return id_
id_ += 1
def get_passage_string(passage_string, candidate_answer, final_sentence, gold_answer_id):
"""
concat each candidate answer to the rest_sentence
Args:
candidate_answer (list): store each candidate answers
final_sentence (str): the 21st sentence string with "XXXXX"
gold_answer_id (int): the id of correct answer.
return:
candidate_passage (list): the length of candidate_sentence equals to length of candidate_answer.
"""
candidate_passage = []
for answer in candidate_answer:
passage = passage_string + " " + final_sentence
passage = passage.replace(" XXXXX", "\t XXXXX")
final_passage = passage.replace("XXXXX", answer)
whole_passage = final_passage + "\t" + str(gold_answer_id)
candidate_passage.append(whole_passage)
return candidate_passage
def cbt_dataset_preprocess(input_file, output_file):
passages = []
candidate_passage_list = []
passage_string = ""
count = 0
with open(input_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
single_sentence = line.split(' ', 1)
line_id = int(single_sentence[0])
string = single_sentence[1]
if line_id == 21:
string = string.replace("\t\t", "\t")
mini_string = string.split("\t")
candidate_answer = mini_string[-1]
candidate_answer_list = candidate_answer.split("|")
gold_answer_id = get_gold_answer_id(mini_string[-2], candidate_answer_list)
candidate_passage = get_passage_string(passage_string,
candidate_answer_list,
mini_string[0],
gold_answer_id)
assert len(candidate_passage) == 10
count += 10
else:
passage_string = passage_string + " " + string
else:
passages.append(candidate_passage)
candidate_passage_list = []
passage_string = ""
print('read {} file finished!\n total count = {}'.format(input_file, count))
with open(output_file, 'w', encoding='utf-8') as f:
for passage in passages:
for candidate_passage in passage:
candidate_passage = candidate_passage.replace(" \t ", "\t ")
candidate_passage = candidate_passage.strip()
f.write(candidate_passage)
f.write("\n")
count -= 1
print('write {} file finished!\n total count = {}'.format(output_file, count))
def wikitext_detokenizer(string):
# contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# number separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" .", ".")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# double brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
return string
def wikitext_dataset_preprocess(input_file, output_file):
dataset_test = []
passage = []
count = 0
with open(input_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
if line.startswith('=') and line.endswith('=') and len(passage) != 0:
dataset_test.append(passage)
count += 1
passage = []
elif line.startswith('=') and line.endswith('='):
continue
else:
passage.append(line)
print('read {} file finished!\n total count = {}'.format(input_file, count))
with open(output_file, 'w', encoding='utf-8') as f:
for line in dataset_test:
text = ""
for sentence in line:
sentence = wikitext_detokenizer(sentence)
text = text + " " + sentence
text = text.strip()
f.write(text)
f.write("\n")
print('write {} file finished!\n total count = {}'.format(output_file, count))
def ptb_detokenizer(string):
string = string.replace(" '", "'")
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" n't", "n't")
string = string.replace(" N ", "1 ")
string = string.replace("$ 1", "$1")
string = string.replace("# 1", "#1")
string = string.replace("\/abc", "")
string = string.replace("\/ua", "")
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
string = string.replace(" 's", "'s")
return string
def ptb_dataset_preprocess(input_file, output_file):
sentences = []
count = 0
with open(input_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
line = ptb_detokenizer(line)
sentences.append(line)
count += 1
print('read {} file finished!\n total count = {}'.format(input_file, count))
with open(output_file, 'w', encoding='utf-8') as f:
for sentence in sentences:
sentence = sentence.strip()
if sentence:
f.write(sentence)
f.write("\n")
count -= 1
print('write {} file finished!\n total count = {}'.format(output_file, count))
def onebw_detokenizer(string):
# contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# number separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# double brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" --", "")
string = string.replace("--", "")
string = string.replace("? ? ?", " ?")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" 't", "'t")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
string = string.replace(" '", "'")
string = string.replace(" n't", "n't")
string = string.replace("$ 1", "$1")
string = string.replace("# 1", "#1")
return string
def test_length(string):
string_list = string.split()
return len(string_list)
def onebw_dataset_preprocess(condition, input_file, output_file):
sentences = []
count = 0
if condition.lower() == "test":
with open(input_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
sentences.append(line)
count += 1
print('read {} file finished!\n total count = {}'.format(input_file, count))
with open(output_file, 'w', encoding='utf-8') as f:
for sentence in sentences:
sentence = sentence.strip()
if sentence:
sentence = onebw_detokenizer(sentence)
f.write(sentence)
f.write("\n")
count -= 1
print('write {} file finished!\n total count = {}'.format(output_file, count))
elif condition.lower() == "train":
with open(input_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
line = onebw_detokenizer(line)
length = test_length(line)
if length > 10 and length < 60:
sentences.append(line)
count += 1
print('read finished! count = ', count)
sample_result_list = random.sample(range(0, count), 30000)
sample_result_list.sort()
count_sample = 0
choiced_sentence = ""
with open(output_file, 'w', encoding='utf-8') as f:
for i in range(len(sample_result_list)):
choiced_sentence = sentences[sample_result_list[i]]
f.write(choiced_sentence)
f.write("\n")
count_sample += 1
print('write finished! ', count_sample)
else:
raise ValueError("condition error support: [train, test]")
def coqa_dataset_preprocess(input_file, output_file):
with open(input_file, 'r', encoding='utf-8') as f:
source_data = json.load(f)
stories = []
instances = []
end_sep = [',', '.', ';']
question_before_sep = " "
question_after_sep = " A: "
answer_sep = " A:\t"
for i, dialog in enumerate(source_data["data"]):
story = dialog["story"].replace("\n", "")
stories.append(story)
concat_ = ""
concat_ += story
for question, answer in zip(dialog["questions"], dialog["answers"]):
question = question["input_text"]
answer = answer["input_text"]
concat_ += question_before_sep
concat_ += question
tmp = concat_ + question_after_sep
concat_ += answer_sep
concat_ += answer
instances.append(concat_)
concat_ = tmp + answer
if concat_[-1] not in end_sep:
concat_ += "."
instances.append("")
with open(output_file, 'w', encoding='utf-8') as f:
for i in range(len(instances)):
if instances[i]:
f.write(instances[i])
f.write("\n")
print('write {} file finished!\n total count = {}'.format(output_file, len(instances)))
def wmt14_en_fr_preprocess(input_file, output_file):
input_file = input_file + "/newstest2014-fren-ref"
output_file = output_file + "/wmt14"
language = ['.en.sgm', '.fr.sgm']
count = 0
# en-fr
with open(input_file + language[0], "r", encoding='utf-8') as english, \
open(input_file + language[1], "r", encoding='utf-8') as french, \
open(output_file + '.en_fr.txt', "a", encoding='utf-8') as enfr_f, \
open(output_file + '.fr_en.txt', "a", encoding='utf-8') as fren_f:
line_id = 0
for en, fr in zip(english, french):
line_id += 1
if (en[:7] == '<seg id'):
print("=" * 20, "\n", line_id, "\n", "=" * 20)
en_start = en.find('>', 0)
en_end = en.find('</seg>', 0)
print(en[en_start + 1:en_end])
en_ = en[en_start + 1:en_end]
fr_start = fr.find('>', 0)
fr_end = fr.find('</seg>', 0)
print(fr[fr_start + 1:fr_end])
fr_ = fr[fr_start + 1:fr_end]
en_fr_str = en_ + "\t" + fr_ + "\n"
enfr_f.write(en_fr_str)
fr_en_str = fr_ + "\t" + en_ + "\n"
fren_f.write(fr_en_str)
count += 1
print('write {} file finished!\n total count = {}'.format(output_file + '.en_fr.txt', count))
print('write {} file finished!\n total count = {}'.format(output_file + '.fr_en.txt', count))
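A short usage sketch for the preprocessing helpers above; the import path and file names are placeholders:

```python
from src.utils.data_preprocess import lambada_detokenizer, lambada_dataset_preprocess  # path assumed

raw = "he looked at her and said `` i will never leave you ''"
print(lambada_detokenizer(raw))  # quotes and spacing are normalized

# Writes one "<context>\t <final_word>" line per passage, which the LAMBADA scripts consume:
# lambada_dataset_preprocess("lambada_test_plain_text.txt", "lambada.test.txt")
```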


@ -0,0 +1,542 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
generation utils
"""
import numpy as np
from scipy.special import softmax
from mindspore.ops import operations as P
from mindspore import dtype as mstype
from mindspore.common.tensor import Tensor
from .tensor_manipulations import extract_single_token_logits, add_last_token
INF = 1. * 1e9
class TopKTopP_Filter():
"""
Top K sampling along with Top P sampling(Nucleus Sampling)
Choose top-K probability of ids and those with top-P probability ids into candidate sample sets.
Use np.random.multinomial to sample
Args:
batch_size (int): batch size of input dataset.
vocab_size (int): the shape of each embedding vector.
k (int): parameter for Top-K sampling, k should be in range of [0, vocab_size].
0 means no Top-K filtering (keep the whole vocabulary). Default: 0.
p (float) [Optional]: parameter for Top-P sampling, a.k.a. Nucleus Sampling; p is in the range [0.0, 1.0].
Default: 1.0.
temperature (float) [Optional]: sampling temperature; larger values make generation more diverse. Default: 1.0.
"""
def __init__(self,
batch_size=None,
vocab_size=None,
k=0,
p=1.0,
temperature=1.0,
min_tokens_to_keep=1,
):
self.k = k
self.p = p
self.temp = temperature
self.batch_size = batch_size
self.vocab_size = vocab_size
self.min_tokens_to_keep = min_tokens_to_keep
assert self.temp > 0.0, 'temperature must be positive'
assert self.k >= 0, 'the top_k number must be no negative.'
if self.k > 0:
assert self.min_tokens_to_keep <= self.k, 'k must be larger than or equal to min_token_to_keep ' \
'for Top-p sampling'
if self.k == 0:
self.k = self.vocab_size
self.safety_mask = np.concatenate((np.ones((self.batch_size, self.min_tokens_to_keep)),
np.zeros((self.batch_size, self.k - self.min_tokens_to_keep))),
axis=1).astype(np.bool)
def calculate(self, distribution):
"""
calculate sampling procedure with setting initialized before, return a list of sampled ids.
Args:
distribution (numpy.ndarray): with shape (batch_size,vocab_size)
Returns:
sampled ids: a list, with length of batch_size
"""
if self.temp != 1.0:
distribution = distribution / float(self.temp)
distribution_sorted = -np.sort(-distribution, axis=1)
index_sorted = np.argsort(-distribution, axis=1)
topk_distribution = distribution_sorted[::, :self.k if self.k > 0 else self.vocab_size]
topk_indices = index_sorted[::, :self.k if self.k > 0 else self.vocab_size]
# safety check of probability
self.p = max(0.0, min(1.0, self.p))
cum_sum = np.cumsum(softmax(topk_distribution, axis=1), axis=1)
bool_map = np.logical_or((cum_sum <= self.p), self.safety_mask).astype(np.float32)
topk_distribution = topk_distribution * bool_map + np.float32(-1e5) * (1.0 - bool_map)
topk_distribution = softmax(topk_distribution, axis=1)
# normalize for np.float64
# choose np.float64 to avoid overflow in softmax operation
topk_distribution = topk_distribution.astype(np.float64)
for batch_idx in range(self.batch_size):
topk_distribution[batch_idx] = topk_distribution[batch_idx] / np.sum(topk_distribution[batch_idx])
ret_ids = []
for batch_idx in range(self.batch_size):
select_index = np.argmax(np.random.multinomial(1, topk_distribution[batch_idx]))
ret_ids.append(topk_indices[batch_idx][select_index])
return ret_ids
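# Example usage (illustrative sketch only, kept as a comment so it is not executed on import):
#   filter_ = TopKTopP_Filter(batch_size=1, vocab_size=8, k=5, p=0.9, temperature=1.0)
#   logits = np.random.randn(1, 8).astype(np.float32)  # unnormalized logits for one batch row
#   sampled_ids = filter_.calculate(logits)            # list with one sampled token id per row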
class Sample():
"""
Initiate a Sample object for sampling next token(s) from previous text.
Args:
decoder (Model): GPT2 model to do generation.
config (GPT2Config): configuration of given GPT2 model.
tokenizer (GPT2Tokenizer): if choose to use input_str parameter in self.generate(), a tokenizer is compulsory.
generate_length (int): number of tokens which should be generated. Default: 1.
topk_num (int): k for Top-k sampling; 0 means no Top-k filtering,
which is equivalent to k = self.vocab_size. Default: 0.
topp_prob (float): probability threshold for Top-p (nucleus) sampling;
p = 1.0 disables the filter. Default: 1.0.
temperature (float): sampling temperature applied before filtering. Default: 1.0.
min_tokens_to_keep (int): guarantees that at least min_tokens_to_keep candidate token(s) are kept during filtering. Default: 1.
early_stop (bool): whether to stop once the model generates the <EOS> token.
It only takes effect when batch_size is 1. Default: False.
demo_mode (bool): True if input_str is a single str rather than a list of str;
self.batch_size must be 1 when it is True. Default: False.
return_ids (bool): whether return ids generated from Sample. Default: False
return_last_token_logits (bool): whether return logits of last token for each time step during generation.
Default: False
append_eos (bool): whether append <EOS> token id to input_ids pass directly to GPT2Model class. Default: False
"""
def __init__(self,
decoder,
config=None,
batch_size=None,
tokenizer=None,
generate_length=1,
topk_num=0,
topp_prob=1.0,
temperature=1.0,
min_tokens_to_keep=1,
early_stop=False,
demo_mode=False,
return_ids=False,
return_last_token_logits=False,
append_eos=False,
):
assert config is not None, 'Config is a must for sampling.'
self.decoder = decoder
self.config = config
self.tokenizer = tokenizer
self.generate_length = generate_length
self.topk_num = topk_num
self.topp_prob = topp_prob
self.temperature = temperature
self.min_tokens_to_keep = min_tokens_to_keep
self.early_stop = early_stop
self.demo_mode = demo_mode
self.return_ids = return_ids
self.return_last_token_logits = return_last_token_logits
self.append_eos = append_eos
self.seq_length = config.seq_length
self.batch_size = config.batch_size if batch_size is None else batch_size
self.vocab_size = config.vocab_size
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reshape = P.Reshape()
self.cumsum = P.CumSum()
self.onehot = P.OneHot()
self.cast = P.Cast()
self.concat = P.Concat()
self.sample_function = P.RandomCategorical(mstype.int32)
self.filter_distribution = TopKTopP_Filter(batch_size=self.batch_size,
vocab_size=self.vocab_size,
k=self.topk_num,
p=self.topp_prob,
temperature=self.temperature,
min_tokens_to_keep=self.min_tokens_to_keep)
if self.tokenizer is not None:
self.eos_id = self.tokenizer.eos_token_id
else:
self.eos_id = config.vocab_size - 1
if self.tokenizer is not None:
self.eos_text = self.tokenizer.eos_token
else:
self.eos_text = "<|endoftext|>"
if self.demo_mode is True:
assert self.batch_size == 1, 'Demo mode requires batch_size equal to 1, but got batch_size={}'.format(
self.batch_size)
def _extract_string_from_tensor(self, input_ids, mode="pair"):
"""
Args:
input_ids(Tensor): input sentences with shape [self.batch_size, self.seq_len]
mode (str): ["pair", "single"]
"pair" for tasks with paired inputs `<bos> A <eos> B <eos>`,
such as summarization task, the dataset format `<bos> Article <eos> Summary <eos>`,
reading comprehension task, the dataset format `<bos> Passage Question <eos> Answer <eos>`.
"single" for tasks with single input `<bos> A <eos>`, such as Language Modeling, Lambada task.
Returns:
source_list (list): the list of source_text or first part of text.
target_list (list): the list of target_text or second part of text.
If self.batch_size is 1, it will return the first sentence of list, that is to say, the string.
Example:
for pair mode, if self.demo_mode is True, it will return source_list[0], target_list[0]
"""
assert self.tokenizer is not None, 'There is no tokenizer'
source_list = [""] * self.batch_size
target_list = [""] * self.batch_size
eos_text = self.tokenizer.eos_token
len_eos_text = len(eos_text)
input_ids_np = input_ids.asnumpy()
input_ids_np = input_ids_np.reshape((self.batch_size, self.seq_length))
# input_ids = self.reshape(input_ids, (self.batch_size, self.seq_length))
if mode == "pair":
for batch_idx in range(self.batch_size):
sentence_tensor = input_ids_np[batch_idx]
sentence_list = sentence_tensor.tolist()[1:]
sentence = self.tokenizer.decode(sentence_list)
source_start = 0
source_end = sentence.find(eos_text, 0)
target_start = source_end + len_eos_text
target_end = sentence[target_start:].find(eos_text, 0) + target_start
source_list[batch_idx] = sentence[source_start:source_end]
target_list[batch_idx] = sentence[target_start:target_end]
if self.batch_size == 1 and self.demo_mode is True:
return source_list[0], target_list[0]
return source_list, target_list
if mode == "single":
for batch_idx in range(self.batch_size):
sentence_tensor = input_ids_np[batch_idx]
sentence_list = sentence_tensor.tolist()[1:]
sentence = self.tokenizer.decode(sentence_list)
source_start = 0
source_end = sentence.find(eos_text, 0)
source_list[batch_idx] = sentence[source_start:source_end]
if self.batch_size == 1 and self.demo_mode is True:
return source_list[0]
else:
raise ValueError('mode:{} not supported, only support [pair, single].'.format(mode))
return source_list
def _tensorize_ids_with_masks(self, src_str):
"""
Transform from string to tensor
Args:
src_str: string or list of strings
Return:
input_ids (Tensor): shape with [self.batch_size, self.seq_length]
input_mask (Tensor): shape with [self.batch_size, self.seq_length]
src_len_list (list): the token length of each string in src_str after encoding by self.tokenizer
"""
if isinstance(src_str, str):
src_str = [src_str]
single_sentence_shape = (1, self.seq_length)
src_len_list = list()
input_ids = None
input_mask = None
for batch_idx in range(self.batch_size):
src_ids_list = self.tokenizer.encode(src_str[batch_idx])
src_ids_len = len(src_ids_list)
if src_ids_len > self.seq_length:
src_ids_list = src_ids_list[:self.seq_length]
src_ids_len = self.seq_length
src_len_list.append(src_ids_len)
return_dict = self.tokenizer.prepare_for_model(src_ids_list,
max_length=self.config.seq_length,
add_special_tokens=False)
input_ids_list = return_dict['input_ids']
input_mask_list = return_dict['attention_mask']
input_ids_np = np.array(input_ids_list, dtype=int)
input_mask_np = np.array(input_mask_list, dtype=int)
input_ids_np = input_ids_np.reshape(single_sentence_shape)
input_mask_np = input_mask_np.reshape(single_sentence_shape)
# input_ids_tensor = self.reshape(Tensor(np.array(input_ids_list, dtype=int), dtype=mstype.int32),
# single_sentence_shape)
# input_mask_tensor = self.reshape(Tensor(np.array(input_mask_list, dtype=int), dtype=mstype.int32),
# single_sentence_shape)
if batch_idx == 0:
# input_ids = input_ids_tensor
# input_mask = input_mask_tensor
input_ids_np_ = input_ids_np
input_mask_np_ = input_mask_np
else:
# input_ids = self.concat((input_ids, input_ids_tensor))
# input_mask = self.concat((input_mask, input_mask_tensor))
input_ids_np_ = np.concatenate((input_ids_np_, input_ids_np), axis=0)
input_mask_np_ = np.concatenate((input_mask_np_, input_mask_np), axis=0)
input_ids = Tensor(input_ids_np_, dtype=mstype.int32)
input_mask = Tensor(input_mask_np_, dtype=mstype.int32)
return input_ids, input_mask, src_len_list
class LastTokenPos():
"""
Class that records input_strs and the positions of their last tokens.
Args:
input_ (Union[list, Tensor]): list if input is a list containing strings,
Tensor with shape (batch_size, seq_length) representing input_mask.
"""
def __init__(self, input_, seq_length=1024):
if isinstance(input_, list):
self.input_strs = input_
self.input_mask = None
else:
self.input_strs = None
self.input_mask = input_
self.seq_length = seq_length
if self.input_strs is not None:
self.pos_list = [len(input_str) - 1 for input_str in self.input_strs]
else:
input_mask_ = P.Cast()(self.input_mask, mstype.float32)
temp_pos_list = P.ReduceSum(keep_dims=False)(input_mask_, axis=1).asnumpy().astype(np.int32).tolist()
# minimum value is always 0 for safety
self.pos_list = [max(0, pos - 1) for pos in temp_pos_list]
def get_pos(self, shift: int = 0):
# return last token if overflow
shift_list = [min(self.seq_length - 1, pos + shift) for pos in self.pos_list]
return shift_list
def _sample_from_distribution(self, distribution):
"""
sample one token per batch from self.sample_function().
Arg:
distribution (Tensor): the distribution or logits of the last token of different batches.
shape with [batch_size, vocab_size]
Return:
word_index (Tensor): shape with [batch_size, ]
"""
distribution = self.reshape(distribution, (self.vocab_size, self.batch_size))
topk_distribution = distribution[:self.topk_num, ::]
topk_distribution = self.reshape(topk_distribution, (self.batch_size, -1))
word_index = self.sample_function(P.Softmax()(topk_distribution), 1, 1)
word_index = self.reshape(word_index, (-1,))
return word_index
def _demo_mode_check(self, input_str):
"""
Type check for demo_mode: batch size must be 1 and input_str must not be None; normalizes input_str to a one-element list.
"""
if self.batch_size == 1 and self.demo_mode is True:
assert input_str is not None, "demo mode should have input str"
# type check
if isinstance(input_str, list):
assert isinstance(input_str[0], str), "type of input_str is {}, " \
"which should be str instead.".format(type(input_str[0]))
if len(input_str) != 1:
print("[WARNING] Sample.generate: length of input_str is larger than 1, "
"choose input_str[0] as input_str.")
input_str = input_str[0]
assert isinstance(input_str, str), "type of input_str is {}, " \
"which should be str instead.".format(type(input_str))
input_str = [input_str]
return input_str
def _input_check_and_normalize(self, input_str=None, input_ids=None, input_mask=None, generate_length=None):
"""
input check function
"""
if input_str is not None:
assert self.tokenizer is not None, 'if input_str is given, a tokenizer is necessary.'
input_str = self._demo_mode_check(input_str)
if input_ids is not None:
assert input_mask is not None, 'if input_ids is given, input_mask is required as well.'
if input_str is not None and input_ids is not None and input_mask is not None:
print('[WARNING] Sample.generate got input_str, input_ids and input_mask, '
'choose input_str as default for input')
if input_ids is None and input_mask is None:
input_ids, input_mask, _ = self._tensorize_ids_with_masks(input_str)
else:
if input_str is None:
if input_ids is not None:
input_str = self._extract_string_from_tensor(input_ids, mode="full")
if generate_length is not None:
# reload generate_length
generate_length = int(generate_length)
assert generate_length >= 0, 'generate_length can not be negative.'
else:
generate_length = self.generate_length
return input_str, input_ids, input_mask, generate_length
def generate(self, input_str=None, input_ids=None, input_mask=None, generate_length=None, do_sample=True):
"""
base function for text generation given a batch_size list of str or str itself (when demo mode is on)
Args
input_str (list(str) or str): prompt string.
generate_length: number of tokens to generate.
Returns:
generate_str: string generated by the GPT-2 model.
full_str: input_str appended with generate_str.
"""
input_str, input_ids, input_mask, generate_length = self._input_check_and_normalize(input_str,
input_ids,
input_mask,
generate_length)
return_ids_list = [[] for _ in range(self.batch_size)]  # independent list per batch item
last_token = self.LastTokenPos(input_mask, seq_length=self.seq_length)
for i in range(generate_length):
last_token_pos_list = last_token.get_pos(shift=i)
early_stop_mask = [0] * self.batch_size
# unsorted logits (distribution) of next word
logits = self.decoder.predict(input_ids, input_mask)
if self.return_last_token_logits is True:
if i == 0:
# [batch_size, 1, vocab_size]
return_last_logits = extract_single_token_logits(logits, last_token_pos_list)
else:
# [batch_size, 1, vocab_size] + [batch_size, i, vocab_size] --> [batch_size, i+1, vocab_size]
return_last_logits = P.Concat(axis=1)((return_last_logits,
extract_single_token_logits(logits, last_token_pos_list)))
nextword_distribution = self.reshape(logits[0, last_token_pos_list[0]:last_token_pos_list[0]+1:1, ::],
(1, -1))
# stack up nextword_distribution if batch_size is larger than 1
if self.batch_size > 1:
for batch_idx in range(1, self.batch_size):
nextword_distribution_rest = self.reshape(
logits[batch_idx, last_token_pos_list[batch_idx]:last_token_pos_list[batch_idx] + 1:1, ::],
(1, -1))
nextword_distribution = self.concat((nextword_distribution, nextword_distribution_rest))
if do_sample:
# get sampled ids
nextword_distribution = nextword_distribution.asnumpy().astype(np.float32)
real_next_word_index_list = self.filter_distribution.calculate(nextword_distribution)
else:
np_nextword_distribution = nextword_distribution.asnumpy()
next_word_index = np.argmax(np_nextword_distribution, axis=-1)
real_next_word_index_list = next_word_index.tolist()
append_ids = []
# tokenizer.decode and early_stop (if every batch item has generated an EOS, stop generating)
for batch_idx in range(self.batch_size):
next_word_index = real_next_word_index_list[batch_idx]
# early stop if the model generates an EOS token.
if self.early_stop is True:
if next_word_index == self.eos_id:
if self.batch_size == 1:
break
else:
early_stop_mask[batch_idx] = 1
continue
return_ids_list[batch_idx].append(next_word_index)
append_ids.append(next_word_index)
# check early_stop mask at the end of each loop
if 0 not in early_stop_mask:
break
input_ids, input_mask = add_last_token(input_ids,
input_mask,
overflow_strategy="shift",
append_ids=append_ids,
next_token_pos=last_token.get_pos(shift=i + 1))
# add str to full str
generate_str = [""] * self.batch_size
full_str = [""] * self.batch_size
text_cnt = 0
for text_ids in return_ids_list:
text = self.tokenizer.decode(text_ids)
generate_str[text_cnt] = text
text_cnt += 1
for batch_idx in range(self.batch_size):
full_str[batch_idx] = input_str[batch_idx] + generate_str[batch_idx]
# return by several conditions
if self.batch_size == 1 and self.demo_mode is True:
if self.return_ids:
return generate_str[0], input_str[0], return_ids_list[0]
return generate_str[0], input_str[0]
if self.return_ids:
if self.return_last_token_logits:
return return_ids_list, return_last_logits
return return_ids_list
return generate_str, full_str
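A minimal, self-contained sketch of how the `LastTokenPos` helper defined above is used: it recovers the position of the last valid token of each prompt from the attention mask, which is what `generate()` relies on at every decoding step. The import path below is an assumption based on the repository tree.

```python
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype

# assumed import path (LastTokenPos is defined in src/GPT2_generation.py)
from src.GPT2_generation import LastTokenPos

# two prompts padded to seq_length=8: the first has 3 valid tokens, the second 5
input_mask = Tensor(np.array([[1, 1, 1, 0, 0, 0, 0, 0],
                              [1, 1, 1, 1, 1, 0, 0, 0]]), mstype.int32)

last_token = LastTokenPos(input_mask, seq_length=8)
print(last_token.get_pos(shift=0))  # [2, 4]: indices of the last valid tokens
print(last_token.get_pos(shift=1))  # [3, 5]: where the next generated tokens are written
```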


@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get config setting"""
def get_train_setting(finetune_config):
"""get train config setting"""
cfg = finetune_config
print("Loading GPT2 Finetune Config setting......")
print(" | optimizer: {}".format(cfg.optimizer))
opt = cfg['optimizer']
print(" | learning rate: {}".format(cfg[opt]['learning_rate']))
print(" | end learning rate: {}".format(
cfg[opt]['end_learning_rate'] if 'end_learning_rate' in cfg[opt] else 'None'))
print(" | weight decay: {}\n".format(cfg[opt]['weight_decay'] if 'weight_decay' in cfg[opt] else 'None'))
def get_model_setting(finetune_config, model_config):
"""get GPT-2 model config setting"""
cfg = finetune_config
gpt2_net_cfg = model_config
print("Loading GPT2 Model Config setting......")
print(" | model size: {}".format(cfg.gpt2_network))
print(" | batch_size: {}".format(gpt2_net_cfg.batch_size))
print(" | seq_length: {}".format(gpt2_net_cfg.seq_length))
print(" | vocab_size: {}".format(gpt2_net_cfg.vocab_size))
print(" | d_model: {}".format(gpt2_net_cfg.d_model))
print(" | num_hidden_layers: {}".format(gpt2_net_cfg.num_hidden_layers))
print(" | num_attention_heads: {}".format(gpt2_net_cfg.num_attention_heads))
print(" | hidden_dropout: {}".format(gpt2_net_cfg.hidden_dropout))
print(" | attention_dropout: {}".format(gpt2_net_cfg.attention_dropout))
print(" | summary_first_dropout: {}\n".format(gpt2_net_cfg.summary_first_dropout))


@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning schedule"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
class GPT2LearningRate(LearningRateSchedule):
"""
Implementation of the warmup-polydecay learning rate scheduler.
Args:
learning_rate (float): The initial value of learning rate.
end_learning_rate (float): The end value of learning rate.
warmup_steps (int): The warm up steps of learning rate.
decay_steps (int): A value used to calculate decayed learning rate.
power (float): A value used to calculate decayed learning rate.
Returns:
lr (Tensor): The learning rate value for the current step.
"""
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(GPT2LearningRate, self).__init__()
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
decay_lr = self.decay_lr(global_step)
if self.warmup_flag:
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
else:
lr = decay_lr
return lr
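A brief sketch of how this schedule is typically consumed; the hyper-parameter values and the optimizer hand-off below are illustrative placeholders rather than the repository's defaults. The schedule can be queried directly with a global-step Tensor or passed to a MindSpore optimizer as a dynamic learning rate.

```python
from mindspore import Tensor
from mindspore.common import dtype as mstype

# assumed import path (GPT2LearningRate is defined in src/utils/lr_schedule.py)
from src.utils.lr_schedule import GPT2LearningRate

lr_schedule = GPT2LearningRate(learning_rate=5e-5,
                               end_learning_rate=1e-7,
                               warmup_steps=100,
                               decay_steps=1000,
                               power=1.0)

# linear warmup below step 100, polynomial decay towards end_learning_rate afterwards
print(lr_schedule(Tensor(50, mstype.int32)))
print(lr_schedule(Tensor(500, mstype.int32)))

# usually handed to an optimizer, e.g.
# optimizer = mindspore.nn.AdamWeightDecay(network.trainable_params(), learning_rate=lr_schedule)
```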


@ -0,0 +1,185 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metric method for downstream task"""
import string
import re
from collections import Counter
import numpy as np
from .rouge_score import get_rouge_score
from .bleu import compute_bleu
class LastWordAccuracy():
"""
LastWordAccuracy is for the LAMBADA task (predicting the final word of a sentence)
"""
def __init__(self):
self.acc_num = 0
self.total_num = 0
def normalize(self, word):
"""normalization"""
word = word.lstrip()
word = word.rstrip()
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return remove_punc(lower(word))
def update(self, predict_label, gold_label):
if isinstance(predict_label, str) and isinstance(gold_label, str):
predict_label = [predict_label]
gold_label = [gold_label]
for predict_word, gold_word in zip(predict_label, gold_label):
self.total_num += 1
if self.normalize(predict_word) == self.normalize(gold_word):
self.acc_num += 1
class Accuracy():
"""
calculate accuracy
"""
def __init__(self):
self.acc_num = 0
self.total_num = 0
def update(self, logits, labels):
"""accuracy update"""
labels = np.reshape(labels, -1)
logits_id = np.argmax(logits, axis=-1)
print(" | Preict Label: {} Gold Label: {}".format(logits_id, labels))
self.acc_num += np.sum(labels == logits_id)
self.total_num += len(labels)
print("\n| Accuracy = {} \n".format(self.acc_num / self.total_num))
class F1():
"""calculate F1 score"""
def __init__(self):
self.f1_score = 0.0
def get_normalize_answer_token(self, string_):
"""Lower text and remove punctuation, article and extra whitespace."""
def remove_articles(text):
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
return re.sub(regex, ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(char for char in text if char not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(string_)))).split()
def update(self, pred_answer, gold_answer):
"""F1 update"""
common = Counter(pred_answer) & Counter(gold_answer)
num_same = sum(common.values())
# the number of same tokens between pred_answer and gold_answer
precision = 1.0 * num_same / len(pred_answer) if pred_answer else 0
recall = 1.0 * num_same / len(gold_answer) if gold_answer else 0
if ' '.join(pred_answer).strip() == "" and ' '.join(gold_answer).strip() == "":
self.f1_score += 1
else:
self.f1_score += 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0.0
print('| precision: {}, recall: {}\n'.format(precision, recall))
class BLEU():
"""calculate BLEU score"""
def __init__(self, tokenizer=None, max_order=4, smooth=True):
self.bleu = 0.0
self.total_num = 0
self.tokenizer = tokenizer
self.max_order = max_order
self.smooth = smooth
def sum_bleu(self, references, translations, max_order, smooth):
"""calculate the sum of bleu score"""
all_result = []
bleu_avg = 0.0
for refer, trans in zip(references, translations):
result = compute_bleu([[refer]], [trans], max_order, smooth)
all_result.append(result)
bleu_avg += result[0]
bleu_avg /= len(references)
return bleu_avg, all_result
def update(self, hypotheses, references):
"""BLEU update"""
hypo_l = []
ref_l = []
if self.tokenizer is not None:
for hypo, ref in zip(hypotheses, references):
if ref.strip() == '':
print("Reference is None, skip it !")
continue
if hypo.strip() == '':
print("translation is None, skip it !")
continue
hypo_l.append(self.tokenizer.encode(hypo))
ref_l.append(self.tokenizer.encode(ref))
if hypo_l and ref_l:
hypotheses = hypo_l
references = ref_l
bleu_avg, _ = self.sum_bleu(references, hypotheses, self.max_order, self.smooth)
self.bleu += bleu_avg * 100
self.total_num += 1
print("============== BLEU: {} ==============".format(float(self.bleu / self.total_num)))
class Rouge():
'''
Get Rouge Score
'''
def __init__(self):
self.Rouge1 = 0.0
self.Rouge2 = 0.0
self.RougeL = 0.0
self.total_num = 0
def update(self, hypothesis, targets):
scores = get_rouge_score(hypothesis, targets)
self.Rouge1 += scores['rouge-1']['f'] * 100
self.Rouge2 += scores['rouge-2']['f'] * 100
self.RougeL += scores['rouge-l']['f'] * 100
self.total_num += 1
print("=============== ROUGE: {} ===============".format(
(self.Rouge1 + self.Rouge2 + self.RougeL) / float(3.0 * self.total_num)))
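The pure-Python metrics above can be exercised without a model; a toy run (assuming the import path below, taken from the repository tree) looks like this:

```python
# assumed import path (these classes live in src/utils/metric_method.py)
from src.utils.metric_method import LastWordAccuracy, F1

acc = LastWordAccuracy()
acc.update("Paris", "paris.")          # normalize() lowercases and strips punctuation -> match
acc.update("london", "Berlin")
print(acc.acc_num / acc.total_num)     # 0.5

f1 = F1()
pred = f1.get_normalize_answer_token("The cat sat on the mat.")
gold = f1.get_normalize_answer_token("A cat sat on a mat")
f1.update(pred, gold)                  # prints precision and recall
print(f1.f1_score)                     # 1.0 for this toy pair
```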


@ -0,0 +1,466 @@
,
.
?
!
#
~
=
-
"
'
:
-
--
|
a
about
above
across
after
again
against
all
almost
alone
along
already
also
although
always
among
an
and
another
any
anybody
anyone
anything
anywhere
are
area
areas
around
as
ask
asked
asking
asks
at
away
b
back
backed
backing
backs
be
became
because
become
becomes
been
before
began
behind
being
beings
best
better
between
big
both
bro
but
by
c
came
can
cannot
case
cases
certain
certainly
clear
clearly
come
could
d
did
differ
different
differently
do
does
done
down
down
downed
downing
downs
during
dr
e
each
early
eh
either
end
ended
ending
ends
enough
even
evenly
ever
every
everybody
everyone
everything
everywhere
f
fact
facts
far
felt
few
find
finds
first
for
four
from
full
fully
further
furthered
furthering
furthers
g
gave
general
generally
get
gets
give
given
gives
going
good
goods
got
great
greater
greatest
group
grouped
grouping
groups
h
had
has
have
having
he
her
here
herself
hey
high
high
high
higher
highest
him
himself
his
house
how
however
i
if
important
in
interest
interested
interesting
interests
into
is
it
its
itself
j
just
k
kae
keep
keeps
kind
knew
know
known
knows
kya
l
lads
large
largely
last
later
latest
least
less
let
lets
like
likely
long
longer
longest
m
made
make
making
man
many
may
me
member
members
men
might
mister
more
most
mostly
mr
Mr
mrs
much
must
my
myself
n
na
necessary
need
needed
needing
needs
never
new
new
newer
newest
next
no
nobody
non
noone
not
nothing
now
nowhere
number
numbers
nt
nn
nope
ny
o
oi
of
off
often
old
older
oldest
on
once
one
only
open
opened
opening
opens
or
order
ordered
ordering
orders
other
others
our
out
over
oh
p
part
parted
parting
parts
per
perhaps
place
places
please
point
pointed
pointing
points
possible
present
presented
presenting
presents
problem
problems
put
puts
q
quite
r
rather
really
right
right
room
rooms
s
said
same
saw
say
says
second
seconds
see
seem
seemed
seeming
seems
sees
several
shall
she
should
show
showed
showing
shows
side
sides
since
small
smaller
smallest
so
some
somebody
someone
something
somewhere
state
states
still
still
such
sure
t
take
taken
than
that
the
their
them
then
there
therefore
these
they
thing
things
think
thinks
this
those
though
thought
thoughts
three
through
thus
to
today
together
too
took
toward
turn
turned
turning
turns
two
u
uh
um
under
until
up
upon
us
use
used
uses
v
very
w
want
wanted
wanting
wants
was
way
ways
we
well
wells
went
were
what
when
where
whether
which
while
who
whole
whose
why
will
with
within
without
work
worked
working
works
would
x
y
ya
ye
year
years
yet
you
young
younger
youngest
your
yours
z


@ -0,0 +1,39 @@
"""Calculate ROUGE score."""
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import List
from rouge import Rouge
def get_rouge_score(hypothesis: List[str], target: List[str]):
"""
Calculate ROUGE score.
Args:
hypothesis (List[str]): Inference result.
target (List[str]): Reference.
"""
if not hypothesis or not target:
raise ValueError(f"`hypothesis` and `target` can not be None.")
_rouge = Rouge()
print("hypothesis:", hypothesis)
print("target:", target)
scores = _rouge.get_scores(hypothesis, target, avg=True)
print(" | ROUGE Score:")
print(f" | RG-1(F): {scores['rouge-1']['f'] * 100:8.2f}")
print(f" | RG-2(F): {scores['rouge-2']['f'] * 100:8.2f}")
print(f" | RG-L(F): {scores['rouge-l']['f'] * 100:8.2f}")
return scores
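A quick sketch of `get_rouge_score` on toy strings; it requires the third-party `rouge` package imported above, and the module path is assumed from the repository tree.

```python
# assumed import path (this helper lives in src/utils/rouge_score.py)
from src.utils.rouge_score import get_rouge_score

hypothesis = ["the cat was found under the bed"]
target = ["the cat was under the bed"]
scores = get_rouge_score(hypothesis, target)   # also prints RG-1 / RG-2 / RG-L F-scores
print(scores["rouge-1"]["f"])
```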


@ -0,0 +1,186 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
task utils
"""
import regex as re
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
# for lambada task
def extract_logits(logits=None, position=None):
"""
Args:
logits (Tensor): Tensor(batch_size, seq_length, vocab_size), e.g. (8, 1024, 50257)
position (numpy.array): the array storing the final word positions, shape [batch_size, 2]
Return:
output_logits (Tensor): the logits extracted at the specified positions,
shape [batch_size, vocab_size]
"""
batch_size = logits.shape[0]
for batch_idx in range(batch_size):
word_logits_pos = int(position[batch_idx, 0] - 1)
logit = logits[batch_idx:batch_idx+1:1, word_logits_pos, ::] # [1, vocab_size]
if batch_idx == 0:
output_logits = logit
else:
output_logits = P.Concat()((output_logits, logit)) # [batch_size, vocab_size]
return output_logits
def get_final_word_label(input_ids, input_length, tokenizer=None):
"""
get whole word label_str from input_ids
Args:
input_ids: Tensor(batch_size, seq_length), indices of input text
input_length: Tensor(batch_size, 2), start and end positions of the final word in each sample
tokenizer: GPT2Tokenizer, the tokenizer used to decode the final word ids, optional
Returns:
batch_word_label: list[str], the final word of each sample, used as the LAMBADA label
"""
input_ids_np = input_ids.asnumpy()
input_length_np = input_length.asnumpy()
batch_word_label = []
for batch_idx in range(len(input_ids_np)):
word_spos = input_length_np[batch_idx, 0]
word_epos = input_length_np[batch_idx, 1]
final_word_ids = input_ids_np[batch_idx, word_spos:word_epos]
final_word_str = tokenizer.decode(final_word_ids.tolist())
batch_word_label.append(final_word_str)
return batch_word_label
def calculate_final_word_loss(logits, batch_size, input_ids, input_length, loss):
"""
Calculate the last word loss.
"""
logits = logits.asnumpy()
input_len_np = input_length.asnumpy()
input_ids_np = input_ids.asnumpy()
sum_batch_loss = 0.0
for batch in range(batch_size):
lastword_spos = input_len_np[batch, 0]
lastword_epos = input_len_np[batch, 1]
last_word_logits = logits[batch, lastword_spos - 1:lastword_epos - 1:1, ::]
last_word_logits_tensor = Tensor(last_word_logits, mstype.float32)
last_word_label = input_ids_np[batch, lastword_spos:lastword_epos:1]
print("last word label: ", last_word_label)
last_word_label_tensor = Tensor(last_word_label, mstype.int32)
last_word_loss = loss(last_word_logits_tensor, last_word_label_tensor)
last_word_loss = float(last_word_loss.asnumpy())
sum_batch_loss += last_word_loss
print(" | loss: ", last_word_loss)
avg_batch_loss = float(sum_batch_loss / batch_size)
return avg_batch_loss
# for cbt task
def calculate_choice_prob_for_cbt(logits, batch_size, input_length, input_ids):
"""
calculate choice prob for cbt
Args:
logits (Tensor): log-probabilities from the model, shape (batch_size, seq_length, vocab_size)
batch_size (int): batch size
input_length (Tensor): start and end positions of the candidate span, shape (batch_size, 2)
input_ids (Tensor): input token ids, shape (batch_size, seq_length)
Returns:
choice_prob: List[float]
"""
choice_prob = [] # [batch_size]
logits = logits.asnumpy()
input_len_np = input_length.asnumpy()
input_ids_np = input_ids.asnumpy()
for batch in range(batch_size):
sum_ = 0.0
rest_spos = input_len_np[batch, 0]
rest_epos = input_len_np[batch, 1] + 1
for rest_pos in range(rest_spos - 1, rest_epos - 1):
rest_token_id = input_ids_np[batch, rest_pos + 1]
log_prob = logits[batch, rest_pos, rest_token_id]
sum_ = sum_ + log_prob
choice_prob.append(sum_)
print("rest sentence prob: ", sum_)
return choice_prob
# for summarization task
def modify_paramdict(param_dict, mode="zero-shot", model_prefix="gpt2."):
"""
modify keys of param_dict to fit model.
Args:
param_dict: dict, dictionary of parameters imported from a ckpt file
mode: str, "zero-shot" for a pretrained GPT-2 model;
"finetuned" for a model finetuned on a certain task.
Return:
reorganized_param_dict: dict, new param_dict to fit in model for different tasks.
"""
final_param_dict = dict()
if mode == "zero-shot":
for name in param_dict:
final_param_dict[model_prefix + name] = param_dict[name]
final_param_dict['lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
elif mode == "finetuned":
embedding_name = "gpt2_embedding_lookup.embedding_table"
embedding_name_old = ""
for name in param_dict:
name_remove_prefix = name[len(model_prefix):]
name_prefix = name[:len(model_prefix)]
final_param_dict[name_remove_prefix] = param_dict[name]
if embedding_name in name and name_prefix == model_prefix:
embedding_name_old = name
final_param_dict[embedding_name] = param_dict[embedding_name_old]
else:
raise ValueError("mode should be [zero-shot, finetuned]")
return final_param_dict
def clean_hypo(text):
"""
to prevent generation of empty string, and lower text
Arg:
text: str, input str
Return:
text: str, cleaned input str
"""
text = text.lower()
eng_re = re.compile(r'[a-z]+', re.I)
length_con = len(eng_re.findall(text))
if length_con == 0:
return '<EMPTY>'
return text
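A hedged sketch of how `modify_paramdict` is meant to be used when restoring a pretrained (zero-shot) checkpoint into a downstream-task network; the checkpoint path and the task network are placeholders, and the import path is assumed from the repository tree.

```python
from mindspore.train.serialization import load_checkpoint

# assumed import path (these helpers live in src/utils/task_utils.py)
from src.utils.task_utils import modify_paramdict, clean_hypo

param_dict = load_checkpoint("gpt2_pretrained.ckpt")        # placeholder checkpoint path
final_param_dict = modify_paramdict(param_dict, mode="zero-shot", model_prefix="gpt2.")
# mindspore.train.serialization.load_param_into_net(task_net, final_param_dict)  # task_net built elsewhere

print(clean_hypo("   !!! "))   # '<EMPTY>': empty generations are replaced before scoring
```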


@ -0,0 +1,217 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
tensor manipulations
"""
import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops import operations as P
def extract_string_from_tensor(input_ids, mode="single", config=None, tokenizer=None):
"""
Args:
input_ids (Tensor): input sentences with shape [batch_size, seq_len].
mode (str): ["pair", "single"]
"pair" for tasks with paired inputs `<bos> A <eos> B <eos>`,
such as summarization task, the dataset format `<bos> Article <eos> Summary <eos>`,
reading comprehension task, the dataset format `<bos> Passage Question <eos> Answer <eos>`.
"single" for tasks with single input `<bos> A <eos>`, such as Language Modeling, Lambada task.
config: the configuration of GPT-2 model.
tokenizer: the tokenizer of GPT-2 model.
Return:
prompt_list (list): list of prompt_text
reference_list (list): list of reference_text, or second part of text
rest_list (list): list of rest_text, or rest part of text
"""
batch_size = config.batch_size
seq_length = config.seq_length
prompt_list = [""] * batch_size
reference_list = [""] * batch_size
eos_text = tokenizer.eos_token
len_eos_text = len(eos_text)
input_ids_np = input_ids.asnumpy()
input_ids_np = input_ids_np.reshape((batch_size, seq_length))
# input_ids = P.Reshape()(input_ids, (batch_size, seq_length))
if mode == "pair":
for batch_idx in range(batch_size):
sentence_tensor = input_ids_np[batch_idx]
sentence_list = sentence_tensor.tolist()[1:]
sentence = tokenizer.decode(sentence_list)
prompt_start = 0
prompt_end = sentence.find(eos_text, 0)
reference_start = prompt_end + len_eos_text
reference_end = sentence[reference_start:].find(
eos_text, 0) + reference_start
prompt_list[batch_idx] = sentence[prompt_start:prompt_end]
reference_list[batch_idx] = sentence[reference_start:reference_end]
return prompt_list, reference_list
# For single output datasets such as WikiText, etc.
if mode == "single":
for batch_idx in range(batch_size):
sentence_tensor = input_ids_np[batch_idx]
sentence_list = sentence_tensor.tolist()[1:]
sentence = tokenizer.decode(sentence_list)
prompt_start = 0
prompt_end = sentence.find(eos_text, 0)
prompt_list[batch_idx] = sentence[prompt_start:prompt_end]
else:
raise NotImplementedError('mode:{} not supported.'.format(mode))
return prompt_list
def extract_single_token_logits(logits=None, seq_pos=None):
"""
Args
logits: (batch_size,seq_length,vocab_size) e.g. when batchsize is 8,
sequence length is 1024 and vocab_size is 50257,
then logits is a Tensor with shape (8,1024,50257)
seq_pos:(batch_size) list
Return:
output_logits: (batch_size,1,vocab_size) extract the logit to predict the last token.
"""
batch_size = logits.shape[0]
logits_np = logits.asnumpy()
logits_type = P.DType()(logits)
for i in range(batch_size):
# logit = logits[i:i + 1:1, seq_pos[i]:seq_pos[i] + 1:1, ::]
logit_np = logits_np[i:i + 1:1, seq_pos[i]:seq_pos[i] + 1:1, ::]
if i == 0:
# output_logits = logit
output_logits = logit_np
else:
# output_logits = P.Concat()((output_logits, logit))
output_logits = np.concatenate((output_logits, logit_np), axis=0)
output_logits = Tensor(output_logits, dtype=logits_type)
return output_logits
def get_last_one_pos(input_mask: Tensor):
"""
Arg:
input_mask (Tensor): (batch_size,seq_length)
Return:
pos (Tensor): (batch_size,)
"""
input_mask_ = P.Cast()(input_mask, mstype.float32)
pos = P.ReduceSum(keep_dims=False)(input_mask_, axis=1) # (batch_size,)
pos = P.Cast()(pos, mstype.int32)
pos = pos - 1
return pos
def get_next_one_pos(input_mask: Tensor):
"""
Arg:
input_mask (Tensor): (batch_size,seq_length)
"""
input_mask_ = P.Cast()(input_mask, mstype.float32)
pos = P.ReduceSum(keep_dims=False)(input_mask_, axis=1) # (batch_size,)
pos = P.Cast()(pos, mstype.int32)
return pos
def add_last_token_mask(input_mask: Tensor, overflow_strategy: str = "shift"):
"""
add last token mask
Args:
input_mask: Tensor
overflow_strategy: str
Returns:
Tensor
"""
pos = get_next_one_pos(input_mask).asnumpy()
input_mask_np = input_mask.asnumpy()
maximum_length = input_mask.shape[1]
batch_size = input_mask.shape[0]
for idx in range(batch_size):
# not overflow
if pos[idx] < maximum_length:
input_mask_np[idx][pos[idx]] = 1
# overflow
else:
if overflow_strategy == "shift":
continue
if overflow_strategy == "truncate":
continue
else:
raise ValueError("{} is not an option in ['shift','truncate'].".format(overflow_strategy))
return Tensor(input_mask_np, dtype=mstype.int32)
def add_last_token(input_ids: Tensor, input_mask: Tensor, overflow_strategy: str = "shift", append_ids=None,
next_token_pos=None):
"""
add last token
Args:
input_ids: Tensor
input_mask: Tensor
overflow_strategy: str
append_ids: Any
next_token_pos: Any
Returns:
Tensor
"""
# get positional list/numpy array
if next_token_pos is None:
pos = get_next_one_pos(input_mask).asnumpy()
else:
pos = next_token_pos
# get numpy of inputs
input_mask_np = input_mask.asnumpy()
input_ids_np = input_ids.asnumpy()
maximum_length = int(input_mask.shape[1])
batch_size = int(input_mask.shape[0])
for idx in range(batch_size):
# not overflow
if pos[idx] < maximum_length:
input_mask_np[idx][int(pos[idx])] = 1
input_ids_np[idx][int(pos[idx])] = append_ids[idx]
# overflow
else:
if overflow_strategy == "shift":
# shift one token left
input_ids_np[idx][0:maximum_length - 1] = input_ids_np[idx][1:maximum_length]
input_ids_np[idx][maximum_length - 1] = append_ids[idx]
continue
if overflow_strategy == "truncate":
# do nothing
continue
else:
raise ValueError("{} is not an option in ['shift','truncate'].".format(overflow_strategy))
return Tensor(input_ids_np, dtype=mstype.int32), Tensor(input_mask_np, dtype=mstype.int32)
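A self-contained sketch (assuming MindSpore is installed and the import path below, taken from the repository tree) of the mask helpers above: append one generated token per batch item and extend the mask accordingly.

```python
import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype

# assumed import path (these helpers live in src/utils/tensor_manipulations.py)
from src.utils.tensor_manipulations import get_last_one_pos, add_last_token

ids = Tensor(np.array([[50256, 11, 22, 0, 0]]), mstype.int32)   # <bos> + two tokens, padded
mask = Tensor(np.array([[1, 1, 1, 0, 0]]), mstype.int32)

print(get_last_one_pos(mask))            # [2]: index of the last valid token
new_ids, new_mask = add_last_token(ids, mask,
                                   overflow_strategy="shift",
                                   append_ids=[33])
print(new_ids)                           # [[50256 11 22 33 0]]
print(new_mask)                          # [[1 1 1 1 0]]
```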


@ -0,0 +1,517 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
tokenization
"""
import json
from functools import lru_cache
from typing import List, Optional
import logging
import regex as re
logger = logging.getLogger(__name__)
@lru_cache()
def bytes_to_unicode():
"""
bytes to unicode
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(i) for i in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""
Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class GPT2Tokenizer():
"""
GPT2Tokenizer
"""
def __init__(
self,
vocab_file,
merge_file,
add_prefix_space=False,
):
with open(vocab_file, 'r', encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
self.vocab_size = len(self.decoder)
with open(merge_file, 'r', encoding="utf-8") as merge_handle:
bpe_merges = merge_handle.read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.add_prefix_space = add_prefix_space
self.cache = {}
self.unk_token = "<|endoftext|>"
self.unk_token_id = 50256
self.bos_token = "<|endoftext|>"
self.bos_token_id = 50256
self.eos_token = "<|endoftext|>"
self.eos_token_id = 50256
self.pad_token = "<|endoftext|>"
self.pad_token_id = 50256
def bpe(self, token):
"""
bpe encode
"""
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(token)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
new_word.extend(word[i:])
break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i + 1 < len(word) and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def _tokenize(self, text):
""" Tokenize a string using bpe encode. """
text = self.prepare_for_tokenization(text, is_pretokenized=False)
# print(text)
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(
self.byte_encoder[b] for b in token.encode("utf-8")
)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def _convert_token_to_id(self, token):
""" the index of the token in the vocabulary. """
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, _id):
""" return the origin bpe token according to id"""
return self.decoder.get(_id)
def _convert_tokens_to_string(self, tokens):
""" return a string according to the list of tokens"""
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors='ignore')
return text
def encode(self, text):
""" get the index list of text"""
text_id = []
bpe_tokens = self._tokenize(text)
for token in bpe_tokens:
text_id.append(self._convert_token_to_id(token))
return text_id
def decode(self, ids):
""" return a string according to the index list of tokens"""
tokens = []
for id_ in ids:
tokens.append(self._convert_id_to_token(id_))
return self._convert_tokens_to_string(tokens)
def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
""" whether to add a whitespace in the front of text """
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
if is_pretokenized or add_prefix_space:
text = " " + text
return text
def add_special_tokens(self, special_tokens_dict):
"""
Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
current vocabulary).
Args:
special_tokens_dict (dictionary `str` to `str`):
Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``,
``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
``additional_special_tokens``].
Returns:
added_tokens (int): Number of tokens added to the vocabulary
"""
# special_tokens_dict = {'cls_token': '<CLS>'}
if not special_tokens_dict:
return 0
added_tokens = 0
for key, value in special_tokens_dict.items():
setattr(self, key, value)
assert isinstance(value, str), f"Token {value} for key {key} should be a str instance"
added_tokens += self.add_tokens([value], special_tokens=True)
return added_tokens
def add_tokens(self, new_tokens, special_tokens=False):
if not new_tokens:
return 0
if not isinstance(new_tokens, (list, tuple)):
new_tokens = [new_tokens]
return self._add_tokens(new_tokens, special_tokens=special_tokens)
def _add_tokens(self, new_tokens, special_tokens=False):
"""
_add_tokens
Args:
new_tokens (list[str]): Token(s) to add in vocabulary.
special_tokens (bool): Whether or not the tokens should be added as special tokens.
Returns:
the number of the new added tokens.
"""
new_tokens = [str(token) for token in new_tokens]
tokens_to_add = []
for token in new_tokens:
assert isinstance(token, str)
tokens_to_add.append(token)
logger.info("Adding %s to the vocabulary ! ", token)
added_tok_encoder = dict((tok, self.vocab_size + i) for i, tok in enumerate(tokens_to_add))
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.encoder.update(added_tok_encoder)
self.decoder.update(added_tok_decoder)
return len(tokens_to_add)
def num_special_tokens_to_add(self, pair: bool = False):
token_ids_0 = []
token_ids_1 = []
return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None):
"""
Build model inputs from a sequence or a pair of sequence by concatenating and adding special tokens.
A GPT2 sequence has the following format:
- single sequence: ``<bos> X <eos>``
- pair of sequences: ``<bos> A <eos> B <eos>``
Args:
token_ids_0 (List[int]): List of IDs to which the special tokens will be added
token_ids_1 (List[int], `optional`, defaults to `None`): Optional second list of IDs for sequence pairs.
"""
bos = [self.bos_token_id]
eos = [self.eos_token_id]
if token_ids_1 is None:
return bos + token_ids_0 + eos
return bos + token_ids_0 + eos + token_ids_1 + eos
def truncate_sequences(self, ids, num_tokens_to_remove, truncation_strategy="ONLY_FIRST", direction="RIGHT"):
"""
truncate sequences
Args:
ids: Any
num_tokens_to_remove:
truncation_strategy: str
direction: str
Returns:
(ids, overflowing_tokens): (Any, list)
"""
if num_tokens_to_remove <= 0:
return ids, []
overflowing_tokens = []
if truncation_strategy == "ONLY_FIRST":
if len(ids) > num_tokens_to_remove:
if direction == "RIGHT":
overflowing_tokens = ids[-num_tokens_to_remove:]
ids = ids[:-num_tokens_to_remove]
if direction == "LEFT":
overflowing_tokens = ids[:num_tokens_to_remove]
ids = ids[num_tokens_to_remove:]
else:
logger.error("The first sequence length is smaller than removed tokens. ")
else:
logger.error("Please select correct truncation strategy, for instance 'ONLY_FIRST'")
return (ids, overflowing_tokens)
def _pad(self, encoded_inputs, max_length=None, padding_strategy=None,
return_attention_mask: Optional[bool] = None):
"""
_pad
Args:
encoded_inputs:
max_length: Any
padding_strategy: Any
return_attention_mask: Optional[bool]
Returns:
encoded_inputs:
"""
needs_to_be_padded = (len(encoded_inputs["input_ids"]) != max_length)
if needs_to_be_padded:
if padding_strategy == "MAX_LENGTH":
difference = max_length - len(encoded_inputs["input_ids"])
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
else:
raise ValueError("Invalid padding strategy")
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
return encoded_inputs
def pad(self, encoded_inputs, max_length: Optional[int] = None, padding_strategy="MAX_LENGTH",
return_attention_mask=True):
"""
pad
Args:
encoded_inputs:
max_length: Optional[int]
padding_strategy: str
return_attention_mask: bool
Returns:
batch_outputs: Dict[Any, list]
"""
# no batch encoded_inputs["input_ids"]--->[98, 67, 32388, 318, 1912, 287, 170, 8496, 318, 905, 2667, 32]
if encoded_inputs["input_ids"] and not isinstance(encoded_inputs["input_ids"][0], (list, tuple)):
encoded_inputs = self._pad(
encoded_inputs,
max_length=max_length,
padding_strategy=padding_strategy,
return_attention_mask=return_attention_mask
)
return encoded_inputs
# encoded_inputs with batch_size
batch_size = len(encoded_inputs["input_ids"])
assert all(
len(v) == batch_size for v in encoded_inputs.values()
), "Some items in the output dictionary have a different batch size than others."
if padding_strategy == "LONGEST":
max_length = max(len(inputs) for inputs in encoded_inputs["input_ids"])
padding_strategy = "MAX_LENGTH"
batch_outputs = {}
for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
outputs = self._pad(
encoded_inputs=inputs,
max_length=max_length,
padding_strategy=padding_strategy,
return_attention_mask=return_attention_mask
)
for key, value in outputs.items():
if key not in batch_outputs:
batch_outputs[key] = []
batch_outputs[key].append(value)
return batch_outputs
def prepare_for_model(self,
ids,
pair_ids=None,
add_special_tokens=True,
max_length=None,
padding=None,
truncate_direction="RIGHT",
return_overflowing_tokens=False,
return_attention_mask=True):
"""
prepare for model
Args:
ids:
pair_ids:
add_special_tokens: bool
max_length: Any
padding: Any
truncate_direction: str
return_overflowing_tokens: bool
return_attention_mask: bool
Returns:
encoded_inputs:Dict
"""
pair = bool(pair_ids is not None)
len_ids = len(ids)
len_pair_ids = len(pair_ids) if pair else 0
encoded_inputs = {}
# Compute the total size of the returned encodings
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
# Truncation: Handle max sequence length
if max_length and total_len > max_length:
ids, overflowing_tokens = self.truncate_sequences(ids=ids,
num_tokens_to_remove=total_len - max_length,
truncation_strategy="ONLY_FIRST",
direction=truncate_direction)
if return_overflowing_tokens:
encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length
if add_special_tokens:
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
else:
sequence = ids + pair_ids if pair else ids
# build output dictionary
encoded_inputs["input_ids"] = sequence
# check lengths
if max_length is None or len(encoded_inputs["input_ids"]) > max_length:
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
"for this model (%ids > %length). Running this sequence through the model will result in "
"indexing errors", len(ids), max_length
)
# padding
if padding or return_attention_mask:
encoded_inputs = self.pad(encoded_inputs=encoded_inputs,
max_length=max_length,
padding_strategy="MAX_LENGTH",
return_attention_mask=return_attention_mask)
return encoded_inputs
class CNN_DailyMail_tokenizer(GPT2Tokenizer):
"""
CNN DailyMail tokenizer
"""
def prepare_for_model(self,
ids,
pair_ids,
max_length=1024,
max_summary_length=150,
add_special_tokens=True,
padding=None,
return_overflowing_tokens=False,
return_attention_mask=True):
len_ids = len(ids)
len_pair_ids = len(pair_ids)
encoded_inputs = {}
# Compute the total size of the returned encodings
total_len = len_ids + len_pair_ids
ids_overflowing_tokens = []
pair_overflowing_tokens = []
# Truncation: Handle max sequence length
if total_len > max_length-3:
if len_pair_ids > max_summary_length:
num_tokens_to_remove = len_pair_ids - max_summary_length
pair_ids, pair_overflowing_tokens = self.truncate_sequences(ids=pair_ids,
num_tokens_to_remove=num_tokens_to_remove,
truncation_strategy="ONLY_FIRST",
direction="RIGHT")
if len_ids+max_summary_length > max_length-3:
num_tokens_to_remove = (len_ids + max_summary_length) - (max_length - 3)
ids, ids_overflowing_tokens = self.truncate_sequences(ids=ids,
num_tokens_to_remove=num_tokens_to_remove,
truncation_strategy="ONLY_FIRST",
direction="RIGHT")
else:
ids, ids_overflowing_tokens = self.truncate_sequences(ids=ids,
num_tokens_to_remove=total_len - (max_length-3),
truncation_strategy="ONLY_FIRST",
direction="RIGHT")
if return_overflowing_tokens:
encoded_inputs["article_overflowing_tokens"] = ids_overflowing_tokens
encoded_inputs["highlights_overflowing_tokens"] = pair_overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - (max_length-3)
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
encoded_inputs["input_ids"] = sequence
# check lengths
if max_length is None or len(encoded_inputs["input_ids"]) > max_length:
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
"for this model (%ids > %length). Running this sequence through the model will result "
"in indexing errors", len(ids), max_length
)
# padding
if padding or return_attention_mask:
encoded_inputs = self.pad(encoded_inputs=encoded_inputs,
max_length=max_length,
padding_strategy="MAX_LENGTH",
return_attention_mask=return_attention_mask)
return encoded_inputs
def Tokenizer(vocab_file="./pretrain-data/gpt2-vocab.json",
merge_file="./pretrain-data/gpt2-merges.txt",
mode="normal"):
""" use the GPT2Tokenizer"""
print(" | Tokenizer mode: {}".format(mode))
if mode == "normal":
tokenizer = GPT2Tokenizer(vocab_file, merge_file, add_prefix_space=False)
elif mode == "cnn_dailymail":
tokenizer = CNN_DailyMail_tokenizer(vocab_file, merge_file, add_prefix_space=False)
else:
raise ValueError("No Such Mode for {} in src.utils.tokenization.Tokenizer()".format(mode))
return tokenizer
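A short sketch of the tokenizer factory above. It assumes the GPT-2 vocab/merge files are present at the default paths (`./pretrain-data/gpt2-vocab.json`, `./pretrain-data/gpt2-merges.txt`); the import path is taken from the repository tree.

```python
# assumed import path (the tokenizer is defined in src/utils/tokenization.py)
from src.utils.tokenization import Tokenizer

tokenizer = Tokenizer(mode="normal")
ids = tokenizer.encode("MindSpore is an AI framework")
print(ids)
print(tokenizer.decode(ids))                       # round-trips back to the input text

# pad to a fixed length and build the attention mask consumed by the model
encoded = tokenizer.prepare_for_model(ids, add_special_tokens=False, max_length=16)
print(encoded["input_ids"])                        # padded with eos_token_id (50256)
print(encoded["attention_mask"])                   # 1 for real tokens, 0 for padding
```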


@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
init weight
"""
import math
import numpy as np
from mindspore.common.tensor import Tensor
def _average_units(shape):
if not shape:
return 1
if len(shape) == 1:
return float(shape[0])
if len(shape) == 2:
return float(shape[0] + shape[1]) / 2.
raise RuntimeError("not support shape.")
def weight_variable(shape):
scale_shape = shape
avg_units = _average_units(scale_shape)
scale = 1.0 / max(1., avg_units)
limit = math.sqrt(3.0 * scale)
values = np.random.uniform(-limit, limit, shape).astype(np.float32)
return Tensor(values)
def one_weight(shape):
ones = np.ones(shape).astype(np.float32)
return Tensor(ones)
def zero_weight(shape):
zeros = np.zeros(shape).astype(np.float32)
return Tensor(zeros)
def normal_weight(shape, num_units):
norm = np.random.normal(0.0, num_units ** -0.5, shape).astype(np.float32)
return Tensor(norm)
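The initializers above return MindSpore Tensors that are wrapped into Parameters when the network layers are built; a tiny sketch (import path assumed from the repository tree):

```python
# assumed import path (the initializers are defined in src/weight_init.py)
from src.weight_init import weight_variable, normal_weight, one_weight, zero_weight

print(weight_variable((768, 768)).shape)                 # uniform init scaled by the average fan
print(normal_weight((50257, 768), num_units=768).shape)  # normal init with std num_units ** -0.5
print(one_weight((768,)).shape)
print(zero_weight((768,)).shape)
```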


@ -0,0 +1,86 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset preprocess"""
import argparse
from src.utils.data_preprocess import lambada_dataset_preprocess
from src.utils.data_preprocess import cbt_dataset_preprocess
from src.utils.data_preprocess import wikitext_dataset_preprocess
from src.utils.data_preprocess import ptb_dataset_preprocess
from src.utils.data_preprocess import onebw_dataset_preprocess
from src.utils.data_preprocess import coqa_dataset_preprocess
from src.utils.data_preprocess import wmt14_en_fr_preprocess
def main():
parser = argparse.ArgumentParser(description="All Task dataset preprocessing")
parser.add_argument("--task", type=str, default="translation",
help="The GPT-2 downstream task, including [LanguageModeling, CBT, Translation, Lambada"
"Summarization, ReadingComprehension]")
parser.add_argument("--input_file", type=str, default="",
help="The raw dataset path. ")
parser.add_argument("--dataset", type=str, default="onebw",
help="The name of dataset which should be processed, only for LanguageModeling task.")
parser.add_argument("--output_file", type=str, default="",
help="The output dataset path after preprocessing.")
parser.add_argument("--condition", type=str, default="test",
help="Process train or test dataset, including [train, test], only for 1BW and "
"CNN & DailyMail dataset.")
args_opt = parser.parse_args()
task = args_opt.task
condition = args_opt.condition
dataset = args_opt.dataset
input_file = args_opt.input_file
output_file = args_opt.output_file
if task.lower() == "languagemodeling":
print("Start processing Language Modeling dataset ...")
if dataset.lower() == "wikitext2" or dataset.lower() == "wikitext103":
wikitext_dataset_preprocess(input_file=input_file, output_file=output_file)
elif dataset.lower() == "ptb":
ptb_dataset_preprocess(input_file=input_file, output_file=output_file)
elif dataset.lower() == "onebw":
onebw_dataset_preprocess(condition, input_file=input_file, output_file=output_file)
else:
raise ValueError("Only support wikitext2, wikitext103, ptb, onebw dataset")
elif task.lower() == "lambada":
print("Start processing Lambada dataset ...")
lambada_dataset_preprocess(input_file=input_file, output_file=output_file)
elif task.lower() == "cbt":
print("Start processing CBT dataset ...")
cbt_dataset_preprocess(input_file=input_file, output_file=output_file)
elif task.lower() == "readingcomprehension":
print("Start processing ReadingComprehension dataset ...")
coqa_dataset_preprocess(input_file=input_file, output_file=output_file)
elif task.lower() == "summarization":
print("Start processing Summarization dataset ...")
elif task.lower() == "translation":
print("Start processing Translation dataset ...")
wmt14_en_fr_preprocess(input_file=input_file, output_file=output_file)
else:
raise ValueError("Only support Language Modeling, CBT, Translation, Lambada, "
"Summarization, Reading Comprehension task.")
if __name__ == "__main__":
main()
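Besides the command line, the same preprocessing functions can be called programmatically; a sketch for the WikiText-2 language-modeling case (file paths are placeholders):

```python
from src.utils.data_preprocess import wikitext_dataset_preprocess

# equivalent to: --task LanguageModeling --dataset wikitext2
wikitext_dataset_preprocess(input_file="./data/wikitext-2/wiki.test.tokens",   # placeholder path
                            output_file="./data/wikitext2.test.txt")           # placeholder path
```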



@ -0,0 +1,107 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python implementation of BLEU and smooth-BLEU.
This module provides a Python implementation of BLEU and smooth-BLEU.
Smooth BLEU is computed following the method outlined in the paper:
Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
evaluation metrics for machine translation. COLING 2004.
"""
import collections
import math
def _get_ngrams(segment, max_order):
"""
Extracts all n-grams upto a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
Returns:
The Counter containing all n-grams upto max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus, translation_corpus, max_order=4,
smooth=False):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of lists of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
precisions and brevity penalty.
"""
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
reference_length = 0
translation_length = 0
for (references, translation) in zip(reference_corpus, translation_corpus):
reference_length += min(len(r) for r in references)
translation_length += len(translation)
merged_ref_ngram_counts = collections.Counter()
for reference in references:
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
translation_ngram_counts = _get_ngrams(translation, max_order)
overlap = translation_ngram_counts & merged_ref_ngram_counts
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for order in range(1, max_order + 1):
possible_matches = len(translation) - order + 1
if possible_matches > 0:
possible_matches_by_order[order - 1] += possible_matches
precisions = [0] * max_order
for i in range(0, max_order):
if smooth:
precisions[i] = ((matches_by_order[i] + 1.) /
(possible_matches_by_order[i] + 1.))
else:
if possible_matches_by_order[i] > 0:
precisions[i] = (float(matches_by_order[i]) /
possible_matches_by_order[i])
else:
precisions[i] = 0.0
if min(precisions) > 0:
p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
geo_mean = math.exp(p_log_sum)
else:
geo_mean = 0
ratio = float(translation_length) / reference_length
if ratio > 1.0:
bp = 1.
else:
bp = math.exp(1 - 1. / ratio)
bleu = geo_mean * bp
return (bleu, precisions, bp, ratio, translation_length, reference_length)
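A toy run of `compute_bleu` on pre-tokenized text; note that `reference_corpus` is a list of reference lists (one list of references per translation), as documented above. The example assumes it is run alongside `compute_bleu`, or that the commented import is adjusted to this module's actual path.

```python
# e.g. from src.utils.bleu_score import compute_bleu   (module path assumed)

reference_corpus = [[["the", "cat", "is", "on", "the", "mat"]]]
translation_corpus = [["the", "cat", "sat", "on", "the", "mat"]]

bleu, precisions, bp, ratio, trans_len, ref_len = compute_bleu(
    reference_corpus, translation_corpus, max_order=4, smooth=True)
print(bleu)          # corpus-level BLEU in [0, 1]
print(precisions)    # n-gram precisions for orders 1..4
```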

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long