From 2c0258d373e73fdaaedf7d6eba687c7cf8db8524 Mon Sep 17 00:00:00 2001
From: jonyguo
Date: Tue, 9 Jun 2020 23:10:16 +0800
Subject: [PATCH] add example for enwiki -> mindrecord

---
 example/nlp_to_mindrecord/enwiki/README.md    | 173 ++++++++++++++++++
 .../enwiki/create_dataset.py                  |  43 +++++
 example/nlp_to_mindrecord/enwiki/run.sh       | 133 ++++++++++++++
 example/nlp_to_mindrecord/enwiki/run_read.sh  |  44 +++++
 example/nlp_to_mindrecord/zhwiki/run.sh       |   2 +-
 .../zhwiki/create_pretraining_data.patch      |  16 +-
 6 files changed, 402 insertions(+), 9 deletions(-)
 create mode 100644 example/nlp_to_mindrecord/enwiki/README.md
 create mode 100644 example/nlp_to_mindrecord/enwiki/create_dataset.py
 create mode 100644 example/nlp_to_mindrecord/enwiki/run.sh
 create mode 100644 example/nlp_to_mindrecord/enwiki/run_read.sh

diff --git a/example/nlp_to_mindrecord/enwiki/README.md b/example/nlp_to_mindrecord/enwiki/README.md
new file mode 100644
index 00000000000..e92e8dbcc65
--- /dev/null
+++ b/example/nlp_to_mindrecord/enwiki/README.md
@@ -0,0 +1,173 @@
+# Guideline to Convert enwiki Training Data to MindRecord for BERT Pre-Training
+
+<!-- TOC -->
+
+- [What does the example do](#what-does-the-example-do)
+- [How to use the example to process enwiki](#how-to-use-the-example-to-process-enwiki)
+    - [Download enwiki training data](#download-enwiki-training-data)
+    - [Process the enwiki data](#process-the-enwiki-data)
+    - [Generate MindRecord](#generate-mindrecord)
+    - [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
+
+<!-- /TOC -->
+
+## What does the example do
+
+This example converts the [enwiki](https://dumps.wikimedia.org/enwiki) training data into MindRecord files, which are then used for BERT network pre-training.
+
+1. run.sh: entry script that generates the MindRecord files.
+2. run_read.sh: entry script that creates a MindDataset from the MindRecord files.
+    - create_dataset.py: uses MindDataset to read the MindRecord files and build the dataset.
+
+## How to use the example to process enwiki
+
+Download the enwiki data, process it, convert it to MindRecord, and read the result back with MindDataset.
+
+### Download enwiki training data
+
+> [enwiki dataset download address](https://dumps.wikimedia.org/enwiki) **-> 20200501 -> enwiki-20200501-pages-articles-multistream.xml.bz2**
+
+### Process the enwiki data
+
+1. Please follow the steps in [process enwiki](https://github.com/mlperf/training/tree/master/language_model/tensorflow/bert).
+    - All rights to this processing step belong to the linked website.
+
+### Generate MindRecord
+
+1. Run the run.sh script.
+    ```
+    bash run.sh input_dir output_dir vocab_file
+    ```
+    - input_dir: the directory that contains files like 'part-00251-of-00500'.
+    - output_dir: the directory that will store the output MindRecord files.
+    - vocab_file: the vocabulary file, which you can download from another open-source project.

+2. The output looks like this:
+    ```
+    ...
+    Begin preprocess Wed Jun 10 09:21:23 CST 2020
+    Begin preprocess input file: /mnt/data/results/part-00000-of-00500
+    Begin output file: part-00000-of-00500.mindrecord
+    Total task: 510, processing: 1
+    Begin preprocess input file: /mnt/data/results/part-00001-of-00500
+    Begin output file: part-00001-of-00500.mindrecord
+    Total task: 510, processing: 2
+    Begin preprocess input file: /mnt/data/results/part-00002-of-00500
+    Begin output file: part-00002-of-00500.mindrecord
+    Total task: 510, processing: 3
+    Begin preprocess input file: /mnt/data/results/part-00003-of-00500
+    Begin output file: part-00003-of-00500.mindrecord
+    Total task: 510, processing: 4
+    Begin preprocess input file: /mnt/data/results/part-00004-of-00500
+    Begin output file: part-00004-of-00500.mindrecord
+    Total task: 510, processing: 4
+    ...
+    ```
+
+3. It generates MindRecord files like these:
+    ```bash
+    $ ls {your_output_dir}/
+    part-00000-of-00500.mindrecord part-00000-of-00500.mindrecord.db part-00001-of-00500.mindrecord part-00001-of-00500.mindrecord.db part-00002-of-00500.mindrecord part-00002-of-00500.mindrecord.db ...
+    ```
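+
+4. Optionally, sanity-check a generated shard before training. This is a minimal sketch, assuming MindSpore is installed; the file name below is just an example:
+    ```python
+    import mindspore.dataset as ds
+
+    # read one shard back and count the rows it contains
+    data_set = ds.MindDataset(dataset_file="{your_output_dir}/part-00000-of-00500.mindrecord")
+    print("total rows: {}".format(data_set.get_dataset_size()))
+    ```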
+
+### Create MindDataset By MindRecord
+
+1. Run the run_read.sh script.
+    ```bash
+    bash run_read.sh input_dir
+    ```
+    - input_dir: the directory that contains the MindRecord files.
+
+2. The output looks like this:
+    ```
+    ...
+    example 633: input_ids: [ 101 2043 19781 4305 2140 4520 2041 1010 103 2034 2455 2002
+                              7879 2003 1996 2455 1997 103 26378 4160 1012 102 7291 2001
+                              1996 103 1011 2343 1997 6327 1010 3423 1998 103 4262 2005
+                              1996 2118 1997 2329 3996 103 102 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0]
+    example 633: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
+                              1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
+    example 633: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
+                               1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                               0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
+    example 633: masked_lm_positions: [ 8 17 20 25 33 41 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                                        0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                                        0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+                                        0 0 0 0]
+    example 633: masked_lm_ids: [ 1996 16137 1012 3580 2451 1012 0 0 0 0 0 0
+                                  0 0 0 0 0 0 0 0 0 0 0 0
+                                  0 0 0 0 0 0 0 0 0 0 0 0
+                                  0 0 0 0 0 0 0 0 0 0 0 0
+                                  0 0 0 0 0 0 0 0 0 0 0 0
+                                  0 0 0 0 0 0 0 0 0 0 0 0
+                                  0 0 0 0]
+    example 633: masked_lm_weights: [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
+                                     0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
+                                     0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
+                                     0. 0. 0. 0.]
+    example 633: next_sentence_labels: [1]
+    ...
+    ```
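+
+3. If you only need some of the columns, MindDataset can select them at load time. This is a minimal sketch, assuming the column names shown above; the file name is just an example:
+    ```python
+    import mindspore.dataset as ds
+
+    # load only the three columns that feed the BERT encoder inputs
+    data_set = ds.MindDataset(dataset_file="part-00000-of-00500.mindrecord",
+                              columns_list=["input_ids", "input_mask", "segment_ids"])
+    ```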
diff --git a/example/nlp_to_mindrecord/enwiki/create_dataset.py b/example/nlp_to_mindrecord/enwiki/create_dataset.py
new file mode 100644
index 00000000000..d90d12b7f23
--- /dev/null
+++ b/example/nlp_to_mindrecord/enwiki/create_dataset.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""create MindDataset by MindRecord"""
+import argparse
+import mindspore.dataset as ds
+
+def create_dataset(data_file):
+    """create MindDataset from the given MindRecord files and iterate over it"""
+    num_readers = 4  # number of parallel workers used to read the files
+    data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
+    index = 0
+    for item in data_set.create_dict_iterator():
+        # print("example {}: {}".format(index, item))
+        print("example {}: input_ids: {}".format(index, item['input_ids']))
+        print("example {}: input_mask: {}".format(index, item['input_mask']))
+        print("example {}: segment_ids: {}".format(index, item['segment_ids']))
+        print("example {}: masked_lm_positions: {}".format(index, item['masked_lm_positions']))
+        print("example {}: masked_lm_ids: {}".format(index, item['masked_lm_ids']))
+        print("example {}: masked_lm_weights: {}".format(index, item['masked_lm_weights']))
+        print("example {}: next_sentence_labels: {}".format(index, item['next_sentence_labels']))
+        index += 1
+        if index % 1000 == 0:
+            print("read rows: {}".format(index))
+    print("total rows: {}".format(index))
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input_file", nargs='+', type=str, help='Input mindrecord file')
+    args = parser.parse_args()
+
+    create_dataset(args.input_file)
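+
+# Usage (file names are illustrative; --input_file accepts one or more files):
+#   python create_dataset.py -i part-00000-of-00500.mindrecord part-00001-of-00500.mindrecord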
diff --git a/example/nlp_to_mindrecord/enwiki/run.sh b/example/nlp_to_mindrecord/enwiki/run.sh
new file mode 100644
index 00000000000..cf66bed0fde
--- /dev/null
+++ b/example/nlp_to_mindrecord/enwiki/run.sh
@@ -0,0 +1,133 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# -ne 3 ]; then
+    echo "Usage: $0 input_dir output_dir vocab_file"
+    exit 1
+fi
+
+if [ ! -d $1 ]; then
+    echo "The input dir: $1 does not exist."
+    exit 1
+fi
+
+if [ ! -d $2 ]; then
+    echo "The output dir: $2 does not exist."
+    exit 1
+fi
+rm -fr $2/*.mindrecord*
+
+if [ ! -f $3 ]; then
+    echo "The vocab file: $3 does not exist."
+    exit 1
+fi
+
+data_dir=$1
+output_dir=$2
+vocab_file=$3
+file_list=()
+output_filename=()
+file_index=0
+
+# walk the input dir recursively, recording every file and deriving its output file name
+function getdir() {
+    elements=`ls $1`
+    for element in ${elements[*]};
+    do
+        dir_or_file=$1"/"$element
+        if [ -d $dir_or_file ];
+        then
+            getdir $dir_or_file
+        else
+            file_list[$file_index]=$dir_or_file
+            # split the path by '/' so that mapfile can pick out the last component as the file name
+            echo "${dir_or_file}" | tr '/' '\n' > dir_file_list.txt
+            mapfile parent_dir < dir_file_list.txt
+            rm dir_file_list.txt >/dev/null 2>&1
+            tmp_output_filename=${parent_dir[${#parent_dir[@]}-1]}".mindrecord"
+            output_filename[$file_index]=`echo ${tmp_output_filename} | sed 's/ //g'`
+            file_index=`expr $file_index + 1`
+        fi
+    done
+}
+
+getdir "${data_dir}"
+# echo "The input files: "${file_list[@]}
+# echo "The output files: "${output_filename[@]}
+
+if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
+    echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki does not exist."
+    exit 1
+fi
+
+if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
+    echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch does not exist."
+    exit 1
+fi
+
+# apply the patch to create_pretraining_data.py, producing create_pretraining_data_patched.py
+patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
+if [ $? -ne 0 ]; then
+    echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
+    exit 1
+fi
+
+# get the cpu core count and cap the concurrency at two thirds of it
+num_cpu_core=`cat /proc/cpuinfo | grep "processor" | wc -l`
+available_core_size=`expr $num_cpu_core / 3 \* 2`
+
+echo "Begin preprocess `date`"
+
+# use the patched script to generate mindrecord files, one background process per input file
+file_list_len=`expr ${#file_list[*]} - 1`
+for index in $(seq 0 $file_list_len); do
+    echo "Begin preprocess input file: ${file_list[$index]}"
+    echo "Begin output file: ${output_filename[$index]}"
+    python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
+        --input_file=${file_list[$index]} \
+        --output_file=${output_dir}/${output_filename[$index]} \
+        --partition_number=1 \
+        --vocab_file=${vocab_file} \
+        --do_lower_case=True \
+        --max_seq_length=512 \
+        --max_predictions_per_seq=76 \
+        --masked_lm_prob=0.15 \
+        --random_seed=12345 \
+        --dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 &
+    process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
+    echo "Total task: ${#file_list[*]}, processing: ${process_count}"
+    # when the concurrency limit is reached, wait until one of the processes finishes
+    if [ $process_count -ge $available_core_size ]; then
+        while [ 1 ]; do
+            process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
+            if [ $process_count -gt $process_num ]; then
+                process_count=$process_num
+                break;
+            fi
+            sleep 2
+        done
+    fi
+done
+
+# wait for all the remaining background processes to finish
+process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
+while [ 1 ]; do
+    if [ $process_num -eq 0 ]; then
+        break;
+    fi
+    echo "There are still ${process_num} preprocess tasks running ..."
+    sleep 2
+    process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
+done
+
+echo "Preprocessed all the data successfully."
+echo "End preprocess `date`"
diff --git a/example/nlp_to_mindrecord/enwiki/run_read.sh b/example/nlp_to_mindrecord/enwiki/run_read.sh
new file mode 100644
index 00000000000..737e9375c4b
--- /dev/null
+++ b/example/nlp_to_mindrecord/enwiki/run_read.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# -ne 1 ]; then
+    echo "Usage: $0 input_dir"
+    exit 1
+fi
+
+if [ ! -d $1 ]; then
+    echo "The input dir: $1 does not exist."
+    exit 1
+fi
+
+file_list=()
+file_index=0
+
+# collect all the mindrecord files from the input dir
+function getdir() {
+    elements=`ls $1/part-*.mindrecord`
+    for element in ${elements[*]};
+    do
+        file_list[$file_index]=$element
+        file_index=`expr $file_index + 1`
+    done
+}
+
+getdir $1
+echo "Get all the mindrecord files: "${file_list[*]}
+
+# create the dataset for training from all the mindrecord files
+python create_dataset.py --input_file ${file_list[*]}
diff --git a/example/nlp_to_mindrecord/zhwiki/run.sh b/example/nlp_to_mindrecord/zhwiki/run.sh
index 24f2a98eb68..431ff54c65f 100644
--- a/example/nlp_to_mindrecord/zhwiki/run.sh
+++ b/example/nlp_to_mindrecord/zhwiki/run.sh
@@ -85,7 +85,7 @@ for index in $(seq 0 $file_list_len); do
         --random_seed=12345 \
         --dupe_factor=5 >/tmp/${output_filename[$index]}.log 2>&1 &
     process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
-    echo "Total task: ${file_list_len}, processing: ${process_count}"
+    echo "Total task: ${#file_list[*]}, processing: ${process_count}"
     if [ $process_count -ge $avaiable_core_size ]; then
         while [ 1 ]; do
             process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
diff --git a/third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch b/third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
index 64a126d899c..1a7b15dce2a 100644
--- a/third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
+++ b/third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
@@ -89,7 +89,7 @@
 +            "segment_ids": {"type": "int64", "shape": [-1]},
 +            "masked_lm_positions": {"type": "int64", "shape": [-1]},
 +            "masked_lm_ids": {"type": "int64", "shape": [-1]},
-+            "masked_lm_weights": {"type": "float64", "shape": [-1]},
++            "masked_lm_weights": {"type": "float32", "shape": [-1]},
 +            "next_sentence_labels": {"type": "int64", "shape": [-1]},
 +            }
 +        writer.add_schema(data_schema, "zhwiki schema")
@@ -112,13 +112,13 @@
 -
 -        writers[writer_index].write(tf_example.SerializeToString())
 -        writer_index = (writer_index + 1) % len(writers)
-+        features["input_ids"] = np.asarray(input_ids)
-+        features["input_mask"] = np.asarray(input_mask)
-+        features["segment_ids"] = np.asarray(segment_ids)
-+        features["masked_lm_positions"] = np.asarray(masked_lm_positions)
-+        features["masked_lm_ids"] = np.asarray(masked_lm_ids)
-+        features["masked_lm_weights"] = np.asarray(masked_lm_weights)
-+        features["next_sentence_labels"] = np.asarray([next_sentence_label])
++        features["input_ids"] = np.asarray(input_ids, np.int64)
++        features["input_mask"] = np.asarray(input_mask, np.int64)
++        features["segment_ids"] = np.asarray(segment_ids, np.int64)
++        features["masked_lm_positions"] = np.asarray(masked_lm_positions, np.int64)
++        features["masked_lm_ids"] = np.asarray(masked_lm_ids, np.int64)
++        features["masked_lm_weights"] = np.asarray(masked_lm_weights, np.float32)
++        features["next_sentence_labels"] = np.asarray([next_sentence_label], np.int64)
 
         total_written += 1
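
The explicit dtypes in the last hunk are what make each row match the declared schema: MindRecord validates every numpy array against its column type, so a column declared "float32" must also be written as np.float32. A minimal sketch of that writer contract, assuming the mindspore.mindrecord FileWriter API (file name and values are illustrative):

```python
import numpy as np
from mindspore.mindrecord import FileWriter

# declare a two-column schema; the dtypes must match the arrays written below
writer = FileWriter(file_name="demo.mindrecord", shard_num=1)
schema = {"input_ids": {"type": "int64", "shape": [-1]},
          "masked_lm_weights": {"type": "float32", "shape": [-1]}}
writer.add_schema(schema, "demo schema")

# each row is validated against the schema before it is written
row = {"input_ids": np.asarray([101, 2043, 102], np.int64),
       "masked_lm_weights": np.asarray([1.0, 1.0], np.float32)}
writer.write_raw_data([row])
writer.commit()
```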