forked from mindspore-Ecosystem/mindspore
add gpt2 to model zoo
This commit is contained in:
parent 44e068f643
commit a465a48fb7

@ -0,0 +1,931 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [GPT-2 Model](#gpt-2-model)
- [Model Architecture](#model-architecture)
- [Downstream Tasks](#downstream-tasks)
- [Script Description](#script-description)
- [Model Conversion](#model-conversion)
- [Dataset Preparation](#dataset-preparation)
    - [Language Modeling Task](#language-modeling-task)
    - [Children's Book Test Task](#childrens-book-test-task)
    - [LAMBADA Task](#lambada-task)
    - [Reading Comprehension Task](#reading-comprehension-task)
    - [Summarization Task](#summarization-task)
    - [Translation Task](#translation-task)
- [Configuration](#configuration)
- [Fine-Tuning & Evaluation](#fine-tuning--evaluation)
    - [Language Modeling Task](#language-modeling-task-1)
        - Fine-tuning
        - Evaluation
    - [Children's Book Test Task](#childrens-book-test-task-1)
        - Evaluation
    - [LAMBADA Task](#lambada-task-1)
        - Evaluation
    - [Reading Comprehension Task](#reading-comprehension-task-1)
        - Evaluation
    - [Summarization Task](#summarization-task-1)
        - Evaluation
    - [Translation Task](#translation-task-1)
        - Evaluation
- [Environment Requirements](#environment-requirements)
    - [Platform](#platform)
    - [Other Requirements](#other-requirements)
- [Performance](#performance)
    - [Inference Performance](#inference-performance)
        - [Language Modeling Task](#language-modeling-task-2)
        - [Children's Book Test Task](#childrens-book-test-task-2)
        - [LAMBADA Task](#lambada-task-2)
        - [Reading Comprehension Task](#reading-comprehension-task-2)
        - [Summarization Task](#summarization-task-2)
        - [Translation Task](#translation-task-2)
- [Others](#others)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

# GPT-2 Model

GPT-2 was released by OpenAI in 2019. It succeeds the GPT model and is a very large language model trained to predict the next word. By parameter count, GPT-2 comes in four sizes: small (117M), medium (345M), large (762M), and xlarge (1542M).

[GPT-2 introduction](https://openai.com/blog/better-language-models/)

[GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf): Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.

# Model Architecture

GPT-2 is built from the Transformer decoder. A full Transformer consists of both encoder and decoder stacks, but GPT-2 uses only the decoder part.

For fine-tuning, the pretrained model is fine-tuned on a dataset chosen for the given task.
At test time, predictions are produced with the fine-tuned model; for some tasks, the pretrained model can simply be evaluated zero-shot.
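
Decoder-only modeling hinges on masked (causal) self-attention: each position can attend only to itself and earlier positions, which is what makes GPT-2 autoregressive. The snippet below is a minimal NumPy sketch of that masking, for illustration only; it is not code from this repository's `GPT2_model.py`.

```python
import numpy as np

def causal_attention_weights(q, k):
    """Scaled dot-product attention weights with a causal (lower-triangular) mask.

    q, k: arrays of shape (seq_len, head_dim). Future positions are masked
    out before the softmax, so position t only attends to positions <= t.
    """
    seq_len, head_dim = q.shape
    scores = q @ k.T / np.sqrt(head_dim)           # (seq_len, seq_len)
    mask = np.tril(np.ones((seq_len, seq_len)))    # 1 = may attend, 0 = masked
    scores = np.where(mask == 1, scores, -1e9)     # block attention to the future
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)   # row-wise softmax

# Toy check: 4 positions, 8-dim head; the upper triangle of the output is ~0.
rng = np.random.default_rng(0)
print(causal_attention_weights(rng.normal(size=(4, 8)), rng.normal(size=(4, 8))).round(3))
```
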
# Downstream Tasks

This document covers six downstream tasks:

- Language Modeling
- Children's Book Test
- LAMBADA
- Reading Comprehension
- Summarization
- Translation

For details on the datasets, see the [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf).

## Script Description

The GPT-2 scripts and code structure are as follows:

```text
├── GPT-2
    ├── README.md                        // GPT-2 model description
    ├── scripts
    │   ├── run_cbt.sh                   // fine-tune & evaluation script for the CBT task
    │   ├── run_lambada.sh               // fine-tune & evaluation script for the LAMBADA task
    │   ├── run_language_model.sh        // fine-tune & evaluation script for the language modeling task
    │   ├── run_read_comprehension.sh    // fine-tune & evaluation script for the reading comprehension task
    │   ├── run_summarization.sh         // fine-tune & evaluation script for the summarization task
    │   ├── run_translation.sh           // fine-tune & evaluation script for the translation task
    ├── src
    │   ├── clip_grad_utils.py           // gradient clipping
    │   ├── dataset.py                   // dataset loading for fine-tuning and inference
    │   ├── finetune_eval_config.py      // fine-tuning and inference configuration
    │   ├── gpt2_for_finetune.py         // GPT-2 fine-tuning network wrapper
    │   ├── GPT2_generation.py           // generation module
    │   ├── GPT2_model.py                // GPT-2 model script
    │   ├── GPT2ForCBT.py                // model script for the CBT task
    │   ├── GPT2ForLanguageModel.py      // model script for the language modeling task
    │   ├── GPT2ForReadComprehension.py  // model script for the reading comprehension task
    │   ├── GPT2ForSummarization.py      // model script for the summarization task
    │   ├── GPT2ForTranslation.py        // model script for the translation task
    │   ├── weight_init.py               // weight initialization
    │   ├── utils
    │       ├── bleu_score.py            // BLEU score computation
    │       ├── rouge_score.py           // ROUGE score computation
    │       ├── CrossEntropy.py          // cross-entropy loss
    │       ├── data_preprocess.py       // dataset preprocessing script
    │       ├── generation_utils.py      // generation helpers, including sampling methods
    │       ├── get_config_setting.py    // configuration retrieval
    │       ├── task_utils.py            // helper functions for the downstream tasks
    │       ├── lr_schedule.py           // learning-rate schedule script
    │       ├── metric_method.py         // evaluation metrics for the downstream tasks
    │       ├── tensor_manipulations.py  // tensor operations
    │       ├── tokenization.py          // tokenization, including BPE encoding and decoding
    │       ├── pretrain-data
    │           ├── stopwords.txt        // stop-word filter for the LAMBADA task
    ├── create_cbt_data.py               // creates MindRecord for the CBT task
    ├── create_lambada_data.py           // creates MindRecord for the LAMBADA task
    ├── create_lm_data.py                // creates MindRecord for the other tasks
    ├── create_summary_data.py           // creates MindRecord for the summarization task
    ├── download_cnn_dailymail.py        // downloads the CNN & DailyMail dataset
    ├── cnn_dataset_sampler.py           // sampler for the CNN & DailyMail training set
    ├── eval_rc_addition_answer.py       // evaluates the reading comprehension task with addition_answer
    ├── run_CBT_task.py                  // fine-tune & inference entry point for the CBT task
    ├── run_lambada.py                   // fine-tune & inference entry point for the LAMBADA task
    ├── run_language_model.py            // fine-tune & inference entry point for the language modeling task
    ├── run_ReadComprehension.py         // fine-tune & inference entry point for the reading comprehension task
    ├── run_summarization.py             // fine-tune & inference entry point for the summarization task
    ├── run_translation.py               // fine-tune & inference entry point for the translation task
    ├── task_dataset_preprocess.py       // dataset preprocessing entry point for all tasks
    ├── convert_tf_ckpt
    │   ├── read_weight_tf.py            // reads the pretrained model from a TensorFlow checkpoint
    │   ├── trans_dict.py                // dictionary of model parameter names
    │   ├── save_weight_ms.py            // generates the MindSpore checkpoint
    ├── third_party
    │   ├── gpt2-merges.txt              // GPT-2 pretrained BPE merge file
    │   ├── gpt2-vocab.json              // GPT-2 pretrained vocabulary
    │   ├── bleu.py                      // third-party code assisting BLEU computation
```

## Model Conversion

- Download the GPT-2 pretrained model: [GPT-2 pretrained model download](https://github.com/openai/gpt-2/blob/master/download_model.py)

- In a TensorFlow environment, run `read_weight_tf.py`, for example:

  `python read_weight_tf.py --ckpt_file_path=/{path}/model.ckpt`

- In a MindSpore environment, run `save_weight_ms.py`, for example:

  `python save_weight_ms.py --output_file_name="mindspore_gpt2_small.ckpt"`
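
After conversion, a quick way to confirm the checkpoint is readable is to load it back in MindSpore and count the parameters. This is a minimal sanity check using the standard `load_checkpoint` API; the file name matches the example above:

```python
from mindspore.train.serialization import load_checkpoint

# load_checkpoint returns a dict of parameter name -> Parameter
params = load_checkpoint("mindspore_gpt2_small.ckpt")
print("loaded {} parameters".format(len(params)))
```
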
## Dataset Preparation

### Language Modeling Task

#### WikiText2, WikiText103, PTB, and 1BW Datasets

- [WikiText2 download](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip): after extraction, use `wikitext-2/wiki.test.tokens` as the test set
- [WikiText103 download](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip): after extraction, use `wikitext-103/wiki.test.tokens` as the test set
- [PTB download](http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz): after extraction, use `/simple-examples/data/ptb.test.txt` as the test set and `/simple-examples/data/ptb.train.txt` as the training set
- [1BW download](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz): after extraction, use `1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050` as the test set and `1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/news.en-00001-of-00100` as the raw training set, from which 30000 training examples are drawn by random sampling

Use `task_dataset_preprocess.py` to clean the datasets above.

The main arguments of `task_dataset_preprocess.py` are:

```bash
--task: The GPT-2 downstream task, including [LanguageModeling, CBT, Translation, Lambada, Summarization, ReadingComprehension].
--input_file: The raw dataset path.
--dataset: The name of dataset which should be processed, only for LanguageModeling task.
--output_file: The output dataset path after preprocessing.
--condition: Process train or test dataset, including [train, test], only for 1BW and CNN & DailyMail dataset.
```

For example, clean the PTB test set (use `--condition "train"` for the training set):

```bash
python task_dataset_preprocess.py --task "LanguageModeling" --input_file /{path}/ptb.test.txt --dataset "ptb" --output_file /{path}/ptb_clean_test.txt --condition "test"
```

Use `create_lm_data.py` to convert the datasets above to MindRecord format.

The main arguments of `create_lm_data.py` are:

```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The MindRecord file will be split into the number of partition.
--max_seq_length: Maximum sequence length.
--vocab_file: url of gpt2-vocab.json.
--merge_file: url of gpt2-merges.txt
```

For example:

```bash
python create_lm_data.py --input_file /{path}/ptb.test.txt --output_file /{path}/ptb-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```

### Children's Book Test Task

#### CBT-CN / CBT-NE Datasets

- [CBT download](http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz): use `cbtest_CN_valid_2000ex.txt` and `cbtest_NE_valid_2000ex.txt` under the `/data` directory as the evaluation sets for this task. Clean the dataset, for example:

```bash
python task_dataset_preprocess.py --task "CBT" --input_file /{path}/cbtest_CN_valid_2000ex.txt --dataset "cbt" --output_file /{path}/cbt_cn_valid.txt
```

Use `create_cbt_data.py` to convert the datasets above to MindRecord format.

The main arguments of `create_cbt_data.py` are:

```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The MindRecord file will be split into the number of partition.
--max_seq_length: Maximum sequence length.
--num_choice: Number of choices.
--vocab_file: url of gpt2-vocab.json.
--merge_file: url of gpt2-merges.txt
```

For example:

```bash
python create_cbt_data.py --input_file /{path}/cbt_cn_valid.txt --output_file /{path}/cbt-cn-valid-mindrecord --num_splits 1 --max_length 1024 --num_choice 10 --vocab_file={path} --merge_file={path}
```

### LAMBADA Task

#### LAMBADA Dataset

- [LAMBADA download](https://zenodo.org/record/2630551#.X-yCSTTithH): use `lambada_test_plain_text.txt` as the evaluation set for this task. Clean the dataset, for example:

```bash
python task_dataset_preprocess.py --task "LAMBADA" --input_file /{path}/lambada_test_plain_text.txt --dataset "LAMBADA" --output_file /{path}/lambada_test_clean.txt
```

Use `create_lambada_data.py` to convert the datasets above to MindRecord format.

The main arguments of `create_lambada_data.py` are:

```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The MindRecord file will be split into the number of partition.
--max_seq_length: Maximum sequence length.
--vocab_file: url of gpt2-vocab.json.
--merge_file: url of gpt2-merges.txt
```

For example:

```bash
python create_lambada_data.py --input_file /{path}/lambada_test_clean.txt --output_file /{path}/lambada-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```

### Reading Comprehension Task

#### CoQA Dataset

- [CoQA download](http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json): use `coqa-dev-v1.0.json` as the evaluation set for this task. Clean the dataset, for example:

```bash
python task_dataset_preprocess.py --task "ReadingComprehension" --input_file /{path}/coqa-dev-v1.0.json --dataset "coqa" --output_file /{path}/coqa_dev.txt
```

Use `create_lm_data.py` to convert the dataset above to MindRecord format, for example:

```bash
python create_lm_data.py --input_file /{path}/coqa_dev.txt --output_file /{path}/coqa-dev-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```

### Summarization Task

#### CNN & DailyMail Dataset

- Download the dataset with the `download_cnn_dailymail.py` script, for example:

```bash
# download the test set
python download_cnn_dailymail.py --dir ./cnn_dailymail/ --split test

# download the training set
python download_cnn_dailymail.py --dir ./cnn_dailymail/ --split train
```

Randomly sample 10000 examples from the training set to serve as the final fine-tuning training set. Use the `cnn_dataset_sampler.py` script to perform the sampling and generate the new training set, for example:

```bash
# GPT-2 small and GPT-2 medium use seq_length=1024 for training, so set max_length=1022 here
python cnn_dataset_sampler.py --input_path="/{path}/cnn_train.txt" \
                              --output_path="/{path}/cnn_train_hint_small.txt" \
                              --replace_hint="true" \
                              --sample="true" \
                              --max_length=1022 \
                              --prob=0.25 \
                              --max_items=10000 \
                              --hint="TL;DR:"

# GPT-2 large uses seq_length=768 for training, so set max_length=766 here
python cnn_dataset_sampler.py --input_path="/{path}/cnn_train.txt" \
                              --output_path="/{path}/cnn_train_hint_large.txt" \
                              --replace_hint="true" \
                              --sample="true" \
                              --max_length=766 \
                              --prob=0.25 \
                              --max_items=10000 \
                              --hint="TL;DR:"
```

Use `create_summary_data.py` to convert the dataset above to MindRecord format, for example:

```bash
python create_summary_data.py --input_file /{path}/cnn_dailymail_test.txt --output_file /{path}/cnn_dailymail-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path} --mode 'cnn_dailymail'
```

### Translation Task

#### WMT14 En-Fr Dataset

- [WMT14 En-Fr download](http://statmt.org/wmt14/test-full.tgz): use `newstest2014-fren-ref.en.sgm` and `newstest2014-fren-ref.fr.sgm` as the evaluation sets for this task. Merge and clean the dataset, for example:

```bash
python task_dataset_preprocess.py --task "Translation" --input_file /{path}/test-full --dataset "wmt14" --output_file /{path}/wmt14
```

Two files, `wmt14.en_fr.txt` and `wmt14.fr_en.txt`, are generated under the `output_file` path; they are used to evaluate `En-Fr` and `Fr-En` respectively.

Use `create_lm_data.py` to convert the datasets above to MindRecord format, for example:

```bash
python create_lm_data.py --input_file /{path}/wmt14.en_fr.txt --output_file /{path}/en-fr-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}

python create_lm_data.py --input_file /{path}/wmt14.fr_en.txt --output_file /{path}/fr-en-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```

## Configuration

`src/finetune_eval_config.py` is the configuration file for GPT-2 training and inference. It holds most of the options and parameters, including the GPT-2 model size, the model configuration, and the optimizer parameters.
For details of the attributes, see the `src/finetune_eval_config.py` file.

## Fine-Tuning & Evaluation

### Language Modeling Task

#### Fine-tuning

- PTB dataset

The GPT-2 small / medium / large models need to be fine-tuned on the PTB training set. To fine-tune, simply use the shell script `scripts/run_language_model.sh`; environment variables can be set in the script, which runs `run_language_model.py` under `GPT-2`.

Before fine-tuning, configure the options in `src/finetune_eval_config.py` (see the sketch after this list):

- Set `gpt2_network` under `cfg` to the desired GPT-2 model size (`small`/`medium`/`large`).
- Set `optimizer` under `cfg` to `Lamb` to select the optimizer (`momentum`/`adam`/`lamb` are available).
- After choosing the GPT-2 model, set its parameters, including `batch_size` and `seq_length`.
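
For orientation, the relevant fields might look roughly like this. This is a hypothetical sketch, not the actual contents of `src/finetune_eval_config.py`; only `cfg.gpt2_network`, `cfg.optimizer`, `batch_size`, and `seq_length` are documented above, and the `easydict` layout and the `gpt2_net_cfg` name are assumptions.

```python
# Hypothetical sketch of the options described above -- not the real config file.
from easydict import EasyDict as edict

cfg = edict({
    'gpt2_network': 'small',   # one of: small / medium / large
    'optimizer': 'Lamb',       # one of: momentum / adam / lamb
})

# model parameters to set once the network size is chosen (names assumed)
gpt2_net_cfg = edict({
    'batch_size': 1,
    'seq_length': 1024,
})
```
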

Then run the `scripts/run_language_model.sh` shell script:

```bash
sh scripts/run_language_model.sh --device_target="Ascend" \
                                 --do_train="true" \
                                 --do_eval="false" \
                                 --epoch_num=1 \
                                 --train_data_shuffle="true" \
                                 --eval_data_shuffle="false" \
                                 --save_finetune_ckpt_path={save_finetune_ckpt_path} \
                                 --load_pretrain_ckpt_path={load_pretrain_ckpt_path} \
                                 --train_data_file_path={train_data_file_path}
```

Logs and output files are available under `./ms_log/`.

```bash
sh scripts/run_language_model.sh [--options]
```

The usage of `run_language_model.sh` is as follows:

```text
usage: run_language_model.sh   [--device_target DEVICE_TARGET] [--device_id N]
                               [--metric_method METRIC_METHOD]
                               [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                               [--eval_type EVAL_TYPE] [--epoch_num N]
                               [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                               [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                               [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                               [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                               [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                               [--train_data_file_path TRAIN_DATA_FILE_PATH]
                               [--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
    --device_target            Device type. Default: "Ascend"
    --device_id                ID of target device
    --metric_method            The eval method including [PPL]. Default: "PPL"
    --do_train                 Enable train. Default: "false"
    --do_eval                  Enable evaluation. Default: "true"
    --eval_type                The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                Epoch number. Default: 1
    --train_data_shuffle       Enable train data shuffle. Default: "true"
    --eval_data_shuffle        Enable eval data shuffle. Default: "false"
    --save_finetune_ckpt_path  Save the finetuned checkpoint path
    --load_pretrain_ckpt_path  Load the checkpoint file path for train
    --load_finetune_ckpt_path  Load the checkpoint file path for evaluation
    --train_data_file_path     Data path, it is better to use absolute path
    --eval_data_file_path      Data path, it is better to use absolute path
```

- 1BW dataset

The GPT-2 large model needs to be fine-tuned on the 1BW training set. To fine-tune, simply use the shell script `run_language_model.sh`; environment variables can be set in the script, which runs `run_language_model.py` under `GPT-2`. The fine-tuning procedure is the same as for the PTB dataset.

#### Evaluation

GPT-2 can be evaluated on the `WikiText2` / `WikiText103` / `PTB` / `1BW` test sets. For these datasets, the evaluation metric is PPL (perplexity), i.e. set `--metric_method="PPL"`.
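
Perplexity is the exponential of the average per-token negative log-likelihood over the test corpus. A minimal sketch of the computation, for illustration; the repository's metric implementation lives in `src/utils/metric_method.py`:

```python
import math

def perplexity(token_log_probs):
    """PPL = exp(mean negative log-likelihood) over all evaluated tokens.

    token_log_probs: natural-log probabilities the model assigned to each
    ground-truth token of the test set.
    """
    nll = [-lp for lp in token_log_probs]
    return math.exp(sum(nll) / len(nll))

# Toy check: probability 0.5 for each of 4 tokens gives PPL = 2.0.
print(perplexity([math.log(0.5)] * 4))
```
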
To evaluate, simply use the shell script `run_language_model.sh`; environment variables can be set in the script, which runs `run_language_model.py` under `GPT-2`.

To evaluate, first configure `src/finetune_eval_config.py`, then run the `scripts/run_language_model.sh` shell script. If the model was fine-tuned on a dataset, set `--eval_type="finetuned"` when evaluating on the corresponding test set; otherwise set `eval_type="zero-shot"`. In addition, `--load_finetune_ckpt_path` is the location of the fine-tuned checkpoint file.

```bash
sh scripts/run_language_model.sh --device_target="Ascend" \
                                 --metric_method="PPL" \
                                 --do_train="false" \
                                 --do_eval="true" \
                                 --eval_type="finetuned" \
                                 --train_data_shuffle="true" \
                                 --eval_data_shuffle="false" \
                                 --load_finetune_ckpt_path={load_eval_ckpt_path} \
                                 --eval_data_file_path={eval_data_file_path}
```

Logs and output files are available under `./ms_log/`.

### Children's Book Test Task

#### Evaluation

GPT-2 can be evaluated on the `CBT-CN` / `CBT-NE` validation sets. For these datasets, the evaluation metric is Accuracy, i.e. set `--metric_method="Accuracy"`.
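
CBT is a multiple-choice cloze test: for each question, the model scores the sentence completed with each of the `num_choice` candidate words, and the highest-scoring candidate is taken as the prediction. A simplified sketch of that decision rule (illustration only, not the repository's exact logic):

```python
import numpy as np

def cbt_predict(choice_log_probs):
    """Pick the candidate whose completed sentence the LM scores highest.

    choice_log_probs: (num_questions, num_choice) array of total log-probs
    the model assigns to each candidate completion.
    """
    return np.argmax(choice_log_probs, axis=-1)

# Toy usage: 2 questions, 3 candidates each.
scores = np.array([[-5.2, -3.1, -7.0],
                   [-2.2, -2.5, -1.9]])
print(cbt_predict(scores))  # [1 2]
```
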
To evaluate, simply use the shell script `run_cbt.sh`; environment variables can be set in the script, which runs `run_CBT_task.py` under `GPT-2`.

To evaluate, first configure `src/finetune_eval_config.py`, then run the `scripts/run_cbt.sh` shell script with `eval_type="zero-shot"`. In addition, `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file.

```bash
sh scripts/run_cbt.sh --device_target="Ascend" \
                      --num_choice=10 \
                      --metric_method="Accuracy" \
                      --do_train="false" \
                      --do_eval="true" \
                      --eval_type="zero-shot" \
                      --train_data_shuffle="true" \
                      --eval_data_shuffle="false" \
                      --load_finetune_ckpt_path={load_eval_ckpt_path} \
                      --eval_data_file_path={eval_data_file_path}
```

Logs and output files are available under `./ms_log/`.

```bash
sh scripts/run_cbt.sh [--options]
```

The usage of `run_cbt.sh` is as follows:

```text
usage: run_cbt.sh   [--device_target DEVICE_TARGET] [--device_id N] [--num_choice N]
                    [--metric_method METRIC_METHOD]
                    [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                    [--eval_type EVAL_TYPE] [--epoch_num N]
                    [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                    [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                    [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                    [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                    [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                    [--train_data_file_path TRAIN_DATA_FILE_PATH]
                    [--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
    --device_target            Device type. Default: "Ascend"
    --device_id                ID of target device
    --num_choice               The number of choice in CBT task
    --metric_method            The eval method including [Accuracy]. Default: "Accuracy"
    --do_train                 Enable train. Default: "false"
    --do_eval                  Enable evaluation. Default: "true"
    --eval_type                The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                Epoch number. Default: 1
    --train_data_shuffle       Enable train data shuffle. Default: "true"
    --eval_data_shuffle        Enable eval data shuffle. Default: "false"
    --save_finetune_ckpt_path  Save the finetuned checkpoint path
    --load_pretrain_ckpt_path  Load the checkpoint file path for train
    --load_finetune_ckpt_path  Load the checkpoint file path for evaluation
    --train_data_file_path     Data path, it is better to use absolute path
    --eval_data_file_path      Data path, it is better to use absolute path
```

### LAMBADA Task

#### Evaluation

GPT-2 can be evaluated on the `LAMBADA` test set. For this dataset, the evaluation metrics are Accuracy and PPL, i.e. set `--metric_method="Accuracy"` or `--metric_method="PPL"`.
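
For LAMBADA, accuracy asks whether the model's continuation reproduces the passage's final word (the stop-word file mentioned below is used to filter candidate words). A simplified sketch of the accuracy side (not the repository's exact logic in `src/utils/metric_method.py`):

```python
def lambada_accuracy(predicted_words, target_words):
    """Fraction of passages whose predicted final word matches the target."""
    hits = sum(1 for pred, tgt in zip(predicted_words, target_words)
               if pred.strip() == tgt.strip())
    return hits / len(target_words)

# Toy check: one of two final words predicted correctly -> 0.5.
print(lambada_accuracy(["dog", "house"], ["dog", "home"]))
```
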
To evaluate, simply use the shell script `run_lambada.sh`; environment variables can be set in the script, which runs `run_lambada.py` under `GPT-2`.

To evaluate, first configure `src/finetune_eval_config.py`, then run the `scripts/run_lambada.sh` shell script with `eval_type="zero-shot"`. In addition, `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file.

Evaluate Accuracy:

```bash
sh scripts/run_lambada.sh --device_target="Ascend" \
                          --metric_method="Accuracy" \
                          --do_train="false" \
                          --do_eval="true" \
                          --eval_type="zero-shot" \
                          --train_data_shuffle="true" \
                          --eval_data_shuffle="false" \
                          --generate_length_dynamically="true" \
                          --load_finetune_ckpt_path={load_eval_ckpt_path} \
                          --eval_data_file_path={eval_data_file_path} \
                          --tokenizer_file_path={tokenizer_file_path} \
                          --stop_word_file_path={stop_word_file_path}
```

Evaluate PPL:

```bash
sh scripts/run_lambada.sh --device_target="Ascend" \
                          --metric_method="PPL" \
                          --do_train="false" \
                          --do_eval="true" \
                          --eval_type="zero-shot" \
                          --train_data_shuffle="true" \
                          --eval_data_shuffle="false" \
                          --load_finetune_ckpt_path={load_eval_ckpt_path} \
                          --eval_data_file_path={eval_data_file_path}
```

Logs and output files are available under `./ms_log/`.

```bash
sh scripts/run_lambada.sh [--options]
```

The usage of `run_lambada.sh` is as follows:

```text
usage: run_lambada.sh   [--device_target DEVICE_TARGET] [--device_id N]
                        [--metric_method METRIC_METHOD]
                        [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                        [--eval_type EVAL_TYPE] [--epoch_num N]
                        [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                        [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                        [--generate_length_dynamically GENERATE_LENGTH_DYNAMICALLY]
                        [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                        [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                        [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                        [--train_data_file_path TRAIN_DATA_FILE_PATH]
                        [--eval_data_file_path EVAL_DATA_FILE_PATH]
                        [--tokenizer_file_path TOKENIZER_FILE_PATH]
                        [--stop_word_file_path STOP_WORD_FILE_PATH]
options:
    --device_target                Device type. Default: "Ascend"
    --device_id                    ID of target device
    --metric_method                The eval method including [Accuracy, PPL]. Default: "Accuracy"
    --do_train                     Enable train. Default: "false"
    --do_eval                      Enable evaluation. Default: "true"
    --eval_type                    The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                    Epoch number. Default: 1
    --train_data_shuffle           Enable train data shuffle. Default: "true"
    --eval_data_shuffle            Enable eval data shuffle. Default: "false"
    --generate_length_dynamically  Enable generating the length dynamically. Default: "true"
    --save_finetune_ckpt_path      Save the checkpoint path
    --load_pretrain_ckpt_path      Load the checkpoint file path
    --load_finetune_ckpt_path      Load the checkpoint file path
    --train_data_file_path         Data path, it is better to use absolute path
    --eval_data_file_path          Data path, it is better to use absolute path
    --tokenizer_file_path          Pretrained vocab and merge file path
    --stop_word_file_path          The stop word file path
```

### Reading Comprehension Task

#### Evaluation

GPT-2 can be evaluated on the `CoQA` dev set. For this dataset, the evaluation metric is F1, i.e. set `--metric_method="F1"`.
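
The F1 here is the usual token-overlap F1 between the generated answer and the reference answer. A minimal sketch (standard CoQA-style scoring in spirit; the repository's implementation is in `src/utils/metric_method.py`):

```python
from collections import Counter

def token_f1(prediction, ground_truth):
    """Token-overlap F1 between a predicted and a reference answer."""
    pred_tokens = prediction.lower().split()
    gt_tokens = ground_truth.lower().split()
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)

print(round(token_f1("in the garden", "the garden"), 3))  # 0.8
```
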
To evaluate, simply use the shell script `run_read_comprehension.sh`; environment variables can be set in the script, which runs `run_ReadComprehension.py` under `GPT-2`.

To evaluate, first configure `src/finetune_eval_config.py`, then run the `scripts/run_read_comprehension.sh` shell script with `eval_type="zero-shot"`. In addition, `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file.

```bash
sh scripts/run_read_comprehension.sh --device_target="Ascend" \
                                     --metric_method="F1" \
                                     --do_train="false" \
                                     --do_eval="true" \
                                     --eval_type="zero-shot" \
                                     --train_data_shuffle="true" \
                                     --eval_data_shuffle="false" \
                                     --load_finetune_ckpt_path={load_eval_ckpt_path} \
                                     --eval_data_file_path={eval_data_file_path} \
                                     --tokenizer_file_path={tokenizer_file_path} \
                                     --generate_length=55 \
                                     --top_k=1 \
                                     --top_p="1.0" \
                                     --temperature="1.0"
```

Logs and output files are available under `./ms_log/`. Then pass the resulting log file to the `eval_rc_addition_answer.py` script as its `input_file`, and the original CoQA dev set `coqa-dev-v1.0.json` as its `addition_file`.

Run `python eval_rc_addition_answer.py --input_file={path} --addition_file={path}` to obtain the final F1 score.

```bash
sh scripts/run_read_comprehension.sh [--options]
```

```text
usage: run_read_comprehension.sh   [--device_target DEVICE_TARGET] [--device_id N]
                                   [--metric_method METRIC_METHOD]
                                   [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                                   [--eval_type EVAL_TYPE] [--epoch_num N]
                                   [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                                   [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                                   [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                                   [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                                   [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                                   [--train_data_file_path TRAIN_DATA_FILE_PATH]
                                   [--eval_data_file_path EVAL_DATA_FILE_PATH]
                                   [--tokenizer_file_path TOKENIZER_FILE_PATH]
                                   [--generate_length N] [--top_k N] [--top_p TOP_P]
                                   [--temperature TEMPERATURE]
options:
    --device_target            Device type. Default: "Ascend"
    --device_id                ID of target device
    --metric_method            The eval method including [F1]. Default: "F1"
    --do_train                 Enable train. Default: "false"
    --do_eval                  Enable evaluation. Default: "false"
    --eval_type                The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                Epoch number. Default: 1
    --train_data_shuffle       Enable train data shuffle. Default: "true"
    --eval_data_shuffle        Enable eval data shuffle. Default: "false"
    --save_finetune_ckpt_path  Save the checkpoint path
    --load_pretrain_ckpt_path  Load the checkpoint file path
    --load_finetune_ckpt_path  Load the checkpoint file path
    --train_data_file_path     Data path, it is better to use absolute path
    --eval_data_file_path      Data path, it is better to use absolute path
    --tokenizer_file_path      Pretrained vocab and merge file path
    --generate_length          The generation length of answer sentence
    --top_k                    Parameter for Top-K sampling
    --top_p                    Parameter for Top-P sampling
    --temperature              Sampling temperature; larger values make generation more diverse
```
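
The `--top_k`, `--top_p`, and `--temperature` options above control how the next token is drawn during generation. Below is a self-contained sketch of that sampling scheme (illustration only; the repository's version lives in `src/utils/generation_utils.py`). With `top_k=1`, as in the example above, it degenerates to greedy decoding.

```python
import numpy as np

def sample_next_token(logits, top_k=0, top_p=1.0, temperature=1.0, rng=None):
    """Sample one token id with temperature, top-k, then top-p (nucleus) filtering."""
    rng = rng or np.random.default_rng()
    logits = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]       # token ids, most probable first
    if top_k > 0:
        order = order[:top_k]             # keep only the k most probable tokens
    keep, cum = [], 0.0
    for idx in order:                     # nucleus: smallest set with mass >= top_p
        keep.append(idx)
        cum += probs[idx]
        if cum >= top_p:
            break
    renorm = probs[keep] / probs[keep].sum()
    return int(rng.choice(keep, p=renorm))

print(sample_next_token([2.0, 1.0, 0.1], top_k=1))  # always 0 (greedy)
```
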

### Summarization Task

#### Evaluation

GPT-2 can be evaluated on the `CNN_Dailymail` dev set. For this dataset, the evaluation metric is ROUGE, i.e. set `--metric_method="ROUGE"`.
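
ROUGE measures n-gram overlap between the generated summary and the reference. The repository computes it in `src/utils/rouge_score.py`; for a quick offline sanity check you can also use the `rouge` package listed under the requirements (the call below is that package's standard API):

```python
from rouge import Rouge  # the `rouge 1.0.0` requirement listed below

hypothesis = "the cat sat on the mat"
reference = "a cat was sitting on the mat"

scores = Rouge().get_scores(hypothesis, reference, avg=True)
print(scores["rouge-1"]["f"], scores["rouge-2"]["f"], scores["rouge-l"]["f"])
```
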
To evaluate, simply use the shell script `run_summarization.sh`; environment variables can be set in the script, which runs `run_summarization.py` under `GPT-2`.

To evaluate, first configure `src/finetune_eval_config.py`, then run the `scripts/run_summarization.sh` shell script. For the `hint` setting, set `eval_type="finetuned"` and point `--load_finetune_ckpt_path` at the fine-tuned checkpoint file; for the `no hint` setting, set `eval_type="zero-shot"` and point `--load_finetune_ckpt_path` at the pretrained checkpoint file.

```bash
sh scripts/run_summarization.sh --device_target="Ascend" \
                                --do_train="false" \
                                --do_eval="true" \
                                --metric_method="Rouge" \
                                --train_data_shuffle="true" \
                                --eval_data_shuffle="false" \
                                --generate_length=100 \
                                --top_k=2 \
                                --top_p="1.0" \
                                --temperature="1.0" \
                                --eval_type="finetuned" \
                                --load_finetune_ckpt_path={load_eval_ckpt_path} \
                                --eval_data_file_path={eval_data_file_path} \
                                --tokenizer_file_path={tokenizer_file_path}
```

Logs and output files are available under `./ms_log/`.

```bash
sh scripts/run_summarization.sh [--options]
```

The usage of `run_summarization.sh` is as follows:

```text
usage: run_summarization.sh   [--device_target DEVICE_TARGET] [--device_id N]
                              [--metric_method METRIC_METHOD]
                              [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                              [--eval_type EVAL_TYPE] [--epoch_num N]
                              [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                              [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                              [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                              [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                              [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                              [--train_data_file_path TRAIN_DATA_FILE_PATH]
                              [--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
    --device_target            Device type. Default: "Ascend"
    --device_id                ID of target device
    --do_train                 Enable train. Default: "false"
    --do_eval                  Enable evaluation. Default: "false"
    --metric_method            The eval method including [Rouge(Rouge1, Rouge2, RougeL, Rouge Avg)]. Default: "Rouge"
    --epoch_num                Epoch number. Default: 2
    --train_data_shuffle       Enable train data shuffle. Default: "true"
    --eval_data_shuffle        Enable eval data shuffle. Default: "false"
    --save_finetune_ckpt_path  Save the checkpoint path
    --load_pretrain_ckpt_path  Load the checkpoint file path
    --load_finetune_ckpt_path  Load the checkpoint file path
    --train_data_file_path     Data path, it is better to use absolute path
    --eval_data_file_path      Data path, it is better to use absolute path
    --eval_type                The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --top_k                    Top k tokens chosen for sampling
    --top_p                    Top p accumulated probability threshold for logit to be counted
    --generate_length          The number of generated tokens
    --temperature              Temperature on logits for sampling
    --tokenizer_file_path      Vocab & merge file path
```

### Translation Task

#### Evaluation

GPT-2 can be evaluated on the `WMT14 En-Fr` and `WMT14 Fr-En` test sets. For these datasets, the evaluation metric is BLEU, i.e. set `--metric_method="BLEU"`.

Note: you need to download the `bleu.py` script yourself ([script link](https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py)) and place it under the `src/utils/` directory.
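
That third-party `bleu.py` exposes a `compute_bleu` function over tokenized corpora. A minimal usage sketch (the argument layout follows the TensorFlow NMT version of the script; double-check against the copy you download):

```python
# Assumes bleu.py has been placed under src/utils/ as described above.
from src.utils.bleu import compute_bleu

# One list of references per translation; corpora are lists of token lists.
references = [[["the", "cat", "is", "on", "the", "mat"]]]
translations = [["the", "cat", "sat", "on", "the", "mat"]]

bleu, *_ = compute_bleu(references, translations, max_order=4, smooth=False)
print(round(bleu * 100, 2))  # corpus-level BLEU as a percentage
```
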
To evaluate, simply use the shell script `run_translation.sh`; environment variables can be set in the script, which runs `run_translation.py` under `GPT-2`.

To evaluate, first configure `src/finetune_eval_config.py`, then run the `scripts/run_translation.sh` shell script with `eval_type="zero-shot"`. In addition, `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file.

```bash
sh scripts/run_translation.sh --device_target="Ascend" \
                              --metric_method="BLEU" \
                              --do_train="false" \
                              --do_eval="true" \
                              --eval_type="zero-shot" \
                              --train_data_shuffle="true" \
                              --eval_data_shuffle="false" \
                              --load_finetune_ckpt_path={load_eval_ckpt_path} \
                              --eval_data_file_path={eval_data_file_path} \
                              --tokenizer_file_path={tokenizer_file_path} \
                              --generate_length=100 \
                              --top_k=1 \
                              --top_p="1.0" \
                              --temperature="1.0"
```

```bash
sh scripts/run_translation.sh [--options]
```

```text
usage: run_translation.sh   [--device_target DEVICE_TARGET] [--device_id N]
                            [--metric_method METRIC_METHOD]
                            [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                            [--eval_type EVAL_TYPE] [--epoch_num N]
                            [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                            [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                            [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                            [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                            [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                            [--train_data_file_path TRAIN_DATA_FILE_PATH]
                            [--eval_data_file_path EVAL_DATA_FILE_PATH]
                            [--tokenizer_file_path TOKENIZER_FILE_PATH]
                            [--generate_length N] [--top_k N] [--top_p TOP_P]
                            [--temperature TEMPERATURE]
options:
    --device_target            Device type. Default: "Ascend"
    --device_id                ID of target device
    --metric_method            The eval method including [BLEU]. Default: "BLEU"
    --do_train                 Enable train. Default: "false"
    --do_eval                  Enable evaluation. Default: "true"
    --eval_type                The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                Epoch number. Default: 1
    --train_data_shuffle       Enable train data shuffle. Default: "true"
    --eval_data_shuffle        Enable eval data shuffle. Default: "false"
    --save_finetune_ckpt_path  Save the checkpoint path
    --load_pretrain_ckpt_path  Load the checkpoint file path
    --load_finetune_ckpt_path  Load the checkpoint file path
    --train_data_file_path     Data path, it is better to use absolute path
    --eval_data_file_path      Data path, it is better to use absolute path
    --tokenizer_file_path      Pretrained vocab and merge file path
    --generate_length          The generation length of translation sentence
    --top_k                    Parameter for Top-K sampling
    --top_p                    Parameter for Top-P sampling
    --temperature              Sampling temperature; larger values make generation more diverse
```

# Environment Requirements

## Platform

- Hardware (Ascend)
    - Prepare the hardware environment with Ascend processors.
    - To try out Ascend processors, send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once the application is approved, you can get access to the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information about MindSpore, see:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

## Other Requirements

```text
math
numpy
copy
collections
re
rouge 1.0.0
datasets >=0.4.0
json
tensorflow
```

# Performance

## Inference Performance

### Language Modeling Task

The table below shows the PPL scores of the GPT-2 small, medium, and large models on the Language Modeling task.

| Model | dataset | device | eval_type | PPL | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | WikiText2 | Ascend | zero-shot | 24.5 | 29.41 |
| GPT-2 medium | WikiText2 | Ascend | zero-shot | 19.41 | 22.76 |
| GPT-2 large | WikiText2 | Ascend | zero-shot | 17.08 | 19.93 |
| GPT-2 small | WikiText103 | Ascend | zero-shot | 26.89 | 37.5 |
| GPT-2 medium | WikiText103 | Ascend | zero-shot | 20.23 | 26.37 |
| GPT-2 large | WikiText103 | Ascend | zero-shot | 17.48 | 22.05 |
| GPT-2 small | PTB | Ascend | finetune | 23.91 | 65.85 |
| GPT-2 medium | PTB | Ascend | finetune | 20.06 | 47.33 |
| GPT-2 large | PTB | Ascend | finetune | 18.84 | 40.31 |
| GPT-2 small | 1BW | Ascend | zero-shot | 63.13 | 75.2 |
| GPT-2 medium | 1BW | Ascend | zero-shot | 50.98 | 55.72 |
| GPT-2 large | 1BW | Ascend | finetune | 29.28 | 44.575 |

### Children's Book Test Task

The table below shows the Accuracy scores of the GPT-2 small, medium, and large models on the Children's Book Test task.

| Model | dataset | device | eval_type | ACC | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CBT-CN valid | Ascend | zero-shot | 87.85 | 87.65 |
| GPT-2 medium | CBT-CN valid | Ascend | zero-shot | 92.1 | 92.35 |
| GPT-2 large | CBT-CN valid | Ascend | zero-shot | 93.7 | 93.45 |
| GPT-2 small | CBT-NE valid | Ascend | zero-shot | 85.1 | 83.4 |
| GPT-2 medium | CBT-NE valid | Ascend | zero-shot | 87.55 | 87.1 |
| GPT-2 large | CBT-NE valid | Ascend | zero-shot | 89.1 | 88 |

### LAMBADA Task

The tables below show the Accuracy and PPL scores of the GPT-2 small, medium, and large models on the LAMBADA task.

| Model | dataset | device | eval_type | ACC | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | Lambada-test | Ascend | zero-shot | 45.99 | 45.99 |
| GPT-2 medium | Lambada-test | Ascend | zero-shot | 58.59 | 55.48 |
| GPT-2 large | Lambada-test | Ascend | zero-shot | 62.74 | 60.12 |

| Model | dataset | device | eval_type | PPL | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | Lambada-test | Ascend | zero-shot | 22.95 | 35.13 |
| GPT-2 medium | Lambada-test | Ascend | zero-shot | 10.69 | 15.6 |
| GPT-2 large | Lambada-test | Ascend | zero-shot | 8.64 | 10.87 |

### Reading Comprehension Task

The table below shows the F1 scores of the GPT-2 small, medium, and large models on the Reading Comprehension task.

| Model | dataset | device | eval_type | F1 | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CoQA | Ascend | zero-shot | 25.94 | 25~26 |
| GPT-2 medium | CoQA | Ascend | zero-shot | 43.69 | 42~43 |
| GPT-2 large | CoQA | Ascend | zero-shot | 49.39 | 49~51 |

### Summarization Task

The tables below show the ROUGE scores of the GPT-2 small, medium, and large models on the Summarization task.

| Model | dataset | device | eval_type | ROUGE | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CNN_Dailymail(TL;DR) | Ascend | finetune | 21.4 | 16.8~17 |
| GPT-2 medium | CNN_Dailymail(TL;DR) | Ascend | finetune | 25.94 | 20.6~20.9 |
| GPT-2 large | CNN_Dailymail(TL;DR) | Ascend | finetune | 26.73 | 21.5~21.6 |

| Model | dataset | device | eval_type | ROUGE | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.08 | 15.03(xlarge) |
| GPT-2 medium | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.16 | 15.03(xlarge) |
| GPT-2 large | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.29 | 15.03(xlarge) |

### Translation Task

The table below shows the BLEU scores of the GPT-2 small, medium, and large models on the Translation task.

| Model | dataset | device | eval_type | BLEU | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | WMT-14 Fr-En | Ascend | zero-shot | 4.49 | 0.7~0.8 |
| GPT-2 medium | WMT-14 Fr-En | Ascend | zero-shot | 7.09 | 2.0~3.0 |
| GPT-2 large | WMT-14 Fr-En | Ascend | zero-shot | 7.97 | 6.5~7.0 |
| GPT-2 small | WMT-14 En-Fr | Ascend | zero-shot | 2.81 | 5(xlarge) |
| GPT-2 medium | WMT-14 En-Fr | Ascend | zero-shot | 3.2 | 5(xlarge) |
| GPT-2 large | WMT-14 En-Fr | Ascend | zero-shot | 3.06 | 5(xlarge) |

# Others

The model has been validated on the Ascend environment.

# ModelZoo Homepage

[Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)

@ -0,0 +1,141 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
CNN & DailyMail train dataset sampler
"""

import os
import sys
import shutil
import argparse
from random import random

from src.utils.tokenization import Tokenizer


def replace_split_word(read_path, output_path, tldr_str="TL;DR:", original_split='\t'):
    """
    append tldr str: insert the hint between the article and its reference summary
    """
    with open(read_path, "r") as r, open(output_path, "a") as w:
        line = r.readline()
        while line:
            article = line[:line.find(original_split)] + ' ' + tldr_str + ' '
            ref = line[line.rfind(original_split) + 1:]
            w.write(article + ref)
            line = r.readline()


def sample(read_path, out_path, threshold=1.0, max_items=0xFFFFFFF):
    """
    sample function: keep each line with probability `threshold`, up to `max_items` lines
    """
    cnt = 0
    total_cnt = 0
    with open(read_path, "r") as r, open(out_path, "a") as w:
        line = r.readline()
        while line:
            total_cnt += 1
            if cnt >= max_items:
                break
            if random() > threshold:
                line = r.readline()
                continue
            w.write(line)
            if (cnt + 1) % 3000 == 0:
                print("Now Processed Samples: {}, total: {}".format(cnt, total_cnt))
            cnt += 1
            line = r.readline()


def clip_article(input_path, out_path, hint, max_length):
    """
    clip article whose sample (article + summary) exceeds max_length tokens
    """
    tokenizer = Tokenizer()
    cnt = 0
    with open(input_path, "r") as r, open(out_path, "a+") as a:
        line = r.readline()
        while line:
            pos = line.rfind(hint)
            article = line[:pos]
            summary = line[pos:]
            if len(tokenizer.encode(line)) > max_length:
                # truncate the article so that article + summary fits into max_length
                l_article = tokenizer.encode(article)[:max_length - len(tokenizer.encode(summary))]
                article = tokenizer.decode(l_article) + " "
            if cnt % 1000 == 0:
                print(article + summary)
                print("==============================")
            cnt += 1
            a.write(article + summary)
            line = r.readline()


def sampler_dataset():
    """
    run CNN & DailyMail train dataset sampler
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, default="",
                        help="input file path")
    parser.add_argument("--output_path", type=str, default="",
                        help="out file path")
    parser.add_argument("--replace_hint", type=str, default="true")
    parser.add_argument("--sample", type=str, default="true",
                        help="do sample? true or false")
    parser.add_argument("--max_length", type=int, default=1022,
                        help="max seq_length of input_raw_dataset")
    parser.add_argument("--prob", type=float, default=0.25,
                        help="sample rate")
    parser.add_argument("--max_items", type=int, default=10000,
                        help="max number of document")
    parser.add_argument("--hint", type=str, default="TL;DR:",
                        help="hint text")
    args = parser.parse_args()

    # temp files: one stores the input of each stage, the other the intermediate result
    temp_file_input = sys.path[0] + '/temp_file1_by_sampler_py.txt'
    temp_file_proc = sys.path[0] + '/temp_file2_by_sampler_py.txt'

    read_path = args.input_path
    output_path = args.output_path
    prob = args.prob
    max_items = args.max_items
    hint = args.hint
    max_length = args.max_length
    split_str = '\t'

    # stage 1: clip over-long articles (article and summary are still
    # separated by the original '\t' split word at this point)
    shutil.copyfile(read_path, temp_file_input)
    clip_article(temp_file_input, temp_file_proc, hint=split_str, max_length=max_length)
    shutil.copyfile(temp_file_proc, temp_file_input)
    os.remove(temp_file_proc)

    # stage 2: replace the '\t' split word with the TL;DR hint
    if args.replace_hint.lower() == "true":
        replace_split_word(temp_file_input, temp_file_proc, hint, split_str)
        shutil.copyfile(temp_file_proc, temp_file_input)
        os.remove(temp_file_proc)

    # stage 3: randomly sample up to max_items training examples
    if args.sample.lower() == "true":
        sample(temp_file_input, temp_file_proc, prob, max_items)
        shutil.copyfile(temp_file_proc, temp_file_input)
        os.remove(temp_file_proc)

    shutil.copyfile(temp_file_input, output_path)
    os.remove(temp_file_input)


if __name__ == "__main__":
    sampler_dataset()

@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Read weight using tensorflow

to read the parameters of the gpt-2 pretrained model from tensorflow checkpoint
and save them into npy files for mindspore to load.

*this script is based on the gpt-2 model downloaded from openai.*
"""
import argparse
import tensorflow as tf
import numpy as np

# absolute import so that the script can be run directly (python read_weight_tf.py)
from trans_dict import trans_dict_tf


def read_weight(ckpt_path):
    """
    read weight

    Args:
        ckpt_path: the path of tensorflow checkpoint
    """
    # list the names and shapes of all variables in the checkpoint
    init_vars = tf.train.list_variables(ckpt_path)
    save_param_num = 0

    for name, _ in init_vars:
        array = tf.train.load_variable(ckpt_path, name)
        # skip the leading 'model/' and replace '/' with '.' so that the name
        # matches the keys of trans_dict_tf
        name = name[6:].replace(r"/", ".")
        if name not in trans_dict_tf.keys():
            print(name + " is not in this model")
        else:
            # save each parameter as an npy file named after its MindSpore name
            np.save(trans_dict_tf[name] + ".npy", array)
            save_param_num = save_param_num + 1

    print("finished!")
    print("save {num} parameters.".format(num=save_param_num))


def main():
    parser = argparse.ArgumentParser(description="Read GPT-2 model checkpoint weight")
    parser.add_argument("--ckpt_file_path", type=str, default="",
                        help="The tensorflow GPT-2 model checkpoint file path")
    args_opt = parser.parse_args()
    ckpt_path = args_opt.ckpt_file_path
    read_weight(ckpt_path=ckpt_path)


if __name__ == "__main__":
    main()

@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Save weight using mindspore, to load the parameters of the gpt-2 model from npy files.
npy files should be in the same path as this script. Otherwise you should change the path names in the script.
"""
import os
import argparse
import numpy as np

from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint

# absolute import so that the script can be run directly (python save_weight_ms.py)
from trans_dict import trans_dict_tf


def trans_model_parameter(ckpt_name):
    """
    transform model parameters

    Args:
        ckpt_name (str): the name of the transformed checkpoint.
    """
    # find all file names with suffix '.npy' in the current path
    file_names = [name for name in os.listdir() if name.endswith(".npy")]
    new_params_list = []
    for file_name in file_names:
        var_name = file_name[:-4]  # strip the '.npy' suffix
        param_dict = {"name": var_name, "data": Tensor(np.load(file_name))}
        if var_name in trans_dict_tf.values():
            new_params_list.append(param_dict)
            print(var_name + " has been saved")

    # load the parameters from the npy files and save them as a mindspore checkpoint
    save_checkpoint(new_params_list, ckpt_name)
    print("Finished: the parameters have been saved into mindspore checkpoint.")


def main():
    parser = argparse.ArgumentParser(description="Save GPT-2 model weight as a MindSpore checkpoint")
    parser.add_argument("--output_file_name", type=str, default="",
                        help="The name of the output checkpoint file")
    args_opt = parser.parse_args()
    ckpt_name = args_opt.output_file_name
    trans_model_parameter(ckpt_name=ckpt_name)


if __name__ == "__main__":
    main()

@ -0,0 +1,892 @@
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""transform diction"""
|
||||
trans_dict_tf = {
|
||||
'h0.attn.c_attn.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h0.attn.c_attn.w': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h0.attn.c_proj.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h0.attn.c_proj.w': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h0.ln_1.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h0.ln_1.g': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h0.ln_2.b': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.beta',
|
||||
'h0.ln_2.g': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.gamma',
|
||||
'h0.mlp.c_fc.b': 'gpt2_decoder.layers.0.feedforward.c_fc.bias',
|
||||
'h0.mlp.c_fc.w': 'gpt2_decoder.layers.0.feedforward.c_fc.weight',
|
||||
'h0.mlp.c_proj.b': 'gpt2_decoder.layers.0.feedforward.c_proj.bias',
|
||||
'h0.mlp.c_proj.w': 'gpt2_decoder.layers.0.feedforward.c_proj.weight',
|
||||
'h1.attn.c_attn.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h1.attn.c_attn.w': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h1.attn.c_proj.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h1.attn.c_proj.w': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h1.ln_1.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h1.ln_1.g': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h1.ln_2.b': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.beta',
|
||||
'h1.ln_2.g': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.gamma',
|
||||
'h1.mlp.c_fc.b': 'gpt2_decoder.layers.1.feedforward.c_fc.bias',
|
||||
'h1.mlp.c_fc.w': 'gpt2_decoder.layers.1.feedforward.c_fc.weight',
|
||||
'h1.mlp.c_proj.b': 'gpt2_decoder.layers.1.feedforward.c_proj.bias',
|
||||
'h1.mlp.c_proj.w': 'gpt2_decoder.layers.1.feedforward.c_proj.weight',
|
||||
'h2.attn.c_attn.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h2.attn.c_attn.w': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h2.attn.c_proj.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h2.attn.c_proj.w': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h2.ln_1.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h2.ln_1.g': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h2.ln_2.b': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.beta',
|
||||
'h2.ln_2.g': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.gamma',
'h2.mlp.c_fc.b': 'gpt2_decoder.layers.2.feedforward.c_fc.bias',
'h2.mlp.c_fc.w': 'gpt2_decoder.layers.2.feedforward.c_fc.weight',
'h2.mlp.c_proj.b': 'gpt2_decoder.layers.2.feedforward.c_proj.bias',
'h2.mlp.c_proj.w': 'gpt2_decoder.layers.2.feedforward.c_proj.weight',
'h3.attn.c_attn.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h3.attn.c_attn.w': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h3.attn.c_proj.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h3.attn.c_proj.w': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h3.ln_1.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h3.ln_1.g': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h3.ln_2.b': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.beta',
'h3.ln_2.g': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.gamma',
'h3.mlp.c_fc.b': 'gpt2_decoder.layers.3.feedforward.c_fc.bias',
'h3.mlp.c_fc.w': 'gpt2_decoder.layers.3.feedforward.c_fc.weight',
'h3.mlp.c_proj.b': 'gpt2_decoder.layers.3.feedforward.c_proj.bias',
'h3.mlp.c_proj.w': 'gpt2_decoder.layers.3.feedforward.c_proj.weight',
'h4.attn.c_attn.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h4.attn.c_attn.w': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h4.attn.c_proj.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h4.attn.c_proj.w': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h4.ln_1.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h4.ln_1.g': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h4.ln_2.b': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.beta',
'h4.ln_2.g': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.gamma',
'h4.mlp.c_fc.b': 'gpt2_decoder.layers.4.feedforward.c_fc.bias',
'h4.mlp.c_fc.w': 'gpt2_decoder.layers.4.feedforward.c_fc.weight',
'h4.mlp.c_proj.b': 'gpt2_decoder.layers.4.feedforward.c_proj.bias',
'h4.mlp.c_proj.w': 'gpt2_decoder.layers.4.feedforward.c_proj.weight',
'h5.attn.c_attn.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h5.attn.c_attn.w': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h5.attn.c_proj.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h5.attn.c_proj.w': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h5.ln_1.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h5.ln_1.g': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h5.ln_2.b': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.beta',
'h5.ln_2.g': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.gamma',
'h5.mlp.c_fc.b': 'gpt2_decoder.layers.5.feedforward.c_fc.bias',
'h5.mlp.c_fc.w': 'gpt2_decoder.layers.5.feedforward.c_fc.weight',
'h5.mlp.c_proj.b': 'gpt2_decoder.layers.5.feedforward.c_proj.bias',
'h5.mlp.c_proj.w': 'gpt2_decoder.layers.5.feedforward.c_proj.weight',
'h6.attn.c_attn.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h6.attn.c_attn.w': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h6.attn.c_proj.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h6.attn.c_proj.w': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h6.ln_1.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h6.ln_1.g': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h6.ln_2.b': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.beta',
'h6.ln_2.g': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.gamma',
'h6.mlp.c_fc.b': 'gpt2_decoder.layers.6.feedforward.c_fc.bias',
'h6.mlp.c_fc.w': 'gpt2_decoder.layers.6.feedforward.c_fc.weight',
'h6.mlp.c_proj.b': 'gpt2_decoder.layers.6.feedforward.c_proj.bias',
'h6.mlp.c_proj.w': 'gpt2_decoder.layers.6.feedforward.c_proj.weight',
'h7.attn.c_attn.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h7.attn.c_attn.w': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h7.attn.c_proj.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h7.attn.c_proj.w': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h7.ln_1.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h7.ln_1.g': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h7.ln_2.b': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.beta',
'h7.ln_2.g': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.gamma',
'h7.mlp.c_fc.b': 'gpt2_decoder.layers.7.feedforward.c_fc.bias',
'h7.mlp.c_fc.w': 'gpt2_decoder.layers.7.feedforward.c_fc.weight',
'h7.mlp.c_proj.b': 'gpt2_decoder.layers.7.feedforward.c_proj.bias',
'h7.mlp.c_proj.w': 'gpt2_decoder.layers.7.feedforward.c_proj.weight',
'h8.attn.c_attn.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h8.attn.c_attn.w': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h8.attn.c_proj.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h8.attn.c_proj.w': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h8.ln_1.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h8.ln_1.g': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h8.ln_2.b': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.beta',
'h8.ln_2.g': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.gamma',
'h8.mlp.c_fc.b': 'gpt2_decoder.layers.8.feedforward.c_fc.bias',
'h8.mlp.c_fc.w': 'gpt2_decoder.layers.8.feedforward.c_fc.weight',
'h8.mlp.c_proj.b': 'gpt2_decoder.layers.8.feedforward.c_proj.bias',
'h8.mlp.c_proj.w': 'gpt2_decoder.layers.8.feedforward.c_proj.weight',
'h9.attn.c_attn.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h9.attn.c_attn.w': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h9.attn.c_proj.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h9.attn.c_proj.w': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h9.ln_1.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h9.ln_1.g': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h9.ln_2.b': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.beta',
'h9.ln_2.g': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.gamma',
'h9.mlp.c_fc.b': 'gpt2_decoder.layers.9.feedforward.c_fc.bias',
'h9.mlp.c_fc.w': 'gpt2_decoder.layers.9.feedforward.c_fc.weight',
'h9.mlp.c_proj.b': 'gpt2_decoder.layers.9.feedforward.c_proj.bias',
'h9.mlp.c_proj.w': 'gpt2_decoder.layers.9.feedforward.c_proj.weight',
'h10.attn.c_attn.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h10.attn.c_attn.w': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h10.attn.c_proj.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h10.attn.c_proj.w': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h10.ln_1.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h10.ln_1.g': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h10.ln_2.b': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.beta',
'h10.ln_2.g': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.gamma',
'h10.mlp.c_fc.b': 'gpt2_decoder.layers.10.feedforward.c_fc.bias',
'h10.mlp.c_fc.w': 'gpt2_decoder.layers.10.feedforward.c_fc.weight',
'h10.mlp.c_proj.b': 'gpt2_decoder.layers.10.feedforward.c_proj.bias',
'h10.mlp.c_proj.w': 'gpt2_decoder.layers.10.feedforward.c_proj.weight',
'h11.attn.c_attn.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h11.attn.c_attn.w': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h11.attn.c_proj.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h11.attn.c_proj.w': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h11.ln_1.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h11.ln_1.g': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h11.ln_2.b': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.beta',
'h11.ln_2.g': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.gamma',
'h11.mlp.c_fc.b': 'gpt2_decoder.layers.11.feedforward.c_fc.bias',
'h11.mlp.c_fc.w': 'gpt2_decoder.layers.11.feedforward.c_fc.weight',
'h11.mlp.c_proj.b': 'gpt2_decoder.layers.11.feedforward.c_proj.bias',
'h11.mlp.c_proj.w': 'gpt2_decoder.layers.11.feedforward.c_proj.weight',
'h12.attn.c_attn.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h12.attn.c_attn.w': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h12.attn.c_proj.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h12.attn.c_proj.w': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h12.ln_1.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h12.ln_1.g': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h12.ln_2.b': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.beta',
'h12.ln_2.g': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.gamma',
'h12.mlp.c_fc.b': 'gpt2_decoder.layers.12.feedforward.c_fc.bias',
'h12.mlp.c_fc.w': 'gpt2_decoder.layers.12.feedforward.c_fc.weight',
'h12.mlp.c_proj.b': 'gpt2_decoder.layers.12.feedforward.c_proj.bias',
'h12.mlp.c_proj.w': 'gpt2_decoder.layers.12.feedforward.c_proj.weight',
'h13.attn.c_attn.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h13.attn.c_attn.w': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h13.attn.c_proj.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h13.attn.c_proj.w': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h13.ln_1.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h13.ln_1.g': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h13.ln_2.b': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.beta',
'h13.ln_2.g': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.gamma',
'h13.mlp.c_fc.b': 'gpt2_decoder.layers.13.feedforward.c_fc.bias',
'h13.mlp.c_fc.w': 'gpt2_decoder.layers.13.feedforward.c_fc.weight',
'h13.mlp.c_proj.b': 'gpt2_decoder.layers.13.feedforward.c_proj.bias',
'h13.mlp.c_proj.w': 'gpt2_decoder.layers.13.feedforward.c_proj.weight',
'h14.attn.c_attn.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h14.attn.c_attn.w': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h14.attn.c_proj.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h14.attn.c_proj.w': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h14.ln_1.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h14.ln_1.g': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h14.ln_2.b': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.beta',
'h14.ln_2.g': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.gamma',
'h14.mlp.c_fc.b': 'gpt2_decoder.layers.14.feedforward.c_fc.bias',
'h14.mlp.c_fc.w': 'gpt2_decoder.layers.14.feedforward.c_fc.weight',
'h14.mlp.c_proj.b': 'gpt2_decoder.layers.14.feedforward.c_proj.bias',
'h14.mlp.c_proj.w': 'gpt2_decoder.layers.14.feedforward.c_proj.weight',
'h15.attn.c_attn.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h15.attn.c_attn.w': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h15.attn.c_proj.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h15.attn.c_proj.w': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h15.ln_1.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h15.ln_1.g': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h15.ln_2.b': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.beta',
'h15.ln_2.g': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.gamma',
'h15.mlp.c_fc.b': 'gpt2_decoder.layers.15.feedforward.c_fc.bias',
'h15.mlp.c_fc.w': 'gpt2_decoder.layers.15.feedforward.c_fc.weight',
'h15.mlp.c_proj.b': 'gpt2_decoder.layers.15.feedforward.c_proj.bias',
'h15.mlp.c_proj.w': 'gpt2_decoder.layers.15.feedforward.c_proj.weight',
'h16.attn.c_attn.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h16.attn.c_attn.w': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h16.attn.c_proj.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h16.attn.c_proj.w': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h16.ln_1.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h16.ln_1.g': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h16.ln_2.b': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.beta',
'h16.ln_2.g': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.gamma',
'h16.mlp.c_fc.b': 'gpt2_decoder.layers.16.feedforward.c_fc.bias',
'h16.mlp.c_fc.w': 'gpt2_decoder.layers.16.feedforward.c_fc.weight',
'h16.mlp.c_proj.b': 'gpt2_decoder.layers.16.feedforward.c_proj.bias',
'h16.mlp.c_proj.w': 'gpt2_decoder.layers.16.feedforward.c_proj.weight',
'h17.attn.c_attn.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h17.attn.c_attn.w': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h17.attn.c_proj.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h17.attn.c_proj.w': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h17.ln_1.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h17.ln_1.g': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h17.ln_2.b': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.beta',
'h17.ln_2.g': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.gamma',
'h17.mlp.c_fc.b': 'gpt2_decoder.layers.17.feedforward.c_fc.bias',
'h17.mlp.c_fc.w': 'gpt2_decoder.layers.17.feedforward.c_fc.weight',
'h17.mlp.c_proj.b': 'gpt2_decoder.layers.17.feedforward.c_proj.bias',
'h17.mlp.c_proj.w': 'gpt2_decoder.layers.17.feedforward.c_proj.weight',
'h18.attn.c_attn.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h18.attn.c_attn.w': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h18.attn.c_proj.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h18.attn.c_proj.w': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h18.ln_1.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h18.ln_1.g': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h18.ln_2.b': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.beta',
'h18.ln_2.g': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.gamma',
'h18.mlp.c_fc.b': 'gpt2_decoder.layers.18.feedforward.c_fc.bias',
'h18.mlp.c_fc.w': 'gpt2_decoder.layers.18.feedforward.c_fc.weight',
'h18.mlp.c_proj.b': 'gpt2_decoder.layers.18.feedforward.c_proj.bias',
'h18.mlp.c_proj.w': 'gpt2_decoder.layers.18.feedforward.c_proj.weight',
'h19.attn.c_attn.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h19.attn.c_attn.w': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h19.attn.c_proj.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h19.attn.c_proj.w': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h19.ln_1.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h19.ln_1.g': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h19.ln_2.b': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.beta',
'h19.ln_2.g': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.gamma',
'h19.mlp.c_fc.b': 'gpt2_decoder.layers.19.feedforward.c_fc.bias',
'h19.mlp.c_fc.w': 'gpt2_decoder.layers.19.feedforward.c_fc.weight',
'h19.mlp.c_proj.b': 'gpt2_decoder.layers.19.feedforward.c_proj.bias',
'h19.mlp.c_proj.w': 'gpt2_decoder.layers.19.feedforward.c_proj.weight',
'h20.attn.c_attn.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h20.attn.c_attn.w': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h20.attn.c_proj.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h20.attn.c_proj.w': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h20.ln_1.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h20.ln_1.g': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h20.ln_2.b': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.beta',
'h20.ln_2.g': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.gamma',
'h20.mlp.c_fc.b': 'gpt2_decoder.layers.20.feedforward.c_fc.bias',
'h20.mlp.c_fc.w': 'gpt2_decoder.layers.20.feedforward.c_fc.weight',
'h20.mlp.c_proj.b': 'gpt2_decoder.layers.20.feedforward.c_proj.bias',
'h20.mlp.c_proj.w': 'gpt2_decoder.layers.20.feedforward.c_proj.weight',
'h21.attn.c_attn.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h21.attn.c_attn.w': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h21.attn.c_proj.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h21.attn.c_proj.w': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h21.ln_1.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h21.ln_1.g': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h21.ln_2.b': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.beta',
'h21.ln_2.g': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.gamma',
'h21.mlp.c_fc.b': 'gpt2_decoder.layers.21.feedforward.c_fc.bias',
'h21.mlp.c_fc.w': 'gpt2_decoder.layers.21.feedforward.c_fc.weight',
'h21.mlp.c_proj.b': 'gpt2_decoder.layers.21.feedforward.c_proj.bias',
'h21.mlp.c_proj.w': 'gpt2_decoder.layers.21.feedforward.c_proj.weight',
'h22.attn.c_attn.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h22.attn.c_attn.w': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h22.attn.c_proj.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h22.attn.c_proj.w': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h22.ln_1.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h22.ln_1.g': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h22.ln_2.b': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.beta',
'h22.ln_2.g': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.gamma',
'h22.mlp.c_fc.b': 'gpt2_decoder.layers.22.feedforward.c_fc.bias',
'h22.mlp.c_fc.w': 'gpt2_decoder.layers.22.feedforward.c_fc.weight',
'h22.mlp.c_proj.b': 'gpt2_decoder.layers.22.feedforward.c_proj.bias',
'h22.mlp.c_proj.w': 'gpt2_decoder.layers.22.feedforward.c_proj.weight',
'h23.attn.c_attn.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h23.attn.c_attn.w': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h23.attn.c_proj.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h23.attn.c_proj.w': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h23.ln_1.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h23.ln_1.g': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h23.ln_2.b': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.beta',
'h23.ln_2.g': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.gamma',
'h23.mlp.c_fc.b': 'gpt2_decoder.layers.23.feedforward.c_fc.bias',
'h23.mlp.c_fc.w': 'gpt2_decoder.layers.23.feedforward.c_fc.weight',
'h23.mlp.c_proj.b': 'gpt2_decoder.layers.23.feedforward.c_proj.bias',
'h23.mlp.c_proj.w': 'gpt2_decoder.layers.23.feedforward.c_proj.weight',
'h24.attn.c_attn.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h24.attn.c_attn.w': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h24.attn.c_proj.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h24.attn.c_proj.w': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h24.ln_1.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h24.ln_1.g': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h24.ln_2.b': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.beta',
'h24.ln_2.g': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.gamma',
'h24.mlp.c_fc.b': 'gpt2_decoder.layers.24.feedforward.c_fc.bias',
'h24.mlp.c_fc.w': 'gpt2_decoder.layers.24.feedforward.c_fc.weight',
'h24.mlp.c_proj.b': 'gpt2_decoder.layers.24.feedforward.c_proj.bias',
'h24.mlp.c_proj.w': 'gpt2_decoder.layers.24.feedforward.c_proj.weight',
'h25.attn.c_attn.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h25.attn.c_attn.w': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h25.attn.c_proj.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h25.attn.c_proj.w': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h25.ln_1.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h25.ln_1.g': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h25.ln_2.b': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.beta',
'h25.ln_2.g': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.gamma',
'h25.mlp.c_fc.b': 'gpt2_decoder.layers.25.feedforward.c_fc.bias',
'h25.mlp.c_fc.w': 'gpt2_decoder.layers.25.feedforward.c_fc.weight',
'h25.mlp.c_proj.b': 'gpt2_decoder.layers.25.feedforward.c_proj.bias',
'h25.mlp.c_proj.w': 'gpt2_decoder.layers.25.feedforward.c_proj.weight',
'h26.attn.c_attn.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h26.attn.c_attn.w': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h26.attn.c_proj.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h26.attn.c_proj.w': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h26.ln_1.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h26.ln_1.g': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h26.ln_2.b': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.beta',
'h26.ln_2.g': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.gamma',
'h26.mlp.c_fc.b': 'gpt2_decoder.layers.26.feedforward.c_fc.bias',
'h26.mlp.c_fc.w': 'gpt2_decoder.layers.26.feedforward.c_fc.weight',
'h26.mlp.c_proj.b': 'gpt2_decoder.layers.26.feedforward.c_proj.bias',
'h26.mlp.c_proj.w': 'gpt2_decoder.layers.26.feedforward.c_proj.weight',
'h27.attn.c_attn.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h27.attn.c_attn.w': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h27.attn.c_proj.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h27.attn.c_proj.w': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h27.ln_1.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h27.ln_1.g': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h27.ln_2.b': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.beta',
'h27.ln_2.g': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.gamma',
'h27.mlp.c_fc.b': 'gpt2_decoder.layers.27.feedforward.c_fc.bias',
'h27.mlp.c_fc.w': 'gpt2_decoder.layers.27.feedforward.c_fc.weight',
'h27.mlp.c_proj.b': 'gpt2_decoder.layers.27.feedforward.c_proj.bias',
'h27.mlp.c_proj.w': 'gpt2_decoder.layers.27.feedforward.c_proj.weight',
'h28.attn.c_attn.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h28.attn.c_attn.w': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h28.attn.c_proj.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h28.attn.c_proj.w': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h28.ln_1.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h28.ln_1.g': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h28.ln_2.b': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.beta',
'h28.ln_2.g': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.gamma',
'h28.mlp.c_fc.b': 'gpt2_decoder.layers.28.feedforward.c_fc.bias',
'h28.mlp.c_fc.w': 'gpt2_decoder.layers.28.feedforward.c_fc.weight',
'h28.mlp.c_proj.b': 'gpt2_decoder.layers.28.feedforward.c_proj.bias',
'h28.mlp.c_proj.w': 'gpt2_decoder.layers.28.feedforward.c_proj.weight',
'h29.attn.c_attn.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h29.attn.c_attn.w': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h29.attn.c_proj.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h29.attn.c_proj.w': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h29.ln_1.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h29.ln_1.g': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h29.ln_2.b': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.beta',
'h29.ln_2.g': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.gamma',
'h29.mlp.c_fc.b': 'gpt2_decoder.layers.29.feedforward.c_fc.bias',
'h29.mlp.c_fc.w': 'gpt2_decoder.layers.29.feedforward.c_fc.weight',
'h29.mlp.c_proj.b': 'gpt2_decoder.layers.29.feedforward.c_proj.bias',
'h29.mlp.c_proj.w': 'gpt2_decoder.layers.29.feedforward.c_proj.weight',
'h30.attn.c_attn.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h30.attn.c_attn.w': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h30.attn.c_proj.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h30.attn.c_proj.w': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h30.ln_1.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h30.ln_1.g': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h30.ln_2.b': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.beta',
'h30.ln_2.g': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.gamma',
'h30.mlp.c_fc.b': 'gpt2_decoder.layers.30.feedforward.c_fc.bias',
'h30.mlp.c_fc.w': 'gpt2_decoder.layers.30.feedforward.c_fc.weight',
'h30.mlp.c_proj.b': 'gpt2_decoder.layers.30.feedforward.c_proj.bias',
'h30.mlp.c_proj.w': 'gpt2_decoder.layers.30.feedforward.c_proj.weight',
'h31.attn.c_attn.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h31.attn.c_attn.w': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h31.attn.c_proj.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h31.attn.c_proj.w': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h31.ln_1.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h31.ln_1.g': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h31.ln_2.b': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.beta',
'h31.ln_2.g': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.gamma',
'h31.mlp.c_fc.b': 'gpt2_decoder.layers.31.feedforward.c_fc.bias',
'h31.mlp.c_fc.w': 'gpt2_decoder.layers.31.feedforward.c_fc.weight',
'h31.mlp.c_proj.b': 'gpt2_decoder.layers.31.feedforward.c_proj.bias',
'h31.mlp.c_proj.w': 'gpt2_decoder.layers.31.feedforward.c_proj.weight',
'h32.attn.c_attn.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h32.attn.c_attn.w': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h32.attn.c_proj.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h32.attn.c_proj.w': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h32.ln_1.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h32.ln_1.g': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h32.ln_2.b': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.beta',
'h32.ln_2.g': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.gamma',
'h32.mlp.c_fc.b': 'gpt2_decoder.layers.32.feedforward.c_fc.bias',
'h32.mlp.c_fc.w': 'gpt2_decoder.layers.32.feedforward.c_fc.weight',
'h32.mlp.c_proj.b': 'gpt2_decoder.layers.32.feedforward.c_proj.bias',
'h32.mlp.c_proj.w': 'gpt2_decoder.layers.32.feedforward.c_proj.weight',
'h33.attn.c_attn.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h33.attn.c_attn.w': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h33.attn.c_proj.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h33.attn.c_proj.w': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h33.ln_1.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h33.ln_1.g': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h33.ln_2.b': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.beta',
'h33.ln_2.g': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.gamma',
'h33.mlp.c_fc.b': 'gpt2_decoder.layers.33.feedforward.c_fc.bias',
'h33.mlp.c_fc.w': 'gpt2_decoder.layers.33.feedforward.c_fc.weight',
'h33.mlp.c_proj.b': 'gpt2_decoder.layers.33.feedforward.c_proj.bias',
'h33.mlp.c_proj.w': 'gpt2_decoder.layers.33.feedforward.c_proj.weight',
'h34.attn.c_attn.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h34.attn.c_attn.w': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h34.attn.c_proj.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h34.attn.c_proj.w': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h34.ln_1.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h34.ln_1.g': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h34.ln_2.b': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.beta',
'h34.ln_2.g': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.gamma',
'h34.mlp.c_fc.b': 'gpt2_decoder.layers.34.feedforward.c_fc.bias',
'h34.mlp.c_fc.w': 'gpt2_decoder.layers.34.feedforward.c_fc.weight',
'h34.mlp.c_proj.b': 'gpt2_decoder.layers.34.feedforward.c_proj.bias',
'h34.mlp.c_proj.w': 'gpt2_decoder.layers.34.feedforward.c_proj.weight',
'h35.attn.c_attn.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h35.attn.c_attn.w': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h35.attn.c_proj.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h35.attn.c_proj.w': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h35.ln_1.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h35.ln_1.g': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h35.ln_2.b': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.beta',
'h35.ln_2.g': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.gamma',
'h35.mlp.c_fc.b': 'gpt2_decoder.layers.35.feedforward.c_fc.bias',
'h35.mlp.c_fc.w': 'gpt2_decoder.layers.35.feedforward.c_fc.weight',
'h35.mlp.c_proj.b': 'gpt2_decoder.layers.35.feedforward.c_proj.bias',
'h35.mlp.c_proj.w': 'gpt2_decoder.layers.35.feedforward.c_proj.weight',
'ln_f.b': 'layer_norm.layer_norm.beta',
'ln_f.g': 'layer_norm.layer_norm.gamma',
'wpe': 'gpt2_embedding_postprocess.position_embedding_table',
'wte': 'gpt2_embedding_lookup.embedding_table'
} # transfer dictionary

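# NOTE: illustrative sketch only -- not part of the original conversion file.
# Assuming the OpenAI TensorFlow checkpoint variable names are normalized to
# the dotted keys used in trans_dict above (e.g. 'model/h2/mlp/c_fc/b' ->
# 'h2.mlp.c_fc.b'), a conversion loop could rename and re-save the weights
# roughly like this; convert_tf_to_ms and both path arguments are hypothetical.
import numpy as np
import tensorflow as tf
from mindspore import Tensor, save_checkpoint

def convert_tf_to_ms(tf_ckpt_path, ms_ckpt_path):
    """Rename TF GPT-2 variables via trans_dict and save a MindSpore checkpoint."""
    reader = tf.train.load_checkpoint(tf_ckpt_path)
    params = []
    for tf_name, _ in tf.train.list_variables(tf_ckpt_path):
        key = tf_name.replace('model/', '').replace('/', '.')
        if key not in trans_dict:
            continue  # skip variables with no MindSpore counterpart
        # conv1d weights in the TF checkpoint carry a leading 1-dim; squeeze it away
        value = np.squeeze(reader.get_tensor(tf_name))
        params.append({'name': trans_dict[key], 'data': Tensor(value)})
    save_checkpoint(params, ms_ckpt_path)
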
trans_dict_py = {
'h.0.attn.c_attn.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.0.attn.c_attn.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.0.attn.c_proj.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.0.attn.c_proj.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.0.ln_1.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.0.ln_1.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.0.ln_2.bias': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.beta',
'h.0.ln_2.weight': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.gamma',
'h.0.mlp.c_fc.bias': 'gpt2_decoder.layers.0.feedforward.c_fc.bias',
'h.0.mlp.c_fc.weight': 'gpt2_decoder.layers.0.feedforward.c_fc.weight',
'h.0.mlp.c_proj.bias': 'gpt2_decoder.layers.0.feedforward.c_proj.bias',
'h.0.mlp.c_proj.weight': 'gpt2_decoder.layers.0.feedforward.c_proj.weight',
'h.1.attn.c_attn.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.1.attn.c_attn.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.1.attn.c_proj.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.1.attn.c_proj.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.1.ln_1.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.1.ln_1.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.1.ln_2.bias': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.beta',
'h.1.ln_2.weight': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.gamma',
'h.1.mlp.c_fc.bias': 'gpt2_decoder.layers.1.feedforward.c_fc.bias',
'h.1.mlp.c_fc.weight': 'gpt2_decoder.layers.1.feedforward.c_fc.weight',
'h.1.mlp.c_proj.bias': 'gpt2_decoder.layers.1.feedforward.c_proj.bias',
'h.1.mlp.c_proj.weight': 'gpt2_decoder.layers.1.feedforward.c_proj.weight',
'h.2.attn.c_attn.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.2.attn.c_attn.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.2.attn.c_proj.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.2.attn.c_proj.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.2.ln_1.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.2.ln_1.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.2.ln_2.bias': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.beta',
'h.2.ln_2.weight': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.gamma',
'h.2.mlp.c_fc.bias': 'gpt2_decoder.layers.2.feedforward.c_fc.bias',
'h.2.mlp.c_fc.weight': 'gpt2_decoder.layers.2.feedforward.c_fc.weight',
'h.2.mlp.c_proj.bias': 'gpt2_decoder.layers.2.feedforward.c_proj.bias',
'h.2.mlp.c_proj.weight': 'gpt2_decoder.layers.2.feedforward.c_proj.weight',
'h.3.attn.c_attn.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.3.attn.c_attn.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.3.attn.c_proj.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.3.attn.c_proj.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.3.ln_1.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.3.ln_1.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.3.ln_2.bias': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.beta',
'h.3.ln_2.weight': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.gamma',
'h.3.mlp.c_fc.bias': 'gpt2_decoder.layers.3.feedforward.c_fc.bias',
'h.3.mlp.c_fc.weight': 'gpt2_decoder.layers.3.feedforward.c_fc.weight',
'h.3.mlp.c_proj.bias': 'gpt2_decoder.layers.3.feedforward.c_proj.bias',
'h.3.mlp.c_proj.weight': 'gpt2_decoder.layers.3.feedforward.c_proj.weight',
'h.4.attn.c_attn.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.4.attn.c_attn.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.4.attn.c_proj.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.4.attn.c_proj.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.4.ln_1.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.4.ln_1.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.4.ln_2.bias': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.beta',
'h.4.ln_2.weight': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.gamma',
'h.4.mlp.c_fc.bias': 'gpt2_decoder.layers.4.feedforward.c_fc.bias',
'h.4.mlp.c_fc.weight': 'gpt2_decoder.layers.4.feedforward.c_fc.weight',
'h.4.mlp.c_proj.bias': 'gpt2_decoder.layers.4.feedforward.c_proj.bias',
'h.4.mlp.c_proj.weight': 'gpt2_decoder.layers.4.feedforward.c_proj.weight',
'h.5.attn.c_attn.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.5.attn.c_attn.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.5.attn.c_proj.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.5.attn.c_proj.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.5.ln_1.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.5.ln_1.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.5.ln_2.bias': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.beta',
'h.5.ln_2.weight': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.gamma',
'h.5.mlp.c_fc.bias': 'gpt2_decoder.layers.5.feedforward.c_fc.bias',
'h.5.mlp.c_fc.weight': 'gpt2_decoder.layers.5.feedforward.c_fc.weight',
'h.5.mlp.c_proj.bias': 'gpt2_decoder.layers.5.feedforward.c_proj.bias',
'h.5.mlp.c_proj.weight': 'gpt2_decoder.layers.5.feedforward.c_proj.weight',
'h.6.attn.c_attn.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.6.attn.c_attn.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.6.attn.c_proj.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.6.attn.c_proj.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.6.ln_1.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.6.ln_1.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.6.ln_2.bias': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.beta',
'h.6.ln_2.weight': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.gamma',
'h.6.mlp.c_fc.bias': 'gpt2_decoder.layers.6.feedforward.c_fc.bias',
'h.6.mlp.c_fc.weight': 'gpt2_decoder.layers.6.feedforward.c_fc.weight',
'h.6.mlp.c_proj.bias': 'gpt2_decoder.layers.6.feedforward.c_proj.bias',
'h.6.mlp.c_proj.weight': 'gpt2_decoder.layers.6.feedforward.c_proj.weight',
'h.7.attn.c_attn.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.7.attn.c_attn.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.7.attn.c_proj.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.7.attn.c_proj.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.7.ln_1.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.7.ln_1.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.7.ln_2.bias': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.beta',
'h.7.ln_2.weight': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.gamma',
'h.7.mlp.c_fc.bias': 'gpt2_decoder.layers.7.feedforward.c_fc.bias',
'h.7.mlp.c_fc.weight': 'gpt2_decoder.layers.7.feedforward.c_fc.weight',
'h.7.mlp.c_proj.bias': 'gpt2_decoder.layers.7.feedforward.c_proj.bias',
'h.7.mlp.c_proj.weight': 'gpt2_decoder.layers.7.feedforward.c_proj.weight',
'h.8.attn.c_attn.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.8.attn.c_attn.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.8.attn.c_proj.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.8.attn.c_proj.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.8.ln_1.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.8.ln_1.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.8.ln_2.bias': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.beta',
'h.8.ln_2.weight': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.gamma',
'h.8.mlp.c_fc.bias': 'gpt2_decoder.layers.8.feedforward.c_fc.bias',
'h.8.mlp.c_fc.weight': 'gpt2_decoder.layers.8.feedforward.c_fc.weight',
'h.8.mlp.c_proj.bias': 'gpt2_decoder.layers.8.feedforward.c_proj.bias',
'h.8.mlp.c_proj.weight': 'gpt2_decoder.layers.8.feedforward.c_proj.weight',
'h.9.attn.c_attn.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.9.attn.c_attn.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.9.attn.c_proj.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.9.attn.c_proj.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.9.ln_1.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.9.ln_1.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.9.ln_2.bias': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.beta',
'h.9.ln_2.weight': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.gamma',
'h.9.mlp.c_fc.bias': 'gpt2_decoder.layers.9.feedforward.c_fc.bias',
'h.9.mlp.c_fc.weight': 'gpt2_decoder.layers.9.feedforward.c_fc.weight',
'h.9.mlp.c_proj.bias': 'gpt2_decoder.layers.9.feedforward.c_proj.bias',
'h.9.mlp.c_proj.weight': 'gpt2_decoder.layers.9.feedforward.c_proj.weight',
'h.10.attn.c_attn.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.10.attn.c_attn.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.10.attn.c_proj.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.10.attn.c_proj.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.10.ln_1.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.10.ln_1.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.10.ln_2.bias': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.beta',
'h.10.ln_2.weight': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.gamma',
'h.10.mlp.c_fc.bias': 'gpt2_decoder.layers.10.feedforward.c_fc.bias',
'h.10.mlp.c_fc.weight': 'gpt2_decoder.layers.10.feedforward.c_fc.weight',
'h.10.mlp.c_proj.bias': 'gpt2_decoder.layers.10.feedforward.c_proj.bias',
'h.10.mlp.c_proj.weight': 'gpt2_decoder.layers.10.feedforward.c_proj.weight',
'h.11.attn.c_attn.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.11.attn.c_attn.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.11.attn.c_proj.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.11.attn.c_proj.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.11.ln_1.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.11.ln_1.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.11.ln_2.bias': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.beta',
'h.11.ln_2.weight': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.gamma',
'h.11.mlp.c_fc.bias': 'gpt2_decoder.layers.11.feedforward.c_fc.bias',
'h.11.mlp.c_fc.weight': 'gpt2_decoder.layers.11.feedforward.c_fc.weight',
'h.11.mlp.c_proj.bias': 'gpt2_decoder.layers.11.feedforward.c_proj.bias',
'h.11.mlp.c_proj.weight': 'gpt2_decoder.layers.11.feedforward.c_proj.weight',
'h.12.attn.c_attn.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.12.attn.c_attn.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.12.attn.c_proj.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.12.attn.c_proj.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.12.ln_1.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.12.ln_1.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.12.ln_2.bias': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.beta',
'h.12.ln_2.weight': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.gamma',
'h.12.mlp.c_fc.bias': 'gpt2_decoder.layers.12.feedforward.c_fc.bias',
'h.12.mlp.c_fc.weight': 'gpt2_decoder.layers.12.feedforward.c_fc.weight',
'h.12.mlp.c_proj.bias': 'gpt2_decoder.layers.12.feedforward.c_proj.bias',
'h.12.mlp.c_proj.weight': 'gpt2_decoder.layers.12.feedforward.c_proj.weight',
'h.13.attn.c_attn.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.13.attn.c_attn.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.13.attn.c_proj.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.13.attn.c_proj.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.13.ln_1.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.13.ln_1.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.13.ln_2.bias': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.beta',
'h.13.ln_2.weight': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.gamma',
'h.13.mlp.c_fc.bias': 'gpt2_decoder.layers.13.feedforward.c_fc.bias',
'h.13.mlp.c_fc.weight': 'gpt2_decoder.layers.13.feedforward.c_fc.weight',
'h.13.mlp.c_proj.bias': 'gpt2_decoder.layers.13.feedforward.c_proj.bias',
'h.13.mlp.c_proj.weight': 'gpt2_decoder.layers.13.feedforward.c_proj.weight',
'h.14.attn.c_attn.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.14.attn.c_attn.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.14.attn.c_proj.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.14.attn.c_proj.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.14.ln_1.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.14.ln_1.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.14.ln_2.bias': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.beta',
'h.14.ln_2.weight': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.gamma',
'h.14.mlp.c_fc.bias': 'gpt2_decoder.layers.14.feedforward.c_fc.bias',
'h.14.mlp.c_fc.weight': 'gpt2_decoder.layers.14.feedforward.c_fc.weight',
'h.14.mlp.c_proj.bias': 'gpt2_decoder.layers.14.feedforward.c_proj.bias',
'h.14.mlp.c_proj.weight': 'gpt2_decoder.layers.14.feedforward.c_proj.weight',
'h.15.attn.c_attn.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.15.attn.c_attn.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.15.attn.c_proj.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.15.attn.c_proj.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.15.ln_1.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.15.ln_1.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.15.ln_2.bias': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.beta',
'h.15.ln_2.weight': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.gamma',
'h.15.mlp.c_fc.bias': 'gpt2_decoder.layers.15.feedforward.c_fc.bias',
'h.15.mlp.c_fc.weight': 'gpt2_decoder.layers.15.feedforward.c_fc.weight',
'h.15.mlp.c_proj.bias': 'gpt2_decoder.layers.15.feedforward.c_proj.bias',
'h.15.mlp.c_proj.weight': 'gpt2_decoder.layers.15.feedforward.c_proj.weight',
'h.16.attn.c_attn.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.16.attn.c_attn.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.16.attn.c_proj.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.16.attn.c_proj.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.16.ln_1.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.16.ln_1.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.16.ln_2.bias': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.beta',
'h.16.ln_2.weight': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.gamma',
'h.16.mlp.c_fc.bias': 'gpt2_decoder.layers.16.feedforward.c_fc.bias',
'h.16.mlp.c_fc.weight': 'gpt2_decoder.layers.16.feedforward.c_fc.weight',
'h.16.mlp.c_proj.bias': 'gpt2_decoder.layers.16.feedforward.c_proj.bias',
'h.16.mlp.c_proj.weight': 'gpt2_decoder.layers.16.feedforward.c_proj.weight',
'h.17.attn.c_attn.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.17.attn.c_attn.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.17.attn.c_proj.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.17.attn.c_proj.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.17.ln_1.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.17.ln_1.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.17.ln_2.bias': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.beta',
'h.17.ln_2.weight': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.gamma',
'h.17.mlp.c_fc.bias': 'gpt2_decoder.layers.17.feedforward.c_fc.bias',
'h.17.mlp.c_fc.weight': 'gpt2_decoder.layers.17.feedforward.c_fc.weight',
'h.17.mlp.c_proj.bias': 'gpt2_decoder.layers.17.feedforward.c_proj.bias',
'h.17.mlp.c_proj.weight': 'gpt2_decoder.layers.17.feedforward.c_proj.weight',
'h.18.attn.c_attn.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.18.attn.c_attn.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.18.attn.c_proj.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.18.attn.c_proj.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.18.ln_1.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.18.ln_1.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.18.ln_2.bias': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.beta',
'h.18.ln_2.weight': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.gamma',
'h.18.mlp.c_fc.bias': 'gpt2_decoder.layers.18.feedforward.c_fc.bias',
'h.18.mlp.c_fc.weight': 'gpt2_decoder.layers.18.feedforward.c_fc.weight',
'h.18.mlp.c_proj.bias': 'gpt2_decoder.layers.18.feedforward.c_proj.bias',
'h.18.mlp.c_proj.weight': 'gpt2_decoder.layers.18.feedforward.c_proj.weight',
'h.19.attn.c_attn.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.19.attn.c_attn.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.19.attn.c_proj.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.19.attn.c_proj.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.19.ln_1.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.19.ln_1.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.19.ln_2.bias': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.beta',
'h.19.ln_2.weight': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.gamma',
'h.19.mlp.c_fc.bias': 'gpt2_decoder.layers.19.feedforward.c_fc.bias',
'h.19.mlp.c_fc.weight': 'gpt2_decoder.layers.19.feedforward.c_fc.weight',
'h.19.mlp.c_proj.bias': 'gpt2_decoder.layers.19.feedforward.c_proj.bias',
'h.19.mlp.c_proj.weight': 'gpt2_decoder.layers.19.feedforward.c_proj.weight',
'h.20.attn.c_attn.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.20.attn.c_attn.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.20.attn.c_proj.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.20.attn.c_proj.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.20.ln_1.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.20.ln_1.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.20.ln_2.bias': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.beta',
'h.20.ln_2.weight': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.gamma',
'h.20.mlp.c_fc.bias': 'gpt2_decoder.layers.20.feedforward.c_fc.bias',
'h.20.mlp.c_fc.weight': 'gpt2_decoder.layers.20.feedforward.c_fc.weight',
'h.20.mlp.c_proj.bias': 'gpt2_decoder.layers.20.feedforward.c_proj.bias',
'h.20.mlp.c_proj.weight': 'gpt2_decoder.layers.20.feedforward.c_proj.weight',
'h.21.attn.c_attn.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.21.attn.c_attn.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.21.attn.c_proj.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.21.attn.c_proj.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.21.ln_1.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.21.ln_1.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.21.ln_2.bias': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.beta',
|
||||
'h.21.ln_2.weight': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.21.mlp.c_fc.bias': 'gpt2_decoder.layers.21.feedforward.c_fc.bias',
|
||||
'h.21.mlp.c_fc.weight': 'gpt2_decoder.layers.21.feedforward.c_fc.weight',
|
||||
'h.21.mlp.c_proj.bias': 'gpt2_decoder.layers.21.feedforward.c_proj.bias',
|
||||
'h.21.mlp.c_proj.weight': 'gpt2_decoder.layers.21.feedforward.c_proj.weight',
|
||||
'h.22.attn.c_attn.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.22.attn.c_attn.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.22.attn.c_proj.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.22.attn.c_proj.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.22.ln_1.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.22.ln_1.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.22.ln_2.bias': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.beta',
|
||||
'h.22.ln_2.weight': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.22.mlp.c_fc.bias': 'gpt2_decoder.layers.22.feedforward.c_fc.bias',
|
||||
'h.22.mlp.c_fc.weight': 'gpt2_decoder.layers.22.feedforward.c_fc.weight',
|
||||
'h.22.mlp.c_proj.bias': 'gpt2_decoder.layers.22.feedforward.c_proj.bias',
|
||||
'h.22.mlp.c_proj.weight': 'gpt2_decoder.layers.22.feedforward.c_proj.weight',
|
||||
'h.23.attn.c_attn.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.23.attn.c_attn.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.23.attn.c_proj.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.23.attn.c_proj.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.23.ln_1.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.23.ln_1.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.23.ln_2.bias': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.beta',
|
||||
'h.23.ln_2.weight': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.23.mlp.c_fc.bias': 'gpt2_decoder.layers.23.feedforward.c_fc.bias',
|
||||
'h.23.mlp.c_fc.weight': 'gpt2_decoder.layers.23.feedforward.c_fc.weight',
|
||||
'h.23.mlp.c_proj.bias': 'gpt2_decoder.layers.23.feedforward.c_proj.bias',
|
||||
'h.23.mlp.c_proj.weight': 'gpt2_decoder.layers.23.feedforward.c_proj.weight',
|
||||
'h.24.attn.c_attn.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.24.attn.c_attn.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.24.attn.c_proj.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.24.attn.c_proj.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.24.ln_1.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.24.ln_1.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.24.ln_2.bias': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.beta',
|
||||
'h.24.ln_2.weight': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.24.mlp.c_fc.bias': 'gpt2_decoder.layers.24.feedforward.c_fc.bias',
|
||||
'h.24.mlp.c_fc.weight': 'gpt2_decoder.layers.24.feedforward.c_fc.weight',
|
||||
'h.24.mlp.c_proj.bias': 'gpt2_decoder.layers.24.feedforward.c_proj.bias',
|
||||
'h.24.mlp.c_proj.weight': 'gpt2_decoder.layers.24.feedforward.c_proj.weight',
|
||||
'h.25.attn.c_attn.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.25.attn.c_attn.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.25.attn.c_proj.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.25.attn.c_proj.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.25.ln_1.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.25.ln_1.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.25.ln_2.bias': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.beta',
|
||||
'h.25.ln_2.weight': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.25.mlp.c_fc.bias': 'gpt2_decoder.layers.25.feedforward.c_fc.bias',
|
||||
'h.25.mlp.c_fc.weight': 'gpt2_decoder.layers.25.feedforward.c_fc.weight',
|
||||
'h.25.mlp.c_proj.bias': 'gpt2_decoder.layers.25.feedforward.c_proj.bias',
|
||||
'h.25.mlp.c_proj.weight': 'gpt2_decoder.layers.25.feedforward.c_proj.weight',
|
||||
'h.26.attn.c_attn.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.26.attn.c_attn.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.26.attn.c_proj.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.26.attn.c_proj.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.26.ln_1.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.26.ln_1.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.26.ln_2.bias': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.beta',
|
||||
'h.26.ln_2.weight': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.26.mlp.c_fc.bias': 'gpt2_decoder.layers.26.feedforward.c_fc.bias',
|
||||
'h.26.mlp.c_fc.weight': 'gpt2_decoder.layers.26.feedforward.c_fc.weight',
|
||||
'h.26.mlp.c_proj.bias': 'gpt2_decoder.layers.26.feedforward.c_proj.bias',
|
||||
'h.26.mlp.c_proj.weight': 'gpt2_decoder.layers.26.feedforward.c_proj.weight',
|
||||
'h.27.attn.c_attn.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.27.attn.c_attn.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.27.attn.c_proj.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.27.attn.c_proj.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.27.ln_1.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.27.ln_1.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.27.ln_2.bias': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.beta',
|
||||
'h.27.ln_2.weight': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.27.mlp.c_fc.bias': 'gpt2_decoder.layers.27.feedforward.c_fc.bias',
|
||||
'h.27.mlp.c_fc.weight': 'gpt2_decoder.layers.27.feedforward.c_fc.weight',
|
||||
'h.27.mlp.c_proj.bias': 'gpt2_decoder.layers.27.feedforward.c_proj.bias',
|
||||
'h.27.mlp.c_proj.weight': 'gpt2_decoder.layers.27.feedforward.c_proj.weight',
|
||||
'h.28.attn.c_attn.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.28.attn.c_attn.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.28.attn.c_proj.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.28.attn.c_proj.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.28.ln_1.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.28.ln_1.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.28.ln_2.bias': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.beta',
|
||||
'h.28.ln_2.weight': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.28.mlp.c_fc.bias': 'gpt2_decoder.layers.28.feedforward.c_fc.bias',
|
||||
'h.28.mlp.c_fc.weight': 'gpt2_decoder.layers.28.feedforward.c_fc.weight',
|
||||
'h.28.mlp.c_proj.bias': 'gpt2_decoder.layers.28.feedforward.c_proj.bias',
|
||||
'h.28.mlp.c_proj.weight': 'gpt2_decoder.layers.28.feedforward.c_proj.weight',
|
||||
'h.29.attn.c_attn.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.29.attn.c_attn.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.29.attn.c_proj.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.29.attn.c_proj.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.29.ln_1.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.29.ln_1.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.29.ln_2.bias': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.beta',
|
||||
'h.29.ln_2.weight': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.29.mlp.c_fc.bias': 'gpt2_decoder.layers.29.feedforward.c_fc.bias',
|
||||
'h.29.mlp.c_fc.weight': 'gpt2_decoder.layers.29.feedforward.c_fc.weight',
|
||||
'h.29.mlp.c_proj.bias': 'gpt2_decoder.layers.29.feedforward.c_proj.bias',
|
||||
'h.29.mlp.c_proj.weight': 'gpt2_decoder.layers.29.feedforward.c_proj.weight',
|
||||
'h.30.attn.c_attn.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.30.attn.c_attn.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.30.attn.c_proj.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.30.attn.c_proj.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.30.ln_1.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.30.ln_1.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.30.ln_2.bias': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.beta',
|
||||
'h.30.ln_2.weight': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.30.mlp.c_fc.bias': 'gpt2_decoder.layers.30.feedforward.c_fc.bias',
|
||||
'h.30.mlp.c_fc.weight': 'gpt2_decoder.layers.30.feedforward.c_fc.weight',
|
||||
'h.30.mlp.c_proj.bias': 'gpt2_decoder.layers.30.feedforward.c_proj.bias',
|
||||
'h.30.mlp.c_proj.weight': 'gpt2_decoder.layers.30.feedforward.c_proj.weight',
|
||||
'h.31.attn.c_attn.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.31.attn.c_attn.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.31.attn.c_proj.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.31.attn.c_proj.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.31.ln_1.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.31.ln_1.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.31.ln_2.bias': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.beta',
|
||||
'h.31.ln_2.weight': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.31.mlp.c_fc.bias': 'gpt2_decoder.layers.31.feedforward.c_fc.bias',
|
||||
'h.31.mlp.c_fc.weight': 'gpt2_decoder.layers.31.feedforward.c_fc.weight',
|
||||
'h.31.mlp.c_proj.bias': 'gpt2_decoder.layers.31.feedforward.c_proj.bias',
|
||||
'h.31.mlp.c_proj.weight': 'gpt2_decoder.layers.31.feedforward.c_proj.weight',
|
||||
'h.32.attn.c_attn.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.32.attn.c_attn.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.32.attn.c_proj.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.32.attn.c_proj.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.32.ln_1.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.32.ln_1.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.32.ln_2.bias': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.beta',
|
||||
'h.32.ln_2.weight': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.32.mlp.c_fc.bias': 'gpt2_decoder.layers.32.feedforward.c_fc.bias',
|
||||
'h.32.mlp.c_fc.weight': 'gpt2_decoder.layers.32.feedforward.c_fc.weight',
|
||||
'h.32.mlp.c_proj.bias': 'gpt2_decoder.layers.32.feedforward.c_proj.bias',
|
||||
'h.32.mlp.c_proj.weight': 'gpt2_decoder.layers.32.feedforward.c_proj.weight',
|
||||
'h.33.attn.c_attn.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.33.attn.c_attn.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.33.attn.c_proj.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.33.attn.c_proj.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.33.ln_1.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.33.ln_1.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.33.ln_2.bias': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.beta',
|
||||
'h.33.ln_2.weight': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.33.mlp.c_fc.bias': 'gpt2_decoder.layers.33.feedforward.c_fc.bias',
|
||||
'h.33.mlp.c_fc.weight': 'gpt2_decoder.layers.33.feedforward.c_fc.weight',
|
||||
'h.33.mlp.c_proj.bias': 'gpt2_decoder.layers.33.feedforward.c_proj.bias',
|
||||
'h.33.mlp.c_proj.weight': 'gpt2_decoder.layers.33.feedforward.c_proj.weight',
|
||||
'h.34.attn.c_attn.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.34.attn.c_attn.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.34.attn.c_proj.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.34.attn.c_proj.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.34.ln_1.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.34.ln_1.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.34.ln_2.bias': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.beta',
|
||||
'h.34.ln_2.weight': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.34.mlp.c_fc.bias': 'gpt2_decoder.layers.34.feedforward.c_fc.bias',
|
||||
'h.34.mlp.c_fc.weight': 'gpt2_decoder.layers.34.feedforward.c_fc.weight',
|
||||
'h.34.mlp.c_proj.bias': 'gpt2_decoder.layers.34.feedforward.c_proj.bias',
|
||||
'h.34.mlp.c_proj.weight': 'gpt2_decoder.layers.34.feedforward.c_proj.weight',
|
||||
'h.35.attn.c_attn.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.bias',
|
||||
'h.35.attn.c_attn.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.weight',
|
||||
'h.35.attn.c_proj.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.bias',
|
||||
'h.35.attn.c_proj.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.weight',
|
||||
'h.35.ln_1.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.beta',
|
||||
'h.35.ln_1.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.gamma',
|
||||
'h.35.ln_2.bias': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.beta',
|
||||
'h.35.ln_2.weight': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.gamma',
|
||||
'h.35.mlp.c_fc.bias': 'gpt2_decoder.layers.35.feedforward.c_fc.bias',
|
||||
'h.35.mlp.c_fc.weight': 'gpt2_decoder.layers.35.feedforward.c_fc.weight',
|
||||
'h.35.mlp.c_proj.bias': 'gpt2_decoder.layers.35.feedforward.c_proj.bias',
|
||||
'h.35.mlp.c_proj.weight': 'gpt2_decoder.layers.35.feedforward.c_proj.weight',
|
||||
'ln_f.bias': 'layer_norm.layer_norm.gamma',
|
||||
'ln_f.weight': 'layer_norm.layer_norm.beta',
|
||||
'wpe.weight': 'gpt2_embedding_postprocess.position_embedding_table',
|
||||
'wte.weight': 'gpt2_embedding_lookup.embedding_table'
|
||||
}
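
The mapping above pairs every Hugging Face GPT-2 parameter name with its MindSpore counterpart (note that each `ln_*.bias` maps to a LayerNorm `beta` and each `ln_*.weight` to a `gamma`). For reference, here is a minimal sketch of how such a name map could drive the checkpoint conversion; `convert_checkpoint` and the paths are illustrative names only, not the repository's actual conversion entry point:

```python
import torch
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint


def convert_checkpoint(torch_ckpt_path, ms_ckpt_path, name_map):
    """Rename PyTorch GPT-2 weights according to name_map and save a MindSpore checkpoint."""
    state_dict = torch.load(torch_ckpt_path, map_location="cpu")
    params = []
    for torch_name, ms_name in name_map.items():
        # Depending on how the MindSpore layers are defined, the Conv1D-style
        # weights (c_attn / c_proj / c_fc) may additionally need a transpose here.
        params.append({"name": ms_name, "data": Tensor(state_dict[torch_name].numpy())})
    save_checkpoint(params, ms_ckpt_path)
```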
@@ -0,0 +1,148 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create mindrecord data for Children's Book Test task"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import numpy as np

from mindspore.mindrecord import FileWriter
from src.utils.tokenization import Tokenizer


def create_instance(tokenizer, text, max_length=None, num_choice=None):
    """A single sample instance for the CBT task."""
    text = text.replace(" \t ", "\t ")
    sentence = text.strip().split("\t")
    context_length = len(tokenizer.encode(sentence[0]))

    whole_sentence = sentence[0] + sentence[1]
    whole_sentence = whole_sentence.strip()
    assert whole_sentence != ""
    print(" | whole sentence: ", whole_sentence)
    ids = tokenizer.encode(whole_sentence)
    input_length = len(ids)
    pair_ids = None

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         truncate_direction="RIGHT",
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)

    output["length"] = [context_length + 1] + [input_length + 1]

    gold_answer_id = int(sentence[2])
    assert gold_answer_id < num_choice
    output["mc_labels"] = gold_answer_id

    for name, value in output.items():
        print(name)
        print(value)
        print("==================================")

    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    assert len(input_ids) == len(input_mask)
    length = instance["length"]  # list
    mc_labels = instance["mc_labels"]

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["input_length"] = np.asarray(length)
    features["mc_labels"] = mc_labels

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file. ')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file. ')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='The MindRecord file will be split into this number of partitions. ')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ')
    parser.add_argument("--num_choice", type=int, required=True, help='Number of choices. ')
    parser.add_argument("--vocab_file", type=str, required=True, help='url of gpt2-vocab.json ')
    parser.add_argument("--merge_file", type=str, required=True, help='url of gpt2-merges.txt ')
    args = parser.parse_args()

    tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)
    num_choice = args.num_choice

    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "input_length": {"type": "int64", "shape": [-1]},
                   "mc_labels": {"type": "int64"}
                   }
    writer.add_schema(data_schema, "cbt-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length, num_choice)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))

                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
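
For reference, a sketch of how this data-preparation script might be invoked; the script name and all paths below are placeholders, not names confirmed by this commit:

```bash
python create_cbt_data.py \
    --input_file=/path/to/cbt_test.txt \
    --output_file=/path/to/cbt_test.mindrecord \
    --num_splits=1 \
    --max_seq_length=1024 \
    --num_choice=10 \
    --vocab_file=/path/to/gpt2-vocab.json \
    --merge_file=/path/to/gpt2-merges.txt
```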
@@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create mindrecord data for LAMBADA task"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import numpy as np

from mindspore.mindrecord import FileWriter
from src.utils.tokenization import Tokenizer


def create_instance(tokenizer, text, max_length=None):
    """A single sample instance for the LAMBADA task."""
    text = text.replace(" \t ", "\t ")
    sentence = text.strip().split("\t")
    context_length = len(tokenizer.encode(sentence[0]))

    whole_sentence = sentence[0] + sentence[1]
    whole_sentence = whole_sentence.strip()
    assert whole_sentence != ""
    print(" | whole sentence: ", whole_sentence)
    ids = tokenizer.encode(whole_sentence)
    input_length = len(ids)
    pair_ids = None

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         truncate_direction="RIGHT",
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)

    # input_length = <bos> + text_length, excluding <eos>
    output["length"] = [context_length + 1] + [input_length + 1]

    for k, v in output.items():
        print(k)
        print(v)
        print("==================================")

    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    assert len(input_ids) == len(input_mask)
    length = instance["length"]  # list

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["input_length"] = np.asarray(length)

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file. ')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file. ')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='The MindRecord file will be split into this number of partitions. ')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ')
    parser.add_argument("--vocab_file", type=str, required=True, help='url of gpt2-vocab.json ')
    parser.add_argument("--merge_file", type=str, required=True, help='url of gpt2-merges.txt ')
    args = parser.parse_args()

    tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)

    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "input_length": {"type": "int64", "shape": [-1]},
                   }
    writer.add_schema(data_schema, "lambada-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))

                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
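
Judging from `create_instance` above (an inference from the code, not a documented format), each raw line carries the LAMBADA context and its final target word separated by a tab; the `" \t "` separator is normalized so that the target keeps its leading space before the two pieces are re-joined. An illustrative line, where `<TAB>` stands for a literal tab character:

```text
he put the gun down and stared at the <TAB> ceiling
```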
@@ -0,0 +1,126 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create mindrecord data for LM task"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import numpy as np

from mindspore.mindrecord import FileWriter
from src.utils.tokenization import Tokenizer


def create_instance(tokenizer, text, max_length=None):
    """A single sample instance for the LM task."""
    sentence = text.strip().split("\t")

    ids = tokenizer.encode(sentence[0])
    pair_ids = None
    if len(sentence) == 2:
        pair_ids = tokenizer.encode(sentence[1])

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         truncate_direction="LEFT",
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)
    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    label_ids = instance["input_ids"]
    assert len(input_ids) == len(label_ids)

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["label_ids"] = np.asarray(label_ids)

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file. ')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file. ')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='The MindRecord file will be split into this number of partitions. ')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ')
    parser.add_argument("--vocab_file", type=str, required=True, help='url of gpt2-vocab.json ')
    parser.add_argument("--merge_file", type=str, required=True, help='url of gpt2-merges.txt ')
    args = parser.parse_args()

    tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)

    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "label_ids": {"type": "int64", "shape": [-1]}
                   }
    writer.add_schema(data_schema, "lm-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))

                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
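
Note that `label_ids` is written as an exact copy of `input_ids`; the example logging above (`input_ids[:-1]` versus `input_ids[1:]`) shows the intended causal alignment, which the training graph presumably applies internally. A minimal, self-contained illustration of that one-token shift (the token ids are made up for demonstration):

```python
import numpy as np

# Illustrative token ids only; real ids come from the GPT-2 BPE tokenizer.
input_ids = np.array([50256, 464, 3290, 318, 1029, 50256])
label_ids = input_ids.copy()

# Causal LM alignment: the model at position t predicts the token at t + 1.
model_inputs = input_ids[:-1]   # everything except the last token
targets = label_ids[1:]         # everything except the first token
print(model_inputs)
print(targets)
```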
@@ -0,0 +1,130 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create mindrecord data for Summarization task"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import numpy as np

from mindspore.mindrecord import FileWriter
from src.utils import tokenization


def create_instance(tokenizer, text, max_length=None):
    """A single sample instance for the Summarization task."""
    sentence = text.strip().split("\t")
    ids = tokenizer.encode(sentence[0])
    pair_ids = None
    if len(sentence) == 2:
        pair_ids = tokenizer.encode(sentence[1])
    if len(sentence) >= 3:
        # If the article itself contained tabs, re-join everything except the
        # last field (the summary) back into a single article string.
        article = sentence[0]
        for i in range(1, len(sentence) - 1):
            article += sentence[i]
        ids = tokenizer.encode(article)
        pair_ids = tokenizer.encode(sentence[-1])

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)
    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    label_ids = instance["input_ids"]
    assert len(input_ids) == len(label_ids)

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["label_ids"] = np.asarray(label_ids)

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file.')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='The MindRecord file will be split into this number of partitions. ')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length.')
    parser.add_argument("--vocab_file", type=str, required=True, help='url of gpt2-vocab.json ')
    parser.add_argument("--merge_file", type=str, required=True, help='url of gpt2-merges.txt ')
    parser.add_argument("--mode", type=str, required=True, default='cnn_dailymail', help='mode of dataset creation')
    args = parser.parse_args()

    tokenizer = tokenization.Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file, mode=args.mode)
    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "label_ids": {"type": "int64", "shape": [-1]}
                   }
    writer.add_schema(data_schema, "summarization-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))

                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
@@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""download the CNN & DailyMail dataset for the Summarization task"""

import argparse
from datasets import load_dataset


def generate_txt(url, split_, number=None, version="3.0.0"):
    """
    Generate a txt file of the cnn_dailymail dataset.

    Args:
        url (str): directory of the dataset txt file.
        split_ (str): test or train.
        number (int): top-n number of samples from the dataset.
        version (str): "3.0.0" by default.
    """
    cnn = load_dataset("cnn_dailymail", version, split=split_)
    if number == -1:
        number = len(cnn)
    f = open(url + split_ + '.txt', 'w')
    for idx in range(number):
        article = cnn[idx]['article']
        article = article.replace('\n', ' ')
        highlights = cnn[idx]['highlights']
        highlights = highlights.replace('\n', ' ')
        f.write(article + "\t" + highlights + '\n')
    f.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Download CNN_Dailymail 3.0.0 using datasets by Huggingface')
    parser.add_argument('--dir', type=str, default="", help="directory of dataset")
    parser.add_argument('--split', type=str, default='test', help="[test, train]")
    parser.add_argument('--num', type=int, default=-1,
                        help="number of samples, taken in default order. "
                             "If num is -1, the whole split is downloaded. Default: -1")
    args = parser.parse_args()

    data_directory = args.dir
    split = args.split
    num = args.num

    generate_txt(url=data_directory, split_=split, number=num)
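
Taken together with the Summarization preprocessing script above, the intended pipeline appears to be: download a raw CNN & DailyMail split to a tab-separated txt file (article, tab, highlights), then convert that file to MindRecord. A hypothetical end-to-end invocation; the script names and paths are placeholders, not names confirmed by this commit:

```bash
python download_cnn_dailymail.py --dir=/path/to/data/ --split=test --num=-1
python create_summarization_data.py \
    --input_file=/path/to/data/test.txt \
    --output_file=/path/to/data/cnn_dailymail_test.mindrecord \
    --max_seq_length=1024 \
    --vocab_file=/path/to/gpt2-vocab.json \
    --merge_file=/path/to/gpt2-merges.txt \
    --mode=cnn_dailymail
```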
@@ -0,0 +1,135 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Evaluate the reading comprehension result with additional answers."""

import json
import re
import string
import argparse
from collections import Counter


def get_normalize_answer_token(string_):
    """Normalize the answer tokens: lower the text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(char for char in text if char not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(string_)))).split()

def calculate_f1(pred_answer, gold_answer):
    """
    Calculate the final F1 score with additional answers.
    """
    pred_tokens = get_normalize_answer_token(pred_answer)
    gold_tokens = get_normalize_answer_token(gold_answer)
    common = Counter(pred_tokens) & Counter(gold_tokens)
    # the number of shared tokens between pred_answer and gold_answer
    num_same = sum(common.values())
    if not pred_tokens and not gold_tokens:
        return 1.0
    if num_same == 0:
        return 0.0
    precision = 1.0 * num_same / len(pred_tokens)
    recall = 1.0 * num_same / len(gold_tokens)
    f1_score = 2 * precision * recall / (precision + recall)
    return f1_score

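# A quick hand-checked example of calculate_f1 (illustrative strings only):
#   pred_answer = "the red cat", gold_answer = "a red dog"
#   normalized tokens: ['red', 'cat'] vs. ['red', 'dog'] (articles are removed)
#   one shared token -> precision = 1/2, recall = 1/2, F1 = 0.5
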
def main():
    parser = argparse.ArgumentParser(description="All Task dataset preprocessing")
    parser.add_argument("--input_file", type=str, default="",
                        help="The log file path of evaluation in Reading Comprehension. ")
    parser.add_argument("--addition_file", type=str, default="", help="Coqa-dev-v1.0.json path")
    args_opt = parser.parse_args()
    input_file = args_opt.input_file
    addition_file = args_opt.addition_file

    find_word = 'Pred_answer:'
    find_word_length = len(find_word)
    pred_answer_list = []

    with open(input_file, 'r', encoding='utf-8') as f:
        while True:
            line = f.readline()
            if not line:
                break
            index = line.find(find_word)
            if index != -1:
                pred_answer = line[index + find_word_length:].strip()
                pred_answer_list.append(pred_answer)

    dataset = json.load(open(addition_file))
    pred_answer_num = 0
    total_f1score = 0
    average_f1score = 0
    data_num = len(pred_answer_list)

    for story in dataset['data']:
        questions = story['questions']
        multiple_answers = [story['answers']]
        multiple_answers += story['additional_answers'].values()
        flag = 0
        for question in questions:
            pred_a = pred_answer_list[pred_answer_num]
            turn_id = question['turn_id']
            max_score = 0
            max_group = 0
            for i, answer in enumerate(multiple_answers):
                gold_a = answer[turn_id - 1]['input_text']
                score = calculate_f1(pred_a, gold_a)
                if score > max_score:
                    max_score = score
                    max_group = i
            # take the best score over the multiple answer groups and record its index.
            gold_a = multiple_answers[max_group][turn_id - 1]['input_text']
            pred_answer_num += 1
            total_f1score += max_score
            average_f1score = total_f1score / pred_answer_num

            print('==================== data {} ===================='.format(pred_answer_num))
            print('| Gold_answer:{}'.format(gold_a))
            print('| Pred_answer:{}'.format(pred_a))
            print('| F1_Score:{:.8f}'.format(average_f1score))
            print('=====================================================\n')

            if pred_answer_num >= data_num:
                flag = 1
                break
        # Stop flag
        if flag:
            print('Finished evaluation with additional answers! \n')
            print("********************** Testing Finished **********************")
            print('| Test file name: {}'.format(input_file))
            print('| Final F1 score: {:.8f}'.format(average_f1score))
            print('| Total data num: {}'.format(pred_answer_num))
            print("**************************************************************")
            break


if __name__ == "__main__":
    main()
@@ -0,0 +1,270 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for the Children's Book Test task.
"""
import argparse
import time
import numpy as np

from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2CBT
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.utils.metric_method import Accuracy
from src.dataset import create_cbt_dataset, create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.task_utils import calculate_choice_prob_for_cbt


def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Do train.
    Args:
        dataset: the train dataset.
        network: the network with loss.
        load_checkpoint_path: the file path of the pretrained model checkpoint.
        save_checkpoint_path: the file path where the finetuned checkpoint will be saved.
        epoch_num: the number of epochs.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrained model is missing: the finetune task must load a pretrained checkpoint!")

    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()

        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. Supported: [AdamWeightDecay, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    prefix_name = "gpt2_" + "cbt_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" + str(epoch_num) + \
                  "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)

    final_param_dict = {}
    for name, _ in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
    final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']

    load_param_into_net(network, final_param_dict)
    print("Load pretrained parameters successfully!\n")

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)

    loss_cb = LossMonitor(per_print_times=1)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print("==================== Starting Finetuning ====================")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("==================== Finetuning Success ====================")

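# For intuition about the schedule above (illustrative numbers, not from this repo):
# with steps_per_epoch = 1000 and epoch_num = 1, the learning rate warms up over the
# first int(1000 * 1 * 0.1) = 100 steps, then decays polynomially (per `power`)
# toward end_learning_rate over decay_steps = 1000 total steps.
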
def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, num_choice=None):
|
||||
"""
|
||||
Do evaluation for CBT task.
|
||||
Args:
|
||||
dataset: the eval dataset.
|
||||
network: the network with loss.
|
||||
metric: the evaluation method.
|
||||
load_checkpoint_path: the file path which saved finetuned model checkpoint.
|
||||
eval_type:
|
||||
num_choice:
|
||||
"""
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
|
||||
|
||||
if metric.lower() == "accuracy":
|
||||
print("Prepare to calculate the accuracy score ...")
|
||||
gpt2_cbt = network(config=gpt2_net_cfg,
|
||||
is_training=False,
|
||||
use_one_hot_embeddings=False
|
||||
)
|
||||
gpt2_cbt.set_train(False)
|
||||
param_dict = load_checkpoint(load_checkpoint_path)
|
||||
|
||||
if eval_type == "zero-shot":
|
||||
final_param_dict = {}
|
||||
for name, _ in param_dict.items():
|
||||
final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
|
||||
final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
|
||||
load_param_into_net(gpt2_cbt, final_param_dict)
|
||||
print("load pretrained parameter successfully!\n")
|
||||
elif eval_type == "finetuned":
|
||||
load_param_into_net(gpt2_cbt, param_dict)
|
||||
print("load finetuned parameter successfully!\n")
|
||||
else:
|
||||
raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]")
|
||||
|
||||
model = Model(gpt2_cbt)
|
||||
callback = Accuracy()
|
||||
columns_list = ["input_ids", "input_mask", "input_length", "mc_labels"]
|
||||
print("==================== [ACC] Testing ====================")
|
||||
num_data = 1
|
||||
all_choice_prob = []
|
||||
|
||||
for data in dataset.create_dict_iterator():
|
||||
input_data = []
|
||||
for i in columns_list:
|
||||
input_data.append(data[i])
|
||||
input_ids, input_mask, input_length, mc_labels = input_data
|
||||
print("| [ACC] number : {} / {} ".format(num_data, dataset.get_dataset_size()))
|
||||
# print("mc_labels: {}".format(mc_labels)) # [batch_size]
|
||||
|
||||
logits = model.predict(input_ids, input_mask)
|
||||
# choice_prob_list [batch_size]
|
||||
choice_prob_list = calculate_choice_prob_for_cbt(logits=logits,
|
||||
batch_size=gpt2_net_cfg.batch_size,
|
||||
input_length=input_length,
|
||||
input_ids=input_ids)
|
||||
all_choice_prob.append(choice_prob_list)
|
||||
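            # CBT scores each of the num_choice candidate completions separately;
            # once a full group has been accumulated, reshape to (-1, num_choice)
            # and let the Accuracy callback pick the most probable candidate.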
            if (num_data * gpt2_net_cfg.batch_size) % num_choice == 0:
                all_choice_prob_np = np.array(all_choice_prob)
                all_choice_prob_np = all_choice_prob_np.reshape((-1, num_choice))
                print("| all_choice_prob_np: ", all_choice_prob_np)
                print("| all_choice_prob_np shape: ", all_choice_prob_np.shape)
                mc_labels = np.array([mc_labels.asnumpy()[0]])
                callback.update(all_choice_prob_np, mc_labels)
                all_choice_prob = []
            num_data += 1

        print("\n\n")
        print("**************************************************************")
        print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                  callback.acc_num / callback.total_num))
        print("********************** Testing Finished **********************")
    else:
        raise ValueError("metric method not supported, support: [Accuracy]")


def run_cbt_task():
    """
    run Children's Book Test (CBT) task
    """
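    # Typically launched through scripts/run_cbt.sh; an equivalent direct call
    # (illustrative only, script name and paths are placeholders):
    #   python run_cbt_task.py --device_id=0 --do_eval="true" --eval_type="zero-shot" \
    #       --num_choice=10 --eval_data_file_path="/path/to/cbt.mindrecord" \
    #       --load_finetune_ckpt_path="/path/to/gpt2.ckpt"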
    parser = argparse.ArgumentParser(description="Finetune and Evaluate CBT task")
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device type. Default: Ascend.")
    parser.add_argument("--device_id", type=int, default=1,
                        help="ID of target device.")
    parser.add_argument("--num_choice", type=int, default=10,
                        help="The number of choices in CBT task.")
    parser.add_argument("--metric_method", type=str, default="Accuracy",
                        help="The eval method including [Accuracy]. Default: Accuracy.")
    parser.add_argument("--do_train", type=str, default="false",
                        help="Enable train. Default: false.")
    parser.add_argument("--do_eval", type=str, default="true",
                        help="Enable evaluation. Default: true.")
    parser.add_argument("--eval_type", type=str, default="zero-shot",
                        help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
    parser.add_argument("--epoch_num", type=int, default=1,
                        help="Epoch number. Default: 1.")
    parser.add_argument("--train_data_shuffle", type=str, default="true",
                        help="Enable train data shuffle. Default: true.")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle. Default: false.")
    parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
                        help="Save the finetuned checkpoint path.")
    parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
                        help="Load the checkpoint file path for train.")
    parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
                        help="Load the checkpoint file path for evaluation.")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path.")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path.")
    args_opt = parser.parse_args()

    epoch_num = args_opt.epoch_num
    metric = args_opt.metric_method
    save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
    load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
    load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when do finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when do evaluation task")

    device_target = args_opt.device_target
    if device_target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=device_target,
                            device_id=args_opt.device_id,
                            max_call_depth=3000)
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id))
    else:
        raise Exception("Device target error, Ascend is supported.")

    gpt2_loss = GPT2CBT(config=gpt2_net_cfg,
                        is_training=True,
                        use_one_hot_embeddings=False)

    if args_opt.do_train.lower() == "true":
        print("============== Start Loading Train Dataset ============")
        print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
        train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
                                                      dataset_path=args_opt.train_data_file_path)
        do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)

    if args_opt.do_eval.lower() == "true":
        print("============== Start Loading Evaluation Dataset ============")
        print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
        eval_dataset = create_cbt_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                          dataset_path=args_opt.eval_data_file_path)
        do_eval(eval_dataset, GPT2CBT, metric, load_finetune_ckpt_path, args_opt.eval_type, args_opt.num_choice)


if __name__ == "__main__":
    print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    run_cbt_task()
    print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
@ -0,0 +1,293 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for Reading Comprehension task.
"""
import argparse
import time

from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2CoQA
from src.GPT2ForReadComprehension import GPT2CoQAModel
from src.utils.metric_method import F1
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.dataset import create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.tokenization import Tokenizer
from src.GPT2_generation import GenerateForReadComprehension


def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Do train
    Args:
        dataset: the train dataset.
        network: the network with loss.
        load_checkpoint_path: the file path which saved pretrained model checkpoint.
        save_checkpoint_path: the file path which will save finetuned model checkpoint.
        epoch_num: the number of epoch.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()

        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    prefix_name = "gpt2_rc_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
                  + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)

    final_param_dict = {}
    for name, _ in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
    final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']

    load_param_into_net(network, final_param_dict)
    print("Load the pretrained parameter successfully!\n")

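    # Dynamic loss scaling for mixed-precision finetuning: the scale is halved on
    # overflow and doubled again after `scale_window` consecutive clean steps.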
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)
    loss_cb = LossMonitor(per_print_times=1)

    model = Model(netwithgrads)

    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print("=================== Starting Training For Reading Comprehension Task ====================")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("=================== Reading Comprehension Training Success ====================")


def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file_path="",
            generate_length=1, top_k=1, top_p=1.0, temperature=1.0):
    """
    Do evaluation on Reading Comprehension task.
    Args:
        dataset: the eval dataset.
        network: the network with loss.
        metric: the evaluation method.
        load_checkpoint_path: the file path which saved finetuned model checkpoint.
        eval_type: the evaluation type, i.e. zero-shot or finetuned.
        tokenizer_file_path: the path of the pretrained vocab and merge files.
        generate_length: the length of the generated answer.
        top_k, top_p, temperature: sampling parameters for generation.
    """
    if load_checkpoint_path == "":
        raise ValueError("Finetune model missed, evaluation task must load finetune model!")
    if metric.lower() == "f1":
        print("Prepare to calculate the F1 score ...")

        gpt2_rc = network(config=gpt2_net_cfg,
                          is_training=False,
                          use_one_hot_embeddings=False)
        gpt2_rc.set_train(False)
        param_dict = load_checkpoint(load_checkpoint_path)

        if eval_type == "zero-shot":
            final_param_dict = {}
            for name, _ in param_dict.items():
                final_param_dict['gpt2.' + name] = param_dict[name]
            final_param_dict['dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
            load_param_into_net(gpt2_rc, final_param_dict)
            print("load pretrained parameter successfully!\n")
        elif eval_type == "finetuned":
            load_param_into_net(gpt2_rc, param_dict)
            print("load finetuned parameter successfully!\n")
        else:
            raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]")

        model = Model(gpt2_rc)
        tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json',
                              merge_file=tokenizer_file_path + 'gpt2-merges.txt')
        callback = F1()
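        # With the defaults (top_k=1, top_p=1.0, temperature=1.0) sampling reduces
        # to greedy decoding; raise top_k/top_p/temperature for more diverse answers.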
        rc_generator = GenerateForReadComprehension(decoder=model,
                                                    config=gpt2_net_cfg,
                                                    tokenizer=tokenizer,
                                                    generate_length=generate_length,
                                                    topk_num=top_k,
                                                    topp_prob=float(top_p),
                                                    temperature=float(temperature))

        columns_list = ["input_ids", "input_mask", "label_ids"]
        print("==================== [F1] Testing ====================")
        num_data = 0
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(data[i])
            input_ids, _, label_ids = input_data

            print("input_ids shape: {}".format(input_ids.shape))
            print("label_ids shape: {}".format(label_ids.shape))

            passage, pred_answer, gold_answer = rc_generator.generate_for_read_comprehension(input_ids)

            for batch_id in range(gpt2_net_cfg.batch_size):
                print("============== [F1] {} ================".format(num_data + 1))
                print(" | Passage: {}".format(passage[batch_id]))
                print(" | Gold_answer: {}".format(gold_answer[batch_id]))
                print(" | Pred_answer: {}".format(pred_answer[batch_id]))

                pred = callback.get_normalize_answer_token(pred_answer[batch_id])
                gold = callback.get_normalize_answer_token(gold_answer[batch_id])

                callback.update(pred, gold)
                num_data += 1

        average_f1_score = callback.f1_score / num_data
        print("============== Evaluation =================")
        print("| Avg F1 Score: {:.8f}".format(average_f1_score))
        print("=============================================\n\n")

        print("********************** Testing Finished **********************")
    else:
        raise ValueError("metric method not supported in Reading Comprehension task, support: [F1]")


def run_Readcomprehension():
    """
    run Reading Comprehension task
    """
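    # Typically launched through scripts/run_read_comprehension.sh; an equivalent
    # direct call (illustrative only, script name and paths are placeholders):
    #   python run_read_comprehension.py --do_eval="true" --eval_type="zero-shot" \
    #       --eval_data_file_path="/path/to/coqa_dev.mindrecord" \
    #       --load_finetune_ckpt_path="/path/to/gpt2.ckpt" \
    #       --tokenizer_file_path="/path/to/vocab_dir/"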
    parser = argparse.ArgumentParser(description="Finetune and Evaluate Reading Comprehension task")
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device type. Default: Ascend.")
    parser.add_argument("--device_id", type=int, default=0,
                        help="ID of target device.")
    parser.add_argument("--metric_method", type=str, default="F1",
                        help="The eval method including [F1]. Default: F1.")
    parser.add_argument("--do_train", type=str, default="false",
                        help="Enable train. Default: false.")
    parser.add_argument("--do_eval", type=str, default="true",
                        help="Enable evaluation. Default: true.")
    parser.add_argument("--eval_type", type=str, default="zero-shot",
                        help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
    parser.add_argument("--epoch_num", type=int, default=1,
                        help="Epoch number. Default: 1.")
    parser.add_argument("--train_data_shuffle", type=str, default="true",
                        help="Enable train data shuffle. Default: true.")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle. Default: false.")
    parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
                        help="Save the checkpoint path.")
    parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
                        help="Load the checkpoint file path.")
    parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
                        help="Load the checkpoint file path.")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path.")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path.")
    parser.add_argument("--tokenizer_file_path", type=str, default="",
                        help="Pretrained vocab and merge file path.")

    parser.add_argument("--generate_length", type=int, default=55,
                        help="The generation length of the predicted answer.")
    parser.add_argument("--top_k", type=int, default=1,
                        help="Parameter for Top-K sampling.")
    parser.add_argument("--top_p", type=str, default="1.0",
                        help="Parameter for Top-P sampling.")
    parser.add_argument("--temperature", type=str, default="1.0",
                        help="Parameter for generation, greater for more diverse generation.")

    args_opt = parser.parse_args()

    epoch_num = args_opt.epoch_num
    metric = args_opt.metric_method
    save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
    load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
    load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when do finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when do evaluation task")

    device_target = args_opt.device_target

    if device_target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=device_target,
                            device_id=args_opt.device_id,
                            max_call_depth=3000)
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id))
    else:
        raise Exception("Device target error, Ascend is supported.")

    gpt2_loss = GPT2CoQA(config=gpt2_net_cfg,
                         is_training=True,
                         use_one_hot_embeddings=False)

    if args_opt.do_train.lower() == "true":
        print("============== Start Loading Reading Comprehension Train Dataset ==============")
        print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
        train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
                                                      dataset_path=args_opt.train_data_file_path)
        do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)

    if args_opt.do_eval.lower() == "true":
        print("============ Start Loading Reading Comprehension Evaluation Dataset ============")
        print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
        eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                                     dataset_path=args_opt.eval_data_file_path)
        do_eval(eval_dataset, GPT2CoQAModel, metric, load_finetune_ckpt_path, args_opt.eval_type,
                args_opt.tokenizer_file_path, args_opt.generate_length, args_opt.top_k, args_opt.top_p,
                args_opt.temperature)


if __name__ == "__main__":
    print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    run_Readcomprehension()
    print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
@ -0,0 +1,328 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for LAMBADA task.
"""
import argparse
import math
import time

from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2Lambada
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.utils.metric_method import LastWordAccuracy
from src.dataset import create_language_model_dataset, create_lambada_control_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.task_utils import get_final_word_label, calculate_final_word_loss
from src.utils.tokenization import Tokenizer
from src.GPT2_generation import GenerateForLambada
from src.utils.CrossEntropy import CrossEntropyCalculationWithMask
from src.utils.get_config_setting import get_train_setting, get_model_setting


def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Do train
    Args:
        dataset: the train dataset.
        network: the network with loss.
        load_checkpoint_path: the file path which saved pretrained model checkpoint.
        save_checkpoint_path: the file path which will save finetuned model checkpoint.
        epoch_num: the number of epoch.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()

        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    prefix_name = "gpt2_" + "lambada_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
                  + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)

    final_param_dict = {}
    for name, _ in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
    final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']

    load_param_into_net(network, final_param_dict)
    print("Load pretrained parameter successfully!\n")

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)

    loss_cb = LossMonitor(per_print_times=1)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print("==================== Starting Finetuning ====================")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("==================== Finetuning Success ====================")


def eval_result_print(metric="accuracy", callback=None):
    """
    Print eval result.
    """
    if metric.lower() == "accuracy":
        print("acc_num {}, total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                 callback.acc_num / callback.total_num))
    else:
        raise ValueError("metric method not supported, support: [accuracy]")


def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, stop_word_file="",
            generate_length_dynamic=True, tokenizer_file_path=""):
    """
    Do eval
    Args:
        dataset: the eval dataset.
        network: the network with loss.
        metric: the evaluation method.
        load_checkpoint_path: the file path which saved finetuned model checkpoint.
        eval_type: the eval type, i.e. zero-shot, finetuned.
        generate_length_dynamic (bool): True for the generate length is dynamic, False for fixed. Default: True.
        tokenizer_file_path: the tokenizer file path for vocab file and merge file.
        stop_word_file: stop word file for calculating Accuracy.
    """
    if load_checkpoint_path == "":
        raise ValueError("Finetune model missed, evaluation task must load finetune model!")

    tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json',
                          merge_file=tokenizer_file_path + 'gpt2-merges.txt')

    gpt2_lambada = network(config=gpt2_net_cfg,
                           is_training=False,
                           use_one_hot_embeddings=False)
    gpt2_lambada.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)

    if eval_type == "zero-shot":
        final_param_dict = {}
        for name, _ in param_dict.items():
            final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
        final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
        load_param_into_net(gpt2_lambada, final_param_dict)
        print("load pretrained parameter successfully!\n")
    elif eval_type == "finetuned":
        load_param_into_net(gpt2_lambada, param_dict)
        print("load finetuned parameter successfully!\n")
    else:
        raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]")

    model = Model(gpt2_lambada)

    if metric.lower() == "accuracy":
        print("Prepare to calculate the accuracy score ...")

        callback = LastWordAccuracy()
        columns_list = ["input_ids", "input_mask", "input_length"]
        print("==================== [ACC] Testing ====================")
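        # LAMBADA accuracy: generate a continuation for each context and compare
        # the predicted final word (stop-word filtered) with the held-out last word.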
        lambada_generator = GenerateForLambada(decoder=model,
                                               config=gpt2_net_cfg,
                                               tokenizer=tokenizer,
                                               generate_length_dynamic=generate_length_dynamic,
                                               max_iterations=200,
                                               stop_word_file=stop_word_file)

        num_data = 1
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(data[i])
            input_ids, input_mask, input_length = input_data
            print("| [ACC] number : {} / {} ".format(num_data, dataset.get_dataset_size()))

            logits = model.predict(input_ids, input_mask)
            predict_str = lambada_generator.generate_for_lambada(input_ids=input_ids,
                                                                 logits=logits,
                                                                 input_length=input_length)
            label_str = get_final_word_label(input_ids=input_ids, input_length=input_length, tokenizer=tokenizer)
            callback.update(predict_str, label_str)
            eval_result_print(metric, callback)
            num_data += 1

        print("\n\n")
        print("**********************************************************")
        eval_result_print(metric, callback)
        print("******************** Testing Finished ********************")

    elif metric.lower() == "ppl":
        print("Prepare to calculate the ppl score ...")
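        # LAMBADA PPL: average the cross-entropy of the target last word over the
        # dataset and report exp(mean loss) as the perplexity.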
        cross_entropy = CrossEntropyCalculationWithMask(is_training=True,
                                                        num_labels=gpt2_net_cfg.vocab_size,
                                                        config=gpt2_net_cfg)
        columns_list = ["input_ids", "input_mask", "input_length"]
        num_data = 1
        total_loss = 0.0
        print("==================== [PPL] Testing ====================")
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(data[i])
            input_ids, input_mask, input_length = input_data
            print("| [PPL] number : {} / {} ".format(num_data, dataset.get_dataset_size()))

            logits = model.predict(input_ids, input_mask)  # (batch_size, seq_len, vocab_size)
            avg_batch_loss = calculate_final_word_loss(logits,
                                                       gpt2_net_cfg.batch_size,
                                                       input_ids,
                                                       input_length,
                                                       cross_entropy)

            total_loss += avg_batch_loss
            avg_total_loss = total_loss / num_data
            print(" | Current AVG loss:", avg_total_loss)
            print(" | Current AVG ppl:", math.exp(avg_total_loss))
            num_data += 1

        print("\n\n")
        print("**********************************************************")
        print("Average PPL: {:.6f}".format(math.exp(avg_total_loss)))
        print("******************** Testing Finished ********************")

    else:
        raise ValueError("metric method not supported, support: [accuracy, ppl]")


def run_lambada():
    """
    Run LAMBADA task.
    """
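    # Typically launched through scripts/run_lambada.sh; an equivalent direct call
    # (illustrative only, script name and paths are placeholders):
    #   python run_lambada.py --metric_method="Accuracy" --do_eval="true" \
    #       --eval_type="zero-shot" --eval_data_file_path="/path/to/lambada.mindrecord" \
    #       --load_finetune_ckpt_path="/path/to/gpt2.ckpt" \
    #       --tokenizer_file_path="/path/to/vocab_dir/" --stop_word_file_path="/path/to/stopwords.txt"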
    parser = argparse.ArgumentParser(description="Finetune and Evaluate LAMBADA task")
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device type. Default: Ascend.")
    parser.add_argument("--device_id", type=int, default=2,
                        help="ID of target device.")
    parser.add_argument("--metric_method", type=str, default="PPL",
                        help="The eval method including [Accuracy, PPL]. Default: PPL.")
    parser.add_argument("--do_train", type=str, default="false",
                        help="Enable train. Default: false.")
    parser.add_argument("--do_eval", type=str, default="true",
                        help="Enable evaluation. Default: true.")
    parser.add_argument("--eval_type", type=str, default="finetuned",
                        help="The type of evaluation including [zero-shot, finetuned]. Default: finetuned.")
    parser.add_argument("--epoch_num", type=int, default=3,
                        help="Epoch number. Default: 3.")
    parser.add_argument("--train_data_shuffle", type=str, default="false",
                        help="Enable train data shuffle. Default: false.")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle. Default: false.")

    parser.add_argument("--generate_length_dynamically", type=str, default="true",
                        help="Enable dynamic generation length. Default: true.")
    parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
                        help="Save the checkpoint path.")
    parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
                        help="Load the checkpoint file path.")
    parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
                        help="Load the checkpoint file path.")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path.")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path.")
    parser.add_argument("--tokenizer_file_path", type=str, default="",
                        help="Pretrained vocab and merge file path.")
    parser.add_argument("--stop_word_file_path", type=str, default="",
                        help="The stop word file path.")
    args_opt = parser.parse_args()

    epoch_num = args_opt.epoch_num
    metric = args_opt.metric_method
    save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
    load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
    load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when do finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when do evaluation task")

    device = args_opt.device_target
    if device == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        print(" | Device: {} | Device id: {}".format(device, args_opt.device_id))
    else:
        raise Exception("Device target error, Ascend is supported.")

    gpt2_loss = GPT2Lambada(config=gpt2_net_cfg,
                            is_training=True,
                            use_one_hot_embeddings=False)

    if args_opt.do_train.lower() == "true":
        get_train_setting(cfg)
        get_model_setting(cfg, gpt2_net_cfg)
        print("============== Start Loading Train Dataset ============")
        print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
        train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
                                                      dataset_path=args_opt.train_data_file_path)
        do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)

    if args_opt.do_eval.lower() == "true":
        get_model_setting(cfg, gpt2_net_cfg)
        print("============== Start Loading Evaluation Dataset ============")
        print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
        eval_dataset = create_lambada_control_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                                      dataset_path=args_opt.eval_data_file_path)
        # Convert the "true"/"false" flag into a real bool; the raw string would
        # always be truthy.
        do_eval(eval_dataset, GPT2Lambada, metric, load_finetune_ckpt_path, args_opt.eval_type,
                args_opt.stop_word_file_path, (args_opt.generate_length_dynamically.lower() == "true"),
                args_opt.tokenizer_file_path)


if __name__ == "__main__":
    print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    run_lambada()
    print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
@ -0,0 +1,255 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for Language Modeling task.
"""
import argparse
import math
import time

from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2LM
from src.utils.lr_schedule import GPT2LearningRate
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.dataset import create_language_model_dataset
from src.utils.get_config_setting import get_train_setting, get_model_setting


def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Do train
    Args:
        dataset: the train dataset.
        network: the network with loss.
        load_checkpoint_path: the file path which saved pretrained model checkpoint.
        save_checkpoint_path: the file path which will save finetuned model checkpoint.
        epoch_num: the number of epoch.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()

        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    prefix_name = "gpt2_language_model_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
                  + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)

    final_param_dict = {}
    for name, _ in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
    final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']

    load_param_into_net(network, final_param_dict)
    print("Load pretrained parameter successfully!\n")

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)

    loss_cb = LossMonitor(per_print_times=1)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print("==================== Starting Finetuning ====================")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("==================== Finetuning Success ====================")


def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None):
    """
    Do eval
    Args:
        dataset: the eval dataset.
        network: the network with loss.
        metric: the evaluation method.
        load_checkpoint_path: the file path which saved finetuned model checkpoint.
        eval_type: the evaluation type, i.e. zero-shot or finetuned.
    """
    if load_checkpoint_path == "":
        raise ValueError("Finetune model missed, evaluation task must load finetune model!")

    if metric.lower() == "ppl":
        print("Prepare to calculate the ppl score ...")
        gpt2_loss = network(config=gpt2_net_cfg,
                            is_training=True,
                            use_one_hot_embeddings=False)
        gpt2_loss.set_train(False)
        param_dict = load_checkpoint(load_checkpoint_path)

        if eval_type == "zero-shot":
            final_param_dict = {}
            for name, _ in param_dict.items():
                final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
            final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
            load_param_into_net(gpt2_loss, final_param_dict)
            print("load pretrained parameter successfully!\n")
        elif eval_type == "finetuned":
            load_param_into_net(gpt2_loss, param_dict)
            print("load finetuned parameter successfully!\n")
        else:
            raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]")

        model = Model(gpt2_loss)
        columns_list = ["input_ids", "input_mask", "label_ids"]
        print("==================== [PPL] Testing ====================")
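        # The wrapper was built with is_training=True above so that predict() yields
        # the LM loss rather than logits; perplexity is then exp of the running
        # average of that cross-entropy loss.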
        num_data = 1
        total_loss = 0.0
        avg_loss = 0.0
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(data[i])
            input_ids, input_mask, label_ids = input_data
            loss = model.predict(input_ids, input_mask, label_ids)
            loss = float(loss.asnumpy())
            total_loss += loss
            avg_loss = float(total_loss / num_data)
            print(" | Current Loss: {:.6f}".format(avg_loss))
            print(" | Current PPL: {}\n\n".format(math.exp(avg_loss)))
            num_data += 1

        print("\n\n")
        print("**************************************************************")
        print("Average Loss: {:.6f}".format(avg_loss))
        print("Average PPL: {:.6f}".format(math.exp(avg_loss)))
        print("********************** Testing Finished **********************")
    else:
        raise ValueError("metric method not supported, support: [ppl]")


def run_languagemodel():
    """
    run Language Modeling task
    """
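    # Typically launched through scripts/run_language_model.sh; an equivalent
    # direct call (illustrative only, script name and paths are placeholders):
    #   python run_language_model.py --metric_method="PPL" --do_eval="true" \
    #       --eval_type="zero-shot" --eval_data_file_path="/path/to/lm_eval.mindrecord" \
    #       --load_finetune_ckpt_path="/path/to/gpt2.ckpt"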
    parser = argparse.ArgumentParser(description="Finetune and Evaluate Language Modeling task")
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device type. Default: Ascend.")
    parser.add_argument("--device_id", type=int, default=1,
                        help="ID of target device.")
    parser.add_argument("--metric_method", type=str, default="PPL",
                        help="The eval method including [PPL]. Default: PPL.")
    parser.add_argument("--do_train", type=str, default="false",
                        help="Enable train. Default: false.")
    parser.add_argument("--do_eval", type=str, default="true",
                        help="Enable evaluation. Default: true.")
    parser.add_argument("--eval_type", type=str, default="zero-shot",
                        help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
    parser.add_argument("--epoch_num", type=int, default=1,
                        help="Epoch number. Default: 1.")
    parser.add_argument("--train_data_shuffle", type=str, default="true",
                        help="Enable train data shuffle. Default: true.")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle. Default: false.")
    parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
                        help="Save the finetuned checkpoint path.")
    parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
                        help="Load the checkpoint file path for train.")
    parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
                        help="Load the checkpoint file path for evaluation.")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path.")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path.")
    args_opt = parser.parse_args()

    epoch_num = args_opt.epoch_num
    metric = args_opt.metric_method
    save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
    load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
    load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when do finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when do evaluation task")

    device_target = args_opt.device_target
    if device_target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=device_target,
                            device_id=args_opt.device_id,
                            max_call_depth=3000)
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id))
    else:
        raise Exception("Device target error, Ascend is supported.")

    gpt2_loss = GPT2LM(config=gpt2_net_cfg,
                       is_training=True,
                       use_one_hot_embeddings=False)

    if args_opt.do_train.lower() == "true":
        get_train_setting(cfg)
        get_model_setting(cfg, gpt2_net_cfg)
        print("==================== Start Loading Train Dataset ==================")
        print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
        train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
                                                      dataset_path=args_opt.train_data_file_path)
        do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)

    if args_opt.do_eval.lower() == "true":
        get_model_setting(cfg, gpt2_net_cfg)
        print("==================== Start Loading Evaluation Dataset ==================")
        print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
        eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                                     dataset_path=args_opt.eval_data_file_path)
        do_eval(eval_dataset, GPT2LM, metric, load_finetune_ckpt_path, args_opt.eval_type)


if __name__ == "__main__":
    print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    run_languagemodel()
    print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
@ -0,0 +1,296 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for Summarization task.
"""

import time
import argparse

from mindspore import context
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.GPT2ForSummarization import GPT2SummarizationModel
from src.gpt2_for_finetune import GPT2Summarization, GPT2FinetuneCell
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.utils.metric_method import Rouge
from src.dataset import create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.tokenization import Tokenizer
from src.utils.task_utils import clean_hypo, modify_paramdict
from src.GPT2_generation import GenerateForSummarization


def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Do train
    Args:
        dataset: the train dataset.
        network: the network with loss.
        load_checkpoint_path: the file path which saved pretrained model checkpoint.
        save_checkpoint_path: the file path which will save finetuned model checkpoint.
        epoch_num: the number of epoch.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()

        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    prefix_name = "gpt2_summarization_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
                  + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)

    final_param_dict = {}
    for name, _ in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
    final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']

    load_param_into_net(network, final_param_dict)
    print("Load pretrained parameter successfully!\n")

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)

    loss_cb = LossMonitor(per_print_times=1)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print("============== Starting Finetuning ==============")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("============== Finetuning Success ==============")


def eval_result_print(metric="Rouge", callback=None):
    """
    Print eval result.
    """
    if metric == "Rouge":
        print("Rouge-1 {:.8f}, Rouge-2 {:.8f}, Rouge-L {:.8f}, Rouge-AVG {:.8f}".
              format(callback.Rouge1 / callback.total_num,
                     callback.Rouge2 / callback.total_num,
                     callback.RougeL / callback.total_num,
                     (callback.Rouge1 + callback.Rouge2 + callback.RougeL) / (3.0 * callback.total_num)))
    else:
        raise ValueError("metric method '{}' not supported, support: [Rouge].".format(str(metric)))


def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file="",
|
||||
top_k=None, top_p=None, temperature=None, generate_length=None):
|
||||
"""
|
||||
Do evaluation on summarization
|
||||
"""
|
||||
if load_checkpoint_path == "":
|
||||
raise ValueError("Finetune model missed, evaluation task must load finetune model!")
|
||||
if metric.lower() == "rouge":
|
||||
print("Prepare to calculate the Rouge score ...")
|
||||
callback = Rouge()
|
||||
|
||||
gpt2_loss = network(config=gpt2_net_cfg,
|
||||
is_training=False,
|
||||
use_one_hot_embeddings=False)
|
||||
gpt2_loss.set_train(False)
|
||||
param_dict = load_checkpoint(load_checkpoint_path)
|
||||
|
||||
reorganized_param_dict = modify_paramdict(param_dict, mode=eval_type, model_prefix="gpt2.")
|
||||
load_param_into_net(gpt2_loss, reorganized_param_dict)
|
||||
|
||||
# load nn.Cell into Model and initiate tokenizer and Sample
|
||||
model = Model(gpt2_loss)
|
||||
tokenizer = Tokenizer(vocab_file=tokenizer_file + 'gpt2-vocab.json',
|
||||
merge_file=tokenizer_file + 'gpt2-merges.txt')
|
||||
|
||||
# load data and process text generation
|
||||
columns_list = ["input_ids", "input_mask", "label_ids"]
|
||||
|
||||
summarization_generator = GenerateForSummarization(model,
|
||||
config=gpt2_net_cfg,
|
||||
tokenizer=tokenizer,
|
||||
select_sentence=3,
|
||||
eval_type=eval_type,
|
||||
topk=top_k,
|
||||
topp=float(top_p),
|
||||
temperature=float(temperature),
|
||||
generate_length=generate_length)
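
        # Decoding sketch: at each step the generator is assumed to keep the
        # top_k highest-probability tokens, restrict them further to the
        # smallest nucleus whose cumulative probability reaches top_p, divide
        # the logits by `temperature`, and then sample. With the script
        # defaults (top_k=2, top_p=1.0, temperature=1.0) this samples from the
        # two most likely tokens at every step, for `generate_length` steps.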
        num_data = 1
        print("==================== [Summarization] Testing ====================")
        for data in dataset.create_dict_iterator():
            input_data = []
            for value in columns_list:
                input_data.append(data[value])
            input_ids, _, label_ids = input_data
            print(" | [ROUGE] number : {} / {} ".format(num_data, dataset.get_dataset_size()))
            print("input_ids shape: {}".format(input_ids.shape))
            print("label_ids shape: {}".format(label_ids.shape))

            hypothesis, ref = summarization_generator.generate_for_summarization(input_ids)
            if ref[0] is None or ref[0] == '':
                print("The reference summary is empty, skipping this batch!")
                continue

            print("REF str:\n ", ref, "\nHYPO str:\n", hypothesis, "\n")
            for batch_idx in range(gpt2_net_cfg.batch_size):
                hypothesis[batch_idx] = clean_hypo(hypothesis[batch_idx]).lower()
                ref[batch_idx] = ref[batch_idx].lower()

            callback.update(hypothesis, ref)
            num_data += 1

        print("\n\n")
        print("**********************************************************")
        eval_result_print(metric, callback)
        print("******************** Testing Finished ********************")

    else:
        raise ValueError("metric method not supported in summarization, support: [Rouge]")


def run_summarization():
    """
    Run the Summarization task.
    """
    # set argument parser
    parser = argparse.ArgumentParser(description="Finetune and Evaluate Summarization")

    # context and task settings
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device type. Default: Ascend.")
    parser.add_argument("--device_id", type=int, default=4,
                        help="ID of the target device.")
    parser.add_argument("--do_train", type=str, default="false",
                        help="Enable train. Default: false.")
    parser.add_argument("--do_eval", type=str, default="true",
                        help="Enable evaluation. Default: true.")
    parser.add_argument("--eval_type", type=str, default="finetuned",
                        help="The type of evaluation including [zero-shot, finetuned]. Default: finetuned.")
    parser.add_argument("--metric_method", type=str, default="Rouge",
                        help="The eval method including [Rouge(Rouge1, Rouge2, RougeL, Rouge Avg)]. Default: Rouge.")
    parser.add_argument("--epoch_num", type=int, default=2,
                        help="Epoch number. Default: 2.")

    # dataset and params_dict file settings
    parser.add_argument("--train_data_shuffle", type=str, default="true",
                        help="Enable train data shuffle. Default: true.")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle. Default: false.")
    parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
                        help="Path for saving the finetune checkpoint.")
    parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
                        help="Path of the pretrained checkpoint to load.")
    parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
                        help="Path of the finetuned checkpoint to load.")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Train data path; an absolute path is recommended.")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Eval data path; an absolute path is recommended.")

    # sampling settings
    parser.add_argument("--top_k", type=int, default=2,
                        help="Number of top-k tokens kept for sampling.")
    parser.add_argument("--top_p", type=str, default="1.0",
                        help="Cumulative probability threshold (top-p) for a logit to be counted.")
    parser.add_argument("--generate_length", type=int, default=100,
                        help="The number of generated tokens.")
    parser.add_argument("--temperature", type=str, default="1.0",
                        help="Temperature applied to logits for sampling.")
    parser.add_argument("--tokenizer_file_path", type=str, default="",
                        help="Vocab & merge file path.")
    args_opt = parser.parse_args()

    epoch_num = args_opt.epoch_num
    metric = args_opt.metric_method
    save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
    load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
    load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path
    eval_type = args_opt.eval_type
    tokenizer_file = args_opt.tokenizer_file_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when doing the finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when doing the evaluation task")

    device = args_opt.device_target
    if device == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        print(" | Device: {} | Device id: {}".format(device, args_opt.device_id))
    else:
        raise Exception("Device target error: only Ascend is supported.")

    if args_opt.do_train.lower() == "true":
        train_data_file_path = args_opt.train_data_file_path
        gpt2_loss = GPT2Summarization(config=gpt2_net_cfg,
                                      is_training=True,
                                      use_one_hot_embeddings=False)
        print("============== Start Loading Train Dataset ============")
        train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
                                                      dataset_path=train_data_file_path)
        do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)

    if args_opt.do_eval.lower() == "true":
        eval_dataset_file_path = args_opt.eval_data_file_path
        print("============== Start Loading Evaluation Dataset ============")
        eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                                     dataset_path=eval_dataset_file_path)
        do_eval(eval_dataset, GPT2SummarizationModel, metric, load_finetune_ckpt_path, eval_type, tokenizer_file,
                args_opt.top_k, args_opt.top_p, args_opt.temperature, args_opt.generate_length)


if __name__ == "__main__":
    print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    run_summarization()
    print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
@@ -0,0 +1,298 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for the Translation task.
"""
import argparse
import time

from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.GPT2ForTranslation import GPT2TranslationModel
from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2Translation
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.dataset import create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.tokenization import Tokenizer
from src.utils.metric_method import BLEU
from src.GPT2_generation import GenerateForTranslation


def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Do train.

    Args:
        dataset: the train dataset.
        network: the network with loss.
        load_checkpoint_path: the file path of the pretrained model checkpoint.
        save_checkpoint_path: the file path where the finetuned model checkpoint will be saved.
        epoch_num: the number of epochs.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrained model is missing: the finetune task must load a pretrained checkpoint!")

    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()

        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported; supported optimizers: [AdamWeightDecay, Lamb, Momentum]")

    # checkpoint saving config and pretrained weight loading
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    prefix_name = "gpt2_translation_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
                  + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)

    final_param_dict = {}
    for name, param in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = param
    final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']

    load_param_into_net(network, final_param_dict)
    print("Loaded the pretrained parameters successfully!\n")

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)
    loss_cb = LossMonitor(per_print_times=1)

    model = Model(netwithgrads)

    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print("=================== Starting Training For Translation Task ====================")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("=================== Translation Training Success ====================")


def eval_result_print(metric="BLEU", callback=None):
    """Print evaluation result."""
    if metric == "BLEU":
        print(" | BLEU: {:.6f}".format(callback.bleu / float(callback.total_num)))
    else:
        raise ValueError("metric method '{}' not supported, support: [BLEU]. ".format(str(metric)))


def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file_path="",
            generate_length=1, top_k=1, top_p=1.0, temperature=1.0):
    """
    Do evaluation on Translation.

    Args:
        dataset: the eval dataset.
        network: the network with loss.
        metric: the evaluation method.
        load_checkpoint_path: the file path of the finetuned model checkpoint.
        eval_type: the evaluation type, one of [zero-shot, finetuned].
        tokenizer_file_path: the path of the vocab and merge files.
        generate_length: the number of generated tokens.
        top_k, top_p, temperature: the sampling parameters.
    """
    if load_checkpoint_path == "":
        raise ValueError("Finetuned model is missing: the evaluation task must load a finetuned checkpoint!")
    if metric.lower() == "bleu":
        print("Prepare to calculate the BLEU score ...")

        gpt2_translation = network(config=gpt2_net_cfg,
                                   is_training=False,
                                   use_one_hot_embeddings=False)
        gpt2_translation.set_train(False)
        param_dict = load_checkpoint(load_checkpoint_path)

        if eval_type == "zero-shot":
            final_param_dict = {}
            for name, param in param_dict.items():
                final_param_dict['gpt2.' + name] = param
            final_param_dict['dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
            load_param_into_net(gpt2_translation, final_param_dict)
            print("Loaded the pretrained parameters successfully!\n")
        elif eval_type == "finetuned":
            load_param_into_net(gpt2_translation, param_dict)
            print("Loaded the finetuned parameters successfully!\n")
        else:
            raise ValueError("Evaluation type is missing: eval_type should be in [zero-shot, finetuned]")

        model = Model(gpt2_translation)
        tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json',
                              merge_file=tokenizer_file_path + 'gpt2-merges.txt')
        callback = BLEU(tokenizer)
        translation_generator = GenerateForTranslation(decoder=model,
                                                       config=gpt2_net_cfg,
                                                       tokenizer=tokenizer,
                                                       generate_length=generate_length,
                                                       use_hint=True,
                                                       select_first_sentence=True,
                                                       topk_num=top_k,
                                                       topp_prob=float(top_p),
                                                       temperature=float(temperature)
                                                       )
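
        # use_hint=True is assumed to apply the GPT-2 paper's translation
        # prompt (the source sentence followed by " = ", so the model continues
        # with the target-language sentence), and select_first_sentence keeps
        # only the first generated sentence as the predicted translation.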

        columns_list = ["input_ids", "input_mask", "label_ids"]
        print("==================== [BLEU] Testing ====================")
        num_data = 1
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(data[i])
            input_ids, input_mask, label_ids = input_data

            print("| Data count: {}".format(num_data * gpt2_net_cfg.batch_size))
            print("input_ids shape: {}".format(input_ids.shape))
            print("input_mask shape: {}".format(input_mask.shape))
            print("label_ids shape: {}".format(label_ids.shape))

            ts_predict_list, ref_list = translation_generator.generate_for_translation(input_ids)
            print("| Batch Reference translation:\n{}\n".format(ref_list))
            if ref_list is None or ref_list == '':
                print("The reference list is empty, skipping this batch!")
                continue
            else:
                print(" | Batch Predict translation:\n{}\n".format(ts_predict_list))
                callback.update(ref_list, ts_predict_list)
                num_data += 1
                print("\n\n")

        print("**************************************************************")
        eval_result_print(metric, callback)
        print("********************** Testing Finished **********************")
    else:
        raise ValueError("metric method not supported in translation, support: [BLEU]")


def run_translation():
    """
    Run the Translation task.
    """
    parser = argparse.ArgumentParser(description="Finetune and Evaluate Translation")
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device type. Default: Ascend.")
    parser.add_argument("--device_id", type=int, default=0,
                        help="ID of the target device.")
    parser.add_argument("--metric_method", type=str, default="BLEU",
                        help="The eval method including [BLEU]. Default: BLEU.")
    parser.add_argument("--do_train", type=str, default="false",
                        help="Enable train. Default: false.")
    parser.add_argument("--do_eval", type=str, default="true",
                        help="Enable evaluation. Default: true.")
    parser.add_argument("--eval_type", type=str, default="zero-shot",
                        help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
    parser.add_argument("--epoch_num", type=int, default=1,
                        help="Epoch number. Default: 1.")
    parser.add_argument("--train_data_shuffle", type=str, default="true",
                        help="Enable train data shuffle. Default: true.")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle. Default: false.")
    parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
                        help="Path for saving the finetune checkpoint.")
    parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
                        help="Path of the pretrained checkpoint to load.")
    parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
                        help="Path of the finetuned checkpoint to load.")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Train data path; an absolute path is recommended.")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Eval data path; an absolute path is recommended.")
    parser.add_argument("--tokenizer_file_path", type=str, default="",
                        help="Pretrained vocab and merge file path.")

    parser.add_argument("--generate_length", type=int, default=150,
                        help="The generation length of the translated sentence.")
    parser.add_argument("--top_k", type=int, default=1,
                        help="Parameter for top-k sampling.")
    parser.add_argument("--top_p", type=str, default="1.0",
                        help="Parameter for top-p sampling.")
    parser.add_argument("--temperature", type=str, default="1.0",
                        help="Sampling temperature; higher values give more diverse generations.")

    args_opt = parser.parse_args()

    epoch_num = args_opt.epoch_num
    metric = args_opt.metric_method
    save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
    load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
    load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when doing the finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when doing the evaluation task")

    device_target = args_opt.device_target

    if device_target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=device_target,
                            device_id=args_opt.device_id,
                            max_call_depth=3000)
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id))
    else:
        raise Exception("Device target error: only Ascend is supported.")

    gpt2_loss = GPT2Translation(config=gpt2_net_cfg,
                                is_training=True,
                                use_one_hot_embeddings=False)

    if args_opt.do_train.lower() == "true":
        print("============== Start Loading Translation Train Dataset ==============")
        print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
        train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
                                                      dataset_path=args_opt.train_data_file_path)
        do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)

    if args_opt.do_eval.lower() == "true":
        print("============ Start Loading Translation Evaluation Dataset ============")
        print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
        eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                                     dataset_path=args_opt.eval_data_file_path)
        do_eval(eval_dataset, GPT2TranslationModel, metric, load_finetune_ckpt_path, args_opt.eval_type,
                args_opt.tokenizer_file_path, args_opt.generate_length, args_opt.top_k, args_opt.top_p,
                args_opt.temperature)


if __name__ == "__main__":
    print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    run_translation()
    print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
@@ -0,0 +1,60 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_cbt.sh"
echo "for example: bash scripts/run_cbt.sh"
echo "metric method: Accuracy"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_cbt.log"

# create the log file and write its head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_CBT_task.py \
    --device_target="Ascend" \
    --device_id=4 \
    --num_choice=10 \
    --metric_method="Accuracy" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path >> $output_log 2>&1 &
@@ -0,0 +1,68 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_lambada.sh"
echo "for example: bash scripts/run_lambada.sh"
echo "metric method include: [Accuracy, PPL]"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_lambada.log"

# create the log file and write its head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

# tokenizer path
tokenizer_file_path=""

# stopword path
stop_word_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_lambada.py \
    --device_target="Ascend" \
    --device_id=1 \
    --metric_method="PPL" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --generate_length_dynamically="true" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path \
    --tokenizer_file_path=$tokenizer_file_path \
    --stop_word_file_path=$stop_word_file_path >> $output_log 2>&1 &
@@ -0,0 +1,59 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_language_model.sh"
echo "for example: bash scripts/run_language_model.sh"
echo "metric method: PPL"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_language_model.log"

# create the log file and write its head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_language_model.py \
    --device_target="Ascend" \
    --device_id=4 \
    --metric_method="PPL" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path >> $output_log 2>&1 &
@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_read_comprehension.sh"
echo "for example: bash scripts/run_read_comprehension.sh"
echo "metric method: F1"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_read_comprehension.log"

# create the log file and write its head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

# tokenizer path
tokenizer_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_ReadComprehension.py \
    --device_target="Ascend" \
    --device_id=7 \
    --metric_method="F1" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path \
    --tokenizer_file_path=$tokenizer_file_path \
    --generate_length=55 \
    --top_k=1 \
    --top_p="1.0" \
    --temperature="1.0" >> $output_log 2>&1 &
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_summarization.sh"
echo "for example: bash scripts/run_summarization.sh"
echo "eval_type include: [zero-shot, finetuned]. Default: finetuned"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_summarization.log"

# create the log file and write its head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

# tokenizer path
tokenizer_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_summarization.py \
    --device_target="Ascend" \
    --device_id=0 \
    --do_train="false" \
    --do_eval="true" \
    --metric_method="Rouge" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --top_k=2 \
    --top_p="1.0" \
    --generate_length=100 \
    --temperature="1.0" \
    --eval_type="finetuned" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path \
    --tokenizer_file_path=$tokenizer_file_path >> $output_log 2>&1 &
@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_translation.sh"
echo "for example: bash scripts/run_translation.sh"
echo "metric method: BLEU"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_translation.log"

# create the log file and write its head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

# tokenizer path
tokenizer_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_translation.py \
    --device_target="Ascend" \
    --device_id=4 \
    --metric_method="BLEU" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path \
    --tokenizer_file_path=$tokenizer_file_path \
    --generate_length=100 \
    --top_k=1 \
    --top_p="1.0" \
    --temperature="1.0" >> $output_log 2>&1 &
@@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (CBT) model script.
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal

from .GPT2_model import GPT2Model


class GPT2CBTModel(nn.Cell):
    """
    GPT2CBTModel is responsible for the Children's Book Test (CBT) task, i.e. the CBT-CN and CBT-NE datasets.
    """
    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        """
        Args:
            config: the configuration of the GPT-2 model.
            is_training (bool): `True` for train (finetune), `False` for evaluation.
            use_one_hot_embeddings (bool): default False.
        """
        super(GPT2CBTModel, self).__init__()
        if not is_training:
            config.summary_first_dropout = 0.0

        self.is_training = is_training
        self.d_model = config.d_model
        self.batch_size = config.batch_size
        self.seq_length = config.seq_length
        self.vocab_size = config.vocab_size
        self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.log_softmax = P.LogSoftmax(axis=-1)

        self.dtype = config.dtype
        self.lm_head = nn.Dense(config.d_model,
                                config.vocab_size,
                                weight_init=TruncatedNormal(config.initializer_range),
                                has_bias=False).to_float(config.compute_type)

        self.first_dropout = nn.Dropout(1 - config.summary_first_dropout)

    def construct(self, input_ids, input_mask):
        """
        Construct network.

        Args:
            input_ids (Tensor): shape [batch_size, seq_len].
            input_mask (Tensor): shape [batch_size, seq_len]; 0 indicates a padding position.

        Returns:
            lm_logits (Tensor): language model distribution with log_softmax,
                                shape [batch_size, seq_len, vocab_size].
        """
        output, _ = self.gpt2(input_ids, input_mask)  # output shape is [batch_size, seq_len, d_model]
        output = self.cast(output, self.dtype)
        output = self.reshape(output, (-1, self.d_model))
        output = self.first_dropout(output)
        lm_logits = self.lm_head(output)  # [batch_size * seq_len, vocab_size]
        lm_logits = self.reshape(lm_logits, (self.batch_size, self.seq_length, self.vocab_size))
        lm_logits = self.cast(lm_logits, self.dtype)
        lm_logits = self.log_softmax(lm_logits)

        return lm_logits

    def get_lm_head(self):
        return self.lm_head.weight
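

# Illustrative only (not used by the eval scripts): a CBT example with, e.g.,
# num_choice=10 candidate answers is typically scored by summing the
# log-probabilities that GPT2CBTModel assigns to each candidate's tokens and
# picking the argmax. A minimal NumPy sketch, assuming `lm_logits` is the
# [batch_size, seq_len, vocab_size] log-softmax output and `blank_pos` is the
# position of the blank to fill (both names are hypothetical):
#
#     import numpy as np
#
#     def score_candidates(lm_logits, candidate_token_ids, blank_pos):
#         scores = []
#         for cand in candidate_token_ids:
#             logp = 0.0
#             for offset, token_id in enumerate(cand):
#                 # logits at position t predict the token at position t + 1
#                 logp += lm_logits[0, blank_pos + offset - 1, token_id]
#             scores.append(logp)
#         return int(np.argmax(scores))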

@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (LAMBADA) model script.
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore.common.initializer import TruncatedNormal

from .GPT2_model import GPT2Model


class GPT2LambadaModel(nn.Cell):
    """
    GPT2LambadaModel is responsible for the LAMBADA task, i.e. the Lambada-train and Lambada-test datasets.
    """
    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        """
        Args:
            config: the configuration of the GPT-2 model.
            is_training (bool): `True` for train (finetune), `False` for evaluation.
            use_one_hot_embeddings (bool): default False.
        """
        super(GPT2LambadaModel, self).__init__()
        if not is_training:
            config.hidden_dropout = 0.0
        self.vocab_size = config.vocab_size
        self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.dense1 = nn.Dense(config.d_model,
                               config.vocab_size,
                               weight_init=TruncatedNormal(config.initializer_range)).to_float(mstype.float16)
        self.dropout = nn.Dropout(1 - config.hidden_dropout)

    def construct(self, input_ids, input_mask):
        """
        Args:
            input_ids (Tensor): shape [batch_size, seq_len].
            input_mask (Tensor): shape [batch_size, seq_len]; 0 indicates a padding position.

        Returns:
            lm_logits (Tensor): language model distribution with log_softmax,
                                shape [batch_size, seq_len, vocab_size].
        """
        output, _ = self.gpt2(input_ids, input_mask)
        output = self.cast(output, self.dtype)
        output = self.dropout(output)
        batch_size, seq_length, d_model = self.shape(output)
        output_reshape = P.Reshape()(output, (-1, d_model))  # [batch_size * seq_len, d_model]
        logits = self.dense1(output_reshape)
        logits = self.cast(logits, self.dtype)
        logits = self.log_softmax(logits)
        lm_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size))
        return lm_logits
@@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (Language Modeling) model script.
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal

from .GPT2_model import GPT2Model


class GPT2LanguageModel(nn.Cell):
    """
    GPT2LanguageModel is responsible for the Language Modeling task, i.e. the WikiText2, WikiText103, PTB,
    and 1BW datasets.
    """
    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        """
        Args:
            config: the configuration of the GPT-2 model.
            is_training (bool): `True` for train (finetune), `False` for evaluation.
            use_one_hot_embeddings (bool): default False.
        """
        super(GPT2LanguageModel, self).__init__()
        if not is_training:
            config.hidden_dropout = 0.0

        self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
        self.vocab_size = config.vocab_size
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.dtype = config.dtype
        self.dense1 = nn.Dense(config.d_model,
                               config.vocab_size,
                               weight_init=TruncatedNormal(config.initializer_range),
                               has_bias=False).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - config.hidden_dropout)
        self.log_softmax = P.LogSoftmax(axis=-1)

    def construct(self, input_ids, input_mask):
        """
        Construct network.

        Args:
            input_ids (Tensor): input sentences with shape [batch_size, seq_len].
            input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len],
                                 where 0 indicates a padding position.

        Returns:
            lm_logits (Tensor): language model distribution with log_softmax,
                                shape [batch_size, seq_len, vocab_size].
        """
        output, _ = self.gpt2(input_ids, input_mask)
        output = self.cast(output, self.dtype)
        batch_size, seq_length, d_model = self.shape(output)
        output_reshape = P.Reshape()(output, (-1, d_model))  # [batch_size * seq_len, d_model]
        logits = self.dense1(output_reshape)
        logits = self.cast(logits, self.dtype)
        logits = self.log_softmax(logits)
        lm_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size))  # [batch_size, seq_len, vocab]

        return lm_logits
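

# Note: because construct() returns log-softmax scores, the PPL metric used by
# run_language_model.py is presumably computed as
#     PPL = exp(-mean_t lm_logits[b, t, label[b, t + 1]])
# i.e. the exponential of the average negative log-likelihood of each next
# token, with padding positions excluded via input_mask.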

@@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (Reading Comprehension) model script.
"""
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P

from .GPT2_model import GPT2Model


class GPT2CoQAModel(nn.Cell):
    """
    This class is responsible for the CoQA reading comprehension task.
    """

    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        super(GPT2CoQAModel, self).__init__()
        if not is_training:
            config.hidden_dropout = 0.0

        self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.dense1 = nn.Dense(config.d_model,
                               config.vocab_size,
                               weight_init=self.weight_init,
                               has_bias=False).to_float(config.compute_type)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.vocab_size = config.vocab_size
        self.dtype = config.dtype

    def construct(self, input_ids, input_mask):
        """
        Construct network.

        Args:
            input_ids (Tensor): input sentences with shape [batch_size, seq_len].
            input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len],
                                 where 0 indicates a padding position.

        Returns:
            logits (Tensor): language model distribution with log_softmax,
                             shape [batch_size, seq_len, vocab_size].
        """
        decoder_output, _ = self.gpt2(input_ids, input_mask)
        decoder_output = P.Cast()(decoder_output, self.dtype)
        batch_size, seq_length, d_model = P.Shape()(decoder_output)
        reshaped_output = P.Reshape()(decoder_output, (-1, d_model))  # [batch_size * seq_length, d_model]
        logits = self.dense1(reshaped_output)
        logits = P.Cast()(logits, self.dtype)
        logits = self.log_softmax(logits)
        logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size))
        return logits
@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (Summarization) model script.
"""
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal

from .GPT2_model import GPT2Model


class GPT2SummarizationModel(nn.Cell):
    """
    GPT2SummarizationModel is responsible for the summarization task, i.e. the cnn_dailymail dataset.

    Args:
        config: the configuration of the GPT-2 model.
        is_training (bool): `True` for train (finetune), `False` for evaluation.
        use_one_hot_embeddings (bool): default False.
    """
    def __init__(self, config, is_training=True, use_one_hot_embeddings=False):
        super(GPT2SummarizationModel, self).__init__()
        self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
        self.lm_head = nn.Dense(config.d_model,
                                config.vocab_size,
                                weight_init=TruncatedNormal(config.initializer_range),
                                has_bias=False).to_float(mstype.float16)
        self.reshape = P.Reshape()
        self.dtype = config.dtype
        self.cast = P.Cast()
        self.shape = P.Shape()

    def construct(self, input_ids, input_mask):
        """
        Construct network.

        Args:
            input_ids (Tensor): input sentences with shape [batch_size, seq_len].
            input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len],
                                 where 0 indicates a padding position.

        Returns:
            lm_logits (Tensor): language model distribution without log_softmax (raw logits for the
                                sampling-based generator), shape [batch_size, seq_len, vocab_size].
        """
        output, _ = self.gpt2(input_ids, input_mask)
        output = self.cast(output, self.dtype)
        batch_size, seq_length, d_model = self.shape(output)

        hidden_state = self.reshape(output, (-1, d_model))
        hidden_state = self.cast(hidden_state, self.dtype)
        lm_logits = self.lm_head(hidden_state)
        lm_logits = self.cast(lm_logits, self.dtype)
        lm_logits = self.reshape(lm_logits, (batch_size, seq_length, -1))

        return lm_logits
@@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 downstream task (Translation) model script.
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal

from .GPT2_model import GPT2Model


class GPT2TranslationModel(nn.Cell):
    """
    GPT2TranslationModel is responsible for the translation task, i.e. the WMT-14 En-Fr and WMT-14 Fr-En datasets.
    """
    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        """
        Args:
            config: the configuration of the GPT-2 model.
            is_training (bool): `True` for train (finetune), `False` for evaluation.
            use_one_hot_embeddings (bool): default False.
        """
        super(GPT2TranslationModel, self).__init__()
        if not is_training:
            config.hidden_dropout = 0.0

        self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings)
        self.vocab_size = config.vocab_size
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.dtype = config.dtype
        self.dense1 = nn.Dense(config.d_model,
                               config.vocab_size,
                               weight_init=TruncatedNormal(config.initializer_range),
                               has_bias=True).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - config.hidden_dropout)

    def construct(self, input_ids, input_mask):
        """
        Construct network.

        Args:
            input_ids (Tensor): input sentences with shape [batch_size, seq_len].
            input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len],
                                 where 0 indicates a padding position.

        Returns:
            translation_logits (Tensor): language model distribution without log_softmax,
                                         shape [batch_size, seq_len, vocab_size].
        """
        output, _ = self.gpt2(input_ids, input_mask)
        output = self.cast(output, self.dtype)
        output = self.dropout(output)
        batch_size, seq_length, d_model = self.shape(output)
        output_reshape = P.Reshape()(output, (-1, d_model))  # [batch_size * seq_len, d_model]
        logits = self.dense1(output_reshape)
        logits = self.cast(logits, self.dtype)
        translation_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size))

        return translation_logits
@@ -0,0 +1,366 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Generation classes for the downstream tasks (Summarization, LAMBADA, Translation, Reading Comprehension).
"""
import numpy as np

from .utils.task_utils import extract_logits
from .utils.generation_utils import Sample
from .utils.tensor_manipulations import extract_string_from_tensor

INF = 1. * 1e9


class GenerateForSummarization():
    """
    Generation for the summarization task.
    """

    def __init__(self,
                 decoder,
                 config=None,
                 tokenizer=None,
                 select_sentence=3,
                 eval_type="finetuned",
                 temperature=1.0,
                 generate_length=100,
                 topk=2,
                 topp=1.0):
        self.decoder = decoder
        self.config = config
        self.tokenizer = tokenizer
        self.select_sentence = select_sentence
        self.eval_type = eval_type

        self.generator = Sample(decoder,
                                tokenizer=tokenizer,
                                config=config,
                                topk_num=topk,
                                topp_prob=topp,
                                min_tokens_to_keep=1,
                                demo_mode=False,
                                temperature=temperature)
        self.generate_length = generate_length

    def generate_for_summarization(self, input_ids):
        """generation function for the summarization task"""
        # prepare input_str
        article_str, summary_str = extract_string_from_tensor(input_ids=input_ids,
                                                              mode="pair",
                                                              config=self.config,
                                                              tokenizer=self.tokenizer)
        generated_summary_list = [""] * self.config.batch_size

        # clip overflow: cut each article after its last complete sentence
        for batch_idx in range(self.config.batch_size):
            last_dot_pos = max(article_str[batch_idx].rfind(' .'), article_str[batch_idx].rfind('. ')) + 2
            article_str[batch_idx] = article_str[batch_idx][:last_dot_pos]

        # append the "TL;DR:" prompt after each article string
        tldr_str = "TL;DR:"
        if self.eval_type == "finetuned":
            for batch_idx in range(self.config.batch_size):
                article_str[batch_idx] += (" " + tldr_str)

        generate_str_list, _ = self.generator.generate(input_str=article_str, generate_length=self.generate_length)

        for batch_idx in range(self.config.batch_size):
            generate_str = generate_str_list[batch_idx]
            generated_summary = ""

            if self.select_sentence > 0:
                # check whether the generated text contains `select_sentence` sentences;
                # if not, the full generated string is kept (see the sketch after this class)
                len_generate_str = len(generate_str)
                search_index = -1
                for _ in range(self.select_sentence):
                    search_index = generate_str.find('.', search_index + 1)
                    if search_index == -1 or search_index >= len_generate_str:
                        search_index = len_generate_str
                        break

                # advance search_index by one to keep the period token ('.') if it does not overflow
                search_index = search_index + 1 if search_index < len_generate_str else len_generate_str
                generated_summary = generate_str[:search_index]
                if generated_summary.find(self.tokenizer.eos_token) != -1:
                    cut_pos = generated_summary.find(self.tokenizer.eos_token, 0)
                    generated_summary = generated_summary[:cut_pos]
            else:
                generated_summary = generate_str

            # if the whole string has been clipped away, restore the full generated string
            if generated_summary == '':
                generated_summary = generate_str
            # empty string check
            if generated_summary == '':
                generated_summary = '<empty>'
            generated_summary_list[batch_idx] = generated_summary

        return generated_summary_list, summary_str  # hypothesis and reference

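The `select_sentence` loop above keeps only the first few sentences of the generated continuation. A standalone illustration of the same truncation rule (plain Python; the helper name is hypothetical):

```python
def first_n_sentences(text, n=3):
    """Return the first n '.'-terminated sentences, or all of text if there are fewer."""
    idx = -1
    for _ in range(n):
        idx = text.find('.', idx + 1)
        if idx == -1:
            return text
    return text[:idx + 1]

assert first_n_sentences("A. B. C. D.", 3) == "A. B. C."
```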
class GenerateForLambada():
    """
    Generation for the LAMBADA task: predict the final word of a passage.
    """
    def __init__(self,
                 decoder,
                 config=None,
                 tokenizer=None,
                 generate_length_dynamic=True,
                 generate_length=1,
                 max_iterations=200,
                 stop_word_file=""):
        """
        Args:
            decoder (Model): GPT2 model used for generation.
            config (object): configuration of the given GPT2 model.
            tokenizer (object): if the input_str parameter of self.generate() is used, a tokenizer is compulsory.
            generate_length_dynamic (bool): True if the generation length is dynamic, False if fixed. Default: True.
            max_iterations (int): number of top candidate tokens considered for the final word,
                                  i.e. k = `max_iterations`.
            generate_length (int): maximum number of tokens generated for the final word.
            stop_word_file (str): path of the stop-word file used as a stop-word filter.
        """
        self.decoder = decoder
        self.config = config
        self.batch_size = config.batch_size
        self.tokenizer = tokenizer
        self.generate_length_dynamic = generate_length_dynamic
        self.generate_length = generate_length
        self.max_iterations = max_iterations
        self.stop_word_set = self.build_stop_word(stop_word_file)

        self.generator = Sample(decoder=decoder,
                                config=config,
                                batch_size=1,
                                tokenizer=tokenizer,
                                topk_num=1,
                                topp_prob=1,
                                return_ids=True)
        self.stop_eos = ['.', ',', '!', '?', '"', " '", " and", " says", " said"]

    def build_stop_word(self, stop_word_file):
        """load the stop-word set from file"""
        stop_words_set = set()
        with open(stop_word_file, 'r', encoding="utf8") as file:
            for line in file.readlines():
                line = line.strip('\n')
                stop_words_set.add(line)
        return stop_words_set

    def is_stop_word(self, word):
        return word in self.stop_word_set

    def generate_for_lambada(self, input_ids, logits, input_length):
        """
        generation function for the LAMBADA task

        Args:
            input_ids (Tensor): input sentences with shape [batch_size, seq_len].
            logits (Tensor): the language model distribution.
            input_length (Tensor): stores the context length (excluding the final word)
                                   and the whole sentence length.

        Returns:
            batch_predict_words (list): the list of predicted words.
        """
        batch_predict_words = ["" for _ in range(self.batch_size)]
        input_len_np = input_length.asnumpy()
        input_ids_list = input_ids.asnumpy().tolist()

        extracted_logits = extract_logits(logits=logits, position=input_len_np)  # [batch_size, vocab_size]
        extracted_logits = extracted_logits.asnumpy()
        sorted_ids = np.argsort(-extracted_logits, axis=-1)[::, :self.max_iterations]  # [batch_size, max_iterations]

        for batch_idx in range(self.batch_size):
            final_word_spos = input_len_np[batch_idx, 0]
            context_ids = input_ids_list[batch_idx][1:final_word_spos]  # 1 for dropping the <bos> token
            last_word_token_num = input_len_np[batch_idx, 1] - input_len_np[batch_idx, 0]

            if self.generate_length_dynamic:
                generate_length = last_word_token_num
            else:
                generate_length = self.generate_length

            for num in range(self.max_iterations):
                id_ = sorted_ids[batch_idx][num]
                source_ids = context_ids + [id_]
                source_string = self.tokenizer.decode(source_ids)
                generated_ids_list = self.generator.generate(input_str=source_string,
                                                             generate_length=generate_length,
                                                             do_sample=False)
                predict_tokens_ids = [id_] + generated_ids_list[0]
                predict_word = self.tokenizer.decode(predict_tokens_ids)

                # cut the candidate at the earliest stop mark, if any
                eos_pos = min(predict_word.find(word) if predict_word.find(word) >= 0
                              else INF for word in self.stop_eos)
                if eos_pos == INF:
                    continue
                predict_word = predict_word[:eos_pos]
                predict_word = predict_word.strip()
                # accept only single words that are not stop words
                if predict_word.find(" ") == -1:
                    if self.is_stop_word(word=predict_word.lower()):
                        continue
                    batch_predict_words[batch_idx] = predict_word
                    print("predict word: {}".format(predict_word))
                    break
        return batch_predict_words

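The candidate loop above rescans the top-`max_iterations` tokens and trims each decoded candidate at the earliest stop mark before applying the stop-word filter. A minimal sketch of the trimming step (plain Python; the helper name is hypothetical):

```python
INF = 1e9
stop_eos = ['.', ',', '!', '?', '"', " '", " and", " says", " said"]

def cut_at_first_stop(word):
    pos = min(word.find(s) if word.find(s) >= 0 else INF for s in stop_eos)
    return word if pos == INF else word[:pos]

assert cut_at_first_stop("night. She") == "night"
```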
class GenerateForTranslation():
    """
    Generation for the translation task.
    """
    def __init__(self,
                 decoder,
                 config=None,
                 tokenizer=None,
                 generate_length=1,
                 use_hint=True,
                 select_first_sentence=True,
                 topk_num=None,
                 topp_prob=None,
                 temperature=None):
        self.decoder = decoder
        self.config = config
        self.batch_size = config.batch_size
        self.tokenizer = tokenizer
        self.generate_length = generate_length
        self.use_hint = use_hint
        self.select_first_sentence = select_first_sentence

        self.generator = Sample(decoder=decoder,
                                config=config,
                                tokenizer=tokenizer,
                                topk_num=topk_num,
                                topp_prob=topp_prob,
                                temperature=temperature,
                                min_tokens_to_keep=1,
                                early_stop=False)

    def generate_for_translation(self, input_ids):
        """generation function for the translation task"""
        source_str_list, ref_str_list = extract_string_from_tensor(input_ids=input_ids,
                                                                   mode="pair",
                                                                   config=self.config,
                                                                   tokenizer=self.tokenizer)
        final_predict_translation_list = [""] * self.batch_size

        if self.use_hint:
            for index in range(self.batch_size):
                source_str_list[index] += " ="  # now source_str is "english sentence ="

        translation_str_list, _ = self.generator.generate(input_str=source_str_list,
                                                          generate_length=self.generate_length,
                                                          do_sample=False)

        for index in range(self.batch_size):
            generate_str = translation_str_list[index].replace('<|endoftext|>', '')
            predict_translation = ""

            # Following the GPT-2 paper, select_first_sentence is set to True:
            # only the first generated sentence is kept as the translation.
            if self.select_first_sentence:
                # if no sentence boundary is found, the full generated string is kept
                search_index = generate_str.find('.', 0, len(generate_str))
                if search_index == -1:
                    search_index = len(generate_str)
                else:
                    search_index = search_index + 1
                predict_translation = generate_str[:search_index]
            else:
                predict_translation = generate_str

            if predict_translation == '':
                predict_translation = '<empty>'

            final_predict_translation_list[index] = predict_translation

        return final_predict_translation_list, ref_str_list

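The `use_hint` prompt mirrors the "english sentence = french sentence" format used for zero-shot translation in the GPT-2 paper; a tiny illustration (plain Python; the values are hypothetical):

```python
source = "I love you ."
prompt = source + " ="
# the model is expected to continue with the target sentence, e.g. " Je t'aime .",
# and only the first generated sentence is kept as the prediction
```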
class GenerateForReadComprehension():
    """
    Generation for the reading comprehension task.

    Args:
        decoder (Model): GPT2 model used for generation.
        config (object): configuration of the given GPT2 model.
        tokenizer (object): if the input_str parameter of self.generate() is used, a tokenizer is compulsory.
        generate_length (int): number of tokens to generate for the answer.
    """

    def __init__(self,
                 decoder,
                 config=None,
                 tokenizer=None,
                 generate_length=1,
                 topk_num=None,
                 topp_prob=None,
                 temperature=None):
        self.decoder = decoder
        self.config = config
        self.batch_size = config.batch_size
        self.tokenizer = tokenizer
        self.generate_length = generate_length

        self.generator = Sample(decoder=decoder,
                                config=config,
                                tokenizer=tokenizer,
                                topk_num=topk_num,
                                topp_prob=topp_prob,
                                temperature=temperature,
                                min_tokens_to_keep=1)

    def generate_for_read_comprehension(self, input_ids):
        """generation function for the reading comprehension task"""
        passage_str_list, answer_str_list = extract_string_from_tensor(input_ids=input_ids,
                                                                       mode="pair",
                                                                       config=self.config,
                                                                       tokenizer=self.tokenizer)
        passage = passage_str_list[:]

        generate_str_list, _ = self.generator.generate(input_str=passage_str_list,
                                                       generate_length=self.generate_length,
                                                       do_sample=False)

        pred_answer = []
        for batch_id in range(self.batch_size):
            new_str = generate_str_list[batch_id].replace('<|endoftext|>', '')
            index_a = new_str.find('.')
            index_b = new_str.find('Q:')
            if index_a != -1 or index_b != -1:
                index = max(index_a, index_b)
                pred_answer += [new_str[1:index]]  # index 1 skips the leading space of the sentence
            else:
                pred_answer += [new_str]

        return passage, pred_answer, answer_str_list
@@ -0,0 +1,896 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 base model
"""
import math
import copy
import numpy as np

import mindspore
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter

from .weight_init import normal_weight, zero_weight


class GPT2Config:
    """
    Configuration for `GPT2Model`.

    Args:
        batch_size (int): Batch size of input dataset. Default: 512.
        seq_length (int): Length of input sequence. Default: 1024.
        vocab_size (int): The shape of each embedding vector. Default: 50257.
        d_model (int): Size of the GPT-2 decoder layers. Default: 768.
        num_hidden_layers (int): Number of hidden layers in the GPT2Transformer decoder block. Default: 12.
        num_attention_heads (int): Number of attention heads in the GPT2Transformer decoder block. Default: 12.
        intermediate_size (int): Size of intermediate layer in the GPT2Transformer decoder block. Default: 3072.
        hidden_act (str): Activation function used in the GPT2Transformer decoder block. Default: "gelu".
        hidden_dropout (float): The dropout probability for GPT2Output. Default: 0.1.
        attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this model. Default: 1024.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        input_mask_from_dataset (bool): Specifies whether to use the input mask loaded from the dataset.
                                        Default: True.
        summary_first_dropout (float): The dropout probability for GPT2CBTModel. Default: 0.1.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in GPT2Transformer. Default: mstype.float16.
    """

    def __init__(self,
                 batch_size=512,
                 seq_length=1024,
                 vocab_size=50257,
                 d_model=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout=0.1,
                 attention_dropout=0.1,
                 max_position_embeddings=1024,
                 initializer_range=0.02,
                 input_mask_from_dataset=True,
                 summary_first_dropout=0.1,
                 dtype=mstype.float32,
                 compute_type=mstype.float16,
                 ):
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.input_mask_from_dataset = input_mask_from_dataset
        self.summary_first_dropout = summary_first_dropout
        self.dtype = dtype
        self.compute_type = compute_type

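A hedged usage sketch for the configuration class (the import path `src.GPT2_model` is assumed from the repository layout; adjust it to your checkout):

```python
from src.GPT2_model import GPT2Config   # assumed import path

config = GPT2Config(batch_size=1, seq_length=16)
print(config.d_model, config.num_hidden_layers)   # 768 12, the small-network defaults
```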
class EmbeddingLookup(nn.Cell):
    """
    An embedding lookup table with a fixed dictionary and size.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_dim (int): The size of each embedding vector.
        use_one_hot_embeddings (bool): Specifies whether to use the one-hot encoding form. Default: False.
    """

    def __init__(self,
                 vocab_size,
                 embedding_dim,
                 use_one_hot_embeddings=False,
                 compute_type=mstype.float16):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.compute_type = compute_type
        self.embedding_table = Parameter(normal_weight([vocab_size, embedding_dim], embedding_dim),
                                         name='embedding_table')
        self.expand = P.ExpandDims()
        self.shape_flat = (-1,)
        self.gather = P.GatherV2()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.cast = P.Cast()

    def construct(self, input_ids):
        """
        Get embeddings according to input_ids.

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.

        Returns:
            output (Tensor): the embedding matrix according to input_ids.
            self.embedding_table (Parameter): the whole embedding table of the GPT-2 model.
        """
        input_shape = self.shape(input_ids)  # [batch_size, seq_length]
        flat_ids = self.reshape(input_ids, self.shape_flat)  # [batch_size * seq_length]

        if self.use_one_hot_embeddings:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            # precision transition fp32 -> fp16
            one_hot_ids = self.cast(one_hot_ids, self.compute_type)
            self.embedding_table = self.cast(self.embedding_table, self.compute_type)
            output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
            output_for_reshape = self.cast(output_for_reshape, mstype.float32)
        else:
            # [batch_size * seq_length, embedding_dim]
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)

        out_shape = input_shape + (self.embedding_dim,)
        output = self.reshape(output_for_reshape, out_shape)  # [batch_size, seq_length, embedding_dim]
        return output, self.embedding_table

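Both lookup paths above yield the same embeddings; the one-hot matmul path mainly eases mixed-precision execution. A numpy check of the equivalence (illustrative only, not part of this file):

```python
import numpy as np

table = np.random.randn(10, 4).astype(np.float32)   # [vocab_size, embedding_dim]
ids = np.array([1, 7, 3])
gathered = table[ids]                                # gather path
one_hot = np.eye(10, dtype=np.float32)[ids]          # one-hot path
assert np.allclose(gathered, one_hot @ table)
```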
class EmbeddingPostprocessor(nn.Cell):
    """
    Postprocessor that applies positional embeddings to the word embeddings.

    Args:
        embedding_dim (int): The size of each embedding vector.
        seq_length (int): The length of the input sequence.
        max_position_embeddings (int): Maximum length of sequences used in this model. Default: 1024.
        dropout_prob (float): The dropout probability. Default: 0.1.
    """

    def __init__(self,
                 embedding_dim=None,
                 seq_length=None,
                 max_position_embeddings=1024,
                 dropout_prob=0.1):
        super(EmbeddingPostprocessor, self).__init__()

        self.position_embedding_table = Parameter(
            normal_weight([max_position_embeddings, embedding_dim], embedding_dim), name='position_embeddings')
        self.expand_dims = P.ExpandDims()
        self.add = P.TensorAdd()
        self.gather = P.GatherV2()
        self.input_indices = Tensor(np.arange(seq_length), mindspore.int32)
        self.dropout = nn.Dropout(1 - dropout_prob, dtype=mstype.float32)
        self.use_dropout = dropout_prob > 0

    def construct(self, word_embeddings):
        """
        Add the position embeddings to the token embeddings.

        Args:
            word_embeddings (Tensor): the token embedding matrix.

        Returns:
            output (Tensor): the final embedding matrix, i.e. the token embeddings
                             plus the position embeddings.
        """
        position_embeddings = self.gather(self.position_embedding_table, self.input_indices, 0)
        position_embeddings = self.expand_dims(position_embeddings, 0)
        output = self.add(word_embeddings, position_embeddings)

        if self.use_dropout:
            output = self.dropout(output)

        return output

class CastWrapper(nn.Cell):
    """
    Cast wrapper
    """

    def __init__(self,
                 dst_type=mstype.float32):
        super(CastWrapper, self).__init__()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        """
        type cast

        Args:
            x (Tensor): the input to be cast.

        Returns:
            Tensor, the cast output.
        """
        return self.cast(x, self.dst_type)


class LayerNorm(nn.Cell):
    """
    Do layer norm

    Args:
        in_channels (int): In channels number of layer norm
    """

    def __init__(self,
                 in_channels=None):
        super(LayerNorm, self).__init__()
        self.layer_norm = nn.LayerNorm((in_channels,))
        self.cast = P.Cast()
        self.get_dtype = P.DType()

    def construct(self, input_tensor):
        """
        layer norm

        Args:
            input_tensor (Tensor): the input of layernorm.

        Returns:
            Tensor, the output after layernorm.
        """
        output = self.cast(input_tensor, mstype.float32)
        output = self.layer_norm(output)
        output = self.cast(output, self.get_dtype(input_tensor))
        return output

class ResidualConnection(nn.Cell):
    """
    Add residual to output.

    Args:
        dropout_prob (float): Dropout rate.
    """

    def __init__(self, dropout_prob=0.0):
        super(ResidualConnection, self).__init__()
        self.add = P.TensorAdd()
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.use_dropout = dropout_prob > 0

    def construct(self, hidden_tensor, input_tensor):
        """
        Args:
            hidden_tensor (Tensor): the output of a sublayer.
            input_tensor (Tensor): the input tensor.

        Returns:
            output (Tensor): with the same shape as hidden_tensor.
        """
        output = hidden_tensor
        if self.use_dropout:
            output = self.dropout(output)
        output = self.add(output, input_tensor)
        return output

class Conv1D(nn.Cell):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nx (int): The number of input features.
        nf (int): The number of output features.
    """

    def __init__(self,
                 nx,
                 nf):
        super(Conv1D, self).__init__()
        self.nx = nx
        self.nf = nf
        self.weight = Parameter(normal_weight([nx, nf], nf), name='projection_weight')
        self.bias = Parameter(zero_weight(nf), name='projection_bias')

        self.matmul = P.MatMul()
        self.bias_add = P.BiasAdd()
        self.cast = P.Cast()

    def construct(self, input_tensor):
        """
        Args:
            input_tensor (Tensor): the input of Conv1D with shape [batch_size * seq_length, nx].

        Returns:
            output_tensor (Tensor): the output with shape [batch_size * seq_length, self.nf].
        """
        # precision transition fp32 -> fp16
        input_tensor = self.cast(input_tensor, mstype.float16)
        fp16_weight = self.cast(self.weight, mstype.float16)
        output_tensor = self.matmul(input_tensor, fp16_weight)  # [batch_size * seq_length, self.nf]
        output_tensor = self.cast(output_tensor, mstype.float32)
        output_tensor = self.bias_add(output_tensor, self.bias)

        return output_tensor

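As the docstring notes, `Conv1D(nx, nf)` behaves like a linear layer with transposed weight storage. A numpy sketch of the equivalent affine map (illustrative only):

```python
import numpy as np

x = np.random.randn(3, 768).astype(np.float32)      # [batch_size * seq_len, nx]
W = np.random.randn(768, 2304).astype(np.float32)   # weight stored as [nx, nf]
b = np.zeros(2304, dtype=np.float32)
y = x @ W + b                                        # [batch_size * seq_len, nf]
assert y.shape == (3, 2304)
```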
class MaskedSelfAttention(nn.Cell):
    """
    Apply masked multi-head attention.

    Args:
        batch_size (int): Batch size of input datasets. Default: 512.
        d_model (int): Size of last dim of input tensor. Default: 768.
        seq_length (int): Length of input tensor sequence. Default: 1024.
        num_attention_heads (int): Number of attention heads. Default: 12.
        dim_per_head (int): Size of each attention head. Default: 64.
        has_attention_mask (bool): Specifies whether to use attention mask. Default: True.
        attention_dropout (float): The dropout probability for MultiheadAttention. Default: 0.0.
        compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float16.

    Returns:
        Tensor, shape [batch_size * seq_length, d_model] when do_return_2d_tensor is True (default),
        otherwise [batch_size, seq_length, d_model].
    """

    def __init__(self,
                 batch_size=512,
                 d_model=768,
                 seq_length=1024,
                 num_attention_heads=12,
                 dim_per_head=64,
                 has_attention_mask=True,
                 do_return_2d_tensor=True,
                 attention_dropout=0.0,
                 compute_type=mstype.float16):
        super(MaskedSelfAttention, self).__init__()

        self.batch_size = batch_size
        self.d_model = d_model
        self.seq_length = seq_length
        self.num_heads = num_attention_heads
        self.dim_per_head = dim_per_head
        self.has_attention_mask = has_attention_mask
        self.compute_type = compute_type
        assert has_attention_mask

        self.scale = Tensor([1.0 / math.sqrt(float(self.dim_per_head))], dtype=compute_type)  # attention scale
        self.mask_data = Tensor([-10000.0,], dtype=compute_type)
        self.split_head_shape = (-1, self.seq_length, self.num_heads, self.dim_per_head)

        self.c_attn = Conv1D(d_model, d_model * 3)
        self.c_proj = Conv1D(d_model, d_model)

        self.split_for_qkv = P.Split(1, 3)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.matmul = P.BatchMatMul()
        self.multiply = P.Mul()

        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.TensorAdd()
            self.cast = P.Cast()
            self.get_dtype = P.DType()

        if do_return_2d_tensor:
            self.shape_return = (-1, d_model)
        else:
            self.shape_return = (-1, seq_length, d_model)

        self.softmax = nn.Softmax()
        self.softmax_cast = P.Cast()
        self.dropout = nn.Dropout(1 - attention_dropout)
        self.use_attention_dropout = attention_dropout > 0

    def construct(self, input_tensor, attention_mask):
        """
        do masked self-attention

        Args:
            input_tensor (Tensor): the embedding of input sequence tokens,
                                   shape with [batch_size * seq_length, d_model].
            attention_mask (Tensor): mask to avoid performing attention on padding token indices,
                                     shape with [batch_size, seq_len, seq_len].

        Returns:
            outputs (Tensor): the output of masked self-attention, shape with [batch_size * seq_len, d_model].
        """
        input_tensor = self.c_attn(input_tensor)  # [batch_size * seq_length, d_model * 3], e.g. [1 * 3, 2304]
        input_tensor = self.split_for_qkv(input_tensor)
        query = input_tensor[0]  # [batch_size * seq_length, d_model], e.g. [1 * 3, 768]
        key = input_tensor[1]
        value = input_tensor[2]

        # split heads
        query = self.reshape(query, self.split_head_shape)
        # query shape [batch_size, num_heads, seq_len, dim_per_head], e.g. [1, 12, 3, 64]
        query = self.transpose(query, self.trans_shape)

        key = self.reshape(key, self.split_head_shape)
        # key shape [batch_size, num_heads, seq_len, dim_per_head], e.g. [1, 12, 3, 64]
        key = self.transpose(key, self.trans_shape)

        value = self.reshape(value, self.split_head_shape)
        # value shape [batch_size, num_heads, seq_len, dim_per_head], e.g. [1, 12, 3, 64]
        value = self.transpose(value, self.trans_shape)

        # attention and mask
        # precision transition fp32 -> fp16
        query = self.cast(query, self.compute_type)
        key = self.cast(key, self.compute_type)
        attention_scores = self.matmul_trans_b(query, key)  # [batch_size, num_heads, seq_len, seq_len]
        attention_scores = self.cast(attention_scores, self.compute_type)
        attention_scores = self.multiply(attention_scores, self.scale)

        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)  # [batch_size, 1, seq_length, seq_length]
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))  # fp16
            adder = self.multiply(multiply_out, self.mask_data)
            adder = self.cast(adder, mstype.float32)
            attention_scores = self.cast(attention_scores, mstype.float32)
            attention_scores = self.add(adder, attention_scores)

        attention_scores = self.softmax_cast(attention_scores, mstype.float32)
        attention_probs = self.softmax(attention_scores)  # [batch_size, num_heads, seq_len, seq_len]
        attention_probs = self.softmax_cast(attention_probs, self.get_dtype(key))

        if self.use_attention_dropout:
            attention_probs = self.dropout(attention_probs)

        value = self.cast(value, mstype.float16)
        attention_probs = self.cast(attention_probs, self.compute_type)
        outputs = self.matmul(attention_probs, value)  # [batch_size, num_heads, seq_len, dim_per_head]
        outputs = self.cast(outputs, mstype.float32)

        # merge heads
        outputs = self.transpose(outputs, self.trans_shape)  # [batch_size, seq_len, num_heads, dim_per_head]
        outputs = self.reshape(outputs,
                               self.shape_return)  # default True, the outputs shape [batch_size * seq_len, d_model]

        # project
        outputs = self.c_proj(outputs)
        return outputs

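The masking above adds a large negative bias (-10000) to disallowed positions before the softmax, so they receive near-zero attention weight. A standalone numpy sketch of the scheme (illustrative only, not the network code):

```python
import numpy as np

def masked_softmax(scores, mask, neg=-10000.0):
    scores = scores + (1.0 - mask) * neg             # hide positions where mask == 0
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

seq = 4
causal = np.tril(np.ones((seq, seq), dtype=np.float32))
probs = masked_softmax(np.random.randn(seq, seq), causal)
assert np.allclose(probs[0, 1:], 0.0)                # token 0 attends only to itself
```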
class FeedForward(nn.Cell):
    """
    Apply a two-layer feed-forward network.

    Args:
        in_channels (int): Size of the input layer. Default: 768.
        out_channels (int): Size of the output layers. Default: 768.
        hidden_size (int): Size of the hidden layer. Default: 3072.
        hidden_dropout (float): The dropout probability for hidden outputs. Default: 0.1.
    """

    def __init__(self,
                 in_channels=768,
                 out_channels=768,
                 hidden_size=3072,
                 hidden_dropout=0.1):
        super(FeedForward, self).__init__()

        self.c_fc = Conv1D(in_channels, hidden_size)
        self.c_proj = Conv1D(hidden_size, out_channels)

        self.layernorm = LayerNorm(in_channels=in_channels)
        self.residual_connect = ResidualConnection(dropout_prob=hidden_dropout)
        self.gelu_act = P.Gelu()
        self.dropout = nn.Dropout(1 - hidden_dropout)
        self.use_dropout = hidden_dropout > 0

    def construct(self, input_tensor):
        """
        FeedForward construct function with layernorm and residual connection.

        Args:
            input_tensor (Tensor): the input of the FeedForward layer, shape with [batch_size * seq_len, d_model].

        Returns:
            output (Tensor): the output of the FeedForward layer, shape with [batch_size * seq_len, d_model].
        """
        # LayerNorm
        output = self.layernorm(input_tensor)
        # Feed Forward
        output = self.c_fc(output)  # [batch_size * seq_len, d_model * 4]
        output = self.gelu_act(output)
        output = self.c_proj(output)  # [batch_size * seq_len, d_model]
        if self.use_dropout:
            output = self.dropout(output)
        # Add
        output = self.residual_connect(output, input_tensor)
        return output

class MaskedMultiHeadAttention(nn.Cell):
    """
    Masked multi-head attention block.
    """
    def __init__(self,
                 batch_size=512,
                 seq_length=1024,
                 d_model=768,
                 num_attention_heads=12,
                 attention_dropout=0.02,
                 hidden_dropout=0.1,
                 has_attention_mask=True,
                 compute_type=mstype.float16
                 ):
        super(MaskedMultiHeadAttention, self).__init__()
        if d_model % num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (d_model, num_attention_heads))

        self.dim_per_head = int(d_model / num_attention_heads)  # 64

        self.masked_self_attention = MaskedSelfAttention(
            batch_size=batch_size,
            d_model=d_model,
            seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            dim_per_head=self.dim_per_head,
            has_attention_mask=has_attention_mask,
            do_return_2d_tensor=True,
            attention_dropout=attention_dropout,
            compute_type=compute_type
        )

        self.layer_norm = LayerNorm(in_channels=d_model)
        self.residual_connection = ResidualConnection()

        self.reshape = P.Reshape()
        self.new_shape = (-1, d_model)

    def construct(self, input_tensor, attention_mask):
        """
        do masked multi-head self-attention with layernorm and residual connection.

        Args:
            input_tensor (Tensor): the embedding matrix of input sequence tokens,
                                   shape with [batch_size * seq_length, d_model].
            attention_mask (Tensor): mask to avoid performing attention on padding token indices,
                                     shape with [batch_size, seq_len, seq_len].

        Returns:
            output (Tensor): the output of MaskedMultiHeadAttention, shape with [batch_size * seq_len, d_model].
        """
        # LayerNorm
        output_tensor = self.layer_norm(input_tensor)
        # masked multi-head attention
        # attention_output shape [batch_size * seq_length, d_model]
        attention_output = self.masked_self_attention(output_tensor, attention_mask)
        # residual connection
        output = self.residual_connection(attention_output, input_tensor)
        return output

class DecoderBlock(nn.Cell):
    """
    decoder block used in GPT2.

    Args:
        batch_size (int): Batch size of input dataset. Default: 512.
        seq_length (int): Length of input sequence. Default: 1024.
        d_model (int): Size of the GPT2 decoder layers. Default: 768.
        num_attention_heads (int): Number of attention heads. Default: 12.
        intermediate_size (int): Size of intermediate layer. Default: 3072.
        attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.02.
        hidden_dropout (float): The dropout probability for hidden outputs. Default: 0.1.
        has_attention_mask (bool): Specifies whether to use attention mask. Default: True.
        compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float16.
    """

    def __init__(self,
                 batch_size=512,
                 seq_length=1024,
                 d_model=768,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_dropout=0.02,
                 hidden_dropout=0.1,
                 has_attention_mask=True,
                 compute_type=mstype.float16
                 ):
        super(DecoderBlock, self).__init__()
        if d_model % num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (d_model, num_attention_heads))

        self.dim_per_head = int(d_model / num_attention_heads)  # 64

        self.masked_multi_head_attention = MaskedMultiHeadAttention(
            batch_size=batch_size,
            seq_length=seq_length,
            d_model=d_model,
            num_attention_heads=num_attention_heads,
            attention_dropout=attention_dropout,
            hidden_dropout=hidden_dropout,
            has_attention_mask=has_attention_mask,
            compute_type=compute_type
        )
        self.feedforward = FeedForward(
            in_channels=d_model,
            out_channels=d_model,
            hidden_size=intermediate_size,
            hidden_dropout=hidden_dropout
        )

        self.reshape = P.Reshape()
        self.new_shape = (-1, d_model)

    def construct(self, input_tensor, attention_mask):  # input_tensor shape [batch_size, seq_length, d_model]
        """
        DecoderBlock with masked_multi_head_attention and feedforward.

        Args:
            input_tensor (Tensor): the embedding matrix of input sequence tokens,
                                   shape with [batch_size * seq_length, d_model].
            attention_mask (Tensor): mask to avoid performing attention on padding token indices,
                                     shape with [batch_size, seq_len, seq_len].

        Returns:
            output (Tensor): the output of DecoderBlock, shape with [batch_size * seq_len, d_model].
        """
        input_tensor = self.reshape(input_tensor, self.new_shape)

        # masked multi-head attention with layernorm and residual connection
        attention_output = self.masked_multi_head_attention(input_tensor, attention_mask)
        # feed forward with layernorm and residual connection
        output = self.feedforward(attention_output)

        return output

class GPT2Transformer(nn.Cell):
    """
    Multi-layer GPT2 transformer.

    Args:
        batch_size (int): Batch size of input dataset. Default: 512.
        d_model (int): Size of the decoder layers. Default: 768.
        seq_length (int): Length of input sequence. Default: 1024.
        num_hidden_layers (int): Number of hidden layers in decoder cells. Default: 12.
        num_attention_heads (int): Number of attention heads in decoder cells. Default: 12.
        intermediate_size (int): Size of intermediate layer in decoder cells. Default: 3072.
        has_attention_mask (bool): Specifies whether to use attention mask. Default: True.
        attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.1.
        hidden_dropout (float): The dropout probability for GPT2Output. Default: 0.1.
        compute_type (:class:`mindspore.dtype`): Compute type in GPT2Transformer. Default: mstype.float16.
    """

    def __init__(self,
                 batch_size=512,
                 d_model=768,
                 seq_length=1024,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 has_attention_mask=True,
                 attention_dropout=0.1,
                 hidden_dropout=0.1,
                 compute_type=mstype.float16):
        super(GPT2Transformer, self).__init__()

        layers = []
        for _ in range(num_hidden_layers):
            layer = DecoderBlock(batch_size=batch_size,
                                 seq_length=seq_length,
                                 d_model=d_model,
                                 num_attention_heads=num_attention_heads,
                                 intermediate_size=intermediate_size,
                                 attention_dropout=attention_dropout,
                                 hidden_dropout=hidden_dropout,
                                 has_attention_mask=has_attention_mask,
                                 compute_type=compute_type)
            layers.append(layer)

        self.layers = nn.CellList(layers)

        self.reshape = P.Reshape()
        self.new_shape = (-1, d_model)
        self.out_shape = (-1, seq_length, d_model)

    def construct(self, input_tensor, attention_mask):
        """
        Apply the stack of DecoderBlocks.

        Args:
            input_tensor (Tensor): the embedding matrix of input sequence tokens,
                                   shape with [batch_size * seq_length, d_model].
            attention_mask (Tensor): mask to avoid performing attention on padding token indices,
                                     shape with [batch_size, seq_len, seq_len].

        Returns:
            output (Tensor): the output of GPT2Transformer, shape with [batch_size, seq_len, d_model].
        """
        prev_output = self.reshape(input_tensor, self.new_shape)
        for layer_module in self.layers:
            layer_output = layer_module(prev_output, attention_mask)
            prev_output = layer_output

        output = self.reshape(prev_output, self.out_shape)
        return output

class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    Args:
        config (Class): Configuration for GPT2Model.
    """

    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.input_mask_from_dataset = config.input_mask_from_dataset
        self.input_mask = None
        self.compute_type = config.compute_type

        assert self.input_mask_from_dataset

        self.cast = P.Cast()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.matmul = P.BatchMatMul()
        self.multiply = P.Mul()

        # mask future positions
        ones = np.ones(shape=(config.batch_size, config.seq_length, config.seq_length))
        self.lower_triangle_mask = Tensor(np.tril(ones), dtype=mstype.float32)

    def construct(self, input_mask, mask_future=True):
        """
        Construct network.

        Args:
            input_mask (Tensor): Tensor mask vectors with shape [batch_size, seq_len].
            mask_future (bool): Whether to mask future positions (for decoder training). Default: True.

        Returns:
            attention_mask (Tensor): shape [batch_size, seq_len, seq_len].
        """
        input_shape = self.shape(input_mask)
        shape_right = (input_shape[0], 1, input_shape[1])  # [batch_size, 1, seq_len]
        shape_left = input_shape + (1,)  # [batch_size, seq_len, 1]

        input_mask = self.cast(input_mask, mstype.float32)
        mask_left = self.reshape(input_mask, shape_left)
        mask_right = self.reshape(input_mask, shape_right)

        # precision transition fp32 -> fp16
        mask_left = self.cast(mask_left, self.compute_type)
        mask_right = self.cast(mask_right, self.compute_type)
        attention_mask = self.matmul(mask_left, mask_right)  # [batch_size, seq_len, seq_len]
        attention_mask = self.cast(attention_mask, mstype.float32)
        if mask_future:
            attention_mask = self.multiply(attention_mask, self.lower_triangle_mask)

        return attention_mask

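The construction above combines a padding mask with a lower-triangular matrix: the outer product of the padding mask with itself hides padding, and the triangle hides future positions. A numpy illustration (not the network code):

```python
import numpy as np

input_mask = np.array([[1, 1, 1, 0]], dtype=np.float32)      # [batch, seq], last token is padding
pad_mask = input_mask[:, :, None] * input_mask[:, None, :]   # [batch, seq, seq]
causal = np.tril(np.ones((1, 4, 4), dtype=np.float32))
attention_mask = pad_mask * causal
assert attention_mask[0, 2].tolist() == [1.0, 1.0, 1.0, 0.0]  # token 2 sees tokens 0-2 only
```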
class GPT2Model(nn.Cell):
    """
    GPT-2 base model: a multi-layer decoder-only Transformer.

    Args:
        config (Class): Configuration for GPT2Model.
        is_training (bool): True for training mode. False for eval mode.
        use_one_hot_embeddings (bool): Specifies whether to use the one-hot encoding form. Default: False.
    """

    def __init__(self,
                 config,
                 is_training,
                 use_one_hot_embeddings=False
                 ):
        super(GPT2Model, self).__init__()
        self.config = copy.deepcopy(config)
        self.is_training = is_training
        if not is_training:
            self.config.hidden_dropout = 0.0
            self.config.attention_dropout = 0.0

        self.input_mask_from_dataset = self.config.input_mask_from_dataset
        self.batch_size = self.config.batch_size
        self.seq_length = self.config.seq_length
        self.d_model = self.config.d_model
        self.num_hidden_layers = self.config.num_hidden_layers
        self.embedding_dim = self.config.d_model

        self.last_idx = self.num_hidden_layers - 1

        self.gpt2_embedding_lookup = EmbeddingLookup(
            vocab_size=self.config.vocab_size,
            embedding_dim=self.embedding_dim,
            use_one_hot_embeddings=use_one_hot_embeddings,
            compute_type=self.config.compute_type
        )
        self.gpt2_embedding_postprocess = EmbeddingPostprocessor(
            embedding_dim=self.embedding_dim,
            seq_length=self.seq_length,
            max_position_embeddings=self.config.max_position_embeddings,
            dropout_prob=self.config.hidden_dropout
        )
        self.gpt2_decoder = GPT2Transformer(
            batch_size=self.batch_size,
            d_model=self.d_model,
            seq_length=self.seq_length,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.config.num_attention_heads,
            intermediate_size=self.config.intermediate_size,
            has_attention_mask=True,
            attention_dropout=self.config.attention_dropout,
            hidden_dropout=self.config.hidden_dropout,
            compute_type=self.config.compute_type
        )

        self.cast_compute_type = CastWrapper(dst_type=self.config.compute_type)
        self.layer_norm = LayerNorm(in_channels=self.d_model)
        self.dropout = nn.Dropout(1 - self.config.hidden_dropout)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(self.config)

        self.reshape = P.Reshape()
        self.new_shape = (-1, self.d_model)

    def construct(self, input_ids, input_mask):
        """
        Construct network.

        Args:
            input_ids (Tensor): input sentences with shape [batch_size, seq_len].
            input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len],
                                 where 0 indicates padding position.

        Returns:
            decoder_output (Tensor): shape [batch_size, seq_len, d_model].
            embedding_tables (Tensor): word embeddings with shape [vocab_size, d_model].
        """
        # Embedding
        word_embeddings, embedding_tables = self.gpt2_embedding_lookup(input_ids)
        embedding_output = self.gpt2_embedding_postprocess(word_embeddings)
        embedding_output = self.dropout(embedding_output)

        # Attention mask with shape [batch_size, seq_len, seq_len]
        attention_mask = self._create_attention_mask_from_input_mask(input_mask, True)

        # GPT2 decoder
        decoder_output = self.gpt2_decoder(
            self.cast_compute_type(embedding_output),
            self.cast_compute_type(attention_mask)
        )

        # LayerNorm
        decoder_output = self.reshape(decoder_output, self.new_shape)
        decoder_output = self.layer_norm(decoder_output)
        decoder_output = self.reshape(decoder_output, (-1, self.seq_length, self.d_model))

        return decoder_output, embedding_tables

    def get_token_embeddings(self):
        return self.gpt2_embedding_lookup.embedding_table.asnumpy()
@@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""clip gradient"""
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore.ops import composite as C

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

clip_grad = C.MultitypeFuncGraph("clip_grad")


# pylint: disable=consider-using-in
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (tuple[Tensor]): Gradients.

    Outputs:
        tuple[Tensor], clipped gradients.
    """
    if clip_type != 0 and clip_type != 1:
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
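In MindSpore model-zoo training cells, a `MultitypeFuncGraph` like `clip_grad` is usually mapped over the gradient tuple with `HyperMap`. A hedged sketch of that wiring (the actual call site lives in the finetune script and may differ):

```python
from mindspore.ops import composite as C
from mindspore.ops import functional as F

hyper_map = C.HyperMap()
# inside a TrainOneStepCell-style construct(), after grads are computed:
# grads = hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
```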
@@ -0,0 +1,95 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations"""
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C

from .finetune_eval_config import gpt2_net_cfg


def create_language_model_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=True, dataset_path=""):
    """create dataset for the language modeling task"""
    type_cast_op = C.TypeCast(mstype.int32)
    ds = de.MindDataset(dataset_path,
                        columns_list=["input_ids", "input_mask", "label_ids"],
                        shuffle=do_shuffle,
                        num_shards=device_num,
                        shard_id=rank_id)
    print("batch_size: {}".format(gpt2_net_cfg.batch_size))

    ds = ds.map(operations=type_cast_op, input_columns="input_ids")
    ds = ds.map(operations=type_cast_op, input_columns="input_mask")
    ds = ds.map(operations=type_cast_op, input_columns="label_ids")
    ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_count)

    print("dataset size: {}".format(ds.get_dataset_size()))
    print("repeat count: {}".format(ds.get_repeat_count()))
    print("output shape: {}".format(ds.output_shapes()))
    print("output type: {}".format(ds.output_types()))
    print("============== create dataset successful ===============")

    return ds


def create_cbt_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=False, dataset_path=""):
    """create dataset for the CBT task"""
    type_cast_op = C.TypeCast(mstype.int32)
    ds = de.MindDataset(dataset_path,
                        columns_list=["input_ids", "input_mask", "input_length", "mc_labels"],
                        shuffle=do_shuffle,
                        num_shards=device_num,
                        shard_id=rank_id)
    print("batch_size: {}".format(gpt2_net_cfg.batch_size))

    ds = ds.map(operations=type_cast_op, input_columns="input_ids")
    ds = ds.map(operations=type_cast_op, input_columns="input_mask")
    ds = ds.map(operations=type_cast_op, input_columns="input_length")
    ds = ds.map(operations=type_cast_op, input_columns="mc_labels")
    ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_count)

    print("dataset size: {}".format(ds.get_dataset_size()))
    print("repeat count: {}".format(ds.get_repeat_count()))
    print("output shape: {}".format(ds.output_shapes()))
    print("output type: {}".format(ds.output_types()))
    print("============== create CBT LM dataset successful ===============")

    return ds


def create_lambada_control_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=True, dataset_path=""):
    """create dataset for the LAMBADA task"""
    type_cast_op = C.TypeCast(mstype.int32)
    ds = de.MindDataset(dataset_path,
                        columns_list=["input_ids", "input_mask", "input_length"],
                        shuffle=do_shuffle,
                        num_shards=device_num,
                        shard_id=rank_id)
    print("batch_size: {}".format(gpt2_net_cfg.batch_size))

    ds = ds.map(operations=type_cast_op, input_columns="input_ids")
    ds = ds.map(operations=type_cast_op, input_columns="input_mask")
    ds = ds.map(operations=type_cast_op, input_columns="input_length")
    ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_count)

    print("dataset size: {}".format(ds.get_dataset_size()))
    print("repeat count: {}".format(ds.get_repeat_count()))
    print("output shape: {}".format(ds.output_shapes()))
    print("output type: {}".format(ds.output_types()))
    print("============== create dataset successful ===============")
    return ds
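A hypothetical call (the MindRecord path is a placeholder) that builds the language-model dataset and checks one batch of the three int32 columns:

```python
ds = create_language_model_dataset(dataset_path="/path/to/wikitext2.train.mindrecord")
for item in ds.create_dict_iterator():
    print(item["input_ids"].shape, item["input_mask"].shape, item["label_ids"].shape)
    break
```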
@@ -0,0 +1,104 @@
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GPT-2 finetune config and GPT-2 model config"""
|
||||
from easydict import EasyDict as edict
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from .GPT2_model import GPT2Config
|
||||
|
||||
cfg = edict({
|
||||
'gpt2_network': 'large',
|
||||
'optimizer': 'Lamb',
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 1e-5,
|
||||
'end_learning_rate': 1e-7,
|
||||
'power': 1.0,
|
||||
'weight_decay': 0.01,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
'eps': 1e-6,
|
||||
}),
|
||||
'Lamb': edict({
|
||||
'learning_rate': 1e-5,
|
||||
'end_learning_rate': 1e-7,
|
||||
'power': 1.0,
|
||||
'weight_decay': 0.01,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
}),
|
||||
'Momentum': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'momentum': 0.9,
|
||||
}),
|
||||
})
|
||||
|
||||
"""
|
||||
three kinds of GPT2 model version
|
||||
"""
|
||||
if cfg.gpt2_network == 'small':
|
||||
gpt2_net_cfg = GPT2Config(
|
||||
batch_size=8,
|
||||
seq_length=1024,
|
||||
vocab_size=50257,
|
||||
d_model=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
max_position_embeddings=1024,
|
||||
initializer_range=0.02,
|
||||
input_mask_from_dataset=True,
|
||||
summary_first_dropout=0.1,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
)
|
||||
if cfg.gpt2_network == 'medium':
|
||||
gpt2_net_cfg = GPT2Config(
|
||||
batch_size=8,
|
||||
seq_length=1024,
|
||||
vocab_size=50257,
|
||||
d_model=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
max_position_embeddings=1024,
|
||||
initializer_range=0.02,
|
||||
input_mask_from_dataset=True,
|
||||
summary_first_dropout=0.1,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
)
|
||||
if cfg.gpt2_network == 'large':
|
||||
gpt2_net_cfg = GPT2Config(
|
||||
batch_size=6,
|
||||
seq_length=1024,
|
||||
vocab_size=50257,
|
||||
d_model=1280,
|
||||
num_hidden_layers=36,
|
||||
num_attention_heads=20,
|
||||
intermediate_size=5120,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
max_position_embeddings=1024,
|
||||
initializer_range=0.02,
|
||||
input_mask_from_dataset=True,
|
||||
summary_first_dropout=0.1,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
)
|
|
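
# --- Hedged usage sketch (illustration only; not part of the upstream file). ---
# Shows how the selected network size propagates into gpt2_net_cfg.
def _show_selected_config():
    print("network: ", cfg.gpt2_network)           # e.g. 'large'
    print("optimizer: ", cfg.optimizer)            # e.g. 'Lamb'
    print("batch_size: ", gpt2_net_cfg.batch_size)
    print("d_model: ", gpt2_net_cfg.d_model)       # 768 / 1024 / 1280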
@ -0,0 +1,464 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPT-2 finetune for downstream task"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size

from .utils.CrossEntropy import CrossEntropyCalculationWithMask
from .clip_grad_utils import clip_grad
from .GPT2ForLanguageModel import GPT2LanguageModel
from .GPT2ForLambada import GPT2LambadaModel
from .GPT2ForCBT import GPT2CBTModel
from .GPT2ForTranslation import GPT2TranslationModel
from .GPT2ForReadComprehension import GPT2CoQAModel
from .GPT2ForSummarization import GPT2SummarizationModel


GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)


_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    return grad_overflow(grad)

class GPT2FinetuneCell(nn.Cell):
    """
    Specifically defined for finetuning, where only three input tensors are needed.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """

    def __init__(self, network, optimizer, scale_update_cell=None):
        super(GPT2FinetuneCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.gpu_target = False
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            self.alloc_status = P.NPUAllocFloatStatus()
            self.get_status = P.NPUGetFloatStatus()
            self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    def construct(self,
                  input_ids,
                  input_mask,
                  label_ids,
                  sens=None):
        """
        GPT2 Finetune.

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
            input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary.
        """

        weights = self.weights
        init = False
        loss = self.network(input_ids,
                            input_mask,
                            label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        if not self.gpu_target:
            init = self.alloc_status()
            clear_before_grad = self.clear_before_grad(init)
            F.control_depend(loss, init)
            self.depend_parameter_use(clear_before_grad, scaling_sens)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        if not self.gpu_target:
            flag = self.get_status(init)
            flag_sum = self.reduce_sum(init, (0,))
            F.control_depend(grads, flag)
            F.control_depend(flag, flag_sum)
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            flag_sum = self.reshape(flag_sum, (()))
        if self.is_distributed:
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond)
        return F.depend(ret, succ)

class GPT2LM(nn.Cell):
    """
    Train interface for Language Modeling finetuning task.

    Args:
        config (class): the configuration of GPT-2 model.
        is_training (bool): whether to train.
        use_one_hot_embeddings (bool): whether to use one-hot embeddings.
    """

    def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False):
        super(GPT2LM, self).__init__()
        self.gpt2 = GPT2LanguageModel(config, is_training, use_one_hot_embeddings)
        self.num_labels = config.vocab_size
        self.loss = CrossEntropyCalculationWithMask(is_training=is_training,
                                                    num_labels=self.num_labels,
                                                    config=config)
        self.is_training = is_training
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.cast = P.Cast()

    def construct(self, input_ids, input_mask, label_ids):
        """
        construct function for Language Modeling

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
            input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary.

        Returns:
            lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits;
                otherwise, return the computed loss.
        """
        lm_logits = self.gpt2(input_ids, input_mask)  # [batch_size, seq_length, vocab_size]

        if self.is_training:
            shift_logits = lm_logits[::, :-1, ::]  # [batch_size, seq_length - 1, vocab_size]
            shift_logits = self.reshape(shift_logits, (-1, self.num_labels))  # [batch * (seq_length - 1), vocab_size]
            label_ids = label_ids[::, 1:]
            input_mask = input_mask[::, 1:]

            loss = self.loss(shift_logits, label_ids, input_mask)
            return loss

        return lm_logits


class GPT2Lambada(nn.Cell):
    """
    Train interface for Lambada finetuning task.

    Args:
        config (class): the configuration of GPT-2 model.
        is_training (bool): whether to train.
        use_one_hot_embeddings (bool): whether to use one-hot embeddings.
    """

    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        super(GPT2Lambada, self).__init__()
        self.gpt2 = GPT2LambadaModel(config, is_training, use_one_hot_embeddings)
        self.num_labels = config.vocab_size
        self.loss = CrossEntropyCalculationWithMask(is_training=is_training,
                                                    num_labels=self.num_labels,
                                                    config=config)
        self.is_training = is_training
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.cast = P.Cast()

    def construct(self, input_ids, input_mask, label_ids=None):
        """
        construct function for Lambada task

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
            input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary. Default: None.

        Returns:
            lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits;
                otherwise, return the computed loss.
        """
        lm_logits = self.gpt2(input_ids, input_mask)  # [batch_size, seq_length, vocab_size]

        if self.is_training:
            shift_logits = lm_logits[:, :-1, :]  # [batch_size, seq_length - 1, vocab_size]
            shift_logits = self.reshape(shift_logits, (-1, self.num_labels))  # [batch * (seq_length - 1), vocab_size]
            label_ids = label_ids[::, 1:]
            input_mask = input_mask[::, 1:]

            loss = self.loss(shift_logits, label_ids, input_mask)

            return loss

        return lm_logits


class GPT2CBT(nn.Cell):
    """
    Train interface for Children's Book Test finetuning task.

    Args:
        config (class): the configuration of GPT-2 model.
        is_training (bool): whether to train.
        use_one_hot_embeddings (bool): whether to use one-hot embeddings.
    """

    def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False):
        super(GPT2CBT, self).__init__()
        self.gpt2 = GPT2CBTModel(config, is_training, use_one_hot_embeddings)
        self.num_labels = config.vocab_size
        self.loss = CrossEntropyCalculationWithMask(is_training=is_training,
                                                    num_labels=self.num_labels,
                                                    config=config)
        self.is_training = is_training
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.cast = P.Cast()

    def construct(self, input_ids, input_mask):
        """
        construct function for CBT task

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
            input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.

        Returns:
            lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits;
                otherwise, return the computed loss.
        """
        lm_logits = self.gpt2(input_ids, input_mask)  # [batch_size, seq_length, vocab_size]

        if self.is_training:
            shift_logits = lm_logits[::, :-1, ::]  # [batch_size, seq_length - 1, vocab_size]
            shift_logits = self.reshape(shift_logits, (-1, self.num_labels))  # [batch * (seq_length - 1), vocab_size]
            label_ids = input_ids[::, 1:]
            input_mask = input_mask[::, 1:]

            loss = self.loss(shift_logits, label_ids, input_mask)
            return loss

        return lm_logits


class GPT2Translation(nn.Cell):
    """
    Train interface for Translation finetuning task.

    Args:
        config (class): the configuration of GPT-2 model.
        is_training (bool): whether to train.
        use_one_hot_embeddings (bool): whether to use one-hot embeddings.
    """

    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        super(GPT2Translation, self).__init__()
        self.gpt2 = GPT2TranslationModel(config, is_training, use_one_hot_embeddings)
        self.num_labels = config.vocab_size
        self.loss = CrossEntropyCalculationWithMask(is_training=is_training,
                                                    num_labels=self.num_labels,
                                                    config=config)
        self.is_training = is_training
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.reshape = P.Reshape()
        self.shape = P.Shape()

    def construct(self, input_ids, input_mask, label_ids):
        """
        construct function for Translation task

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
            input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary.

        Returns:
            translation_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits;
                otherwise, return the computed loss.
        """
        translation_logits = self.gpt2(input_ids, input_mask)  # [batch_size, seq_length, vocab_size]
        translation_logits = self.log_softmax(translation_logits)

        if self.is_training:
            shift_logits = translation_logits[::, :-1, ::]  # [batch_size, seq_length - 1, vocab_size]
            shift_logits = self.reshape(shift_logits, (-1, self.num_labels))  # [batch * (seq_length - 1), vocab_size]
            label_ids = label_ids[::, 1:]
            input_mask = input_mask[::, 1:]

            loss = self.loss(shift_logits, label_ids, input_mask)
            return loss

        return translation_logits


class GPT2Summarization(nn.Cell):
    """
    Train interface for Summarization finetuning task.

    Args:
        config (class): the configuration of GPT-2 model.
        is_training (bool): whether to train.
        use_one_hot_embeddings (bool): whether to use one-hot embeddings.
    """

    def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False):
        super(GPT2Summarization, self).__init__()
        self.gpt2 = GPT2SummarizationModel(config, is_training, use_one_hot_embeddings)
        self.is_training = is_training
        self.last_idx = (-1,)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.batch_size = config.batch_size
        self.seq_length = config.seq_length
        self.vocab_size = config.vocab_size
        self.cast = P.Cast()
        self.loss_function = CrossEntropyCalculationWithMask(num_labels=self.vocab_size,
                                                             is_training=self.is_training,
                                                             config=config)

    def construct(self, input_ids, input_mask, label_ids):
        """
        construct function for the Summarization task

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
            input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary.

        Returns:
            loss (mstype.float32): if is_training is True, return the computed loss.
        """
        output = self.gpt2(input_ids, input_mask)

        shift_logits = output[::, :-1, ::]
        shift_logits = self.reshape(shift_logits, (-1, self.vocab_size))
        shift_logits = self.log_softmax(shift_logits)
        label_ids = label_ids[::, 1:]
        input_mask = input_mask[::, 1:]

        loss = self.loss_function(shift_logits, label_ids, input_mask)

        return loss


class GPT2CoQA(nn.Cell):
    """
    Train interface for Reading Comprehension finetuning task.

    Args:
        config (class): the configuration of GPT-2 model.
        is_training (bool): whether to train.
        use_one_hot_embeddings (bool): whether to use one-hot embeddings.
    """
    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        super(GPT2CoQA, self).__init__()
        self.gpt2 = GPT2CoQAModel(config, is_training, use_one_hot_embeddings)
        self.num_labels = config.vocab_size
        self.loss = CrossEntropyCalculationWithMask(is_training=is_training,
                                                    num_labels=self.num_labels,
                                                    config=config)
        self.is_training = is_training
        self.reshape = P.Reshape()
        self.log_softmax = P.LogSoftmax(axis=-1)

    def construct(self, input_ids, input_mask, label_ids=None):
        """
        construct function for reading comprehension task

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
            input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary. Default: None.

        Returns:
            lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits;
                otherwise, return the computed loss.
        """
        lm_logits = self.gpt2(input_ids, input_mask)
        lm_logits = self.log_softmax(lm_logits)

        if self.is_training:
            shift_logits = lm_logits[::, :-1, ::]
            shift_logits = self.reshape(shift_logits, (-1, self.num_labels))
            label_ids = label_ids[::, 1:]
            input_mask = input_mask[::, 1:]

            loss = self.loss(shift_logits, label_ids, input_mask)
            return loss

        return lm_logits
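
# --- Hedged usage sketch (illustration only; not part of the upstream file). ---
# One plausible way to wrap GPT2LM for loss-scaled finetuning; the optimizer
# choice and loss-scale hyper-parameters below are placeholders, not the
# settings used by the training scripts.
def _demo_build_finetune_cell(config):
    from mindspore.nn import Momentum
    from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
    network = GPT2LM(config=config, is_training=True)
    optimizer = Momentum(network.trainable_params(), learning_rate=2e-5, momentum=0.9)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32,
                                             scale_factor=2, scale_window=1000)
    # GPT2FinetuneCell returns (loss, overflow_cond) per training step.
    return GPT2FinetuneCell(network, optimizer=optimizer,
                            scale_update_cell=update_cell)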
@ -0,0 +1,82 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Calculate Cross Entropy With Mask"""
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
import mindspore.nn as nn


class CrossEntropyCalculationWithMask(nn.Cell):
    """
    Cross Entropy loss
    """

    def __init__(self, is_training=None, num_labels=None, config=None):
        super(CrossEntropyCalculationWithMask, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.is_training = is_training
        self.num_labels = num_labels
        if config is not None:
            # for PPL calculation in evaluation
            self.input_mask_length = Tensor(config.batch_size * (config.seq_length - 1), mstype.float32)

    def construct(self, logits, label_ids, input_mask=None):
        """
        Calculate loss

        Args:
            logits (Tensor): the probability distribution over the vocabulary.
            label_ids (Tensor): the indices of input sequence tokens in the vocabulary.
            input_mask (Tensor): input sentence padding mask, where 0 indicates padding position.

        Returns:
            return_value (Tensor, mstype.float32): if is_training is False, directly return the logits; otherwise,
                return the computed loss.
        """

        # logits [batch * (seq_length-1), vocab_size]    label_ids [batch, seq_length-1]
        if self.is_training:
            label_ids = self.reshape(label_ids, self.last_idx)  # label_ids [batch * (seq_length-1)]
            one_hot_labels = self.onehot(label_ids, self.num_labels, self.on_value,
                                         self.off_value)  # [batch * (seq_length-1), vocab_size]
            per_example_loss = self.neg(
                self.reduce_sum(one_hot_labels * logits, self.last_idx))  # [batch * (seq_length-1)]

            # for PPL calculation in evaluation
            if input_mask is not None:
                input_mask = self.cast(self.reshape(input_mask, self.last_idx),
                                       mstype.float32)  # [batch * (seq_length-1)]

                valid_loss_sum = self.reduce_sum(input_mask * per_example_loss, ())
                valid_element_sum = self.reduce_sum(input_mask, ()) + self.cast(F.tuple_to_array((1e-5,)),
                                                                                mstype.float32)
                loss = valid_loss_sum / valid_element_sum
            else:
                loss = self.reduce_mean(per_example_loss, self.last_idx)  # a number
            return_value = self.cast(loss, mstype.float32)
        else:
            return_value = logits * 1.0  # [batch * (seq_length-1), vocab_size]

        return return_value
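
# --- Hedged numeric sketch (illustration only; not part of the upstream file). ---
# Reproduces the masked-mean reduction above in plain NumPy:
# loss = sum(mask * nll) / (sum(mask) + 1e-5), so padded positions contribute nothing.
def _demo_masked_nll():
    import numpy as np
    log_probs = np.log(np.array([[0.7, 0.2, 0.1],    # token 0: correct class is 0
                                 [0.1, 0.8, 0.1]]))  # token 1: padding position
    labels = np.array([0, 1])
    mask = np.array([1.0, 0.0])
    nll = -log_probs[np.arange(2), labels]           # per-example negative log-likelihood
    return float((mask * nll).sum() / (mask.sum() + 1e-5))  # ~= -ln(0.7)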
@ -0,0 +1,488 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""data preprocess for downstream task"""
import re
import json
import random


def lambada_detokenizer(string):
    string = re.sub(r"``", "-DQ-", string)
    string = re.sub(r"`", "-SQ-", string)
    string = re.sub(r"''", "-DQ-", string)
    string = re.sub(r" '", "-SQ-", string)
    string = re.sub("-DQ-", '"', string)
    string = re.sub("-SQ-", "'", string)
    string = re.sub(r"([,?!.]['\"])(\w)", r"\g<1> \g<2>", string)
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    string = string.replace(" 'd", "'d")
    string = string.replace(" '", "'")
    string = string.replace(" n't", "n't")

    string = string.replace(" .", ".")
    string = string.replace(" ,", ",")
    string = string.replace(" !", "!")
    string = string.replace(" ?", "?")
    string = string.replace(" :", ":")
    string = string.replace(" ;", ";")

    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" ,'", ",'")
    string = string.replace(" .'", ".'")
    string = string.replace(" !'", "!'")
    string = string.replace(" ?'", "?'")
    string = string.replace("~", "")
    string = string.replace("---", "")
    string = string.replace("<", "")
    string = string.replace(">", "")
    string = string.replace("#", "")

    string = string.replace(', "', ',"')
    string = string.replace('. "', '."')
    string = string.replace('! "', '!"')
    string = string.replace('? "', '?"')
    string = string.replace('"" ', '" "')
    string = string.replace('• • •', '')

    # sensitive word process
    string = string.replace("f ** k", "fuck")
    string = string.replace("f ** king", "fucking")
    string = string.replace("f ** ked", "fucked")
    string = string.replace("c ** k", "cock")
    string = string.replace("br ** sts", "breasts")
    string = string.replace("n ** ples", "nipples")
    string = string.replace("ni ** les", "nipples")
    string = string.replace("a ** hole", "asshole")
    string = string.replace("ass ** le", "asshole")
    string = string.replace("p ** sy", "pussy")
    string = string.replace("pu ** y", "pussy")
    string = string.replace("na ** d", "naked")
    string = string.replace("nak * d", "naked")
    string = string.replace("cli ** x", "climax")
    string = string.replace("h * ps", "hips")
    string = string.replace("c * ck", "cock")
    string = string.replace("coc ** ne", "cocaine")
    string = string.replace("*", "")

    # collapse runs of spaces (the exact widths were lost to whitespace
    # collapsing in the diff view; reconstructed as repeated double-space folds)
    string = re.sub("  ", " ", string)
    string = re.sub("  ", " ", string)
    string = re.sub("  ", " ", string)

    return string


def lambada_dataset_preprocess(input_file, output_file):
    sentences = []
    count = 0
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = lambada_detokenizer(line)
                split_sentence_list = line.split()
                final_word = split_sentence_list[-1]
                context = split_sentence_list[:-1]
                new_sentence = ' '.join(context) + '\t' + ' ' + final_word
                sentences.append(new_sentence)
                count += 1
    print('read {} file finished!\n total count = {}'.format(input_file, count))

    with open(output_file, 'w', encoding='utf-8') as f:
        for sentence in sentences:
            sentence = sentence.strip()
            if sentence:
                f.write(sentence)
                f.write('\n')
                count -= 1
    print('write {} file finished!\n total count = {}'.format(output_file, count))


def get_gold_answer_id(gold_answer, candidate_answer_list):
    """return the index of gold_answer within candidate_answer_list"""
    id_ = 0
    for candidate in candidate_answer_list:
        if gold_answer == candidate:
            return id_
        id_ += 1


def get_passage_string(passage_string, candidate_answer, final_sentence, gold_answer_id):
    """
    concatenate each candidate answer to the passage

    Args:
        candidate_answer (list): stores each candidate answer
        final_sentence (str): the 21st sentence string containing "XXXXX"
        gold_answer_id (int): the id of the correct answer.

    return:
        candidate_passage (list): its length equals the length of candidate_answer.
    """
    candidate_passage = []
    for answer in candidate_answer:
        passage = passage_string + " " + final_sentence
        passage = passage.replace(" XXXXX", "\t XXXXX")
        final_passage = passage.replace("XXXXX", answer)
        whole_passage = final_passage + "\t" + str(gold_answer_id)
        candidate_passage.append(whole_passage)

    return candidate_passage


def cbt_dataset_preprocess(input_file, output_file):
    passages = []
    candidate_passage_list = []
    passage_string = ""
    count = 0
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                single_sentence = line.split(' ', 1)
                line_id = int(single_sentence[0])
                string = single_sentence[1]

                if line_id == 21:
                    string = string.replace("\t\t", "\t")
                    mini_string = string.split("\t")
                    candidate_answer = mini_string[-1]
                    candidate_answer_list = candidate_answer.split("|")
                    gold_answer_id = get_gold_answer_id(mini_string[-2], candidate_answer_list)
                    candidate_passage = get_passage_string(passage_string,
                                                           candidate_answer_list,
                                                           mini_string[0],
                                                           gold_answer_id)
                    assert len(candidate_passage) == 10
                    count += 10

                else:
                    passage_string = passage_string + " " + string
            else:
                passages.append(candidate_passage)
                candidate_passage_list = []
                passage_string = ""

    print('read {} file finished!\n total count = {}'.format(input_file, count))

    with open(output_file, 'w', encoding='utf-8') as f:
        for passage in passages:
            for candidate_passage in passage:
                candidate_passage = candidate_passage.replace(" \t ", "\t ")
                candidate_passage = candidate_passage.strip()
                f.write(candidate_passage)
                f.write("\n")
                count -= 1

    print('write {} file finished!\n total count = {}'.format(output_file, count))


def wikitext_detokenizer(string):
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" .", ".")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")

    return string


def wikitext_dataset_preprocess(input_file, output_file):
    dataset_test = []
    passage = []
    count = 0
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                if line.startswith('=') and line.endswith('=') and len(passage) != 0:
                    dataset_test.append(passage)
                    count += 1
                    passage = []
                elif line.startswith('=') and line.endswith('='):
                    continue
                else:
                    passage.append(line)
    print('read {} file finished!\n total count = {}'.format(input_file, count))

    with open(output_file, 'w', encoding='utf-8') as f:
        for line in dataset_test:
            text = ""
            for sentence in line:
                sentence = wikitext_detokenizer(sentence)
                text = text + " " + sentence
            text = text.strip()
            f.write(text)
            f.write("\n")
    print('write {} file finished!\n total count = {}'.format(output_file, count))


def ptb_detokenizer(string):
    string = string.replace(" '", "'")
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" n't", "n't")
    string = string.replace(" N ", "1 ")
    string = string.replace("$ 1", "$1")
    string = string.replace("# 1", "#1")
    string = string.replace("\\/abc", "")
    string = string.replace("\\/ua", "")

    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")

    string = string.replace(" 's", "'s")

    return string


def ptb_dataset_preprocess(input_file, output_file):
    sentences = []
    count = 0
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = ptb_detokenizer(line)
                sentences.append(line)
                count += 1
    print('read {} file finished!\n total count = {}'.format(input_file, count))

    with open(output_file, 'w', encoding='utf-8') as f:
        for sentence in sentences:
            sentence = sentence.strip()
            if sentence:
                f.write(sentence)
                f.write("\n")
                count -= 1
    print('write {} file finished!\n total count = {}'.format(output_file, count))


def onebw_detokenizer(string):
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" --", "")
    string = string.replace("--", "")
    string = string.replace("? ? ?", " ?")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" 't", "'t")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")

    string = string.replace(" '", "'")
    string = string.replace(" n't", "n't")
    string = string.replace("$ 1", "$1")
    string = string.replace("# 1", "#1")

    return string


def test_length(string):
    string_list = string.split()
    return len(string_list)


def onebw_dataset_preprocess(condition, input_file, output_file):
    sentences = []
    count = 0
    if condition.lower() == "test":
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    sentences.append(line)
                    count += 1
        print('read {} file finished!\n total count = {}'.format(input_file, count))

        with open(output_file, 'w', encoding='utf-8') as f:
            for sentence in sentences:
                sentence = sentence.strip()
                if sentence:
                    sentence = onebw_detokenizer(sentence)
                    f.write(sentence)
                    f.write("\n")
                    count -= 1
        print('write {} file finished!\n total count = {}'.format(output_file, count))

    elif condition.lower() == "train":
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    line = onebw_detokenizer(line)
                    length = test_length(line)
                    if length > 10 and length < 60:
                        sentences.append(line)
                        count += 1
        print('read finished! count = ', count)

        sample_result_list = random.sample(range(0, count), 30000)
        sample_result_list.sort()
        count_sample = 0
        choiced_sentence = ""
        with open(output_file, 'w', encoding='utf-8') as f:
            for i in range(len(sample_result_list)):
                choiced_sentence = sentences[sample_result_list[i]]
                f.write(choiced_sentence)
                f.write("\n")
                count_sample += 1
        print('write finished! ', count_sample)

    else:
        raise ValueError("condition should be one of: [train, test]")


def coqa_dataset_preprocess(input_file, output_file):
    with open(input_file, 'r', encoding='utf-8') as f:
        source_data = json.load(f)

    stories = []
    instances = []
    end_sep = [',', '.', ';']
    question_before_sep = " "
    question_after_sep = " A: "
    answer_sep = " A:\t"

    for i, dialog in enumerate(source_data["data"]):
        story = dialog["story"].replace("\n", "")
        stories.append(story)

        concat_ = ""
        concat_ += story
        for question, answer in zip(dialog["questions"], dialog["answers"]):
            question = question["input_text"]
            answer = answer["input_text"]
            concat_ += question_before_sep
            concat_ += question
            tmp = concat_ + question_after_sep
            concat_ += answer_sep
            concat_ += answer
            instances.append(concat_)
            concat_ = tmp + answer
            if concat_[-1] not in end_sep:
                concat_ += "."
        instances.append("")

    with open(output_file, 'w', encoding='utf-8') as f:
        for i in range(len(instances)):
            if instances[i]:
                f.write(instances[i])
                f.write("\n")

    print('write {} file finished!\n total count = {}'.format(output_file, len(instances)))


def wmt14_en_fr_preprocess(input_file, output_file):
    input_file = input_file + "/newstest2014-fren-ref"
    output_file = output_file + "/wmt14"
    language = ['.en.sgm', '.fr.sgm']
    count = 0
    # en-fr
    with open(input_file + language[0], "r", encoding='utf-8') as english, \
            open(input_file + language[1], "r", encoding='utf-8') as french, \
            open(output_file + '.en_fr.txt', "a", encoding='utf-8') as enfr_f, \
            open(output_file + '.fr_en.txt', "a", encoding='utf-8') as fren_f:
        line_id = 0
        for en, fr in zip(english, french):
            line_id += 1
            if en[:7] == '<seg id':
                print("=" * 20, "\n", line_id, "\n", "=" * 20)
                en_start = en.find('>', 0)
                en_end = en.find('</seg>', 0)
                print(en[en_start + 1:en_end])
                en_ = en[en_start + 1:en_end]

                fr_start = fr.find('>', 0)
                fr_end = fr.find('</seg>', 0)
                print(fr[fr_start + 1:fr_end])
                fr_ = fr[fr_start + 1:fr_end]

                en_fr_str = en_ + "\t" + fr_ + "\n"
                enfr_f.write(en_fr_str)
                fr_en_str = fr_ + "\t" + en_ + "\n"
                fren_f.write(fr_en_str)
                count += 1

    print('write {} file finished!\n total count = {}'.format(output_file + '.en_fr.txt', count))
    print('write {} file finished!\n total count = {}'.format(output_file + '.fr_en.txt', count))
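
# --- Hedged usage sketch (illustration only; not part of the upstream file). ---
# The file names below are hypothetical placeholders; point them at real
# WikiText-2 / PTB test files before running.
def _demo_preprocess():
    wikitext_dataset_preprocess("wiki.test.tokens", "wikitext2.clean.txt")
    ptb_dataset_preprocess("ptb.test.txt", "ptb.clean.txt")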
@ -0,0 +1,542 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
generation utils
"""

import numpy as np
from scipy.special import softmax

from mindspore.ops import operations as P
from mindspore import dtype as mstype
from mindspore.common.tensor import Tensor

from .tensor_manipulations import extract_single_token_logits, add_last_token

INF = 1. * 1e9


class TopKTopP_Filter():
    """
    Top-k sampling along with Top-p sampling (Nucleus Sampling)

    Keep the top-k highest-probability ids and, among them, the ids within the top-p
    cumulative probability as the candidate sample set.
    Use np.random.multinomial to sample.

    Args:
        batch_size (int): batch size of input dataset.
        vocab_size (int): the shape of each embedding vector.
        k (int): parameter for Top-k sampling; k should be in the range [0, vocab_size].
            0 means no Top-k filtering (do nothing). Default: 0.
        p (float) [Optional]: parameter for Top-p sampling a.k.a. Nucleus Sampling; p is between 0.0 and 1.0.
            Default: 1.0.
        temperature (float) [Optional]: sampling temperature; higher values make generation more diverse. Default: 1.0.

    """

    def __init__(self,
                 batch_size=None,
                 vocab_size=None,
                 k=0,
                 p=1.0,
                 temperature=1.0,
                 min_tokens_to_keep=1,
                 ):

        self.k = k
        self.p = p
        self.temp = temperature

        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.min_tokens_to_keep = min_tokens_to_keep

        assert self.temp > 0.0, 'temperature must be positive'
        assert self.k >= 0, 'the top_k number must be non-negative.'
        if self.k > 0:
            assert self.min_tokens_to_keep <= self.k, 'k must be larger than or equal to min_tokens_to_keep ' \
                                                      'for Top-p sampling'

        if self.k == 0:
            self.k = self.vocab_size

        self.safety_mask = np.concatenate((np.ones((self.batch_size, self.min_tokens_to_keep)),
                                           np.zeros((self.batch_size, self.k - self.min_tokens_to_keep))),
                                          axis=1).astype(np.bool)

    def calculate(self, distribution):
        """
        run the sampling procedure with the settings initialized before; return a list of sampled ids.

        Args:
            distribution (numpy.ndarray): with shape (batch_size, vocab_size)

        Returns:
            sampled ids: a list with length batch_size
        """

        if self.temp != 1.0:
            distribution = distribution / float(self.temp)

        distribution_sorted = -np.sort(-distribution, axis=1)
        index_sorted = np.argsort(-distribution, axis=1)

        topk_distribution = distribution_sorted[::, :self.k if self.k > 0 else self.vocab_size]
        topk_indices = index_sorted[::, :self.k if self.k > 0 else self.vocab_size]

        # safety check of probability
        self.p = max(0.0, min(1.0, self.p))
        cum_sum = np.cumsum(softmax(topk_distribution, axis=1), axis=1)
        bool_map = np.logical_or((cum_sum <= self.p), self.safety_mask).astype(np.float32)

        topk_distribution = topk_distribution * bool_map + np.float32(-1e5) * (1.0 - bool_map)
        topk_distribution = softmax(topk_distribution, axis=1)

        # normalize in np.float64
        # choose np.float64 to avoid overflow in the softmax operation
        topk_distribution = topk_distribution.astype(np.float64)
        for batch_idx in range(self.batch_size):
            topk_distribution[batch_idx] = topk_distribution[batch_idx] / np.sum(topk_distribution[batch_idx])

        ret_ids = []
        for batch_idx in range(self.batch_size):
            select_index = np.argmax(np.random.multinomial(1, topk_distribution[batch_idx]))
            ret_ids.append(topk_indices[batch_idx][select_index])

        return ret_ids

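# --- Hedged usage sketch (illustration only; not part of the upstream file). ---
# Shows one way TopKTopP_Filter might be driven on raw logits; the k/p values
# are arbitrary. Note: `np.bool` above follows the 2020-era NumPy API this
# file was written against.
def _demo_topk_topp_filter():
    """Sample one token id per batch row from random fake GPT-2 logits."""
    fake_logits = np.random.randn(2, 50257).astype(np.float32)
    token_filter = TopKTopP_Filter(batch_size=2, vocab_size=50257,
                                   k=40, p=0.9, temperature=1.0)
    return token_filter.calculate(fake_logits)  # e.g. [1023, 88]
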
class Sample():
    """
    Initiate a Sample object for sampling next token(s) from previous text.

    Args:
        decoder (Model): GPT2 model to do generation.
        config (GPT2Config): configuration of the given GPT2 model.
        tokenizer (GPT2Tokenizer): if you choose to use the input_str parameter in self.generate(), a tokenizer is compulsory.
        generate_length (int): number of tokens which should be generated. Default: 1.
        topk_num (int): the k in Top-k sampling; 0 for no constraint,
            equivalent to k = self.vocab_size. Default: 0.
        topp_prob (float): probability parameter of Top-p (nucleus) sampling;
            p = 1.0 is equivalent to doing nothing. Default: 1.0.
        temperature (float): temperature for Top-k sampling. Default: 1.0.
        min_tokens_to_keep (int): guarantees that at least min_tokens_to_keep token(s) remain sampleable. Default: 1.
        early_stop (bool): whether to stop when the model generates the <EOS> token.
            It only takes effect when batch_size is 1. Default: False.
        demo_mode (bool): True if input_str is a str rather than a list of str.
            self.batch_size should be 1 if it is True. Default: False.
        return_ids (bool): whether to return the ids generated by Sample. Default: False.
        return_last_token_logits (bool): whether to return the logits of the last token at each time step during generation.
            Default: False.
        append_eos (bool): whether to append the <EOS> token id to input_ids passed directly to the GPT2Model class. Default: False.

    """

    def __init__(self,
                 decoder,
                 config=None,
                 batch_size=None,
                 tokenizer=None,
                 generate_length=1,
                 topk_num=0,
                 topp_prob=1.0,
                 temperature=1.0,
                 min_tokens_to_keep=1,
                 early_stop=False,
                 demo_mode=False,
                 return_ids=False,
                 return_last_token_logits=False,
                 append_eos=False,
                 ):

        assert config is not None, 'Config is a must for sampling.'

        self.decoder = decoder
        self.config = config
        self.tokenizer = tokenizer
        self.generate_length = generate_length
        self.topk_num = topk_num
        self.topp_prob = topp_prob
        self.temperature = temperature
        self.min_tokens_to_keep = min_tokens_to_keep
        self.early_stop = early_stop
        self.demo_mode = demo_mode
        self.return_ids = return_ids
        self.return_last_token_logits = return_last_token_logits
        self.append_eos = append_eos

        self.seq_length = config.seq_length
        self.batch_size = config.batch_size if batch_size is None else batch_size
        self.vocab_size = config.vocab_size

        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reshape = P.Reshape()
        self.cumsum = P.CumSum()
        self.onehot = P.OneHot()
        self.cast = P.Cast()
        self.concat = P.Concat()
        self.sample_function = P.RandomCategorical(mstype.int32)
        self.filter_distribution = TopKTopP_Filter(batch_size=self.batch_size,
                                                   vocab_size=self.vocab_size,
                                                   k=self.topk_num,
                                                   p=self.topp_prob,
                                                   temperature=self.temperature,
                                                   min_tokens_to_keep=self.min_tokens_to_keep)

        if self.tokenizer is not None:
            self.eos_id = self.tokenizer.eos_token_id
            self.eos_text = self.tokenizer.eos_token
        else:
            self.eos_id = config.vocab_size - 1
            self.eos_text = "<|endoftext|>"

        if self.demo_mode is True:
            assert self.batch_size == 1, 'Demo mode requires batch_size equal to 1, but got batch_size={}'.format(
                self.batch_size)

    def _extract_string_from_tensor(self, input_ids, mode="pair"):
        """
        Args:
            input_ids (Tensor): input sentences with shape [self.batch_size, self.seq_len].
            mode (str): ["pair", "single"]
                "pair" for tasks with paired inputs `<bos> A <eos> B <eos>`,
                such as the summarization task with dataset format `<bos> Article <eos> Summary <eos>`
                and the reading comprehension task with dataset format `<bos> Passage Question <eos> Answer <eos>`.

                "single" for tasks with a single input `<bos> A <eos>`, such as Language Modeling and the Lambada task.
        Returns:
            source_list (list): the list of source_text or first part of text.
            target_list (list): the list of target_text or second part of text.

            If self.batch_size is 1, the first sentence of the list, i.e. the string itself, is returned.

        Example:
            for pair mode, if self.demo_mode is True, it will return source_list[0], target_list[0]
        """
        assert self.tokenizer is not None, 'There is no tokenizer'
        source_list = [""] * self.batch_size
        target_list = [""] * self.batch_size
        eos_text = self.tokenizer.eos_token
        len_eos_text = len(eos_text)
        input_ids_np = input_ids.asnumpy()
        input_ids_np = input_ids_np.reshape((self.batch_size, self.seq_length))
        # input_ids = self.reshape(input_ids, (self.batch_size, self.seq_length))

        if mode == "pair":
            for batch_idx in range(self.batch_size):
                sentence_tensor = input_ids_np[batch_idx]
                sentence_list = sentence_tensor.tolist()[1:]

                sentence = self.tokenizer.decode(sentence_list)
                source_start = 0
                source_end = sentence.find(eos_text, 0)
                target_start = source_end + len_eos_text
                target_end = sentence[target_start:].find(eos_text, 0) + target_start
                source_list[batch_idx] = sentence[source_start:source_end]
                target_list[batch_idx] = sentence[target_start:target_end]

            if self.batch_size == 1 and self.demo_mode is True:
                return source_list[0], target_list[0]
            return source_list, target_list

        if mode == "single":
            for batch_idx in range(self.batch_size):
                sentence_tensor = input_ids_np[batch_idx]
                sentence_list = sentence_tensor.tolist()[1:]

                sentence = self.tokenizer.decode(sentence_list)
                source_start = 0
                source_end = sentence.find(eos_text, 0)
                source_list[batch_idx] = sentence[source_start:source_end]
            if self.batch_size == 1 and self.demo_mode is True:
                return source_list[0]
        else:
            raise ValueError('mode:{} not supported, only support [pair, single].'.format(mode))

        return source_list

    def _tensorize_ids_with_masks(self, src_str):
        """
        Transform from string to tensor

        Args:
            src_str: string or list of strings
        Return:
            input_ids (Tensor): shape with [self.batch_size, self.seq_length]
            input_mask (Tensor): shape with [self.batch_size, self.seq_length]
            src_len_list (list): the lengths of the tokenized src_str sequences (tokenized by self.tokenizer)
        """

        if isinstance(src_str, str):
            src_str = [src_str]

        single_sentence_shape = (1, self.seq_length)
        src_len_list = list()
        input_ids = None
        input_mask = None

        for batch_idx in range(self.batch_size):
            src_ids_list = self.tokenizer.encode(src_str[batch_idx])
            src_ids_len = len(src_ids_list)
            if src_ids_len > self.seq_length:
                src_ids_list = src_ids_list[:self.seq_length]
                src_ids_len = self.seq_length

            src_len_list.append(src_ids_len)
            return_dict = self.tokenizer.prepare_for_model(src_ids_list,
                                                           max_length=self.config.seq_length,
                                                           add_special_tokens=False)

            input_ids_list = return_dict['input_ids']
            input_mask_list = return_dict['attention_mask']

            input_ids_np = np.array(input_ids_list, dtype=int)
            input_mask_np = np.array(input_mask_list, dtype=int)
            input_ids_np = input_ids_np.reshape(single_sentence_shape)
            input_mask_np = input_mask_np.reshape(single_sentence_shape)

            # input_ids_tensor = self.reshape(Tensor(np.array(input_ids_list, dtype=int), dtype=mstype.int32),
            #                                 single_sentence_shape)
            # input_mask_tensor = self.reshape(Tensor(np.array(input_mask_list, dtype=int), dtype=mstype.int32),
            #                                  single_sentence_shape)
            if batch_idx == 0:
                # input_ids = input_ids_tensor
                # input_mask = input_mask_tensor
                input_ids_np_ = input_ids_np
                input_mask_np_ = input_mask_np
            else:
                # input_ids = self.concat((input_ids, input_ids_tensor))
                # input_mask = self.concat((input_mask, input_mask_tensor))
                input_ids_np_ = np.concatenate((input_ids_np_, input_ids_np), axis=0)
                input_mask_np_ = np.concatenate((input_mask_np_, input_mask_np), axis=0)

        input_ids = Tensor(input_ids_np_, dtype=mstype.int32)
        input_mask = Tensor(input_mask_np_, dtype=mstype.int32)

        return input_ids, input_mask, src_len_list

    class LastTokenPos():
        """
        inner helper recording input strings (or an input mask) and the positions of their last tokens

        Args:
            input_ (Union[list, Tensor]): list if input is a list containing strings,
                Tensor with shape (batch_size, seq_length) representing input_mask.
        """

        def __init__(self, input_, seq_length=1024):
            if isinstance(input_, list):
                self.input_strs = input_
                self.input_mask = None
            else:
                self.input_strs = None
                self.input_mask = input_

            self.seq_length = seq_length
            if self.input_strs is not None:
                self.pos_list = [len(input_str) - 1 for input_str in self.input_strs]
            else:
                input_mask_ = P.Cast()(self.input_mask, mstype.float32)
                temp_pos_list = P.ReduceSum(keep_dims=False)(input_mask_, axis=1).asnumpy().astype(np.int32).tolist()
                # minimum value is always 0 for safety
                self.pos_list = [max(0, pos - 1) for pos in temp_pos_list]

        def get_pos(self, shift: int = 0):
            # return the last token position if the shift overflows
            shift_list = [min(self.seq_length - 1, pos + shift) for pos in self.pos_list]
            return shift_list

    def _sample_from_distribution(self, distribution):
        """
        sample one token per batch from self.sample_function().

        Arg:
            distribution (Tensor): the distribution or logits of the last token of different batches,
                shape with [batch_size, vocab_size]

        Return:
            word_index (Tensor): shape with [batch_size, ]
        """

        distribution = self.reshape(distribution, (self.vocab_size, self.batch_size))
        topk_distribution = distribution[:self.topk_num, ::]
        topk_distribution = self.reshape(topk_distribution, (self.batch_size, -1))

        word_index = self.sample_function(P.Softmax()(topk_distribution), 1, 1)
        word_index = self.reshape(word_index, (-1,))

        return word_index

    def _demo_mode_check(self, input_str):
        """
        type check for demo_mode: batch_size is 1 and input_str is not None; normalize input_str to a list
        """
        if self.batch_size == 1 and self.demo_mode is True:
            assert input_str is not None, "demo mode should have input str"
            # type check
            if isinstance(input_str, list):
                assert isinstance(input_str[0], str), "type of input_str is {}, " \
                                                      "which should be str instead.".format(type(input_str[0]))
                if len(input_str) != 1:
                    print("[WARNING] Sample.generate: length of input_str is larger than 1, "
                          "choose input_str[0] as input_str.")
                input_str = input_str[0]
            assert isinstance(input_str, str), "type of input_str is {}, " \
                                               "which should be str instead.".format(input_str)
            input_str = [input_str]
        return input_str

    def _input_check_and_normalize(self, input_str=None, input_ids=None, input_mask=None, generate_length=None):
        """
        input check function
        """
        if input_str is not None:
            assert self.tokenizer is not None, 'if choose to give input_str, a tokenizer is necessary.'
            input_str = self._demo_mode_check(input_str)

        if input_ids is not None:
            assert input_mask is not None, 'if input_ids is given, input_mask is required as well.'

        if input_str is not None and input_ids is not None and input_mask is not None:
            print('[WARNING] Sample.generate got input_str, input_ids and input_mask, '
                  'choose input_str as default for input')

        if input_ids is None and input_mask is None:
            input_ids, input_mask, _ = self._tensorize_ids_with_masks(input_str)
        else:
            if input_str is None:
                if input_ids is not None:
                    input_str = self._extract_string_from_tensor(input_ids, mode="full")

        if generate_length is not None:
            # reload generate_length
            generate_length = int(generate_length)
            assert generate_length >= 0, 'generate_length can not be negative.'
|
||||
else:
|
||||
generate_length = self.generate_length
|
||||
|
||||
return input_str, input_ids, input_mask, generate_length
|
||||
|
||||
def generate(self, input_str=None, input_ids=None, input_mask=None, generate_length=None, do_sample=True):
|
||||
"""
|
||||
base function for text generation given a batch_size list of str or str itself (when demo mode is on)
|
||||
|
||||
Args
|
||||
input_str (list(str) or str): prompt string.
|
||||
generate_length: number of tokens to generate.
|
||||
|
||||
Returns:
|
||||
generate_str: string generated by the GPT-2 model.
|
||||
full_str: input_str appended with generate_str.
|
||||
"""
|
||||
|
||||
input_str, input_ids, input_mask, generate_length = self._input_check_and_normalize(input_str,
|
||||
input_ids,
|
||||
input_mask,
|
||||
generate_length)
|
||||
return_ids_list = [[]] * self.batch_size
|
||||
|
||||
last_token = self.LastTokenPos(input_mask, seq_length=self.seq_length)
|
||||
|
||||
for i in range(generate_length):
|
||||
last_token_pos_list = last_token.get_pos(shift=i)
|
||||
early_stop_mask = [0] * self.batch_size
|
||||
|
||||
# unsorted logits (distribution) of next word
|
||||
logits = self.decoder.predict(input_ids, input_mask)
|
||||
|
||||
if self.return_last_token_logits is True:
|
||||
if i == 0:
|
||||
# [batch_size, 1, vocab_size]
|
||||
return_last_logits = extract_single_token_logits(logits, last_token_pos_list)
|
||||
else:
|
||||
# [batch_size, 1, vocab_size] + [batch_size, i, vocab_size] --> [batch_size, i+1, vocab_size]
|
||||
return_last_logits = P.Concat(axis=1)((return_last_logits,
|
||||
extract_single_token_logits(logits, last_token_pos_list)))
|
||||
|
||||
nextword_distribution = self.reshape(logits[0, last_token_pos_list[0]:last_token_pos_list[0]+1:1, ::],
|
||||
(1, -1))
|
||||
|
||||
# stack up nextword_distribution if batch_size is larger than 1
|
||||
if self.batch_size > 1:
|
||||
for batch_idx in range(1, self.batch_size):
|
||||
nextword_distribution_rest = self.reshape(
|
||||
logits[batch_idx, last_token_pos_list[batch_idx]:last_token_pos_list[batch_idx] + 1:1, ::],
|
||||
(1, -1))
|
||||
nextword_distribution = self.concat((nextword_distribution, nextword_distribution_rest))
|
||||
|
||||
if do_sample:
|
||||
# get sampled ids
|
||||
nextword_distribution = nextword_distribution.asnumpy().astype(np.float32)
|
||||
real_next_word_index_list = self.filter_distribution.calculate(nextword_distribution)
|
||||
else:
|
||||
np_nextword_distribution = nextword_distribution.asnumpy()
|
||||
next_word_index = np.argmax(np_nextword_distribution, axis=-1)
|
||||
real_next_word_index_list = next_word_index.tolist()
|
||||
|
||||
append_ids = []
|
||||
|
||||
# tokenizer.decode and early_stop (if all batched generates a EOS, then it is time to say goodbye)
|
||||
for batch_idx in range(self.batch_size):
|
||||
next_word_index = real_next_word_index_list[batch_idx]
|
||||
# earlystop if the model generates a EOS token.
|
||||
if self.early_stop is True:
|
||||
if next_word_index == self.eos_id:
|
||||
if self.batch_size == 1:
|
||||
break
|
||||
else:
|
||||
early_stop_mask[batch_idx] = 1
|
||||
continue
|
||||
|
||||
return_ids_list[batch_idx].append(next_word_index)
|
||||
append_ids.append(next_word_index)
|
||||
|
||||
# check early_stop mask at the end of each loop
|
||||
if 0 not in early_stop_mask:
|
||||
break
|
||||
input_ids, input_mask = add_last_token(input_ids,
|
||||
input_mask,
|
||||
overflow_strategy="shift",
|
||||
append_ids=append_ids,
|
||||
next_token_pos=last_token.get_pos(shift=i + 1))
|
||||
|
||||
# add str to full str
|
||||
generate_str = [""] * self.batch_size
|
||||
full_str = [""] * self.batch_size
|
||||
text_cnt = 0
|
||||
|
||||
for text_ids in return_ids_list:
|
||||
text = self.tokenizer.decode(text_ids)
|
||||
generate_str[text_cnt] = text
|
||||
text_cnt += 1
|
||||
|
||||
for batch_idx in range(self.batch_size):
|
||||
full_str[batch_idx] = input_str[batch_idx] + generate_str[batch_idx]
|
||||
|
||||
# return by several conditions
|
||||
if self.batch_size == 1 and self.demo_mode is True:
|
||||
if self.return_ids:
|
||||
return generate_str[0], input_str[0], return_ids_list[0]
|
||||
return generate_str[0], input_str[0]
|
||||
|
||||
if self.return_ids:
|
||||
if self.return_last_token_logits:
|
||||
return return_ids_list, return_last_logits
|
||||
return return_ids_list
|
||||
|
||||
return generate_str, full_str
|
|
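
# A minimal usage sketch (illustrative only, not part of the original file); it
# assumes a Sample-style generator built earlier in this module with a loaded
# GPT-2 network, a tokenizer from utils.tokenization, and demo_mode enabled, in
# which case generate() returns the generated string and the prompt:
#
#     generate_str, prompt_str = sample.generate(input_str="Once upon a time",
#                                                generate_length=32,
#                                                do_sample=True)
#     print(prompt_str + generate_str)  # prompt plus the generated continuation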
@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get config setting"""


def get_train_setting(finetune_config):
    """get train config setting"""
    cfg = finetune_config

    print("Loading GPT2 Finetune Config setting......")
    print(" | optimizer: {}".format(cfg.optimizer))
    opt = cfg['optimizer']
    print(" | learning rate: {}".format(cfg[opt]['learning_rate']))
    print(" | end learning rate: {}".format(
        cfg[opt]['end_learning_rate'] if 'end_learning_rate' in cfg[opt] else 'None'))
    print(" | weight decay: {}\n".format(cfg[opt]['weight_decay'] if 'weight_decay' in cfg[opt] else 'None'))


def get_model_setting(finetune_config, model_config):
    """get GPT-2 model config setting"""
    cfg = finetune_config
    gpt2_net_cfg = model_config

    print("Loading GPT2 Model Config setting......")
    print(" | model size: {}".format(cfg.gpt2_network))
    print(" | batch_size: {}".format(gpt2_net_cfg.batch_size))
    print(" | seq_length: {}".format(gpt2_net_cfg.seq_length))
    print(" | vocab_size: {}".format(gpt2_net_cfg.vocab_size))
    print(" | d_model: {}".format(gpt2_net_cfg.d_model))
    print(" | num_hidden_layers: {}".format(gpt2_net_cfg.num_hidden_layers))
    print(" | num_attention_heads: {}".format(gpt2_net_cfg.num_attention_heads))
    print(" | hidden_dropout: {}".format(gpt2_net_cfg.hidden_dropout))
    print(" | attention_dropout: {}".format(gpt2_net_cfg.attention_dropout))
    print(" | summary_first_dropout: {}\n".format(gpt2_net_cfg.summary_first_dropout))
@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate schedule"""

import numpy as np

from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR


class GPT2LearningRate(LearningRateSchedule):
    """
    Implementation of a warmup-polydecay learning rate scheduler.

    Args:
        learning_rate (float): The initial value of the learning rate.
        end_learning_rate (float): The end value of the learning rate.
        warmup_steps (int): The warm up steps of the learning rate.
        decay_steps (int): A value used to calculate the decayed learning rate.
        power (float): A value used to calculate the decayed learning rate.

    Returns:
        lr (Tensor): The learning rate value for the current step.
    """
    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(GPT2LearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        decay_lr = self.decay_lr(global_step)
        if self.warmup_flag:
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr
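
# A minimal usage sketch (illustrative only, not part of the original file): the
# schedule returns the warmup value while global_step < warmup_steps and the
# polynomial-decay value afterwards. All numbers below are made-up examples.
if __name__ == "__main__":
    lr_schedule = GPT2LearningRate(learning_rate=1e-4,
                                   end_learning_rate=1e-7,
                                   warmup_steps=100,
                                   decay_steps=1000,
                                   power=1.0)
    for step in (1, 50, 100, 500, 1000):
        # construct() expects the global step as a Tensor
        print(step, lr_schedule(Tensor(step, mstype.int32)))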
@ -0,0 +1,185 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metric methods for downstream tasks"""

import string
import re
from collections import Counter
import numpy as np

from .rouge_score import get_rouge_score
from .bleu import compute_bleu


class LastWordAccuracy():
    """
    LastWordAccuracy class is for the LAMBADA task (predicting the final word of a sentence)
    """

    def __init__(self):
        self.acc_num = 0
        self.total_num = 0

    def normalize(self, word):
        """normalization: strip whitespace, lowercase, and remove punctuation"""
        word = word.lstrip()
        word = word.rstrip()

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return remove_punc(lower(word))

    def update(self, predict_label, gold_label):
        if isinstance(predict_label, str) and isinstance(gold_label, str):
            predict_label = [predict_label]
            gold_label = [gold_label]
        for predict_word, gold_word in zip(predict_label, gold_label):
            self.total_num += 1
            if self.normalize(predict_word) == self.normalize(gold_word):
                self.acc_num += 1


class Accuracy():
    """
    calculate accuracy
    """

    def __init__(self):
        self.acc_num = 0
        self.total_num = 0

    def update(self, logits, labels):
        """accuracy update"""
        labels = np.reshape(labels, -1)
        logits_id = np.argmax(logits, axis=-1)
        print(" | Predict Label: {} Gold Label: {}".format(logits_id, labels))
        self.acc_num += np.sum(labels == logits_id)
        self.total_num += len(labels)
        print("\n| Accuracy = {} \n".format(self.acc_num / self.total_num))


class F1():
    """calculate F1 score"""

    def __init__(self):
        self.f1_score = 0.0

    def get_normalize_answer_token(self, string_):
        """Lower text and remove punctuation, articles and extra whitespace."""

        def remove_articles(text):
            regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
            return re.sub(regex, ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(char for char in text if char not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(string_)))).split()

    def update(self, pred_answer, gold_answer):
        """F1 update"""
        # the number of tokens shared by pred_answer and gold_answer
        common = Counter(pred_answer) & Counter(gold_answer)
        num_same = sum(common.values())
        precision = 1.0 * num_same / len(pred_answer) if pred_answer else 0
        recall = 1.0 * num_same / len(gold_answer) if gold_answer else 0
        if ' '.join(pred_answer).strip() == "" and ' '.join(gold_answer).strip() == "":
            self.f1_score += 1
        else:
            self.f1_score += 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0.0

        print('| precision: {}, recall: {}\n'.format(precision, recall))
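
# A minimal usage sketch (illustrative only, not part of the original file); the
# answer strings are made-up. Because get_normalize_answer_token() lowercases,
# strips punctuation and drops articles, "The cat sat." and "cat sat" match:
#
#     f1_metric = F1()
#     pred = f1_metric.get_normalize_answer_token("The cat sat.")   # ['cat', 'sat']
#     gold = f1_metric.get_normalize_answer_token("cat sat")        # ['cat', 'sat']
#     f1_metric.update(pred, gold)                                  # precision = recall = 1.0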
class BLEU():
    """calculate BLEU score"""

    def __init__(self, tokenizer=None, max_order=4, smooth=True):
        self.bleu = 0.0
        self.total_num = 0
        self.tokenizer = tokenizer
        self.max_order = max_order
        self.smooth = smooth

    def sum_bleu(self, references, translations, max_order, smooth):
        """calculate the average and per-sample BLEU scores"""
        all_result = []
        bleu_avg = 0.0
        for refer, trans in zip(references, translations):
            result = compute_bleu([[refer]], [trans], max_order, smooth)
            all_result.append(result)
            bleu_avg += result[0]
        bleu_avg /= len(references)
        return bleu_avg, all_result

    def update(self, hypotheses, references):
        """BLEU update"""
        hypo_l = []
        ref_l = []
        if self.tokenizer is not None:
            for hypo, ref in zip(hypotheses, references):
                if ref.strip() == '':
                    print("Reference is None, skip it !")
                    continue
                if hypo.strip() == '':
                    print("Translation is None, skip it !")
                    continue
                hypo_l.append(self.tokenizer.encode(hypo))
                ref_l.append(self.tokenizer.encode(ref))

        if hypo_l and ref_l:
            hypotheses = hypo_l
            references = ref_l

        bleu_avg, _ = self.sum_bleu(references, hypotheses, self.max_order, self.smooth)
        self.bleu += bleu_avg * 100

        self.total_num += 1

        print("============== BLEU: {} ==============".format(float(self.bleu / self.total_num)))


class Rouge():
    '''
    Get Rouge Score
    '''

    def __init__(self):
        self.Rouge1 = 0.0
        self.Rouge2 = 0.0
        self.RougeL = 0.0
        self.total_num = 0

    def update(self, hypothesis, targets):
        scores = get_rouge_score(hypothesis, targets)
        self.Rouge1 += scores['rouge-1']['f'] * 100
        self.Rouge2 += scores['rouge-2']['f'] * 100
        self.RougeL += scores['rouge-l']['f'] * 100
        self.total_num += 1

        print("=============== ROUGE: {} ===============".format(
            (self.Rouge1 + self.Rouge2 + self.RougeL) / float(3.0 * self.total_num)))
@ -0,0 +1,466 @@
,
.
?
!
#
~
=
-
"
'
:
-
…
--

a
about
above
across
after
again
against
all
almost
alone
along
already
also
although
always
among
an
and
another
any
anybody
anyone
anything
anywhere
are
area
areas
around
as
ask
asked
asking
asks
at
away
b
back
backed
backing
backs
be
became
because
become
becomes
been
before
began
behind
being
beings
best
better
between
big
both
bro
but
by
c
came
can
cannot
case
cases
certain
certainly
clear
clearly
come
could
d
did
differ
different
differently
do
does
done
down
down
downed
downing
downs
during
dr
e
each
early
eh
either
end
ended
ending
ends
enough
even
evenly
ever
every
everybody
everyone
everything
everywhere
f
fact
facts
far
felt
few
find
finds
first
for
four
from
full
fully
further
furthered
furthering
furthers
g
gave
general
generally
get
gets
give
given
gives
going
good
goods
got
great
greater
greatest
group
grouped
grouping
groups
h
had
has
have
having
he
her
here
herself
hey
high
high
high
higher
highest
him
himself
his
house
how
however
i
if
important
in
interest
interested
interesting
interests
into
is
it
its
itself
j
just
k
kae
keep
keeps
kind
knew
know
known
knows
kya
l
lads
large
largely
last
later
latest
least
less
let
lets
like
likely
long
longer
longest
m
made
make
making
man
many
may
me
member
members
men
might
mister
more
most
mostly
mr
Mr
mrs
much
must
my
myself
n
na
necessary
need
needed
needing
needs
never
new
new
newer
newest
next
no
nobody
non
noone
not
nothing
now
nowhere
number
numbers
nt
nn
nope
ny
o
oi
of
off
often
old
older
oldest
on
once
one
only
open
opened
opening
opens
or
order
ordered
ordering
orders
other
others
our
out
over
oh
p
part
parted
parting
parts
per
perhaps
place
places
please
point
pointed
pointing
points
possible
present
presented
presenting
presents
problem
problems
put
puts
q
quite
r
rather
really
right
right
room
rooms
s
said
same
saw
say
says
second
seconds
see
seem
seemed
seeming
seems
sees
several
shall
she
should
show
showed
showing
shows
side
sides
since
small
smaller
smallest
so
some
somebody
someone
something
somewhere
state
states
still
still
such
sure
t
take
taken
than
that
the
their
them
then
there
therefore
these
they
thing
things
think
thinks
this
those
though
thought
thoughts
three
through
thus
to
today
together
too
took
toward
turn
turned
turning
turns
two
u
uh
um
under
until
up
upon
us
use
used
uses
v
very
w
want
wanted
wanting
wants
was
way
ways
we
well
wells
went
were
what
when
where
whether
which
while
who
whole
whose
why
will
with
within
without
work
worked
working
works
would
x
y
ya
ye
year
years
yet
you
young
younger
youngest
your
yours
z
@ -0,0 +1,39 @@
"""Calculate ROUGE score."""
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from typing import List
from rouge import Rouge


def get_rouge_score(hypothesis: List[str], target: List[str]):
    """
    Calculate ROUGE score.

    Args:
        hypothesis (List[str]): Inference result.
        target (List[str]): Reference.
    """
    if not hypothesis or not target:
        raise ValueError("`hypothesis` and `target` can not be None or empty.")
    _rouge = Rouge()
    print("hypothesis:", hypothesis)
    print("target:", target)
    scores = _rouge.get_scores(hypothesis, target, avg=True)
    print(" | ROUGE Score:")
    print(f" | RG-1(F): {scores['rouge-1']['f'] * 100:8.2f}")
    print(f" | RG-2(F): {scores['rouge-2']['f'] * 100:8.2f}")
    print(f" | RG-L(F): {scores['rouge-l']['f'] * 100:8.2f}")
    return scores
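
# A minimal usage sketch (illustrative only, not part of the original file); the
# sentences are made-up and the `rouge` pip package must be installed:
if __name__ == "__main__":
    hyp = ["the cat sat on the mat"]
    ref = ["a cat was sitting on the mat"]
    scores = get_rouge_score(hyp, ref)
    # scores is a dict keyed by 'rouge-1', 'rouge-2' and 'rouge-l',
    # each holding 'f', 'p' and 'r' values in [0, 1].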
@ -0,0 +1,186 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
task utils
"""
import regex as re

from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor


# for lambada task
def extract_logits(logits=None, position=None):
    """
    Args:
        logits (Tensor): Tensor(batch_size, seq_length, vocab_size), e.g. (8, 1024, 50257)
        position (numpy.array): the array storing the final word positions, shape with [batch_size, 2]

    Return:
        output_logits (Tensor): extract the specified logits according to the positions,
            shape with [batch_size, vocab_size]
    """
    batch_size = logits.shape[0]

    for batch_idx in range(batch_size):
        word_logits_pos = int(position[batch_idx, 0] - 1)

        logit = logits[batch_idx:batch_idx + 1:1, word_logits_pos, ::]  # [1, vocab_size]
        if batch_idx == 0:
            output_logits = logit
        else:
            output_logits = P.Concat()((output_logits, logit))  # [batch_size, vocab_size]

    return output_logits
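
# A minimal usage sketch (illustrative only, not part of the original file); the
# shapes are made-up. With logits of shape (2, 8, 50257) and
# position = np.array([[5, 6], [3, 4]]), extract_logits() picks step 4 of the
# first sequence and step 2 of the second (position[:, 0] - 1, i.e. the step
# that predicts the final word) and stacks them into a (2, 50257) Tensor:
#
#     last_word_logits = extract_logits(logits=logits, position=position)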
def get_final_word_label(input_ids, input_length, tokenizer=None):
    """
    get the whole-word label_str from input_ids
    Args:
        input_ids: Tensor(batch_size, seq_length), indices of the input text
        input_length: Tensor(batch_size, 2), start and end positions of the final word
        tokenizer: GPT2Tokenizer, if not initiated, it will be created using the default setting
            in utils.tokenization, optional
    Returns:
        batch_word_label: [str], the final word strings of a LAMBADA batch, used as labels
    """
    input_ids_np = input_ids.asnumpy()
    input_length_np = input_length.asnumpy()
    batch_word_label = []

    for batch_idx in range(len(input_ids_np)):
        word_spos = input_length_np[batch_idx, 0]
        word_epos = input_length_np[batch_idx, 1]
        final_word_ids = input_ids_np[batch_idx, word_spos:word_epos]
        final_word_str = tokenizer.decode(final_word_ids.tolist())
        batch_word_label.append(final_word_str)

    return batch_word_label


def calculate_final_word_loss(logits, batch_size, input_ids, input_length, loss):
    """
    Calculate the last word loss.
    """
    logits = logits.asnumpy()
    input_len_np = input_length.asnumpy()
    input_ids_np = input_ids.asnumpy()

    sum_batch_loss = 0.0

    for batch in range(batch_size):
        lastword_spos = input_len_np[batch, 0]
        lastword_epos = input_len_np[batch, 1]

        last_word_logits = logits[batch, lastword_spos - 1:lastword_epos - 1:1, ::]
        last_word_logits_tensor = Tensor(last_word_logits, mstype.float32)

        last_word_label = input_ids_np[batch, lastword_spos:lastword_epos:1]
        print("last word label: ", last_word_label)
        last_word_label_tensor = Tensor(last_word_label, mstype.int32)

        last_word_loss = loss(last_word_logits_tensor, last_word_label_tensor)
        last_word_loss = float(last_word_loss.asnumpy())
        sum_batch_loss += last_word_loss
        print(" | loss: ", last_word_loss)

    avg_batch_loss = float(sum_batch_loss / batch_size)
    return avg_batch_loss


# for cbt task
def calculate_choice_prob_for_cbt(logits, batch_size, input_length, input_ids):
    """
    calculate choice prob for cbt
    Args:
        logits (Tensor): (batch_size, seq_length, vocab_size), log-probabilities over the vocabulary
        batch_size (int): batch size
        input_length (Tensor): (batch_size, 2), start and end positions of the rest of the sentence
        input_ids (Tensor): (batch_size, seq_length), indices of the input text

    Returns:
        choice_prob: List[float]
    """
    choice_prob = []  # [batch_size]
    logits = logits.asnumpy()
    input_len_np = input_length.asnumpy()
    input_ids_np = input_ids.asnumpy()

    for batch in range(batch_size):
        sum_ = 0.0
        rest_spos = input_len_np[batch, 0]
        rest_epos = input_len_np[batch, 1] + 1
        for rest_pos in range(rest_spos - 1, rest_epos - 1):
            rest_token_id = input_ids_np[batch, rest_pos + 1]
            log_prob = logits[batch, rest_pos, rest_token_id]
            sum_ = sum_ + log_prob
        choice_prob.append(sum_)
        print("rest sentence prob: ", sum_)

    return choice_prob


# for summarization task
def modify_paramdict(param_dict, mode="zero-shot", model_prefix="gpt2."):
    """
    modify keys of param_dict to fit the model.

    Args:
        param_dict: dict, dictionary of parameters imported from a ckpt file
        mode: str, "zero-shot" for a pretrained GPT2 model;
            "finetuned" for a model finetuned on a certain task.
    Return:
        final_param_dict: dict, new param_dict to fit in the model for different tasks.
    """
    final_param_dict = dict()
    if mode == "zero-shot":
        for name in param_dict:
            final_param_dict[model_prefix + name] = param_dict[name]
        final_param_dict['lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
    elif mode == "finetuned":
        embedding_name = "gpt2_embedding_lookup.embedding_table"
        embedding_name_old = ""
        for name in param_dict:
            name_remove_prefix = name[len(model_prefix):]
            name_prefix = name[:len(model_prefix)]
            final_param_dict[name_remove_prefix] = param_dict[name]
            if embedding_name in name and name_prefix == model_prefix:
                embedding_name_old = name
        final_param_dict[embedding_name] = param_dict[embedding_name_old]
    else:
        raise ValueError("mode should be one of [zero-shot, finetuned]")
    return final_param_dict


def clean_hypo(text):
    """
    lower the text and prevent the generation of an empty string

    Arg:
        text: str, input str
    Return:
        text: str, cleaned input str
    """
    text = text.lower()
    eng_re = re.compile(r'[a-z]+', re.I)
    length_con = len(eng_re.findall(text))
    if length_con == 0:
        return '<EMPTY>'
    return text
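
# A minimal usage sketch (illustrative only, not part of the original file); the
# checkpoint path and the `gpt2_net` variable are made-up placeholders. For
# zero-shot evaluation the raw pretrained keys get the "gpt2." prefix and the
# embedding table is reused as the LM head weight:
#
#     from mindspore.train.serialization import load_checkpoint, load_param_into_net
#     param_dict = load_checkpoint("gpt2_pretrained.ckpt")
#     param_dict = modify_paramdict(param_dict, mode="zero-shot", model_prefix="gpt2.")
#     load_param_into_net(gpt2_net, param_dict)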
@ -0,0 +1,217 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
tensor manipulations
"""
import numpy as np

from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops import operations as P


def extract_string_from_tensor(input_ids, mode="single", config=None, tokenizer=None):
    """
    Args:
        input_ids (Tensor): input sentences with shape [batch_size, seq_len].
        mode (str): ["pair", "single"]
            "pair" for tasks with paired inputs `<bos> A <eos> B <eos>`,
            such as the summarization task with dataset format `<bos> Article <eos> Summary <eos>`
            and the reading comprehension task with dataset format `<bos> Passage Question <eos> Answer <eos>`.

            "single" for tasks with a single input `<bos> A <eos>`, such as the Language Modeling and LAMBADA tasks.
        config: the configuration of the GPT-2 model.
        tokenizer: the tokenizer of the GPT-2 model.

    Return:
        prompt_list (list): list of prompt_text
        reference_list (list): list of reference_text, the second part of the text (only in "pair" mode)
    """
    batch_size = config.batch_size
    seq_length = config.seq_length
    prompt_list = [""] * batch_size
    reference_list = [""] * batch_size
    eos_text = tokenizer.eos_token
    len_eos_text = len(eos_text)
    input_ids_np = input_ids.asnumpy()
    input_ids_np = input_ids_np.reshape((batch_size, seq_length))

    if mode == "pair":
        for batch_idx in range(batch_size):
            sentence_tensor = input_ids_np[batch_idx]
            # sentence_tensor is already a numpy row, so tolist() is enough
            sentence_list = sentence_tensor.tolist()[1:]

            sentence = tokenizer.decode(sentence_list)
            prompt_start = 0
            prompt_end = sentence.find(eos_text, 0)
            reference_start = prompt_end + len_eos_text
            reference_end = sentence[reference_start:].find(
                eos_text, 0) + reference_start
            prompt_list[batch_idx] = sentence[prompt_start:prompt_end]
            reference_list[batch_idx] = sentence[reference_start:reference_end]

        return prompt_list, reference_list

    # For single-output datasets such as WikiText, etc.
    if mode == "single":
        for batch_idx in range(batch_size):
            sentence_tensor = input_ids_np[batch_idx]
            sentence_list = sentence_tensor.tolist()[1:]

            sentence = tokenizer.decode(sentence_list)
            prompt_start = 0
            prompt_end = sentence.find(eos_text, 0)
            prompt_list[batch_idx] = sentence[prompt_start:prompt_end]
    else:
        raise NotImplementedError('mode:{} not supported.'.format(mode))

    return prompt_list


def extract_single_token_logits(logits=None, seq_pos=None):
    """
    Args:
        logits (Tensor): (batch_size, seq_length, vocab_size), e.g. when batch_size is 8,
            sequence length is 1024 and vocab_size is 50257,
            then logits is a Tensor with shape (8, 1024, 50257)
        seq_pos (list): (batch_size,)
    Return:
        output_logits (Tensor): (batch_size, 1, vocab_size), the logits used to predict the next token.
    """
    batch_size = logits.shape[0]
    logits_np = logits.asnumpy()
    logits_type = P.DType()(logits)
    for i in range(batch_size):
        logit_np = logits_np[i:i + 1:1, seq_pos[i]:seq_pos[i] + 1:1, ::]
        if i == 0:
            output_logits = logit_np
        else:
            output_logits = np.concatenate((output_logits, logit_np), axis=0)

    output_logits = Tensor(output_logits, dtype=logits_type)

    return output_logits
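
# A minimal usage sketch (illustrative only, not part of the original file); the
# numbers are made-up. For a batch of 2 with seq_pos = [4, 7],
# extract_single_token_logits() slices logits[0, 4:5, :] and logits[1, 7:8, :]
# and stacks them into a (2, 1, vocab_size) Tensor:
#
#     next_token_logits = extract_single_token_logits(logits=logits, seq_pos=[4, 7])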
def get_last_one_pos(input_mask: Tensor):
    """
    Arg:
        input_mask (Tensor): (batch_size, seq_length)
    Return:
        pos (Tensor): (batch_size,), the position of the last 1 in each row
    """
    input_mask_ = P.Cast()(input_mask, mstype.float32)
    pos = P.ReduceSum(keep_dims=False)(input_mask_, axis=1)  # (batch_size,)
    pos = P.Cast()(pos, mstype.int32)
    pos = pos - 1
    return pos


def get_next_one_pos(input_mask: Tensor):
    """
    Arg:
        input_mask (Tensor): (batch_size, seq_length)
    Return:
        pos (Tensor): (batch_size,), the position right after the last 1 in each row
    """
    input_mask_ = P.Cast()(input_mask, mstype.float32)
    pos = P.ReduceSum(keep_dims=False)(input_mask_, axis=1)  # (batch_size,)
    pos = P.Cast()(pos, mstype.int32)
    return pos


def add_last_token_mask(input_mask: Tensor, overflow_strategy: str = "shift"):
    """
    add a mask bit for the last token
    Args:
        input_mask: Tensor
        overflow_strategy: str, "shift" or "truncate"

    Returns:
        Tensor
    """
    pos = get_next_one_pos(input_mask).asnumpy()
    input_mask_np = input_mask.asnumpy()
    maximum_length = input_mask.shape[1]
    batch_size = input_mask.shape[0]
    for idx in range(batch_size):
        # not overflow
        if pos[idx] < maximum_length:
            input_mask_np[idx][pos[idx]] = 1

        # overflow
        else:
            if overflow_strategy == "shift":
                continue
            if overflow_strategy == "truncate":
                continue
            raise ValueError("{} is not an option in ['shift','truncate'].".format(overflow_strategy))
    return Tensor(input_mask_np, dtype=mstype.int32)


def add_last_token(input_ids: Tensor, input_mask: Tensor, overflow_strategy: str = "shift", append_ids=None,
                   next_token_pos=None):
    """
    append the newly generated token to input_ids and extend input_mask accordingly
    Args:
        input_ids: Tensor
        input_mask: Tensor
        overflow_strategy: str, "shift" or "truncate"
        append_ids: token ids to append, one per batch item
        next_token_pos: positions at which to place the appended tokens, optional

    Returns:
        (Tensor, Tensor): updated input_ids and input_mask
    """
    # get the positional list/numpy array
    if next_token_pos is None:
        pos = get_next_one_pos(input_mask).asnumpy()
    else:
        pos = next_token_pos
    # get numpy copies of the inputs
    input_mask_np = input_mask.asnumpy()
    input_ids_np = input_ids.asnumpy()
    maximum_length = int(input_mask.shape[1])
    batch_size = int(input_mask.shape[0])

    for idx in range(batch_size):
        # not overflow
        if pos[idx] < maximum_length:
            input_mask_np[idx][int(pos[idx])] = 1
            input_ids_np[idx][int(pos[idx])] = append_ids[idx]

        # overflow
        else:
            if overflow_strategy == "shift":
                # shift the sequence one token left
                input_ids_np[idx][0:maximum_length - 1] = input_ids_np[idx][1:maximum_length]
                input_ids_np[idx][maximum_length - 1] = append_ids[idx]
                continue
            if overflow_strategy == "truncate":
                # do nothing
                continue
            raise ValueError("{} is not an option in ['shift','truncate'].".format(overflow_strategy))
    return Tensor(input_ids_np, dtype=mstype.int32), Tensor(input_mask_np, dtype=mstype.int32)
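
# A minimal usage sketch (illustrative only, not part of the original file); the
# token ids are made-up. With one free slot left, the appended id lands at the
# first masked-out position; once the row is full, "shift" drops the oldest token:
if __name__ == "__main__":
    ids = Tensor(np.array([[50256, 11, 22, 0]]), mstype.int32)
    mask = Tensor(np.array([[1, 1, 1, 0]]), mstype.int32)
    ids, mask = add_last_token(ids, mask, overflow_strategy="shift", append_ids=[33])
    print(ids)   # [[50256, 11, 22, 33]]
    print(mask)  # [[1, 1, 1, 1]]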
@ -0,0 +1,517 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
tokenization
"""
import json
from functools import lru_cache
from typing import List, Optional
import logging
import regex as re

logger = logging.getLogger(__name__)


@lru_cache()
def bytes_to_unicode():
    """
    map bytes to printable unicode characters for byte-level BPE
    """
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(i) for i in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """
    Return the set of symbol pairs in a word.
    Word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
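
# A minimal usage sketch (illustrative only, not part of the original file):
# get_pairs(("h", "e", "ll", "o")) returns {("h", "e"), ("e", "ll"), ("ll", "o")};
# bpe() below repeatedly merges the pair with the lowest merge rank until no
# ranked pair remains, which is how a word collapses into larger subword units.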
class GPT2Tokenizer():
|
||||
"""
|
||||
GPT2Tokenizer
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merge_file,
|
||||
add_prefix_space=False,
|
||||
):
|
||||
with open(vocab_file, 'r', encoding="utf-8") as vocab_handle:
|
||||
self.encoder = json.load(vocab_handle)
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.vocab_size = len(self.decoder)
|
||||
with open(merge_file, 'r', encoding="utf-8") as merge_handle:
|
||||
bpe_merges = merge_handle.read().split('\n')[1:-1]
|
||||
|
||||
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
|
||||
|
||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
|
||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||
self.add_prefix_space = add_prefix_space
|
||||
self.cache = {}
|
||||
|
||||
self.unk_token = "<|endoftext|>"
|
||||
self.unk_token_id = 50256
|
||||
self.bos_token = "<|endoftext|>"
|
||||
self.bos_token_id = 50256
|
||||
self.eos_token = "<|endoftext|>"
|
||||
self.eos_token_id = 50256
|
||||
self.pad_token = "<|endoftext|>"
|
||||
self.pad_token_id = 50256
|
||||
|
||||
def bpe(self, token):
|
||||
"""
|
||||
bpe encode
|
||||
"""
|
||||
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(token)
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
except ValueError:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
else:
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
|
||||
if word[i] == first and i + 1 < len(word) and word[i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = " ".join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def _tokenize(self, text):
|
||||
""" Tokenize a string using bpe encode. """
|
||||
text = self.prepare_for_tokenization(text, is_pretokenized=False)
|
||||
# print(text)
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
token = "".join(
|
||||
self.byte_encoder[b] for b in token.encode("utf-8")
|
||||
)
|
||||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
|
||||
return bpe_tokens
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" the index of the token in the vocabulary. """
|
||||
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||
|
||||
def _convert_id_to_token(self, _id):
|
||||
""" return the origin bpe token according to id"""
|
||||
return self.decoder.get(_id)
|
||||
|
||||
def _convert_tokens_to_string(self, tokens):
|
||||
""" return a string according to the list of tokens"""
|
||||
text = "".join(tokens)
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors='ignore')
|
||||
return text
|
||||
|
||||
def encode(self, text):
|
||||
""" get the index list of text"""
|
||||
text_id = []
|
||||
bpe_tokens = self._tokenize(text)
|
||||
for token in bpe_tokens:
|
||||
text_id.append(self._convert_token_to_id(token))
|
||||
return text_id
|
||||
|
||||
def decode(self, ids):
|
||||
""" return a string according to the index list of tokens"""
|
||||
tokens = []
|
||||
for id_ in ids:
|
||||
tokens.append(self._convert_id_to_token(id_))
|
||||
return self._convert_tokens_to_string(tokens)
|
||||
|
||||
def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
|
||||
""" whether to add a whitespace in the front of text """
|
||||
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
|
||||
if is_pretokenized or add_prefix_space:
|
||||
text = " " + text
|
||||
return text
|
||||
|
||||
def add_special_tokens(self, special_tokens_dict):
|
||||
"""
|
||||
Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
|
||||
special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
|
||||
current vocabulary).
|
||||
Args:
|
||||
special_tokens_dict (dictionary `str` to `str`):
|
||||
Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``,
|
||||
``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
|
||||
``additional_special_tokens``].
|
||||
|
||||
Returns:
|
||||
added_tokens (int): Number of tokens added to the vocabulary
|
||||
"""
|
||||
# special_tokens_dict = {'cls_token': '<CLS>'}
|
||||
if not special_tokens_dict:
|
||||
return 0
|
||||
|
||||
added_tokens = 0
|
||||
for key, value in special_tokens_dict.items():
|
||||
setattr(self, key, value)
|
||||
assert isinstance(value, str), f"Token {value} for key {key} should be a str instance"
|
||||
added_tokens += self.add_tokens([value], special_tokens=True)
|
||||
return added_tokens
|
||||
|
||||
def add_tokens(self, new_tokens, special_tokens=False):
|
||||
if not new_tokens:
|
||||
return 0
|
||||
if not isinstance(new_tokens, (list, tuple)):
|
||||
new_tokens = [new_tokens]
|
||||
return self._add_tokens(new_tokens, special_tokens=special_tokens)
|
||||
|
||||
def _add_tokens(self, new_tokens, special_tokens=False):
|
||||
"""
|
||||
_add_tokens
|
||||
Args:
|
||||
new_tokens (list[str]): Token(s) to add in vocabulary.
|
||||
special_tokens (bool): Whether or not the tokens should be added as special tokens.
|
||||
|
||||
Returns:
|
||||
the number of the new added tokens.
|
||||
"""
|
||||
new_tokens = [str(token) for token in new_tokens]
|
||||
|
||||
tokens_to_add = []
|
||||
for token in new_tokens:
|
||||
assert isinstance(token, str)
|
||||
tokens_to_add.append(token)
|
||||
logger.info("Adding %s to the vocabulary ! ", token)
|
||||
|
||||
added_tok_encoder = dict((tok, self.vocab_size + i) for i, tok in enumerate(tokens_to_add))
|
||||
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
|
||||
self.encoder.update(added_tok_encoder)
|
||||
self.decoder.update(added_tok_decoder)
|
||||
return len(tokens_to_add)
|
||||
|
||||
def num_special_tokens_to_add(self, pair: bool = False):
|
||||
token_ids_0 = []
|
||||
token_ids_1 = []
|
||||
return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None):
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence by concatenating and adding special tokens.
|
||||
|
||||
A GPT2 sequence has the following format:
|
||||
- single sequence: ``<bos> X <eos>``
|
||||
- pair of sequences: ``<bos> A <eos> B <eos>``
|
||||
|
||||
Args:
|
||||
token_ids_0 (List[int]): List of IDs to which the special tokens will be added
|
||||
token_ids_1 (List[int], `optional`, defaults to `None`): Optional second list of IDs for sequence pairs.
|
||||
"""
|
||||
bos = [self.bos_token_id]
|
||||
eos = [self.eos_token_id]
|
||||
if token_ids_1 is None:
|
||||
return bos + token_ids_0 + eos
|
||||
return bos + token_ids_0 + eos + token_ids_1 + eos
|
||||
|
||||
def truncate_sequences(self, ids, num_tokens_to_remove, truncation_strategy="ONLY_FIRST", direction="RIGHT"):
|
||||
"""
|
||||
truncate sequences
|
||||
Args:
|
||||
ids: Any
|
||||
num_tokens_to_remove:
|
||||
truncation_strategy: str
|
||||
direction: str
|
||||
|
||||
Returns:
|
||||
(ids, overflowing_tokens): (Any, list)
|
||||
|
||||
"""
|
||||
if num_tokens_to_remove <= 0:
|
||||
return ids, []
|
||||
|
||||
overflowing_tokens = []
|
||||
if truncation_strategy == "ONLY_FIRST":
|
||||
if len(ids) > num_tokens_to_remove:
|
||||
if direction == "RIGHT":
|
||||
overflowing_tokens = ids[-num_tokens_to_remove:]
|
||||
ids = ids[:-num_tokens_to_remove]
|
||||
if direction == "LEFT":
|
||||
overflowing_tokens = ids[:num_tokens_to_remove]
|
||||
ids = ids[num_tokens_to_remove:]
|
||||
else:
|
||||
logger.error("The first sequence length is smaller than removed tokens. ")
|
||||
else:
|
||||
logger.error("Please select correct truncation strategy, for instance 'ONLY_FIRST'")
|
||||
return (ids, overflowing_tokens)
|
||||
|
||||
def _pad(self, encoded_inputs, max_length=None, padding_strategy=None,
|
||||
return_attention_mask: Optional[bool] = None):
|
||||
"""
|
||||
_pad
|
||||
Args:
|
||||
encoded_inputs:
|
||||
max_length: Any
|
||||
padding_strategy: Any
|
||||
return_attention_mask: Optional[bool]
|
||||
|
||||
Returns:
|
||||
encoded_inputs:
|
||||
|
||||
"""
|
||||
needs_to_be_padded = (len(encoded_inputs["input_ids"]) != max_length)
|
||||
if needs_to_be_padded:
|
||||
if padding_strategy == "MAX_LENGTH":
|
||||
difference = max_length - len(encoded_inputs["input_ids"])
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
|
||||
encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
|
||||
else:
|
||||
raise ValueError("Invalid padding strategy")
|
||||
else:
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def pad(self, encoded_inputs, max_length: Optional[int] = None, padding_strategy="MAX_LENGTH",
|
||||
return_attention_mask=True):
|
||||
"""
|
||||
pad
|
||||
Args:
|
||||
encoded_inputs:
|
||||
max_length: Optional[int]
|
||||
padding_strategy: str
|
||||
return_attention_mask: bool
|
||||
|
||||
Returns:
|
||||
batch_outputs: Dict[Any, list]
|
||||
|
||||
"""
|
||||
# no batch encoded_inputs["input_ids"]--->[98, 67, 32388, 318, 1912, 287, 170, 8496, 318, 905, 2667, 32]
|
||||
if encoded_inputs["input_ids"] and not isinstance(encoded_inputs["input_ids"][0], (list, tuple)):
|
||||
encoded_inputs = self._pad(
|
||||
encoded_inputs,
|
||||
max_length=max_length,
|
||||
padding_strategy=padding_strategy,
|
||||
return_attention_mask=return_attention_mask
|
||||
)
|
||||
return encoded_inputs
|
||||
|
||||
# encoded_inputs with batch_size
|
||||
batch_size = len(encoded_inputs["input_ids"])
|
||||
assert all(
|
||||
len(v) == batch_size for v in encoded_inputs.values()
|
||||
), "Some items in the output dictionary have a different batch size than others."
|
||||
|
||||
if padding_strategy == "LONGEST":
|
||||
max_length = max(len(inputs) for inputs in encoded_inputs["input_ids"])
|
||||
padding_strategy = "MAX_LENGTH"
|
||||
|
||||
batch_outputs = {}
|
||||
for i in range(batch_size):
|
||||
inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
|
||||
outputs = self._pad(
|
||||
encoded_inputs=inputs,
|
||||
max_length=max_length,
|
||||
padding_strategy=padding_strategy,
|
||||
return_attention_mask=return_attention_mask
|
||||
)
|
||||
for key, value in outputs.items():
|
||||
if key not in batch_outputs:
|
||||
batch_outputs[key] = []
|
||||
batch_outputs[key].append(value)
|
||||
|
||||
return batch_outputs
|
||||
|
||||
def prepare_for_model(self,
|
||||
ids,
|
||||
pair_ids=None,
|
||||
add_special_tokens=True,
|
||||
max_length=None,
|
||||
padding=None,
|
||||
truncate_direction="RIGHT",
|
||||
return_overflowing_tokens=False,
|
||||
return_attention_mask=True):
|
||||
"""
|
||||
prepare for model
|
||||
Args:
|
||||
ids:
|
||||
pair_ids:
|
||||
add_special_tokens: bool
|
||||
max_length: Any
|
||||
padding: Any
|
||||
truncate_direction: str
|
||||
return_overflowing_tokens: bool
|
||||
return_attention_mask: bool
|
||||
|
||||
Returns:
|
||||
encoded_inputs:Dict
|
||||
|
||||
"""
|
||||
|
||||
pair = bool(pair_ids is not None)
|
||||
len_ids = len(ids)
|
||||
len_pair_ids = len(pair_ids) if pair else 0
|
||||
|
||||
encoded_inputs = {}
|
||||
# Compute the total size of the returned encodings
|
||||
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
|
||||
|
||||
# Truncation: Handle max sequence length
|
||||
if max_length and total_len > max_length:
|
||||
|
||||
ids, overflowing_tokens = self.truncate_sequences(ids=ids,
|
||||
num_tokens_to_remove=total_len - max_length,
|
||||
truncation_strategy="ONLY_FIRST",
|
||||
direction=truncate_direction)
|
||||
if return_overflowing_tokens:
|
||||
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
||||
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
||||
|
||||
if add_special_tokens:
|
||||
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
||||
else:
|
||||
sequence = ids + pair_ids if pair else ids
|
||||
|
||||
# build output dictionary
|
||||
encoded_inputs["input_ids"] = sequence
|
||||
        # check lengths; warn when the final sequence still exceeds max_length
        if max_length is not None and len(encoded_inputs["input_ids"]) > max_length:
            logger.warning(
                "Token indices sequence length is longer than the specified maximum sequence length "
                "for this model (%d > %d). Running this sequence through the model will result in "
                "indexing errors", len(ids), max_length
            )
        # padding
        if padding or return_attention_mask:
            encoded_inputs = self.pad(encoded_inputs=encoded_inputs,
                                      max_length=max_length,
                                      padding_strategy="MAX_LENGTH",
                                      return_attention_mask=return_attention_mask)

        return encoded_inputs
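
For orientation, a hedged single-sequence sketch; whether `GPT2Tokenizer` exposes an `encode` helper is an assumption here, only `prepare_for_model` itself is defined above:

```python
tokenizer = Tokenizer(mode="normal")
ids = tokenizer.encode("GPT-2 is based in ...")   # assumed helper; any list of token ids works
inputs = tokenizer.prepare_for_model(ids,
                                     max_length=1024,
                                     padding=True,
                                     return_attention_mask=True)
# inputs["input_ids"] is truncated to fit, wrapped with special tokens,
# and padded to 1024 entries
```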


class CNN_DailyMail_tokenizer(GPT2Tokenizer):
    """
    GPT2Tokenizer variant for the CNN & DailyMail summarization task: the article
    (ids) and the highlights (pair_ids) are truncated under separate budgets.
    """
    def prepare_for_model(self,
                          ids,
                          pair_ids,
                          max_length=1024,
                          max_summary_length=150,
                          add_special_tokens=True,
                          padding=None,
                          return_overflowing_tokens=False,
                          return_attention_mask=True):
        len_ids = len(ids)
        len_pair_ids = len(pair_ids)
        encoded_inputs = {}
        # Compute the total size of the returned encodings
        total_len = len_ids + len_pair_ids
        ids_overflowing_tokens = []
        pair_overflowing_tokens = []
        # Truncation: handle max sequence length, reserving 3 positions for special tokens
        if total_len > max_length - 3:
            if len_pair_ids > max_summary_length:
                # first cap the summary at max_summary_length tokens ...
                num_tokens_to_remove = len_pair_ids - max_summary_length
                pair_ids, pair_overflowing_tokens = self.truncate_sequences(ids=pair_ids,
                                                                            num_tokens_to_remove=num_tokens_to_remove,
                                                                            truncation_strategy="ONLY_FIRST",
                                                                            direction="RIGHT")
                # ... then give the article whatever budget remains
                if len_ids + max_summary_length > max_length - 3:
                    num_tokens_to_remove = (len_ids + max_summary_length) - (max_length - 3)
                    ids, ids_overflowing_tokens = self.truncate_sequences(ids=ids,
                                                                          num_tokens_to_remove=num_tokens_to_remove,
                                                                          truncation_strategy="ONLY_FIRST",
                                                                          direction="RIGHT")
            else:
                # the summary already fits, so only the article is truncated
                ids, ids_overflowing_tokens = self.truncate_sequences(ids=ids,
                                                                      num_tokens_to_remove=total_len - (max_length - 3),
                                                                      truncation_strategy="ONLY_FIRST",
                                                                      direction="RIGHT")
            if return_overflowing_tokens:
                encoded_inputs["article_overflowing_tokens"] = ids_overflowing_tokens
                encoded_inputs["highlights_overflowing_tokens"] = pair_overflowing_tokens
                encoded_inputs["num_truncated_tokens"] = total_len - (max_length - 3)
        sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
        encoded_inputs["input_ids"] = sequence
        # check lengths; warn when the final sequence still exceeds max_length
        if max_length is not None and len(encoded_inputs["input_ids"]) > max_length:
            logger.warning(
                "Token indices sequence length is longer than the specified maximum sequence length "
                "for this model (%d > %d). Running this sequence through the model will result "
                "in indexing errors", len(ids), max_length
            )
        # padding
        if padding or return_attention_mask:
            encoded_inputs = self.pad(encoded_inputs=encoded_inputs,
                                      max_length=max_length,
                                      padding_strategy="MAX_LENGTH",
                                      return_attention_mask=return_attention_mask)
        return encoded_inputs
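
A worked restatement of the budget arithmetic above, using nothing beyond the defaults in the signature:

```python
max_length, max_summary_length = 1024, 150
budget = max_length - 3                        # 1021 tokens; 3 reserved for special tokens
article_budget = budget - max_summary_length   # 871 tokens left for the article
# e.g. a 2000-token article with a 300-token summary is truncated
# to 871 article tokens + 150 summary tokens = 1021 tokens
```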


def Tokenizer(vocab_file="./pretrain-data/gpt2-vocab.json",
              merge_file="./pretrain-data/gpt2-merges.txt",
              mode="normal"):
    """build a GPT2Tokenizer (or a task-specific subclass) from vocab and merges files"""
    print(" | Tokenizer mode: {}".format(mode))
    if mode == "normal":
        tokenizer = GPT2Tokenizer(vocab_file, merge_file, add_prefix_space=False)
    elif mode == "cnn_dailymail":
        tokenizer = CNN_DailyMail_tokenizer(vocab_file, merge_file, add_prefix_space=False)
    else:
        raise ValueError("No such mode '{}' in src.utils.tokenization.Tokenizer()".format(mode))
    return tokenizer
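
Usage follows directly from the two branches; the vocab and merges paths are the defaults above:

```python
lm_tokenizer = Tokenizer(mode="normal")            # plain GPT2Tokenizer
sum_tokenizer = Tokenizer(mode="cnn_dailymail")    # summarization variant
```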
@@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
init weight
"""
import math
import numpy as np

from mindspore.common.tensor import Tensor


def _average_units(shape):
    """average unit count of a 1-D or 2-D shape, used to scale the init range"""
    if not shape:
        return 1
    if len(shape) == 1:
        return float(shape[0])
    if len(shape) == 2:
        return float(shape[0] + shape[1]) / 2.
    raise RuntimeError("unsupported shape: only 1-D and 2-D shapes are supported.")


def weight_variable(shape):
    """uniform initialization with a Glorot-style limit derived from the average units"""
    scale_shape = shape
    avg_units = _average_units(scale_shape)
    scale = 1.0 / max(1., avg_units)
    limit = math.sqrt(3.0 * scale)
    values = np.random.uniform(-limit, limit, shape).astype(np.float32)
    return Tensor(values)


def one_weight(shape):
    """all-ones tensor, e.g. for scale parameters"""
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones)


def zero_weight(shape):
    """all-zeros tensor, e.g. for biases"""
    zeros = np.zeros(shape).astype(np.float32)
    return Tensor(zeros)


def normal_weight(shape, num_units):
    """normal initialization with standard deviation num_units ** -0.5"""
    norm = np.random.normal(0.0, num_units ** -0.5, shape).astype(np.float32)
    return Tensor(norm)
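
A short sketch of these helpers in use. The 768 hidden size and 50257 vocabulary size are GPT-2 small's published dimensions, used here only as illustrative shapes:

```python
w = weight_variable((768, 768))                    # uniform in [-sqrt(3/768), +sqrt(3/768)]
gamma = one_weight((768,))                         # ones
beta = zero_weight((768,))                         # zeros
emb = normal_weight((50257, 768), num_units=768)   # std = 768 ** -0.5
```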
@@ -0,0 +1,86 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset preprocess"""

import argparse

from src.utils.data_preprocess import lambada_dataset_preprocess
from src.utils.data_preprocess import cbt_dataset_preprocess
from src.utils.data_preprocess import wikitext_dataset_preprocess
from src.utils.data_preprocess import ptb_dataset_preprocess
from src.utils.data_preprocess import onebw_dataset_preprocess
from src.utils.data_preprocess import coqa_dataset_preprocess
from src.utils.data_preprocess import wmt14_en_fr_preprocess


def main():
    parser = argparse.ArgumentParser(description="All Task dataset preprocessing")
    parser.add_argument("--task", type=str, default="translation",
                        help="The GPT-2 downstream task, including [LanguageModeling, CBT, Translation, Lambada, "
                             "Summarization, ReadingComprehension]")
    parser.add_argument("--input_file", type=str, default="",
                        help="The raw dataset path.")
    parser.add_argument("--dataset", type=str, default="onebw",
                        help="The name of the dataset to be processed, only for the LanguageModeling task.")
    parser.add_argument("--output_file", type=str, default="",
                        help="The output dataset path after preprocessing.")
    parser.add_argument("--condition", type=str, default="test",
                        help="Process the train or test dataset, one of [train, test]; only for the 1BW and "
                             "CNN & DailyMail datasets.")
    args_opt = parser.parse_args()

    task = args_opt.task
    condition = args_opt.condition
    dataset = args_opt.dataset
    input_file = args_opt.input_file
    output_file = args_opt.output_file

    if task.lower() == "languagemodeling":
        print("Start processing Language Modeling dataset ...")
        if dataset.lower() == "wikitext2" or dataset.lower() == "wikitext103":
            wikitext_dataset_preprocess(input_file=input_file, output_file=output_file)
        elif dataset.lower() == "ptb":
            ptb_dataset_preprocess(input_file=input_file, output_file=output_file)
        elif dataset.lower() == "onebw":
            onebw_dataset_preprocess(condition, input_file=input_file, output_file=output_file)
        else:
            raise ValueError("Only support wikitext2, wikitext103, ptb, onebw dataset")

    elif task.lower() == "lambada":
        print("Start processing Lambada dataset ...")
        lambada_dataset_preprocess(input_file=input_file, output_file=output_file)

    elif task.lower() == "cbt":
        print("Start processing CBT dataset ...")
        cbt_dataset_preprocess(input_file=input_file, output_file=output_file)

    elif task.lower() == "readingcomprehension":
        print("Start processing ReadingComprehension dataset ...")
        coqa_dataset_preprocess(input_file=input_file, output_file=output_file)

    elif task.lower() == "summarization":
        print("Start processing Summarization dataset ...")
        # NOTE: no preprocessing call is wired up for this branch in the original script

    elif task.lower() == "translation":
        print("Start processing Translation dataset ...")
        wmt14_en_fr_preprocess(input_file=input_file, output_file=output_file)

    else:
        raise ValueError("Only support Language Modeling, CBT, Translation, Lambada, "
                         "Summarization, Reading Comprehension task.")


if __name__ == "__main__":
    main()
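
A usage sketch that bypasses the CLI and calls one of the imported preprocessors directly; the file paths and the `preprocess.py` script name are placeholders:

```python
from src.utils.data_preprocess import wikitext_dataset_preprocess

# roughly equivalent to running the script as, e.g.:
#   python preprocess.py --task LanguageModeling --dataset wikitext2 \
#       --input_file ./data/wiki.test.tokens --output_file ./data/wikitext2.test.txt
wikitext_dataset_preprocess(input_file="./data/wiki.test.tokens",
                            output_file="./data/wikitext2.test.txt")
```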
@@ -0,0 +1,107 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Python implementation of BLEU and smooth-BLEU.

This module provides a Python implementation of BLEU and smooth-BLEU.
Smooth BLEU is computed following the method outlined in the paper:
Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
evaluation metrics for machine translation. COLING 2004.
"""

import collections
import math


def _get_ngrams(segment, max_order):
    """
    Extracts all n-grams up to a given maximum order from an input segment.

    Args:
        segment: text segment from which n-grams will be extracted.
        max_order: maximum length in tokens of the n-grams returned by this
            method.

    Returns:
        The Counter containing all n-grams up to max_order in segment
        with a count of how many times each n-gram occurred.
    """
    ngram_counts = collections.Counter()
    for order in range(1, max_order + 1):
        for i in range(0, len(segment) - order + 1):
            ngram = tuple(segment[i:i + order])
            ngram_counts[ngram] += 1
    return ngram_counts
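
For example, with a three-token segment and max_order=2:

```python
counts = _get_ngrams(["the", "cat", "sat"], max_order=2)
# Counter({("the",): 1, ("cat",): 1, ("sat",): 1,
#          ("the", "cat"): 1, ("cat", "sat"): 1})
```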


def compute_bleu(reference_corpus, translation_corpus, max_order=4,
                 smooth=False):
    """Computes BLEU score of translated segments against one or more references.

    Args:
        reference_corpus: list of lists of references for each translation. Each
            reference should be tokenized into a list of tokens.
        translation_corpus: list of translations to score. Each translation
            should be tokenized into a list of tokens.
        max_order: Maximum n-gram order to use when computing BLEU score.
        smooth: Whether or not to apply Lin et al. 2004 smoothing.

    Returns:
        6-tuple with the BLEU score, n-gram precisions, brevity penalty,
        length ratio, translation length and reference length.
    """
    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    reference_length = 0
    translation_length = 0
    for (references, translation) in zip(reference_corpus, translation_corpus):
        reference_length += min(len(r) for r in references)
        translation_length += len(translation)

        # clipped n-gram matching: take the element-wise max over the references,
        # then intersect with the translation's own n-gram counts
        merged_ref_ngram_counts = collections.Counter()
        for reference in references:
            merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
        translation_ngram_counts = _get_ngrams(translation, max_order)
        overlap = translation_ngram_counts & merged_ref_ngram_counts
        for ngram in overlap:
            matches_by_order[len(ngram) - 1] += overlap[ngram]
        for order in range(1, max_order + 1):
            possible_matches = len(translation) - order + 1
            if possible_matches > 0:
                possible_matches_by_order[order - 1] += possible_matches

    precisions = [0] * max_order
    for i in range(0, max_order):
        if smooth:
            precisions[i] = ((matches_by_order[i] + 1.) /
                             (possible_matches_by_order[i] + 1.))
        else:
            if possible_matches_by_order[i] > 0:
                precisions[i] = (float(matches_by_order[i]) /
                                 possible_matches_by_order[i])
            else:
                precisions[i] = 0.0

    if min(precisions) > 0:
        p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
        geo_mean = math.exp(p_log_sum)
    else:
        geo_mean = 0

    # brevity penalty (assumes a non-empty translation corpus;
    # ratio == 0 would divide by zero below)
    ratio = float(translation_length) / reference_length

    if ratio > 1.0:
        bp = 1.
    else:
        bp = math.exp(1 - 1. / ratio)

    bleu = geo_mean * bp

    return (bleu, precisions, bp, ratio, translation_length, reference_length)
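
Scoring a single translation against a single reference. A perfect match makes every n-gram precision 1.0 and the brevity penalty 1.0, so the score is 1.0:

```python
references = [["the", "cat", "sat", "on", "the", "mat"]]   # one reference list per translation
translation = ["the", "cat", "sat", "on", "the", "mat"]
bleu, precisions, bp, ratio, t_len, r_len = compute_bleu([references], [translation])
assert bleu == 1.0
```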
File diff suppressed because it is too large
File diff suppressed because one or more lines are too long