forked from mindspore-Ecosystem/mindspore
enhance: add example for zhwiki and CLUERNER2020 to mindrecord
This commit is contained in:
parent
5306172fee
commit
56d03f9eb9
|
@ -0,0 +1,82 @@
|
|||
# Guideline to Convert Training Data CLUERNER2020 to MindRecord For Bert Fine Tuning
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [What does the example do](#what-does-the-example-do)
|
||||
- [How to use the example to process CLUERNER2020](#how-to-use-the-example-to-process-cluerner2020)
|
||||
- [Download CLUERNER2020 and unzip](#download-cluerner2020-and-unzip)
|
||||
- [Generate MindRecord](#generate-mindrecord)
|
||||
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
|
||||
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
## What does the example do
|
||||
|
||||
This example is based on [CLUERNER2020](https://www.cluebenchmarks.com/introduce.html) training data, generating MindRecord file, and finally used for Bert Fine Tuning progress.
|
||||
|
||||
1. run.sh: generate MindRecord entry script
|
||||
2. run_read.py: create MindDataset by MindRecord entry script.
|
||||
- create_dataset.py: use MindDataset to read MindRecord to generate dataset.
|
||||
|
||||
## How to use the example to process CLUERNER2020
|
||||
|
||||
Download CLUERNER2020, convert it to MindRecord, use MindDataset to read MindRecord.
|
||||
|
||||
### Download CLUERNER2020 and unzip
|
||||
|
||||
1. Download the training data zip.
|
||||
> [CLUERNER2020 dataset download address](https://www.cluebenchmarks.com/introduce.html) **-> 任务介绍 -> CLUENER 细粒度命名实体识别 -> cluener下载链接**
|
||||
|
||||
2. Unzip the training data to dir example/nlp_to_mindrecord/CLUERNER2020/cluener_public.
|
||||
```
|
||||
unzip -d {your-mindspore}/example/nlp_to_mindrecord/CLUERNER2020/data/cluener_public cluener_public.zip
|
||||
```
|
||||
|
||||
### Generate MindRecord
|
||||
|
||||
1. Run the run.sh script.
|
||||
```bash
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
2. Output like this:
|
||||
```
|
||||
...
|
||||
[INFO] ME(17603:139620983514944,MainProcess):2020-04-28-16:56:12.498.235 [mindspore/mindrecord/filewriter.py:313] The list of mindrecord files created are: ['data/train.mindrecord'], and the list of index files are: ['data/train.mindrecord.db']
|
||||
...
|
||||
[INFO] ME(17603,python):2020-04-28-16:56:13.400.175 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
|
||||
[INFO] ME(17603,python):2020-04-28-16:56:13.400.863 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
|
||||
[INFO] ME(17603,python):2020-04-28-16:56:13.401.534 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
|
||||
[INFO] ME(17603,python):2020-04-28-16:56:13.402.179 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
|
||||
[INFO] ME(17603,python):2020-04-28-16:56:13.402.702 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
|
||||
...
|
||||
[INFO] ME(17603:139620983514944,MainProcess):2020-04-28-16:56:13.431.208 [mindspore/mindrecord/filewriter.py:313] The list of mindrecord files created are: ['data/dev.mindrecord'], and the list of index files are: ['data/dev.mindrecord.db']
|
||||
```
|
||||
|
||||
3. Generate files like this:
|
||||
```bash
|
||||
$ ls output/
|
||||
dev.mindrecord dev.mindrecord.db README.md train.mindrecord train.mindrecord.db
|
||||
```
|
||||
|
||||
### Create MindDataset By MindRecord
|
||||
|
||||
1. Run the run_read.sh script.
|
||||
```bash
|
||||
bash run_read.sh
|
||||
```
|
||||
|
||||
2. Output like this:
|
||||
```
|
||||
...
|
||||
example 1340: input_ids: [ 101 3173 1290 4852 7676 3949 122 3299 123 126 3189 4510 8020 6381 5442 7357 2590 3636 8021 7676 3949 4294 1166 6121 3124 1277 6121 3124 7270 2135 3295 5789 3326 123 126 3189 1355 6134 1093 1325 3173 2399 6590 6791 8024 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1340: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1340: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1340: label_ids: [ 0 18 19 20 2 4 0 0 0 0 0 0 0 34 36 26 27 28 0 34 35 35 35 35 35 35 35 35 35 36 26 27 28 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1341: input_ids: [ 101 1728 711 4293 3868 1168 2190 2150 3791 934 3633 3428 4638 6237 7025 8024 3297 1400 5310 3362 6206 5023 5401 1744 3297 7770 3791 7368 976 1139 1104 2137 511 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1341: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1341: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1341: label_ids: [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 18 19 19 19 19 20 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
...
|
||||
```
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""create MindDataset by MindRecord"""
|
||||
import mindspore.dataset as ds
|
||||
|
||||
def create_dataset(data_file):
|
||||
"""create MindDataset"""
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
|
||||
index = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
# print("example {}: {}".format(index, item))
|
||||
print("example {}: input_ids: {}".format(index, item['input_ids']))
|
||||
print("example {}: input_mask: {}".format(index, item['input_mask']))
|
||||
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
|
||||
print("example {}: label_ids: {}".format(index, item['label_ids']))
|
||||
index += 1
|
||||
if index % 1000 == 0:
|
||||
print("read rows: {}".format(index))
|
||||
print("total rows: {}".format(index))
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_dataset('output/train.mindrecord')
|
||||
create_dataset('output/dev.mindrecord')
|
|
@ -0,0 +1 @@
|
|||
cluener_public
|
|
@ -0,0 +1 @@
|
|||
## The input dataset
|
|
@ -0,0 +1 @@
|
|||
## output dir
|
|
@ -0,0 +1,40 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
rm -f output/train.mindrecord*
|
||||
rm -f output/dev.mindrecord*
|
||||
|
||||
if [ ! -d "../../../third_party/to_mindrecord/CLUERNER2020" ]; then
|
||||
echo "The patch base dir ../../../third_party/to_mindrecord/CLUERNER2020 is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch" ]; then
|
||||
echo "The patch file ../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# patch for data_processor_seq.py
|
||||
patch -p0 -d ../../../third_party/to_mindrecord/CLUERNER2020/ -o data_processor_seq_patched.py < ../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Patch ../../../third_party/to_mindrecord/CLUERNER2020/data_processor_seq.py failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# use patched script
|
||||
python ../../../third_party/to_mindrecord/CLUERNER2020/data_processor_seq_patched.py \
|
||||
--vocab_file=../../../third_party/to_mindrecord/CLUERNER2020/vocab.txt \
|
||||
--label2id_file=../../../third_party/to_mindrecord/CLUERNER2020/label2id.json
|
|
@ -0,0 +1,17 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
python create_dataset.py
|
|
@ -0,0 +1,113 @@
|
|||
# Guideline to Convert Training Data zhwiki to MindRecord For Bert Pre Training
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [What does the example do](#what-does-the-example-do)
|
||||
- [Run simple test](#run-simple-test)
|
||||
- [How to use the example to process zhwiki](#how-to-use-the-example-to-process-zhwiki)
|
||||
- [Download zhwiki training data](#download-zhwiki-training-data)
|
||||
- [Extract the zhwiki](#extract-the-zhwiki)
|
||||
- [Generate MindRecord](#generate-mindrecord)
|
||||
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
|
||||
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
## What does the example do
|
||||
|
||||
This example is based on [zhwiki](https://dumps.wikimedia.org/zhwiki) training data, generating MindRecord file, and finally used for Bert network training.
|
||||
|
||||
1. run.sh: generate MindRecord entry script.
|
||||
2. run_read.py: create MindDataset by MindRecord entry script.
|
||||
- create_dataset.py: use MindDataset to read MindRecord to generate dataset.
|
||||
|
||||
## Run simple test
|
||||
|
||||
Follow the step:
|
||||
|
||||
```bash
|
||||
bash run_simple.sh # generate output/simple.mindrecord* by ../../../third_party/to_mindrecord/zhwiki/sample_text.txt
|
||||
bash run_read_simple.sh # use MindDataset to read output/simple.mindrecord*
|
||||
```
|
||||
|
||||
## How to use the example to process zhwiki
|
||||
|
||||
Download zhwiki data, extract it, convert it to MindRecord, use MindDataset to read MindRecord.
|
||||
|
||||
### Download zhwiki training data
|
||||
|
||||
> [zhwiki dataset download address](https://dumps.wikimedia.org/zhwiki) **-> 20200401 -> zhwiki-20200401-pages-articles-multistream.xml.bz2**
|
||||
|
||||
- put the zhwiki-20200401-pages-articles-multistream.xml.bz2 in {your-mindspore}/example/nlp_to_mindrecord/zhwiki/data directory.
|
||||
|
||||
### Extract the zhwiki
|
||||
|
||||
1. Download [wikiextractor](https://github.com/attardi/wikiextractor) script to {your-mindspore}/example/nlp_to_mindrecord/zhwiki/data directory.
|
||||
|
||||
```
|
||||
$ ls data/
|
||||
README.md wikiextractor zhwiki-20200401-pages-articles-multistream.xml.bz2
|
||||
```
|
||||
|
||||
2. Extract the zhwiki.
|
||||
```python
|
||||
python data/wikiextractor/WikiExtractor.py data/zhwiki-20200401-pages-articles-multistream.xml.bz2 --processes 4 --templates data/template --bytes 8M --min_text_length 0 --filter_disambig_pages --output data/extract
|
||||
```
|
||||
|
||||
3. Generate like this:
|
||||
```
|
||||
$ ls data/extract
|
||||
AA AB
|
||||
```
|
||||
|
||||
### Generate MindRecord
|
||||
|
||||
1. Run the run.sh script.
|
||||
```
|
||||
bash run.sh
|
||||
```
|
||||
> Caution: This process maybe slow, please wait patiently. If you do not have a machine with enough memory and cpu, it is recommended that you modify the script to generate mindrecord in step by step.
|
||||
|
||||
2. The output like this:
|
||||
```
|
||||
patching file create_pretraining_data_patched.py (read from create_pretraining_data.py)
|
||||
Begin preprocess input file: ./data/extract/AA/wiki_00
|
||||
Begin output file: AAwiki_00.mindrecord
|
||||
Total task: 5, processing: 1
|
||||
Begin preprocess input file: ./data/extract/AA/wiki_01
|
||||
Begin output file: AAwiki_01.mindrecord
|
||||
Total task: 5, processing: 2
|
||||
Begin preprocess input file: ./data/extract/AA/wiki_02
|
||||
Begin output file: AAwiki_02.mindrecord
|
||||
Total task: 5, processing: 3
|
||||
Begin preprocess input file: ./data/extract/AB/wiki_02
|
||||
Begin output file: ABwiki_02.mindrecord
|
||||
Total task: 5, processing: 4
|
||||
...
|
||||
```
|
||||
|
||||
3. Generate files like this:
|
||||
```bash
|
||||
$ ls output/
|
||||
AAwiki_00.mindrecord AAwiki_00.mindrecord.db AAwiki_01.mindrecord AAwiki_01.mindrecord.db AAwiki_02.mindrecord AAwiki_02.mindrecord.db ... ABwiki_00.mindrecord ABwiki_00.mindrecord.db ...
|
||||
```
|
||||
|
||||
### Create MindDataset By MindRecord
|
||||
|
||||
1. Run the run_read.sh script.
|
||||
```bash
|
||||
bash run_read.sh
|
||||
```
|
||||
|
||||
2. The output like this:
|
||||
```
|
||||
...
|
||||
example 74: input_ids: [ 101 8168 118 12847 8783 9977 15908 117 8256 9245 11643 8168 8847 8588 11575 8154 8228 143 8384 8376 9197 10241 103 10564 11421 8199 12268 112 161 8228 11541 9586 8436 8174 8363 9864 9702 103 103 119 103 9947 10564 103 8436 8806 11479 103 8912 119 103 103 103 12209 8303 103 8757 8824 117 8256 103 8619 8168 11541 102 11684 8196 103 8228 8847 11523 117 9059 9064 12410 8358 8181 10764 117 11167 11706 9920 148 8332 11390 8936 8205 10951 11997 103 8154 117 103 8670 10467 112 161 10951 13139 12413 117 10288 143 10425 8205 152 10795 8472 8196 103 161 12126 9172 13129 12106 8217 8174 12244 8205 143 103 8461 8277 10628 160 8221 119 102]
|
||||
example 74: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
|
||||
example 74: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
|
||||
example 74: masked_lm_positions: [ 6 22 37 38 40 43 47 50 51 52 55 60 67 76 89 92 98 109 120 0]
|
||||
example 74: masked_lm_ids: [ 8118 8165 8329 8890 8554 8458 119 8850 8565 10392 8174 11467 10291 8181 8549 12718 13139 112 158 0]
|
||||
example 74: masked_lm_weights: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]
|
||||
example 74: next_sentence_labels: [0]
|
||||
...
|
||||
```
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""create MindDataset by MindRecord"""
|
||||
import argparse
|
||||
import mindspore.dataset as ds
|
||||
|
||||
def create_dataset(data_file):
|
||||
"""create MindDataset"""
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
|
||||
index = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
# print("example {}: {}".format(index, item))
|
||||
print("example {}: input_ids: {}".format(index, item['input_ids']))
|
||||
print("example {}: input_mask: {}".format(index, item['input_mask']))
|
||||
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
|
||||
print("example {}: masked_lm_positions: {}".format(index, item['masked_lm_positions']))
|
||||
print("example {}: masked_lm_ids: {}".format(index, item['masked_lm_ids']))
|
||||
print("example {}: masked_lm_weights: {}".format(index, item['masked_lm_weights']))
|
||||
print("example {}: next_sentence_labels: {}".format(index, item['next_sentence_labels']))
|
||||
index += 1
|
||||
if index % 1000 == 0:
|
||||
print("read rows: {}".format(index))
|
||||
print("total rows: {}".format(index))
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-i", "--input_file", nargs='+', type=str, help='Input mindreord file')
|
||||
args = parser.parse_args()
|
||||
|
||||
create_dataset(args.input_file)
|
|
@ -0,0 +1,3 @@
|
|||
wikiextractor/
|
||||
zhwiki-20200401-pages-articles-multistream.xml.bz2
|
||||
extract/
|
|
@ -0,0 +1 @@
|
|||
## The input dataset
|
|
@ -0,0 +1 @@
|
|||
## Output the mindrecord
|
|
@ -0,0 +1,112 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
rm -f output/*.mindrecord*
|
||||
|
||||
data_dir="./data/extract"
|
||||
file_list=()
|
||||
output_filename=()
|
||||
file_index=0
|
||||
|
||||
function getdir() {
|
||||
elements=`ls $1`
|
||||
for element in ${elements[*]};
|
||||
do
|
||||
dir_or_file=$1"/"$element
|
||||
if [ -d $dir_or_file ];
|
||||
then
|
||||
getdir $dir_or_file
|
||||
else
|
||||
file_list[$file_index]=$dir_or_file
|
||||
echo "${dir_or_file}" | tr '/' '\n' > dir_file_list.txt # dir dir file to mapfile
|
||||
mapfile parent_dir < dir_file_list.txt
|
||||
rm dir_file_list.txt >/dev/null 2>&1
|
||||
tmp_output_filename=${parent_dir[${#parent_dir[@]}-2]}${parent_dir[${#parent_dir[@]}-1]}".mindrecord"
|
||||
output_filename[$file_index]=`echo ${tmp_output_filename} | sed 's/ //g'`
|
||||
file_index=`expr $file_index + 1`
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
getdir "${data_dir}"
|
||||
# echo "The input files: "${file_list[@]}
|
||||
# echo "The output files: "${output_filename[@]}
|
||||
|
||||
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
|
||||
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
|
||||
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# patch for create_pretraining_data.py
|
||||
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# get the cpu core count
|
||||
num_cpu_core=`cat /proc/cpuinfo | grep "processor" | wc -l`
|
||||
avaiable_core_size=`expr $num_cpu_core / 3 \* 2`
|
||||
|
||||
echo "Begin preprocess `date`"
|
||||
|
||||
# using patched script to generate mindrecord
|
||||
file_list_len=`expr ${#file_list[*]} - 1`
|
||||
for index in $(seq 0 $file_list_len); do
|
||||
echo "Begin preprocess input file: ${file_list[$index]}"
|
||||
echo "Begin output file: ${output_filename[$index]}"
|
||||
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
|
||||
--input_file=${file_list[$index]} \
|
||||
--output_file=output/${output_filename[$index]} \
|
||||
--partition_number=1 \
|
||||
--vocab_file=../../../third_party/to_mindrecord/zhwiki/vocab.txt \
|
||||
--do_lower_case=True \
|
||||
--max_seq_length=128 \
|
||||
--max_predictions_per_seq=20 \
|
||||
--masked_lm_prob=0.15 \
|
||||
--random_seed=12345 \
|
||||
--dupe_factor=5 >/tmp/${output_filename[$index]}.log 2>&1 &
|
||||
process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
echo "Total task: ${file_list_len}, processing: ${process_count}"
|
||||
if [ $process_count -ge $avaiable_core_size ]; then
|
||||
while [ 1 ]; do
|
||||
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
if [ $process_count -gt $process_num ]; then
|
||||
process_count=$process_num
|
||||
break;
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
fi
|
||||
done
|
||||
|
||||
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
while [ 1 ]; do
|
||||
if [ $process_num -eq 0 ]; then
|
||||
break;
|
||||
fi
|
||||
echo "There are still ${process_num} preprocess running ..."
|
||||
sleep 2
|
||||
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
done
|
||||
|
||||
echo "Preprocess all the data success."
|
||||
echo "End preprocess `date`"
|
|
@ -0,0 +1,34 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
file_list=()
|
||||
file_index=0
|
||||
|
||||
# get all the mindrecord file from output dir
|
||||
function getdir() {
|
||||
elements=`ls $1/[A-Z]*.mindrecord`
|
||||
for element in ${elements[*]};
|
||||
do
|
||||
file_list[$file_index]=$element
|
||||
file_index=`expr $file_index + 1`
|
||||
done
|
||||
}
|
||||
|
||||
getdir "./output"
|
||||
echo "Get all the mindrecord files: "${file_list[*]}
|
||||
|
||||
# create dataset for train
|
||||
python create_dataset.py --input_file ${file_list[*]}
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# create dataset for train
|
||||
python create_dataset.py --input_file=output/simple.mindrecord0
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
rm -f output/simple.mindrecord*
|
||||
|
||||
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
|
||||
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
|
||||
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# patch for create_pretraining_data.py
|
||||
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# using patched script to generate mindrecord
|
||||
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
|
||||
--input_file=../../../third_party/to_mindrecord/zhwiki/sample_text.txt \
|
||||
--output_file=output/simple.mindrecord \
|
||||
--partition_number=4 \
|
||||
--vocab_file=../../../third_party/to_mindrecord/zhwiki/vocab.txt \
|
||||
--do_lower_case=True \
|
||||
--max_seq_length=128 \
|
||||
--max_predictions_per_seq=20 \
|
||||
--masked_lm_prob=0.15 \
|
||||
--random_seed=12345 \
|
||||
--dupe_factor=5
|
|
@ -0,0 +1 @@
|
|||
## the file is a patch which is about just change data_processor_seq.py the part of generated tfrecord to MindRecord in [CLUEbenchmark/CLUENER2020](https://github.com/CLUEbenchmark/CLUENER2020/tree/master/tf_version)
|
|
@ -0,0 +1,105 @@
|
|||
--- data_processor_seq.py 2020-05-28 10:07:13.365947168 +0800
|
||||
+++ data_processor_seq.py 2020-05-28 10:14:33.298177130 +0800
|
||||
@@ -4,11 +4,18 @@
|
||||
@author: Cong Yu
|
||||
@time: 2019-12-07 17:03
|
||||
"""
|
||||
+import sys
|
||||
+sys.path.append("../../../third_party/to_mindrecord/CLUERNER2020")
|
||||
+
|
||||
+import argparse
|
||||
import json
|
||||
import tokenization
|
||||
import collections
|
||||
-import tensorflow as tf
|
||||
|
||||
+import numpy as np
|
||||
+from mindspore.mindrecord import FileWriter
|
||||
+
|
||||
+# pylint: skip-file
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
@@ -80,11 +87,18 @@ def process_one_example(tokenizer, label
|
||||
return feature
|
||||
|
||||
|
||||
-def prepare_tf_record_data(tokenizer, max_seq_len, label2id, path, out_path):
|
||||
+def prepare_mindrecord_data(tokenizer, max_seq_len, label2id, path, out_path):
|
||||
"""
|
||||
- 生成训练数据, tf.record, 单标签分类模型, 随机打乱数据
|
||||
+ 生成训练数据, *.mindrecord, 单标签分类模型, 随机打乱数据
|
||||
"""
|
||||
- writer = tf.python_io.TFRecordWriter(out_path)
|
||||
+ writer = FileWriter(out_path)
|
||||
+
|
||||
+ data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
|
||||
+ "input_mask": {"type": "int64", "shape": [-1]},
|
||||
+ "segment_ids": {"type": "int64", "shape": [-1]},
|
||||
+ "label_ids": {"type": "int64", "shape": [-1]}}
|
||||
+ writer.add_schema(data_schema, "CLUENER2020 schema")
|
||||
+
|
||||
example_count = 0
|
||||
|
||||
for line in open(path):
|
||||
@@ -113,16 +127,12 @@ def prepare_tf_record_data(tokenizer, ma
|
||||
feature = process_one_example(tokenizer, label2id, list(_["text"]), labels,
|
||||
max_seq_len=max_seq_len)
|
||||
|
||||
- def create_int_feature(values):
|
||||
- f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
- return f
|
||||
-
|
||||
features = collections.OrderedDict()
|
||||
# 序列标注任务
|
||||
- features["input_ids"] = create_int_feature(feature[0])
|
||||
- features["input_mask"] = create_int_feature(feature[1])
|
||||
- features["segment_ids"] = create_int_feature(feature[2])
|
||||
- features["label_ids"] = create_int_feature(feature[3])
|
||||
+ features["input_ids"] = np.asarray(feature[0])
|
||||
+ features["input_mask"] = np.asarray(feature[1])
|
||||
+ features["segment_ids"] = np.asarray(feature[2])
|
||||
+ features["label_ids"] = np.asarray(feature[3])
|
||||
if example_count < 5:
|
||||
print("*** Example ***")
|
||||
print(_["text"])
|
||||
@@ -132,8 +142,7 @@ def prepare_tf_record_data(tokenizer, ma
|
||||
print("segment_ids: %s" % " ".join([str(x) for x in feature[2]]))
|
||||
print("label: %s " % " ".join([str(x) for x in feature[3]]))
|
||||
|
||||
- tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
- writer.write(tf_example.SerializeToString())
|
||||
+ writer.write_raw_data([features])
|
||||
example_count += 1
|
||||
|
||||
# if example_count == 20:
|
||||
@@ -141,17 +150,22 @@ def prepare_tf_record_data(tokenizer, ma
|
||||
if example_count % 3000 == 0:
|
||||
print(example_count)
|
||||
print("total example:", example_count)
|
||||
- writer.close()
|
||||
+ writer.commit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- vocab_file = "./vocab.txt"
|
||||
+ parser = argparse.ArgumentParser()
|
||||
+ parser.add_argument("--vocab_file", type=str, required=True, help='The vocabulary file.')
|
||||
+ parser.add_argument("--label2id_file", type=str, required=True, help='The label2id.json file.')
|
||||
+ args = parser.parse_args()
|
||||
+
|
||||
+ vocab_file = args.vocab_file
|
||||
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file)
|
||||
- label2id = json.loads(open("label2id.json").read())
|
||||
+ label2id = json.loads(open(args.label2id_file).read())
|
||||
|
||||
max_seq_len = 64
|
||||
|
||||
- prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_train.json",
|
||||
- out_path="data/train.tf_record")
|
||||
- prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_valid.json",
|
||||
- out_path="data/dev.tf_record")
|
||||
+ prepare_mindrecord_data(tokenizer, max_seq_len, label2id, path="data/cluener_public/train.json",
|
||||
+ out_path="output/train.mindrecord")
|
||||
+ prepare_mindrecord_data(tokenizer, max_seq_len, label2id, path="data/cluener_public/dev.json",
|
||||
+ out_path="output/dev.mindrecord")
|
|
@ -0,0 +1 @@
|
|||
## the file is a patch which is about just change create_pretraining_data.py the part of generated tfrecord to MindRecord in [google-research/bert](https://github.com/google-research/bert)
|
|
@ -0,0 +1,288 @@
|
|||
--- create_pretraining_data.py 2020-05-27 17:02:14.285363720 +0800
|
||||
+++ create_pretraining_data.py 2020-05-27 17:30:52.427767841 +0800
|
||||
@@ -12,57 +12,28 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
-"""Create masked LM/next sentence masked_lm TF examples for BERT."""
|
||||
+"""Create masked LM/next sentence masked_lm MindRecord files for BERT."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
+import sys
|
||||
+sys.path.append("../../../third_party/to_mindrecord/zhwiki")
|
||||
+
|
||||
+import argparse
|
||||
import collections
|
||||
+import logging
|
||||
import random
|
||||
import tokenization
|
||||
-import tensorflow as tf
|
||||
-
|
||||
-flags = tf.flags
|
||||
-
|
||||
-FLAGS = flags.FLAGS
|
||||
-
|
||||
-flags.DEFINE_string("input_file", None,
|
||||
- "Input raw text file (or comma-separated list of files).")
|
||||
-
|
||||
-flags.DEFINE_string(
|
||||
- "output_file", None,
|
||||
- "Output TF example file (or comma-separated list of files).")
|
||||
-
|
||||
-flags.DEFINE_string("vocab_file", None,
|
||||
- "The vocabulary file that the BERT model was trained on.")
|
||||
-
|
||||
-flags.DEFINE_bool(
|
||||
- "do_lower_case", True,
|
||||
- "Whether to lower case the input text. Should be True for uncased "
|
||||
- "models and False for cased models.")
|
||||
-
|
||||
-flags.DEFINE_bool(
|
||||
- "do_whole_word_mask", False,
|
||||
- "Whether to use whole word masking rather than per-WordPiece masking.")
|
||||
-
|
||||
-flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
|
||||
|
||||
-flags.DEFINE_integer("max_predictions_per_seq", 20,
|
||||
- "Maximum number of masked LM predictions per sequence.")
|
||||
+import numpy as np
|
||||
+from mindspore.mindrecord import FileWriter
|
||||
|
||||
-flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
|
||||
+# pylint: skip-file
|
||||
|
||||
-flags.DEFINE_integer(
|
||||
- "dupe_factor", 10,
|
||||
- "Number of times to duplicate the input data (with different masks).")
|
||||
-
|
||||
-flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
|
||||
-
|
||||
-flags.DEFINE_float(
|
||||
- "short_seq_prob", 0.1,
|
||||
- "Probability of creating sequences which are shorter than the "
|
||||
- "maximum length.")
|
||||
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
+ datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
|
||||
|
||||
|
||||
class TrainingInstance(object):
|
||||
@@ -94,13 +65,19 @@ class TrainingInstance(object):
|
||||
|
||||
|
||||
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
|
||||
- max_predictions_per_seq, output_files):
|
||||
- """Create TF example files from `TrainingInstance`s."""
|
||||
- writers = []
|
||||
- for output_file in output_files:
|
||||
- writers.append(tf.python_io.TFRecordWriter(output_file))
|
||||
-
|
||||
- writer_index = 0
|
||||
+ max_predictions_per_seq, output_file, partition_number):
|
||||
+ """Create MindRecord files from `TrainingInstance`s."""
|
||||
+ writer = FileWriter(output_file, int(partition_number))
|
||||
+
|
||||
+ data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
|
||||
+ "input_mask": {"type": "int64", "shape": [-1]},
|
||||
+ "segment_ids": {"type": "int64", "shape": [-1]},
|
||||
+ "masked_lm_positions": {"type": "int64", "shape": [-1]},
|
||||
+ "masked_lm_ids": {"type": "int64", "shape": [-1]},
|
||||
+ "masked_lm_weights": {"type": "float64", "shape": [-1]},
|
||||
+ "next_sentence_labels": {"type": "int64", "shape": [-1]},
|
||||
+ }
|
||||
+ writer.add_schema(data_schema, "zhwiki schema")
|
||||
|
||||
total_written = 0
|
||||
for (inst_index, instance) in enumerate(instances):
|
||||
@@ -130,55 +107,35 @@ def write_instance_to_example_files(inst
|
||||
next_sentence_label = 1 if instance.is_random_next else 0
|
||||
|
||||
features = collections.OrderedDict()
|
||||
- features["input_ids"] = create_int_feature(input_ids)
|
||||
- features["input_mask"] = create_int_feature(input_mask)
|
||||
- features["segment_ids"] = create_int_feature(segment_ids)
|
||||
- features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
||||
- features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
||||
- features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
||||
- features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
||||
-
|
||||
- tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
-
|
||||
- writers[writer_index].write(tf_example.SerializeToString())
|
||||
- writer_index = (writer_index + 1) % len(writers)
|
||||
+ features["input_ids"] = np.asarray(input_ids)
|
||||
+ features["input_mask"] = np.asarray(input_mask)
|
||||
+ features["segment_ids"] = np.asarray(segment_ids)
|
||||
+ features["masked_lm_positions"] = np.asarray(masked_lm_positions)
|
||||
+ features["masked_lm_ids"] = np.asarray(masked_lm_ids)
|
||||
+ features["masked_lm_weights"] = np.asarray(masked_lm_weights)
|
||||
+ features["next_sentence_labels"] = np.asarray([next_sentence_label])
|
||||
|
||||
total_written += 1
|
||||
|
||||
if inst_index < 20:
|
||||
- tf.logging.info("*** Example ***")
|
||||
- tf.logging.info("tokens: %s" % " ".join(
|
||||
+ logging.info("*** Example ***")
|
||||
+ logging.info("tokens: %s" % " ".join(
|
||||
[tokenization.printable_text(x) for x in instance.tokens]))
|
||||
|
||||
for feature_name in features.keys():
|
||||
feature = features[feature_name]
|
||||
- values = []
|
||||
- if feature.int64_list.value:
|
||||
- values = feature.int64_list.value
|
||||
- elif feature.float_list.value:
|
||||
- values = feature.float_list.value
|
||||
- tf.logging.info(
|
||||
- "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
|
||||
-
|
||||
- for writer in writers:
|
||||
- writer.close()
|
||||
-
|
||||
- tf.logging.info("Wrote %d total instances", total_written)
|
||||
-
|
||||
-
|
||||
-def create_int_feature(values):
|
||||
- feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
- return feature
|
||||
+ logging.info(
|
||||
+ "%s: %s" % (feature_name, " ".join([str(x) for x in feature])))
|
||||
+ writer.write_raw_data([features])
|
||||
|
||||
+ writer.commit()
|
||||
|
||||
-def create_float_feature(values):
|
||||
- feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
||||
- return feature
|
||||
+ logging.info("Wrote %d total instances", total_written)
|
||||
|
||||
|
||||
def create_training_instances(input_files, tokenizer, max_seq_length,
|
||||
dupe_factor, short_seq_prob, masked_lm_prob,
|
||||
- max_predictions_per_seq, rng):
|
||||
+ max_predictions_per_seq, rng, do_whole_word_mask):
|
||||
"""Create `TrainingInstance`s from raw text."""
|
||||
all_documents = [[]]
|
||||
|
||||
@@ -189,7 +146,7 @@ def create_training_instances(input_file
|
||||
# (2) Blank lines between documents. Document boundaries are needed so
|
||||
# that the "next sentence prediction" task doesn't span between documents.
|
||||
for input_file in input_files:
|
||||
- with tf.gfile.GFile(input_file, "r") as reader:
|
||||
+ with open(input_file, "r") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
@@ -214,7 +171,7 @@ def create_training_instances(input_file
|
||||
instances.extend(
|
||||
create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
- masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
|
||||
+ masked_lm_prob, max_predictions_per_seq, vocab_words, rng, do_whole_word_mask))
|
||||
|
||||
rng.shuffle(instances)
|
||||
return instances
|
||||
@@ -222,7 +179,7 @@ def create_training_instances(input_file
|
||||
|
||||
def create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
- masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
|
||||
+ masked_lm_prob, max_predictions_per_seq, vocab_words, rng, do_whole_word_mask):
|
||||
"""Creates `TrainingInstance`s for a single document."""
|
||||
document = all_documents[document_index]
|
||||
|
||||
@@ -320,7 +277,7 @@ def create_instances_from_document(
|
||||
|
||||
(tokens, masked_lm_positions,
|
||||
masked_lm_labels) = create_masked_lm_predictions(
|
||||
- tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
|
||||
+ tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng, do_whole_word_mask)
|
||||
instance = TrainingInstance(
|
||||
tokens=tokens,
|
||||
segment_ids=segment_ids,
|
||||
@@ -340,7 +297,7 @@ MaskedLmInstance = collections.namedtupl
|
||||
|
||||
|
||||
def create_masked_lm_predictions(tokens, masked_lm_prob,
|
||||
- max_predictions_per_seq, vocab_words, rng):
|
||||
+ max_predictions_per_seq, vocab_words, rng, do_whole_word_mask):
|
||||
"""Creates the predictions for the masked LM objective."""
|
||||
|
||||
cand_indexes = []
|
||||
@@ -356,7 +313,7 @@ def create_masked_lm_predictions(tokens,
|
||||
# Note that Whole Word Masking does *not* change the training code
|
||||
# at all -- we still predict each WordPiece independently, softmaxed
|
||||
# over the entire vocabulary.
|
||||
- if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
|
||||
+ if (do_whole_word_mask and len(cand_indexes) >= 1 and
|
||||
token.startswith("##")):
|
||||
cand_indexes[-1].append(i)
|
||||
else:
|
||||
@@ -433,37 +390,42 @@ def truncate_seq_pair(tokens_a, tokens_b
|
||||
trunc_tokens.pop()
|
||||
|
||||
|
||||
-def main(_):
|
||||
- tf.logging.set_verbosity(tf.logging.INFO)
|
||||
+def main():
|
||||
+ parser = argparse.ArgumentParser()
|
||||
+ parser.add_argument("--input_file", type=str, required=True, help='Input raw text file (or comma-separated list of files).')
|
||||
+ parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
|
||||
+ parser.add_argument("--partition_number", type=int, default=1, help='The MindRecord file will be split into the number of partition.')
|
||||
+ parser.add_argument("--vocab_file", type=str, required=True, help='The vocabulary file than the BERT model was trained on.')
|
||||
+ parser.add_argument("--do_lower_case", type=bool, default=False, help='Whether to lower case the input text. Should be True for uncased models and False for cased models.')
|
||||
+ parser.add_argument("--do_whole_word_mask", type=bool, default=False, help='Whether to use whole word masking rather than per-WordPiece masking.')
|
||||
+ parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
|
||||
+ parser.add_argument("--max_predictions_per_seq", type=int, default=20, help='Maximum number of masked LM predictions per sequence.')
|
||||
+ parser.add_argument("--random_seed", type=int, default=12345, help='Random seed for data generation.')
|
||||
+ parser.add_argument("--dupe_factor", type=int, default=10, help='Number of times to duplicate the input data (with diffrent masks).')
|
||||
+ parser.add_argument("--masked_lm_prob", type=float, default=0.15, help='Masked LM probability.')
|
||||
+ parser.add_argument("--short_seq_prob", type=float, default=0.1, help='Probability of creating sequences which are shorter than the maximum length.')
|
||||
+ args = parser.parse_args()
|
||||
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
- vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
+ vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
|
||||
|
||||
input_files = []
|
||||
- for input_pattern in FLAGS.input_file.split(","):
|
||||
- input_files.extend(tf.gfile.Glob(input_pattern))
|
||||
+ for input_pattern in args.input_file.split(","):
|
||||
+ input_files.append(input_pattern)
|
||||
|
||||
- tf.logging.info("*** Reading from input files ***")
|
||||
+ logging.info("*** Reading from input files ***")
|
||||
for input_file in input_files:
|
||||
- tf.logging.info(" %s", input_file)
|
||||
+ logging.info(" %s", input_file)
|
||||
|
||||
- rng = random.Random(FLAGS.random_seed)
|
||||
+ rng = random.Random(args.random_seed)
|
||||
instances = create_training_instances(
|
||||
- input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
|
||||
- FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
|
||||
- rng)
|
||||
-
|
||||
- output_files = FLAGS.output_file.split(",")
|
||||
- tf.logging.info("*** Writing to output files ***")
|
||||
- for output_file in output_files:
|
||||
- tf.logging.info(" %s", output_file)
|
||||
+ input_files, tokenizer, args.max_seq_length, args.dupe_factor,
|
||||
+ args.short_seq_prob, args.masked_lm_prob, args.max_predictions_per_seq,
|
||||
+ rng, args.do_whole_word_mask)
|
||||
|
||||
- write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
|
||||
- FLAGS.max_predictions_per_seq, output_files)
|
||||
+ write_instance_to_example_files(instances, tokenizer, args.max_seq_length,
|
||||
+ args.max_predictions_per_seq, args.output_file, args.partition_number)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
- flags.mark_flag_as_required("input_file")
|
||||
- flags.mark_flag_as_required("output_file")
|
||||
- flags.mark_flag_as_required("vocab_file")
|
||||
- tf.app.run()
|
||||
+ main()
|
|
@ -0,0 +1 @@
|
|||
data_processor_seq_patched.py
|
|
@ -0,0 +1 @@
|
|||
## All the scripts here come from [CLUEbenchmark/CLUENER2020](https://github.com/CLUEbenchmark/CLUENER2020/tree/master/tf_version)
|
|
@ -0,0 +1,157 @@
|
|||
#!/usr/bin/python
|
||||
# coding:utf8
|
||||
"""
|
||||
@author: Cong Yu
|
||||
@time: 2019-12-07 17:03
|
||||
"""
|
||||
import json
|
||||
import tokenization
|
||||
import collections
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
break
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
|
||||
def process_one_example(tokenizer, label2id, text, label, max_seq_len=128):
|
||||
# textlist = text.split(' ')
|
||||
# labellist = label.split(' ')
|
||||
textlist = list(text)
|
||||
labellist = list(label)
|
||||
tokens = []
|
||||
labels = []
|
||||
for i, word in enumerate(textlist):
|
||||
token = tokenizer.tokenize(word)
|
||||
tokens.extend(token)
|
||||
label_1 = labellist[i]
|
||||
for m in range(len(token)):
|
||||
if m == 0:
|
||||
labels.append(label_1)
|
||||
else:
|
||||
print("some unknown token...")
|
||||
labels.append(labels[0])
|
||||
# tokens = tokenizer.tokenize(example.text) -2 的原因是因为序列需要加一个句首和句尾标志
|
||||
if len(tokens) >= max_seq_len - 1:
|
||||
tokens = tokens[0:(max_seq_len - 2)]
|
||||
labels = labels[0:(max_seq_len - 2)]
|
||||
ntokens = []
|
||||
segment_ids = []
|
||||
label_ids = []
|
||||
ntokens.append("[CLS]") # 句子开始设置CLS 标志
|
||||
segment_ids.append(0)
|
||||
# [CLS] [SEP] 可以为 他们构建标签,或者 统一到某个标签,反正他们是不变的,基本不参加训练 即:x-l 永远不变
|
||||
label_ids.append(0) # label2id["[CLS]"]
|
||||
for i, token in enumerate(tokens):
|
||||
ntokens.append(token)
|
||||
segment_ids.append(0)
|
||||
label_ids.append(label2id[labels[i]])
|
||||
ntokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
# append("O") or append("[SEP]") not sure!
|
||||
label_ids.append(0) # label2id["[SEP]"]
|
||||
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
while len(input_ids) < max_seq_len:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
label_ids.append(0)
|
||||
ntokens.append("**NULL**")
|
||||
assert len(input_ids) == max_seq_len
|
||||
assert len(input_mask) == max_seq_len
|
||||
assert len(segment_ids) == max_seq_len
|
||||
assert len(label_ids) == max_seq_len
|
||||
|
||||
feature = (input_ids, input_mask, segment_ids, label_ids)
|
||||
return feature
|
||||
|
||||
|
||||
def prepare_tf_record_data(tokenizer, max_seq_len, label2id, path, out_path):
|
||||
"""
|
||||
生成训练数据, tf.record, 单标签分类模型, 随机打乱数据
|
||||
"""
|
||||
writer = tf.python_io.TFRecordWriter(out_path)
|
||||
example_count = 0
|
||||
|
||||
for line in open(path):
|
||||
if not line.strip():
|
||||
continue
|
||||
_ = json.loads(line.strip())
|
||||
len_ = len(_["text"])
|
||||
labels = ["O"] * len_
|
||||
for k, v in _["label"].items():
|
||||
for kk, vv in v.items():
|
||||
for vvv in vv:
|
||||
span = vvv
|
||||
s = span[0]
|
||||
e = span[1] + 1
|
||||
# print(s, e)
|
||||
if e - s == 1:
|
||||
labels[s] = "S_" + k
|
||||
else:
|
||||
labels[s] = "B_" + k
|
||||
for i in range(s + 1, e - 1):
|
||||
labels[i] = "M_" + k
|
||||
labels[e - 1] = "E_" + k
|
||||
# print()
|
||||
# feature = process_one_example(tokenizer, label2id, row[column_name_x1], row[column_name_y],
|
||||
# max_seq_len=max_seq_len)
|
||||
feature = process_one_example(tokenizer, label2id, list(_["text"]), labels,
|
||||
max_seq_len=max_seq_len)
|
||||
|
||||
def create_int_feature(values):
|
||||
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
return f
|
||||
|
||||
features = collections.OrderedDict()
|
||||
# 序列标注任务
|
||||
features["input_ids"] = create_int_feature(feature[0])
|
||||
features["input_mask"] = create_int_feature(feature[1])
|
||||
features["segment_ids"] = create_int_feature(feature[2])
|
||||
features["label_ids"] = create_int_feature(feature[3])
|
||||
if example_count < 5:
|
||||
print("*** Example ***")
|
||||
print(_["text"])
|
||||
print(_["label"])
|
||||
print("input_ids: %s" % " ".join([str(x) for x in feature[0]]))
|
||||
print("input_mask: %s" % " ".join([str(x) for x in feature[1]]))
|
||||
print("segment_ids: %s" % " ".join([str(x) for x in feature[2]]))
|
||||
print("label: %s " % " ".join([str(x) for x in feature[3]]))
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
writer.write(tf_example.SerializeToString())
|
||||
example_count += 1
|
||||
|
||||
# if example_count == 20:
|
||||
# break
|
||||
if example_count % 3000 == 0:
|
||||
print(example_count)
|
||||
print("total example:", example_count)
|
||||
writer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
vocab_file = "./vocab.txt"
|
||||
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file)
|
||||
label2id = json.loads(open("label2id.json").read())
|
||||
|
||||
max_seq_len = 64
|
||||
|
||||
prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_train.json",
|
||||
out_path="data/train.tf_record")
|
||||
prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_valid.json",
|
||||
out_path="data/dev.tf_record")
|
|
@ -0,0 +1,43 @@
|
|||
{
|
||||
"O": 0,
|
||||
"S_address": 1,
|
||||
"B_address": 2,
|
||||
"M_address": 3,
|
||||
"E_address": 4,
|
||||
"S_book": 5,
|
||||
"B_book": 6,
|
||||
"M_book": 7,
|
||||
"E_book": 8,
|
||||
"S_company": 9,
|
||||
"B_company": 10,
|
||||
"M_company": 11,
|
||||
"E_company": 12,
|
||||
"S_game": 13,
|
||||
"B_game": 14,
|
||||
"M_game": 15,
|
||||
"E_game": 16,
|
||||
"S_government": 17,
|
||||
"B_government": 18,
|
||||
"M_government": 19,
|
||||
"E_government": 20,
|
||||
"S_movie": 21,
|
||||
"B_movie": 22,
|
||||
"M_movie": 23,
|
||||
"E_movie": 24,
|
||||
"S_name": 25,
|
||||
"B_name": 26,
|
||||
"M_name": 27,
|
||||
"E_name": 28,
|
||||
"S_organization": 29,
|
||||
"B_organization": 30,
|
||||
"M_organization": 31,
|
||||
"E_organization": 32,
|
||||
"S_position": 33,
|
||||
"B_position": 34,
|
||||
"M_position": 35,
|
||||
"E_position": 36,
|
||||
"S_scene": 37,
|
||||
"B_scene": 38,
|
||||
"M_scene": 39,
|
||||
"E_scene": 40
|
||||
}
|
|
@ -0,0 +1,388 @@
|
|||
"""Tokenization classes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
import unicodedata
|
||||
import six
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
||||
"""Checks whether the casing config is consistent with the checkpoint name."""
|
||||
|
||||
# The casing has to be passed in by the user and there is no explicit check
|
||||
# as to whether it matches the checkpoint. The casing information probably
|
||||
# should have been stored in the bert_config.json file, but it's not, so
|
||||
# we have to heuristically detect it to validate.
|
||||
|
||||
if not init_checkpoint:
|
||||
return
|
||||
|
||||
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
|
||||
if m is None:
|
||||
return
|
||||
|
||||
model_name = m.group(1)
|
||||
|
||||
lower_models = [
|
||||
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
|
||||
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
cased_models = [
|
||||
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
|
||||
"multi_cased_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
is_bad_config = False
|
||||
if model_name in lower_models and not do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "False"
|
||||
case_name = "lowercased"
|
||||
opposite_flag = "True"
|
||||
|
||||
if model_name in cased_models and do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "True"
|
||||
case_name = "cased"
|
||||
opposite_flag = "False"
|
||||
|
||||
if is_bad_config:
|
||||
raise ValueError(
|
||||
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
|
||||
"However, `%s` seems to be a %s model, so you "
|
||||
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
|
||||
"how the model was pre-training. If this error is wrong, please "
|
||||
"just comment out this check." % (actual_flag, init_checkpoint,
|
||||
model_name, case_name, opposite_flag))
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
elif isinstance(text, unicode):
|
||||
return text
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def printable_text(text):
|
||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, unicode):
|
||||
return text.encode("utf-8")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def convert_by_vocab(vocab, items):
|
||||
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||
output = []
|
||||
for item in items:
|
||||
if item in vocab:
|
||||
output.append(vocab[item])
|
||||
else:
|
||||
output.append(vocab['[UNK]'])
|
||||
return output
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab, tokens):
|
||||
return convert_by_vocab(vocab, tokens)
|
||||
|
||||
|
||||
def convert_ids_to_tokens(inv_vocab, ids):
|
||||
return convert_by_vocab(inv_vocab, ids)
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class FullTokenizer(object):
|
||||
"""Runs end-to-end tokenziation."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_by_vocab(self.vocab, tokens)
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
return convert_by_vocab(self.inv_vocab, ids)
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self, do_lower_case=True):
|
||||
"""Constructs a BasicTokenizer.
|
||||
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
# models. This is also applied to the English models now, but it doesn't
|
||||
# matter since the English models were not trained on any Chinese data
|
||||
# and generally don't have any Chinese data in them (there are Chinese
|
||||
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||
# words in the English Wikipedia.).
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like the all of the other languages.
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
"""Runs WordPiece tokenziation."""
|
||||
|
||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer.
|
||||
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
text = convert_to_unicode(text)
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically contorl characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1 @@
|
|||
create_pretraining_data_patched.py
|
|
@ -0,0 +1 @@
|
|||
## All the scripts here come from [google-research/bert](https://github.com/google-research/bert)
|
|
@ -0,0 +1,469 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import random
|
||||
import tokenization
|
||||
import tensorflow as tf
|
||||
|
||||
flags = tf.flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("input_file", None,
|
||||
"Input raw text file (or comma-separated list of files).")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"output_file", None,
|
||||
"Output TF example file (or comma-separated list of files).")
|
||||
|
||||
flags.DEFINE_string("vocab_file", None,
|
||||
"The vocabulary file that the BERT model was trained on.")
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"do_lower_case", True,
|
||||
"Whether to lower case the input text. Should be True for uncased "
|
||||
"models and False for cased models.")
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"do_whole_word_mask", False,
|
||||
"Whether to use whole word masking rather than per-WordPiece masking.")
|
||||
|
||||
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
|
||||
|
||||
flags.DEFINE_integer("max_predictions_per_seq", 20,
|
||||
"Maximum number of masked LM predictions per sequence.")
|
||||
|
||||
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"dupe_factor", 10,
|
||||
"Number of times to duplicate the input data (with different masks).")
|
||||
|
||||
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"short_seq_prob", 0.1,
|
||||
"Probability of creating sequences which are shorter than the "
|
||||
"maximum length.")
|
||||
|
||||
|
||||
class TrainingInstance(object):
|
||||
"""A single training instance (sentence pair)."""
|
||||
|
||||
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
|
||||
is_random_next):
|
||||
self.tokens = tokens
|
||||
self.segment_ids = segment_ids
|
||||
self.is_random_next = is_random_next
|
||||
self.masked_lm_positions = masked_lm_positions
|
||||
self.masked_lm_labels = masked_lm_labels
|
||||
|
||||
def __str__(self):
|
||||
s = ""
|
||||
s += "tokens: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.tokens]))
|
||||
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
||||
s += "is_random_next: %s\n" % self.is_random_next
|
||||
s += "masked_lm_positions: %s\n" % (" ".join(
|
||||
[str(x) for x in self.masked_lm_positions]))
|
||||
s += "masked_lm_labels: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
|
||||
s += "\n"
|
||||
return s
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
|
||||
max_predictions_per_seq, output_files):
|
||||
"""Create TF example files from `TrainingInstance`s."""
|
||||
writers = []
|
||||
for output_file in output_files:
|
||||
writers.append(tf.python_io.TFRecordWriter(output_file))
|
||||
|
||||
writer_index = 0
|
||||
|
||||
total_written = 0
|
||||
for (inst_index, instance) in enumerate(instances):
|
||||
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
segment_ids = list(instance.segment_ids)
|
||||
assert len(input_ids) <= max_seq_length
|
||||
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
masked_lm_positions = list(instance.masked_lm_positions)
|
||||
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
||||
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
||||
|
||||
while len(masked_lm_positions) < max_predictions_per_seq:
|
||||
masked_lm_positions.append(0)
|
||||
masked_lm_ids.append(0)
|
||||
masked_lm_weights.append(0.0)
|
||||
|
||||
next_sentence_label = 1 if instance.is_random_next else 0
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["input_ids"] = create_int_feature(input_ids)
|
||||
features["input_mask"] = create_int_feature(input_mask)
|
||||
features["segment_ids"] = create_int_feature(segment_ids)
|
||||
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
||||
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
||||
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
||||
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
|
||||
writers[writer_index].write(tf_example.SerializeToString())
|
||||
writer_index = (writer_index + 1) % len(writers)
|
||||
|
||||
total_written += 1
|
||||
|
||||
if inst_index < 20:
|
||||
tf.logging.info("*** Example ***")
|
||||
tf.logging.info("tokens: %s" % " ".join(
|
||||
[tokenization.printable_text(x) for x in instance.tokens]))
|
||||
|
||||
for feature_name in features.keys():
|
||||
feature = features[feature_name]
|
||||
values = []
|
||||
if feature.int64_list.value:
|
||||
values = feature.int64_list.value
|
||||
elif feature.float_list.value:
|
||||
values = feature.float_list.value
|
||||
tf.logging.info(
|
||||
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
|
||||
|
||||
for writer in writers:
|
||||
writer.close()
|
||||
|
||||
tf.logging.info("Wrote %d total instances", total_written)
|
||||
|
||||
|
||||
def create_int_feature(values):
|
||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
return feature
|
||||
|
||||
|
||||
def create_float_feature(values):
|
||||
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
||||
return feature
|
||||
|
||||
|
||||
def create_training_instances(input_files, tokenizer, max_seq_length,
|
||||
dupe_factor, short_seq_prob, masked_lm_prob,
|
||||
max_predictions_per_seq, rng):
|
||||
"""Create `TrainingInstance`s from raw text."""
|
||||
all_documents = [[]]
|
||||
|
||||
# Input file format:
|
||||
# (1) One sentence per line. These should ideally be actual sentences, not
|
||||
# entire paragraphs or arbitrary spans of text. (Because we use the
|
||||
# sentence boundaries for the "next sentence prediction" task).
|
||||
# (2) Blank lines between documents. Document boundaries are needed so
|
||||
# that the "next sentence prediction" task doesn't span between documents.
|
||||
for input_file in input_files:
|
||||
with tf.gfile.GFile(input_file, "r") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
|
||||
# Empty lines are used as document delimiters
|
||||
if not line:
|
||||
all_documents.append([])
|
||||
tokens = tokenizer.tokenize(line)
|
||||
if tokens:
|
||||
all_documents[-1].append(tokens)
|
||||
|
||||
# Remove empty documents
|
||||
all_documents = [x for x in all_documents if x]
|
||||
rng.shuffle(all_documents)
|
||||
|
||||
vocab_words = list(tokenizer.vocab.keys())
|
||||
instances = []
|
||||
for _ in range(dupe_factor):
|
||||
for document_index in range(len(all_documents)):
|
||||
instances.extend(
|
||||
create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
|
||||
|
||||
rng.shuffle(instances)
|
||||
return instances
|
||||
|
||||
|
||||
def create_instances_from_document(
|
||||
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
|
||||
"""Creates `TrainingInstance`s for a single document."""
|
||||
document = all_documents[document_index]
|
||||
|
||||
# Account for [CLS], [SEP], [SEP]
|
||||
max_num_tokens = max_seq_length - 3
|
||||
|
||||
# We *usually* want to fill up the entire sequence since we are padding
|
||||
# to `max_seq_length` anyways, so short sequences are generally wasted
|
||||
# computation. However, we *sometimes*
|
||||
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
||||
# sequences to minimize the mismatch between pre-training and fine-tuning.
|
||||
# The `target_seq_length` is just a rough target however, whereas
|
||||
# `max_seq_length` is a hard limit.
|
||||
target_seq_length = max_num_tokens
|
||||
if rng.random() < short_seq_prob:
|
||||
target_seq_length = rng.randint(2, max_num_tokens)
|
||||
|
||||
# We DON'T just concatenate all of the tokens from a document into a long
|
||||
# sequence and choose an arbitrary split point because this would make the
|
||||
# next sentence prediction task too easy. Instead, we split the input into
|
||||
# segments "A" and "B" based on the actual "sentences" provided by the user
|
||||
# input.
|
||||
instances = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i = 0
|
||||
while i < len(document):
|
||||
segment = document[i]
|
||||
current_chunk.append(segment)
|
||||
current_length += len(segment)
|
||||
if i == len(document) - 1 or current_length >= target_seq_length:
|
||||
if current_chunk:
|
||||
# `a_end` is how many segments from `current_chunk` go into the `A`
|
||||
# (first) sentence.
|
||||
a_end = 1
|
||||
if len(current_chunk) >= 2:
|
||||
a_end = rng.randint(1, len(current_chunk) - 1)
|
||||
|
||||
tokens_a = []
|
||||
for j in range(a_end):
|
||||
tokens_a.extend(current_chunk[j])
|
||||
|
||||
tokens_b = []
|
||||
# Random next
|
||||
is_random_next = False
|
||||
if len(current_chunk) == 1 or rng.random() < 0.5:
|
||||
is_random_next = True
|
||||
target_b_length = target_seq_length - len(tokens_a)
|
||||
|
||||
# This should rarely go for more than one iteration for large
|
||||
# corpora. However, just to be careful, we try to make sure that
|
||||
# the random document is not the same as the document
|
||||
# we're processing.
|
||||
for _ in range(10):
|
||||
random_document_index = rng.randint(0, len(all_documents) - 1)
|
||||
if random_document_index != document_index:
|
||||
break
|
||||
|
||||
random_document = all_documents[random_document_index]
|
||||
random_start = rng.randint(0, len(random_document) - 1)
|
||||
for j in range(random_start, len(random_document)):
|
||||
tokens_b.extend(random_document[j])
|
||||
if len(tokens_b) >= target_b_length:
|
||||
break
|
||||
# We didn't actually use these segments so we "put them back" so
|
||||
# they don't go to waste.
|
||||
num_unused_segments = len(current_chunk) - a_end
|
||||
i -= num_unused_segments
|
||||
# Actual next
|
||||
else:
|
||||
is_random_next = False
|
||||
for j in range(a_end, len(current_chunk)):
|
||||
tokens_b.extend(current_chunk[j])
|
||||
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
|
||||
|
||||
assert len(tokens_a) >= 1
|
||||
assert len(tokens_b) >= 1
|
||||
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
(tokens, masked_lm_positions,
|
||||
masked_lm_labels) = create_masked_lm_predictions(
|
||||
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
|
||||
instance = TrainingInstance(
|
||||
tokens=tokens,
|
||||
segment_ids=segment_ids,
|
||||
is_random_next=is_random_next,
|
||||
masked_lm_positions=masked_lm_positions,
|
||||
masked_lm_labels=masked_lm_labels)
|
||||
instances.append(instance)
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
i += 1
|
||||
|
||||
return instances
|
||||
|
||||
|
||||
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
|
||||
["index", "label"])
|
||||
|
||||
|
||||
def create_masked_lm_predictions(tokens, masked_lm_prob,
|
||||
max_predictions_per_seq, vocab_words, rng):
|
||||
"""Creates the predictions for the masked LM objective."""
|
||||
|
||||
cand_indexes = []
|
||||
for (i, token) in enumerate(tokens):
|
||||
if token == "[CLS]" or token == "[SEP]":
|
||||
continue
|
||||
# Whole Word Masking means that if we mask all of the wordpieces
|
||||
# corresponding to an original word. When a word has been split into
|
||||
# WordPieces, the first token does not have any marker and any subsequence
|
||||
# tokens are prefixed with ##. So whenever we see the ## token, we
|
||||
# append it to the previous set of word indexes.
|
||||
#
|
||||
# Note that Whole Word Masking does *not* change the training code
|
||||
# at all -- we still predict each WordPiece independently, softmaxed
|
||||
# over the entire vocabulary.
|
||||
if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
|
||||
token.startswith("##")):
|
||||
cand_indexes[-1].append(i)
|
||||
else:
|
||||
cand_indexes.append([i])
|
||||
|
||||
rng.shuffle(cand_indexes)
|
||||
|
||||
output_tokens = list(tokens)
|
||||
|
||||
num_to_predict = min(max_predictions_per_seq,
|
||||
max(1, int(round(len(tokens) * masked_lm_prob))))
|
||||
|
||||
masked_lms = []
|
||||
covered_indexes = set()
|
||||
for index_set in cand_indexes:
|
||||
if len(masked_lms) >= num_to_predict:
|
||||
break
|
||||
# If adding a whole-word mask would exceed the maximum number of
|
||||
# predictions, then just skip this candidate.
|
||||
if len(masked_lms) + len(index_set) > num_to_predict:
|
||||
continue
|
||||
is_any_index_covered = False
|
||||
for index in index_set:
|
||||
if index in covered_indexes:
|
||||
is_any_index_covered = True
|
||||
break
|
||||
if is_any_index_covered:
|
||||
continue
|
||||
for index in index_set:
|
||||
covered_indexes.add(index)
|
||||
|
||||
masked_token = None
|
||||
# 80% of the time, replace with [MASK]
|
||||
if rng.random() < 0.8:
|
||||
masked_token = "[MASK]"
|
||||
else:
|
||||
# 10% of the time, keep original
|
||||
if rng.random() < 0.5:
|
||||
masked_token = tokens[index]
|
||||
# 10% of the time, replace with random word
|
||||
else:
|
||||
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
|
||||
|
||||
output_tokens[index] = masked_token
|
||||
|
||||
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
|
||||
assert len(masked_lms) <= num_to_predict
|
||||
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||
|
||||
masked_lm_positions = []
|
||||
masked_lm_labels = []
|
||||
for p in masked_lms:
|
||||
masked_lm_positions.append(p.index)
|
||||
masked_lm_labels.append(p.label)
|
||||
|
||||
return (output_tokens, masked_lm_positions, masked_lm_labels)
|
||||
|
||||
|
||||
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
|
||||
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_num_tokens:
|
||||
break
|
||||
|
||||
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
||||
assert len(trunc_tokens) >= 1
|
||||
|
||||
# We want to sometimes truncate from the front and sometimes from the
|
||||
# back to add more randomness and avoid biases.
|
||||
if rng.random() < 0.5:
|
||||
del trunc_tokens[0]
|
||||
else:
|
||||
trunc_tokens.pop()
|
||||
|
||||
|
||||
def main(_):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
tokenizer = tokenization.FullTokenizer(
|
||||
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||
|
||||
input_files = []
|
||||
for input_pattern in FLAGS.input_file.split(","):
|
||||
input_files.extend(tf.gfile.Glob(input_pattern))
|
||||
|
||||
tf.logging.info("*** Reading from input files ***")
|
||||
for input_file in input_files:
|
||||
tf.logging.info(" %s", input_file)
|
||||
|
||||
rng = random.Random(FLAGS.random_seed)
|
||||
instances = create_training_instances(
|
||||
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
|
||||
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
|
||||
rng)
|
||||
|
||||
output_files = FLAGS.output_file.split(",")
|
||||
tf.logging.info("*** Writing to output files ***")
|
||||
for output_file in output_files:
|
||||
tf.logging.info(" %s", output_file)
|
||||
|
||||
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
|
||||
FLAGS.max_predictions_per_seq, output_files)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
flags.mark_flag_as_required("input_file")
|
||||
flags.mark_flag_as_required("output_file")
|
||||
flags.mark_flag_as_required("vocab_file")
|
||||
tf.app.run()
|
|
@ -0,0 +1,33 @@
|
|||
This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
|
||||
Text should be one-sentence-per-line, with empty lines between documents.
|
||||
This sample text is public domain and was randomly selected from Project Guttenberg.
|
||||
|
||||
The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
|
||||
Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
|
||||
Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them.
|
||||
"Cass" Beard had risen early that morning, but not with a view to discovery.
|
||||
A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets.
|
||||
The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency.
|
||||
This was nearly opposite.
|
||||
Mr. Cassius crossed the highway, and stopped suddenly.
|
||||
Something glittered in the nearest red pool before him.
|
||||
Gold, surely!
|
||||
But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring.
|
||||
Looking at it more attentively, he saw that it bore the inscription, "May to Cass."
|
||||
Like most of his fellow gold-seekers, Cass was superstitious.
|
||||
|
||||
The fountain of classic wisdom, Hypatia herself.
|
||||
As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge.
|
||||
From my youth I felt in me a soul above the matter-entangled herd.
|
||||
She revealed to me the glorious fact, that I am a spark of Divinity itself.
|
||||
A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's.
|
||||
There is a philosophic pleasure in opening one's treasures to the modest young.
|
||||
Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street.
|
||||
Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide;
|
||||
but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind.
|
||||
Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now.
|
||||
His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert;
|
||||
while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts.
|
||||
At last they reached the quay at the opposite end of the street;
|
||||
and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers.
|
||||
He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him.
|
|
@ -0,0 +1,394 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
import unicodedata
|
||||
import six
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
||||
"""Checks whether the casing config is consistent with the checkpoint name."""
|
||||
|
||||
# The casing has to be passed in by the user and there is no explicit check
|
||||
# as to whether it matches the checkpoint. The casing information probably
|
||||
# should have been stored in the bert_config.json file, but it's not, so
|
||||
# we have to heuristically detect it to validate.
|
||||
|
||||
if not init_checkpoint:
|
||||
return
|
||||
|
||||
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
|
||||
if m is None:
|
||||
return
|
||||
|
||||
model_name = m.group(1)
|
||||
|
||||
lower_models = [
|
||||
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
|
||||
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
cased_models = [
|
||||
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
|
||||
"multi_cased_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
is_bad_config = False
|
||||
if model_name in lower_models and not do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "False"
|
||||
case_name = "lowercased"
|
||||
opposite_flag = "True"
|
||||
|
||||
if model_name in cased_models and do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "True"
|
||||
case_name = "cased"
|
||||
opposite_flag = "False"
|
||||
|
||||
if is_bad_config:
|
||||
raise ValueError(
|
||||
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
|
||||
"However, `%s` seems to be a %s model, so you "
|
||||
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
|
||||
"how the model was pre-training. If this error is wrong, please "
|
||||
"just comment out this check." % (actual_flag, init_checkpoint,
|
||||
model_name, case_name, opposite_flag))
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
elif isinstance(text, unicode):
|
||||
return text
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def printable_text(text):
|
||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, unicode):
|
||||
return text.encode("utf-8")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def convert_by_vocab(vocab, items):
|
||||
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||
output = []
|
||||
for item in items:
|
||||
output.append(vocab[item])
|
||||
return output
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab, tokens):
|
||||
return convert_by_vocab(vocab, tokens)
|
||||
|
||||
|
||||
def convert_ids_to_tokens(inv_vocab, ids):
|
||||
return convert_by_vocab(inv_vocab, ids)
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class FullTokenizer(object):
|
||||
"""Runs end-to-end tokenziation."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_by_vocab(self.vocab, tokens)
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
return convert_by_vocab(self.inv_vocab, ids)
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self, do_lower_case=True):
|
||||
"""Constructs a BasicTokenizer.
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
# models. This is also applied to the English models now, but it doesn't
|
||||
# matter since the English models were not trained on any Chinese data
|
||||
# and generally don't have any Chinese data in them (there are Chinese
|
||||
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||
# words in the English Wikipedia.).
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like the all of the other languages.
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
"""Runs WordPiece tokenziation."""
|
||||
|
||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer.
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
text = convert_to_unicode(text)
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically contorl characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat in ("Cc", "Cf"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue