# Source: mindspore/model_zoo/Transformer/create_data.py
# (202 lines, 8.3 KiB, Python; Raw / Normal View / History)
#
# 2020-05-28 19:36:33 +08:00
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create training instances for Transformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import logging
import numpy as np
import src.tokenization as tokenization
from mindspore.mindrecord import FileWriter
class SampleInstance:
    """Container for one training example built from a sentence pair.

    Stores four token lists:
      * ``source_sos_tokens`` / ``target_sos_tokens`` -- the side with a
        start-of-sequence token in front,
      * ``source_eos_tokens`` / ``target_eos_tokens`` -- the side with an
        end-of-sequence token at the back.
    """

    def __init__(self, source_sos_tokens, source_eos_tokens, target_sos_tokens, target_eos_tokens):
        self.source_sos_tokens, self.source_eos_tokens = source_sos_tokens, source_eos_tokens
        self.target_sos_tokens, self.target_eos_tokens = target_sos_tokens, target_eos_tokens

    def __str__(self):
        # One "<label>: <space-joined printable tokens>" line per list,
        # followed by a single blank line.
        labelled_lists = (
            ("source sos tokens", self.source_sos_tokens),
            ("source eos tokens", self.source_eos_tokens),
            ("target sos tokens", self.target_sos_tokens),
            ("target eos tokens", self.target_eos_tokens),
        )
        rendered = []
        for label, tokens in labelled_lists:
            printable = " ".join([tokenization.convert_to_printable(tok) for tok in tokens])
            rendered.append("%s: %s\n" % (label, printable))
        return "".join(rendered) + "\n"

    def __repr__(self):
        return str(self)
def write_instance_to_file(writer, instance, tokenizer, max_seq_length):
    """Convert a `SampleInstance` to padded id/mask arrays and write it.

    Each of the instance's four token lists is mapped to ids with
    ``tokenizer.convert_tokens_to_ids`` and right-padded with zeros to
    ``max_seq_length``. The matching mask holds 1 for real tokens and 0 for
    padding.

    Args:
        writer: object exposing ``write_raw_data(list_of_dicts)`` (the
            MindRecord ``FileWriter`` when called from ``main``).
        instance: the `SampleInstance` to serialize.
        tokenizer: object exposing ``convert_tokens_to_ids(tokens)``.
        max_seq_length: fixed length of every id and mask array.

    Returns:
        The ``OrderedDict`` of int64 numpy arrays that was written, with keys
        ``<side>_ids`` / ``<side>_mask`` for ``source_sos``, ``source_eos``,
        ``target_sos`` and ``target_eos``, in that order.

    Raises:
        ValueError: if any token list is longer than ``max_seq_length``.
    """
    def _convert_ids_and_mask(input_tokens):
        # Copy so the tokenizer's returned list is never mutated.
        input_ids = list(tokenizer.convert_tokens_to_ids(input_tokens))
        num_tokens = len(input_ids)
        # Explicit check rather than `assert`, which `python -O` strips.
        if num_tokens > max_seq_length:
            raise ValueError("sequence of %d tokens exceeds max_seq_length=%d"
                             % (num_tokens, max_seq_length))
        padding = [0] * (max_seq_length - num_tokens)
        return input_ids + padding, [1] * num_tokens + padding

    # Convert every side before writing, so nothing is written if one fails.
    features = collections.OrderedDict()
    for side in ("source_sos", "source_eos", "target_sos", "target_eos"):
        ids, mask = _convert_ids_and_mask(getattr(instance, side + "_tokens"))
        # Force int64 to match the "int64" MindRecord schema; numpy's default
        # integer type is platform-dependent.
        features[side + "_ids"] = np.asarray(ids, dtype=np.int64)
        features[side + "_mask"] = np.asarray(mask, dtype=np.int64)
    writer.write_raw_data([features])
    return features
def create_training_instance(source_words, target_words, max_seq_length, clip_to_max_len):
    """Create a `SampleInstance` for a single sentence pair.

    Each side gets exactly one special token (``<s>`` in front or ``</s>``
    at the back), so a side of ``n`` words becomes ``n + 1`` tokens. At most
    ``max_seq_length - 1`` words therefore fit into a sequence.

    Args:
        source_words: list of source-side tokens.
        target_words: list of target-side tokens.
        max_seq_length: maximum length of every produced token sequence.
        clip_to_max_len: if True, an over-long side is truncated to
            ``max_seq_length - 1`` words; if False, the pair is dropped.

    Returns:
        A `SampleInstance`, or None when the pair is too long and
        ``clip_to_max_len`` is False.
    """
    EOS = "</s>"
    SOS = "<s>"
    max_words = max_seq_length - 1
    if len(source_words) > max_words or len(target_words) > max_words:
        if not clip_to_max_len:
            return None
        # Slicing already handles sides that are shorter than the limit.
        source_words = source_words[:max_words]
        target_words = target_words[:max_words]
    return SampleInstance(
        source_sos_tokens=[SOS] + source_words,
        source_eos_tokens=source_words + [EOS],
        target_sos_tokens=[SOS] + target_words,
        target_eos_tokens=target_words + [EOS])
def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string (including ``"False"``), so boolean options need
    an explicit parser.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


def main():
    """Convert tab-separated sentence pairs into a MindRecord dataset.

    Every non-blank input line must have the form ``source<TAB>target``.
    Blank lines are skipped; malformed lines are skipped with a warning.
    Both sides are tokenized with the whitespace tokenizer, wrapped in a
    `SampleInstance` and written through `write_instance_to_file`. If a side
    has ``max_seq_length`` tokens or more, it is truncated when
    ``--clip_to_max_len`` is true; otherwise the pair is skipped.
    """
    # Without this the module's logging.info progress messages are dropped.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True,
                        help='Input raw text file (or comma-separated list of files).')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
    parser.add_argument("--num_splits", type=int, default=16,
                        help='The MindRecord file will be split into the number of partition.')
    parser.add_argument("--vocab_file", type=str, required=True,
                        help='The vocabulary file that the Transformer model was trained on.')
    parser.add_argument("--clip_to_max_len", type=_str_to_bool, default=False,
                        help='clip sequences to maximum sequence length.')
    parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
    args = parser.parse_args()

    tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)

    # Tolerate spaces around commas and a trailing comma in --input_file.
    input_files = [path.strip() for path in args.input_file.split(",") if path.strip()]
    logging.info("*** Reading from input files ***")
    for input_file in input_files:
        logging.info(" %s", input_file)

    output_file = args.output_file
    logging.info("*** Writing to output files ***")
    logging.info(" %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
                   "source_sos_mask": {"type": "int64", "shape": [-1]},
                   "source_eos_ids": {"type": "int64", "shape": [-1]},
                   "source_eos_mask": {"type": "int64", "shape": [-1]},
                   "target_sos_ids": {"type": "int64", "shape": [-1]},
                   "target_sos_mask": {"type": "int64", "shape": [-1]},
                   "target_eos_ids": {"type": "int64", "shape": [-1]},
                   "target_eos_mask": {"type": "int64", "shape": [-1]}
                   }
    writer.add_schema(data_schema, "tranformer hisi")

    # Each side gains one special token (<s> or </s>), so at most
    # max_seq_length - 1 words fit into a sequence.
    max_words = args.max_seq_length - 1
    total_written = 0
    total_read = 0
    for input_file in input_files:
        logging.info("*** Reading from %s ***", input_file)
        with open(input_file, "r", encoding="utf-8") as reader:
            for line_no, raw_line in enumerate(reader, start=1):
                total_read += 1
                if total_read % 100000 == 0:
                    logging.info("%d ...", total_read)

                line = tokenization.convert_to_unicode(raw_line).strip()
                if not line:
                    continue
                parts = line.split("\t")
                if len(parts) != 2:
                    logging.warning("Skipping malformed line %d in %s (expected 'source<TAB>target')",
                                    line_no, input_file)
                    continue
                source_line, target_line = parts

                source_tokens = tokenizer.tokenize(source_line)
                target_tokens = tokenizer.tokenize(target_line)
                # Honour --clip_to_max_len here: truncate long pairs when it is
                # set, otherwise skip them.
                if len(source_tokens) > max_words or len(target_tokens) > max_words:
                    if not args.clip_to_max_len:
                        logging.info("ignore long sentence!")
                        continue
                    source_tokens = source_tokens[:max_words]
                    target_tokens = target_tokens[:max_words]

                instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length,
                                                    clip_to_max_len=args.clip_to_max_len)
                if instance is None:
                    continue

                features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length)
                total_written += 1

                if total_written <= 20:
                    logging.info("*** Example ***")
                    logging.info("source tokens: %s", " ".join(
                        [tokenization.convert_to_printable(x) for x in instance.source_eos_tokens]))
                    logging.info("target tokens: %s", " ".join(
                        [tokenization.convert_to_printable(x) for x in instance.target_sos_tokens]))
                    for feature_name, feature in features.items():
                        logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()