From e6cbb48157a0771190494ccaf79442b9eb49e2b6 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Thu, 28 May 2020 19:36:33 +0800 Subject: [PATCH] add Transformer model --- model_zoo/Transformer/README.md | 176 +++ model_zoo/Transformer/create_data.py | 201 +++ model_zoo/Transformer/eval.py | 136 ++ .../Transformer/scripts/process_output.sh | 35 + .../Transformer/scripts/replace-quote.perl | 11 + .../scripts/run_distribute_train.sh | 63 + .../scripts/run_standalone_train.sh | 45 + model_zoo/Transformer/src/__init__.py | 0 model_zoo/Transformer/src/beam_search.py | 269 ++++ model_zoo/Transformer/src/config.py | 71 + model_zoo/Transformer/src/dataset.py | 48 + model_zoo/Transformer/src/eval_config.py | 69 + model_zoo/Transformer/src/lr_schedule.py | 52 + model_zoo/Transformer/src/process_output.py | 47 + model_zoo/Transformer/src/tokenization.py | 193 +++ .../Transformer/src/transformer_for_train.py | 341 +++++ .../Transformer/src/transformer_model.py | 1237 +++++++++++++++++ model_zoo/Transformer/src/weight_init.py | 52 + model_zoo/Transformer/train.py | 179 +++ 19 files changed, 3225 insertions(+) create mode 100644 model_zoo/Transformer/README.md create mode 100644 model_zoo/Transformer/create_data.py create mode 100644 model_zoo/Transformer/eval.py create mode 100644 model_zoo/Transformer/scripts/process_output.sh create mode 100644 model_zoo/Transformer/scripts/replace-quote.perl create mode 100644 model_zoo/Transformer/scripts/run_distribute_train.sh create mode 100644 model_zoo/Transformer/scripts/run_standalone_train.sh create mode 100644 model_zoo/Transformer/src/__init__.py create mode 100644 model_zoo/Transformer/src/beam_search.py create mode 100644 model_zoo/Transformer/src/config.py create mode 100644 model_zoo/Transformer/src/dataset.py create mode 100644 model_zoo/Transformer/src/eval_config.py create mode 100644 model_zoo/Transformer/src/lr_schedule.py create mode 100644 model_zoo/Transformer/src/process_output.py create mode 100644 model_zoo/Transformer/src/tokenization.py create mode 100644 model_zoo/Transformer/src/transformer_for_train.py create mode 100644 model_zoo/Transformer/src/transformer_model.py create mode 100644 model_zoo/Transformer/src/weight_init.py create mode 100644 model_zoo/Transformer/train.py diff --git a/model_zoo/Transformer/README.md b/model_zoo/Transformer/README.md new file mode 100644 index 00000000000..7ba0c8eb3d4 --- /dev/null +++ b/model_zoo/Transformer/README.md @@ -0,0 +1,176 @@ +# Transformer Example +## Description +This example implements training and evaluation of Transformer Model, which is introduced in the following paper: +- Ashish Vaswani, Noam Shazeer, Niki Parmar, JakobUszkoreit, Llion Jones, Aidan N Gomez, Ł ukaszKaiser, and Illia Polosukhin. 2017. Attention is all you need. In NIPS 2017, pages 5998–6008. + +## Requirements +- Install [MindSpore](https://www.mindspore.cn/install/en). +- Download and preprocess the WMT English-German dataset for training and evaluation. + +> Notes:If you are running an evaluation task, prepare the corresponding checkpoint file. + +## Example structure + +```shell +. 
+└─Transformer
+  ├─README.md
+  ├─scripts
+    ├─process_output.sh
+    ├─replace-quote.perl
+    ├─run_distribute_train.sh
+    └─run_standalone_train.sh
+  ├─src
+    ├─__init__.py
+    ├─beam_search.py
+    ├─config.py
+    ├─dataset.py
+    ├─eval_config.py
+    ├─lr_schedule.py
+    ├─process_output.py
+    ├─tokenization.py
+    ├─transformer_for_train.py
+    ├─transformer_model.py
+    └─weight_init.py
+  ├─create_data.py
+  ├─eval.py
+  └─train.py
+```
+
+---
+
+## Prepare the dataset
+- You may use this [shell script](https://github.com/tensorflow/nmt/blob/master/nmt/scripts/wmt16_en_de.sh) to download and preprocess the WMT English-German dataset. After preprocessing you should have the following files:
+  - train.tok.clean.bpe.32000.en
+  - train.tok.clean.bpe.32000.de
+  - vocab.bpe.32000
+  - newstest2014.tok.bpe.32000.en
+  - newstest2014.tok.bpe.32000.de
+  - newstest2014.tok.de
+
+- Convert the original data to MindRecord format for training:
+
+  ``` bash
+  paste train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.all
+  python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128
+  ```
+- Convert the original data to MindRecord format for evaluation:
+
+  ``` bash
+  paste newstest2014.tok.bpe.32000.en newstest2014.tok.bpe.32000.de > test.all
+  python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True
+  ```
+
+## Running the example
+
+### Training
+- Set options in `config.py`, including loss scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#mindspore) for more information about loading datasets in MindSpore.
+
+- Run `run_standalone_train.sh` for non-distributed training of the Transformer model.
+
+  ``` bash
+  sh scripts/run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_PATH
+  ```
+- Run `run_distribute_train.sh` for distributed training of the Transformer model.
+
+  ``` bash
+  sh scripts/run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_PATH MINDSPORE_HCCL_CONFIG_PATH
+  ```
+
+### Evaluation
+- Set options in `eval_config.py`. Make sure the 'data_file', 'model_file' and 'output_file' are set to your own paths.
+
+- Run `eval.py` for evaluation of the Transformer model.
+
+  ```bash
+  python eval.py
+  ```
+
+- Run `process_output.sh` to convert the output token ids into the real translation results.
+
+  ```bash
+  sh scripts/process_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE
+  ```
+  You will get two files, REF_DATA.forbleu and EVAL_OUTPUT.forbleu, for BLEU score calculation.
+
+- To calculate the BLEU score, you may use this [perl script](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl) and run the following command:
+
+  ```bash
+  perl multi-bleu.perl REF_DATA.forbleu < EVAL_OUTPUT.forbleu
+  ```
+
+---
+
+## Usage
+
+### Training
+```
+usage: train.py  [--distribute DISTRIBUTE] [--epoch_size N] [--device_num N] [--device_id N]
+                 [--enable_save_ckpt ENABLE_SAVE_CKPT]
+                 [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
+                 [--enable_data_sink ENABLE_DATA_SINK] [--save_checkpoint_steps N]
+                 [--save_checkpoint_num N] [--save_checkpoint_path SAVE_CHECKPOINT_PATH]
+                 [--data_path DATA_PATH]
+
+options:
+    --distribute               distribute training across several devices: "true" (training with more than 1 device) | "false", default is "false"
+    --epoch_size               epoch size: N, default is 52
+    --device_num               number of used devices: N, default is 1
+    --device_id                device id: N, default is 0
+    --enable_save_ckpt         enable save checkpoint: "true" | "false", default is "true"
+    --enable_lossscale         enable lossscale: "true" | "false", default is "true"
+    --do_shuffle               enable shuffle: "true" | "false", default is "true"
+    --enable_data_sink         enable data sink: "true" | "false", default is "false"
+    --checkpoint_path          path to load checkpoint files: PATH, default is ""
+    --save_checkpoint_steps    steps for saving checkpoint files: N, default is 2500
+    --save_checkpoint_num      number for saving checkpoint files: N, default is 30
+    --save_checkpoint_path     path to save checkpoint files: PATH, default is "./checkpoint/"
+    --data_path                path to dataset file: PATH, default is ""
+```
+
+## Options and Parameters
+Parameters of the Transformer model and options for training and evaluation are set in `config.py` and `eval_config.py` respectively.
+### Options:
+```
+config.py:
+    transformer_network        version of Transformer model: base | large, default is large
+    init_loss_scale_value      initial value of loss scale: N, default is 2^10
+    scale_factor               factor used to update loss scale: N, default is 2
+    scale_window               number of steps between loss scale updates: N, default is 2000
+    optimizer                  optimizer used in the network: Adam, default is "Adam"
+
+eval_config.py:
+    transformer_network        version of Transformer model: base | large, default is large
+    data_file                  data file: PATH
+    model_file                 checkpoint file to be loaded: PATH
+    output_file                output file of evaluation: PATH
+```
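+
+For reference, here is a rough sketch (not a substitute for `train.py`, which contains the actual training pipeline) of how the loss-scale options above are typically consumed, together with the learning-rate options listed under "Parameters for learning rate" below. The `Adam` and `DynamicLossScaleManager` calls are standard MindSpore APIs; the exact wiring shown here is illustrative, and the value of `training_steps` is a placeholder.
+
+```python
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+from mindspore.nn.optim import Adam
+from mindspore.train.loss_scale_manager import DynamicLossScaleManager
+
+from src.config import cfg, transformer_net_cfg
+from src.lr_schedule import create_dynamic_lr
+from src.transformer_for_train import (TransformerNetworkWithLoss,
+                                       TransformerTrainOneStepWithLossScaleCell)
+
+# network with loss, built from the selected config ("base" or "large")
+netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, is_training=True)
+
+# dynamic learning rate built from the lr_schedule options in config.py
+training_steps = 100000  # placeholder: steps per epoch * epoch_size
+lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
+                              training_steps=training_steps,
+                              learning_rate=cfg.lr_schedule.learning_rate,
+                              warmup_steps=cfg.lr_schedule.warmup_steps,
+                              hidden_size=transformer_net_cfg.hidden_size,
+                              start_decay_step=cfg.lr_schedule.start_decay_step,
+                              min_lr=cfg.lr_schedule.min_lr), mstype.float32)
+optimizer = Adam(netwithloss.trainable_params(), lr)
+
+# init_loss_scale_value, scale_factor and scale_window drive dynamic loss scaling
+scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value,
+                                        scale_factor=cfg.scale_factor,
+                                        scale_window=cfg.scale_window)
+netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
+                                                        scale_update_cell=scale_manager.get_update_cell())
+```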
+
+### Parameters:
+```
+Parameters for dataset and network (Training/Evaluation):
+    batch_size                      batch size of input dataset: N, default is 96
+    seq_length                      length of input sequence: N, default is 128
+    vocab_size                      size of vocabulary: N, default is 36560
+    hidden_size                     size of Transformer encoder layers: N, default is 1024
+    num_hidden_layers               number of hidden layers: N, default is 6
+    num_attention_heads             number of attention heads: N, default is 16
+    intermediate_size               size of intermediate layer: N, default is 4096
+    hidden_act                      activation function used: ACTIVATION, default is "relu"
+    hidden_dropout_prob             dropout probability for TransformerOutput: Q, default is 0.3
+    attention_probs_dropout_prob    dropout probability for TransformerAttention: Q, default is 0.3
+    max_position_embeddings         maximum length of sequences: N, default is 128
+    initializer_range               initialization value of TruncatedNormal: Q, default is 0.02
+    label_smoothing                 label smoothing setting: Q, default is 0.1
+    input_mask_from_dataset         use the input mask loaded from dataset or not: True | False, default is True
+    beam_width                      beam width setting: N, default is 4
+    max_decode_length               max decode length in evaluation: N, default is 80
+    length_penalty_weight           normalize scores of translations according to their length: Q, default
is 1.0 + compute_type compute type in Transformer: mstype.float16 | mstype.float32, default is mstype.float16 + +Parameters for learning rate: + learning_rate value of learning rate: Q + warmup_steps steps of the learning rate warm up: N + start_decay_step step of the learning rate to decay: N + min_lr minimal learning rate: Q +``` \ No newline at end of file diff --git a/model_zoo/Transformer/create_data.py b/model_zoo/Transformer/create_data.py new file mode 100644 index 00000000000..af941623cbc --- /dev/null +++ b/model_zoo/Transformer/create_data.py @@ -0,0 +1,201 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Create training instances for Transformer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import collections +import logging +import numpy as np +import src.tokenization as tokenization +from mindspore.mindrecord import FileWriter + +class SampleInstance(): + """A single sample instance (sentence pair).""" + + def __init__(self, source_sos_tokens, source_eos_tokens, target_sos_tokens, target_eos_tokens): + self.source_sos_tokens = source_sos_tokens + self.source_eos_tokens = source_eos_tokens + self.target_sos_tokens = target_sos_tokens + self.target_eos_tokens = target_eos_tokens + + def __str__(self): + s = "" + s += "source sos tokens: %s\n" % (" ".join( + [tokenization.printable_text(x) for x in self.source_sos_tokens])) + s += "source eos tokens: %s\n" % (" ".join( + [tokenization.printable_text(x) for x in self.source_eos_tokens])) + s += "target sos tokens: %s\n" % (" ".join( + [tokenization.printable_text(x) for x in self.target_sos_tokens])) + s += "target eos tokens: %s\n" % (" ".join( + [tokenization.printable_text(x) for x in self.target_eos_tokens])) + s += "\n" + return s + + def __repr__(self): + return self.__str__() + + +def write_instance_to_file(writer, instance, tokenizer, max_seq_length): + """Create files from `SampleInstance`s.""" + + def _convert_ids_and_mask(input_tokens): + input_ids = tokenizer.convert_tokens_to_ids(input_tokens) + input_mask = [1] * len(input_ids) + assert len(input_ids) <= max_seq_length + + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + + return input_ids, input_mask + + source_sos_ids, source_sos_mask = _convert_ids_and_mask(instance.source_sos_tokens) + source_eos_ids, source_eos_mask = _convert_ids_and_mask(instance.source_eos_tokens) + target_sos_ids, target_sos_mask = _convert_ids_and_mask(instance.target_sos_tokens) + target_eos_ids, target_eos_mask = _convert_ids_and_mask(instance.target_eos_tokens) + + features = collections.OrderedDict() + features["source_sos_ids"] = np.asarray(source_sos_ids) + features["source_sos_mask"] = np.asarray(source_sos_mask) + features["source_eos_ids"] = 
np.asarray(source_eos_ids) + features["source_eos_mask"] = np.asarray(source_eos_mask) + features["target_sos_ids"] = np.asarray(target_sos_ids) + features["target_sos_mask"] = np.asarray(target_sos_mask) + features["target_eos_ids"] = np.asarray(target_eos_ids) + features["target_eos_mask"] = np.asarray(target_eos_mask) + + writer.write_raw_data([features]) + return features + +def create_training_instance(source_words, target_words, max_seq_length, clip_to_max_len): + """Creates `SampleInstance`s for a single sentence pair.""" + EOS = "" + SOS = "" + + if len(source_words) >= max_seq_length or len(target_words) >= max_seq_length: + if clip_to_max_len: + print("####lalalal") + source_words = source_words[:min([len(source_words, max_seq_length-1)])] + target_words = target_words[:min([len(target_words, max_seq_length-1)])] + else: + return None + + source_sos_tokens = [SOS] + source_words + source_eos_tokens = source_words + [EOS] + target_sos_tokens = [SOS] + target_words + target_eos_tokens = target_words + [EOS] + + instance = SampleInstance( + source_sos_tokens=source_sos_tokens, + source_eos_tokens=source_eos_tokens, + target_sos_tokens=target_sos_tokens, + target_eos_tokens=target_eos_tokens) + return instance + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_file", type=str, required=True, + help='Input raw text file (or comma-separated list of files).') + parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.') + parser.add_argument("--num_splits", type=int, default=16, + help='The MindRecord file will be split into the number of partition.') + parser.add_argument("--vocab_file", type=str, required=True, + help='The vocabulary file that the Transformer model was trained on.') + parser.add_argument("--clip_to_max_len", type=bool, default=False, + help='clip sequences to maximum sequence length.') + parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.') + args = parser.parse_args() + + tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file) + + input_files = [] + for input_pattern in args.input_file.split(","): + input_files.append(input_pattern) + + logging.info("*** Reading from input files ***") + for input_file in input_files: + logging.info(" %s", input_file) + + output_file = args.output_file + logging.info("*** Writing to output files ***") + logging.info(" %s", output_file) + + writer = FileWriter(output_file, args.num_splits) + data_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]}, + "source_sos_mask": {"type": "int64", "shape": [-1]}, + "source_eos_ids": {"type": "int64", "shape": [-1]}, + "source_eos_mask": {"type": "int64", "shape": [-1]}, + "target_sos_ids": {"type": "int64", "shape": [-1]}, + "target_sos_mask": {"type": "int64", "shape": [-1]}, + "target_eos_ids": {"type": "int64", "shape": [-1]}, + "target_eos_mask": {"type": "int64", "shape": [-1]} + } + writer.add_schema(data_schema, "tranformer hisi") + + total_written = 0 + total_read = 0 + + for input_file in input_files: + logging.info("*** Reading from %s ***", input_file) + with open(input_file, "r") as reader: + while True: + line = tokenization.convert_to_unicode(reader.readline()) + if not line: + break + + total_read += 1 + if total_read % 100000 == 0: + logging.info("%d ...", total_read) + + source_line, target_line = line.strip().split("\t") + source_tokens = tokenizer.tokenize(source_line) + target_tokens = tokenizer.tokenize(target_line) + + if len(source_tokens) >= 
args.max_seq_length or len(target_tokens) >= args.max_seq_length: + logging.info("ignore long sentence!") + continue + + instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length, + clip_to_max_len=args.clip_to_max_len) + if instance is None: + continue + + features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length) + total_written += 1 + + if total_written <= 20: + logging.info("*** Example ***") + logging.info("source tokens: %s", " ".join( + [tokenization.printable_text(x) for x in instance.source_eos_tokens])) + logging.info("target tokens: %s", " ".join( + [tokenization.printable_text(x) for x in instance.target_sos_tokens])) + + for feature_name in features.keys(): + feature = features[feature_name] + logging.info("%s: %s", feature_name, feature) + + writer.commit() + logging.info("Wrote %d total instances", total_written) + + +if __name__ == "__main__": + main() diff --git a/model_zoo/Transformer/eval.py b/model_zoo/Transformer/eval.py new file mode 100644 index 00000000000..26d00f1c589 --- /dev/null +++ b/model_zoo/Transformer/eval.py @@ -0,0 +1,136 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer evaluation script.""" + +import os +import numpy as np + +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.c_transforms as deC +from mindspore import context + +from src.transformer_model import TransformerModel +from src.eval_config import cfg, transformer_net_cfg + +def load_test_data(batch_size=1, data_file=None): + """ + Load test dataset + """ + ds = de.MindDataset(data_file, + columns_list=["source_eos_ids", "source_eos_mask", + "target_sos_ids", "target_sos_mask", + "target_eos_ids", "target_eos_mask"], + shuffle=False) + type_cast_op = deC.TypeCast(mstype.int32) + ds = ds.map(input_columns="source_eos_ids", operations=type_cast_op) + ds = ds.map(input_columns="source_eos_mask", operations=type_cast_op) + ds = ds.map(input_columns="target_sos_ids", operations=type_cast_op) + ds = ds.map(input_columns="target_sos_mask", operations=type_cast_op) + ds = ds.map(input_columns="target_eos_ids", operations=type_cast_op) + ds = ds.map(input_columns="target_eos_mask", operations=type_cast_op) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + ds.channel_name = 'transformer' + return ds + +class TransformerInferCell(nn.Cell): + """ + Encapsulation class of transformer network infer. 
+ """ + def __init__(self, network): + super(TransformerInferCell, self).__init__(auto_prefix=False) + self.network = network + + def construct(self, + source_ids, + source_mask): + predicted_ids = self.network(source_ids, source_mask) + return predicted_ids + +def load_weights(model_path): + """ + Load checkpoint as parameter dict, support both npz file and mindspore checkpoint file. + """ + if model_path.endswith(".npz"): + ms_ckpt = np.load(model_path) + is_npz = True + else: + ms_ckpt = load_checkpoint(model_path) + is_npz = False + + weights = {} + for msname in ms_ckpt: + infer_name = msname.replace("transformer.transformer.", "") + if "tfm_decoder" in msname: + infer_name = infer_name.replace(".layers.", ".layer") + infer_name = "tfm_decoder.decoder." + infer_name + if is_npz: + weights[infer_name] = ms_ckpt[msname] + else: + weights[infer_name] = ms_ckpt[msname].data.asnumpy() + weights["tfm_decoder.decoder.tfm_embedding_lookup.embedding_table"] = \ + weights["tfm_embedding_lookup.embedding_table"] + + parameter_dict = {} + for name in weights: + parameter_dict[name] = Parameter(Tensor(weights[name]), name=name) + return parameter_dict + +def run_transformer_eval(): + """ + Transformer evaluation. + """ + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False, + device_id=device_id) + + dataset = load_test_data(batch_size=transformer_net_cfg.batch_size, data_file=cfg.data_file) + tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False) + + parameter_dict = load_weights(cfg.model_file) + load_param_into_net(tfm_model, parameter_dict) + + tfm_infer = TransformerInferCell(tfm_model) + model = Model(tfm_infer) + + predictions = [] + source_sents = [] + target_sents = [] + for batch in dataset.create_dict_iterator(): + source_sents.append(batch["source_eos_ids"]) + target_sents.append(batch["target_eos_ids"]) + source_ids = Tensor(batch["source_eos_ids"], mstype.int32) + source_mask = Tensor(batch["source_eos_mask"], mstype.int32) + predicted_ids = model.predict(source_ids, source_mask) + predictions.append(predicted_ids.asnumpy()) + + # decode and write to file + f = open(cfg.output_file, 'w') + for batch_out in predictions: + for i in range(transformer_net_cfg.batch_size): + if batch_out.ndim == 3: + batch_out = batch_out[:, 0] + token_ids = [str(x) for x in batch_out[i].tolist()] + f.write(" ".join(token_ids) + "\n") + f.close() + +if __name__ == "__main__": + run_transformer_eval() diff --git a/model_zoo/Transformer/scripts/process_output.sh b/model_zoo/Transformer/scripts/process_output.sh new file mode 100644 index 00000000000..c7bc2b5e4ef --- /dev/null +++ b/model_zoo/Transformer/scripts/process_output.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "sh process_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE" +echo "for example: sh process_output.sh /path/newstest2014.tok.de /path/eval_output_file /path/vocab.bpe.32000" +echo "It is better to use absolute path." +echo "==============================================================================================================" + +BASEDIR=$(dirname "$0") + +ref_data=$1 +eval_output=$2 +vocab_file=$3 + +cat $eval_output \ + | python src/process_output.py --vocab_file $vocab_file \ + | sed 's/@@ //g' > ${eval_output}.processed + +perl -ple 's/(\S)-(\S)/$1 #@#-#@# $2/g' < $ref_data | perl ${BASEDIR}/replace-quote.perl > ${ref_data}.forbleu +perl -ple 's/(\S)-(\S)/$1 #@#-#@# $2/g' < ${eval_output}.processed > ${eval_output}.forbleu \ No newline at end of file diff --git a/model_zoo/Transformer/scripts/replace-quote.perl b/model_zoo/Transformer/scripts/replace-quote.perl new file mode 100644 index 00000000000..95f9abcc912 --- /dev/null +++ b/model_zoo/Transformer/scripts/replace-quote.perl @@ -0,0 +1,11 @@ +#!/usr/bin/env perl + +use warnings; +use strict; + +while() { + s/”/\"/g; + s/“/\"/g; + s/„/\"/g; + print $_; +} \ No newline at end of file diff --git a/model_zoo/Transformer/scripts/run_distribute_train.sh b/model_zoo/Transformer/scripts/run_distribute_train.sh new file mode 100644 index 00000000000..772e690dc2e --- /dev/null +++ b/model_zoo/Transformer/scripts/run_distribute_train.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_PATH MINDSPORE_HCCL_CONFIG_PATH" +echo "for example: sh run_distribute_pretrain.sh 8 52 /path/ende-l128-mindrecord00 /path/hccl.json" +echo "It is better to use absolute path." +echo "==============================================================================================================" + +rm -rf run_distribute_train +mkdir run_distribute_train +cd run_distribute_train || exit + +EPOCH_SIZE=$2 +DATA_PATH=$3 + +export MINDSPORE_HCCL_CONFIG_PATH=$4 +export RANK_TABLE_FILE=$4 +export RANK_SIZE=$1 +export HCCL_FLAG=1 +export DEPLOY_MODE=0 + +for((i=0;i env.log + python train.py \ + --distribute="true" \ + --epoch_size=$EPOCH_SIZE \ + --device_id=$DEVICE_ID \ + --device_num=$RANK_SIZE \ + --enable_save_ckpt="true" \ + --enable_lossscale="true" \ + --do_shuffle="true" \ + --enable_data_sink="false" \ + --checkpoint_path="" \ + --save_checkpoint_steps=2500 \ + --save_checkpoint_num=30 \ + --data_path=$DATA_PATH > log.txt 2>&1 & + cd ../ +done +cd .. 
\ No newline at end of file diff --git a/model_zoo/Transformer/scripts/run_standalone_train.sh b/model_zoo/Transformer/scripts/run_standalone_train.sh new file mode 100644 index 00000000000..8e677191a8a --- /dev/null +++ b/model_zoo/Transformer/scripts/run_standalone_train.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_PATH" +echo "for example: sh run_standalone_train.sh 0 52 /path/ende-l128-mindrecord00" +echo "It is better to use absolute path." +echo "==============================================================================================================" + +rm -rf run_standalone_train +mkdir run_standalone_train +cp -rf ./src/ train.py ./run_standalone_train +cd run_standalone_train || exit + +export DEVICE_ID=$1 +EPOCH_SIZE=$2 +DATA_PATH=$3 + +python train.py \ + --distribute="false" \ + --epoch_size=$EPOCH_SIZE \ + --device_id=$DEVICE_ID \ + --enable_save_ckpt="true" \ + --enable_lossscale="true" \ + --do_shuffle="true" \ + --enable_data_sink="false" \ + --checkpoint_path="" \ + --save_checkpoint_steps=2500 \ + --save_checkpoint_num=30 \ + --data_path=$DATA_PATH > log.txt 2>&1 & +cd .. \ No newline at end of file diff --git a/model_zoo/Transformer/src/__init__.py b/model_zoo/Transformer/src/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model_zoo/Transformer/src/beam_search.py b/model_zoo/Transformer/src/beam_search.py new file mode 100644 index 00000000000..9742924a736 --- /dev/null +++ b/model_zoo/Transformer/src/beam_search.py @@ -0,0 +1,269 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer beam search module.""" + +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor + +INF = 1. * 1e9 + +class LengthPenalty(nn.Cell): + """ + Normalize scores of translations according to their length. + + Args: + weight (float): Weight of length penalty. Default: 1.0. + compute_type (:class:`mindspore.dtype`): Compute type in Transformer. 
Default: mstype.float32. + """ + def __init__(self, + weight=1.0, + compute_type=mstype.float32): + super(LengthPenalty, self).__init__() + self.weight = weight + self.add = P.TensorAdd() + self.pow = P.Pow() + self.div = P.RealDiv() + self.cast = P.Cast() + self.five = Tensor(5.0, mstype.float32) + self.six = Tensor(6.0, mstype.float32) + + def construct(self, length_tensor): + length_tensor = self.cast(length_tensor, mstype.float32) + output = self.add(length_tensor, self.five) + output = self.div(output, self.six) + output = self.pow(output, self.weight) + return output + + +class TileBeam(nn.Cell): + """ + TileBeam. + + Args: + beam_width (int): beam width setting. Default: 4. + compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32. + """ + def __init__(self, + beam_width, + compute_type=mstype.float32): + super(TileBeam, self).__init__() + self.beam_width = beam_width + self.expand = P.ExpandDims() + self.tile = P.Tile() + self.reshape = P.Reshape() + self.shape = P.Shape() + + def construct(self, input_tensor): + """ + input_tensor: shape [batch, dim1, dim2] + output_tensor: shape [batch*beam, dim1, dim2] + """ + shape = self.shape(input_tensor) + input_tensor = self.expand(input_tensor, 1) + tile_shape = (1,) + (self.beam_width,) + for _ in range(len(shape)-1): + tile_shape = tile_shape + (1,) + output = self.tile(input_tensor, tile_shape) + out_shape = (shape[0]*self.beam_width,) + shape[1:] + output = self.reshape(output, out_shape) + return output + + +class Mod(nn.Cell): + """ + Mod function. + + Args: + compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32. + """ + def __init__(self, + compute_type=mstype.float32): + super(Mod, self).__init__() + self.compute_type = compute_type + self.floor_div = P.FloorDiv() + self.sub = P.Sub() + self.multiply = P.Mul() + + def construct(self, input_x, input_y): + x = self.floor_div(input_x, input_y) + x = self.multiply(x, input_y) + x = self.sub(input_x, x) + return x + + +class BeamSearchDecoder(nn.Cell): + """ + Beam search decoder. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. + vocab_size (int): Size of vocabulary. + decoder (:class:`TransformerDecoderStep`): Decoder module. + beam_width (int): beam width setting. Default: 4. + length_penalty_weight (float): Weight of length penalty. Default: 1.0. + max_decode_length (int): max decode length. Default: 128. + sos_id (int): Id of sequence start token. Default: 1. + eos_id (int): Id of sequence end token. Default: 2. + compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32. 
+ """ + def __init__(self, + batch_size, + seq_length, + vocab_size, + decoder, + beam_width=4, + length_penalty_weight=1.0, + max_decode_length=128, + sos_id=1, + eos_id=2, + compute_type=mstype.float32): + super(BeamSearchDecoder, self).__init__(auto_prefix=False) + self.batch_size = batch_size + self.vocab_size = vocab_size + self.beam_width = beam_width + self.length_penalty_weight = length_penalty_weight + self.max_decode_length = max_decode_length + self.decoder = decoder + + self.add = P.TensorAdd() + self.expand = P.ExpandDims() + self.reshape = P.Reshape() + self.shape_flat = (-1,) + self.shape = P.Shape() + + self.zero_tensor = Tensor(np.zeros([batch_size, beam_width]), mstype.float32) + self.ninf_tensor = Tensor(np.full([batch_size, beam_width], -INF), mstype.float32) + + self.select = P.Select() + self.flat_shape = (batch_size, beam_width * vocab_size) + self.topk = P.TopK(sorted=True) + self.floor_div = P.FloorDiv() + self.vocab_size_tensor = Tensor(self.vocab_size, mstype.int32) + self.real_div = P.RealDiv() + self.mod = Mod() + self.equal = P.Equal() + self.eos_ids = Tensor(np.full([batch_size, beam_width], eos_id), mstype.int32) + + beam_ids = np.tile(np.arange(beam_width).reshape((1, beam_width)), [batch_size, 1]) + self.beam_ids = Tensor(beam_ids, mstype.int32) + batch_ids = np.arange(batch_size*beam_width).reshape((batch_size, beam_width)) // beam_width + self.batch_ids = Tensor(batch_ids, mstype.int32) + self.concat = P.Concat(axis=-1) + self.gather_nd = P.GatherNd() + + # init inputs and states + self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32) + self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) + init_scores = np.tile(np.array([[0.] + [-INF]*(beam_width-1)]), [batch_size, 1]) + self.init_scores = Tensor(init_scores, mstype.float32) + self.init_finished = Tensor(np.zeros([batch_size, beam_width], dtype=np.bool)) + self.init_length = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32)) + self.length_penalty = LengthPenalty(weight=length_penalty_weight) + self.one = Tensor(1, mstype.int32) + + def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs, + state_seq, state_finished, state_length): + """ + One step for decode + """ + log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask) + log_probs = self.reshape(log_probs, (self.batch_size, self.beam_width, self.vocab_size)) + + # select topk indices + total_log_probs = self.add(log_probs, self.expand(state_log_probs, -1)) + + # mask finished beams + mask_tensor = self.select(state_finished, self.ninf_tensor, self.zero_tensor) + total_log_probs = self.add(total_log_probs, self.expand(mask_tensor, -1)) + + # reshape scores to [batch, beam*vocab] + flat_scores = self.reshape(total_log_probs, self.flat_shape) + # select topk + topk_scores, topk_indices = self.topk(flat_scores, self.beam_width) + + # convert to beam and word indices + beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor) + word_indices = self.mod(topk_indices, self.vocab_size_tensor) + + # mask finished indices + beam_indices = self.select(state_finished, self.beam_ids, beam_indices) + word_indices = self.select(state_finished, self.eos_ids, word_indices) + topk_scores = self.select(state_finished, state_log_probs, topk_scores) + + ###### put finished sequences to the end + # sort according to scores with -inf for finished beams + tmp_log_probs = self.select( + self.equal(word_indices, self.eos_ids), + self.ninf_tensor, + 
topk_scores) + _, tmp_indices = self.topk(tmp_log_probs, self.beam_width) + # update + tmp_gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(tmp_indices, -1))) + beam_indices = self.gather_nd(beam_indices, tmp_gather_indices) + word_indices = self.gather_nd(word_indices, tmp_gather_indices) + topk_scores = self.gather_nd(topk_scores, tmp_gather_indices) + + ###### generate new beam_search states + # gather indices for selecting alive beams + gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(beam_indices, -1))) + + # length add 1 if not finished in the previous step + length_add = self.add(state_length, self.one) + state_length = self.select(state_finished, state_length, length_add) + state_length = self.gather_nd(state_length, gather_indices) + + # concat seq + seq = self.gather_nd(state_seq, gather_indices) + state_seq = self.concat((seq, self.expand(word_indices, -1))) + + # new finished flag and log_probs + state_finished = self.equal(word_indices, self.eos_ids) + state_log_probs = topk_scores + + ###### generate new inputs and decoder states + cur_input_ids = self.reshape(state_seq, (self.batch_size*self.beam_width, -1)) + return cur_input_ids, state_log_probs, state_seq, state_finished, state_length + + def construct(self, enc_states, enc_attention_mask): + cur_input_ids = self.start_ids + # beam search states + state_log_probs = self.init_scores + state_seq = self.init_seq + state_finished = self.init_finished + state_length = self.init_length + + for _ in range(self.max_decode_length): + # run one step decoder to get outputs of the current step + # shape [batch*beam, 1, vocab] + cur_input_ids, state_log_probs, state_seq, state_finished, state_length = self.one_step( + cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq, state_finished, state_length) + + # add length penalty scores + penalty_len = self.length_penalty(state_length) + # return penalty_len + log_probs = self.real_div(state_log_probs, penalty_len) + + # sort according to scores + _, top_beam_indices = self.topk(log_probs, self.beam_width) + gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1))) + # sort sequence + predicted_ids = self.gather_nd(state_seq, gather_indices) + # take the first one + predicted_ids = predicted_ids[::, 0:1:1, ::] + return predicted_ids diff --git a/model_zoo/Transformer/src/config.py b/model_zoo/Transformer/src/config.py new file mode 100644 index 00000000000..25d23a1fbbb --- /dev/null +++ b/model_zoo/Transformer/src/config.py @@ -0,0 +1,71 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Network config setting, will be used in dataset.py, train.py.""" + +from easydict import EasyDict as edict +import mindspore.common.dtype as mstype +from .transformer_model import TransformerConfig +cfg = edict({ + 'transformer_network': 'large', + 'init_loss_scale_value': 1024, + 'scale_factor': 2, + 'scale_window': 2000, + 'optimizer': 'Adam', + 'lr_schedule': edict({ + 'learning_rate': 2.0, + 'warmup_steps': 8000, + 'start_decay_step': 16000, + 'min_lr': 0.0, + }), +}) +''' +two kinds of transformer model version +''' +if cfg.transformer_network == 'large': + transformer_net_cfg = TransformerConfig( + batch_size=96, + seq_length=128, + vocab_size=36560, + hidden_size=1024, + num_hidden_layers=6, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="relu", + hidden_dropout_prob=0.2, + attention_probs_dropout_prob=0.2, + max_position_embeddings=128, + initializer_range=0.02, + label_smoothing=0.1, + input_mask_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16) +if cfg.transformer_network == 'base': + transformer_net_cfg = TransformerConfig( + batch_size=96, + seq_length=128, + vocab_size=36560, + hidden_size=512, + num_hidden_layers=6, + num_attention_heads=8, + intermediate_size=2048, + hidden_act="relu", + hidden_dropout_prob=0.2, + attention_probs_dropout_prob=0.2, + max_position_embeddings=128, + initializer_range=0.02, + label_smoothing=0.1, + input_mask_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16) diff --git a/model_zoo/Transformer/src/dataset.py b/model_zoo/Transformer/src/dataset.py new file mode 100644 index 00000000000..5b006046a5d --- /dev/null +++ b/model_zoo/Transformer/src/dataset.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Data operations, will be used in train.py.""" + +import mindspore.common.dtype as mstype +import mindspore.dataset.engine.datasets as de +import mindspore.dataset.transforms.c_transforms as deC +from mindspore import log as logger +from .config import transformer_net_cfg + +def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true", + dataset_path=None): + """create dataset""" + repeat_count = epoch_count + ds = de.MindDataset(dataset_path, + columns_list=["source_eos_ids", "source_eos_mask", + "target_sos_ids", "target_sos_mask", + "target_eos_ids", "target_eos_mask"], + shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id) + + type_cast_op = deC.TypeCast(mstype.int32) + ds = ds.map(input_columns="source_eos_ids", operations=type_cast_op) + ds = ds.map(input_columns="source_eos_mask", operations=type_cast_op) + ds = ds.map(input_columns="target_sos_ids", operations=type_cast_op) + ds = ds.map(input_columns="target_sos_mask", operations=type_cast_op) + ds = ds.map(input_columns="target_eos_ids", operations=type_cast_op) + ds = ds.map(input_columns="target_eos_mask", operations=type_cast_op) + + # apply batch operations + ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True) + ds = ds.repeat(repeat_count) + + ds.channel_name = 'transformer' + logger.info("data size: {}".format(ds.get_dataset_size())) + logger.info("repeatcount: {}".format(ds.get_repeat_count())) + return ds, repeat_count diff --git a/model_zoo/Transformer/src/eval_config.py b/model_zoo/Transformer/src/eval_config.py new file mode 100644 index 00000000000..e3d3915867b --- /dev/null +++ b/model_zoo/Transformer/src/eval_config.py @@ -0,0 +1,69 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Network evaluation config setting, will be used in eval.py.""" + +from easydict import EasyDict as edict +import mindspore.common.dtype as mstype +from .transformer_model import TransformerConfig + +cfg = edict({ + 'transformer_network': 'large', + 'data_file': '/your/path/evaluation.mindrecord', + 'model_file': '/your/path/checkpoint_file', + 'output_file': '/your/path/output', +}) +''' +two kinds of transformer model version +''' +if cfg.transformer_network == 'large': + transformer_net_cfg = TransformerConfig( + batch_size=1, + seq_length=128, + vocab_size=36560, + hidden_size=1024, + num_hidden_layers=6, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="relu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=128, + label_smoothing=0.1, + input_mask_from_dataset=True, + beam_width=4, + max_decode_length=80, + length_penalty_weight=1.0, + dtype=mstype.float32, + compute_type=mstype.float16) +if cfg.transformer_network == 'base': + transformer_net_cfg = TransformerConfig( + batch_size=1, + seq_length=128, + vocab_size=36560, + hidden_size=512, + num_hidden_layers=6, + num_attention_heads=8, + intermediate_size=2048, + hidden_act="relu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=128, + label_smoothing=0.1, + input_mask_from_dataset=True, + beam_width=4, + max_decode_length=80, + length_penalty_weight=1.0, + dtype=mstype.float32, + compute_type=mstype.float16) diff --git a/model_zoo/Transformer/src/lr_schedule.py b/model_zoo/Transformer/src/lr_schedule.py new file mode 100644 index 00000000000..1f393737387 --- /dev/null +++ b/model_zoo/Transformer/src/lr_schedule.py @@ -0,0 +1,52 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Learning rate utilities.""" + +def linear_warmup(warmup_steps, current_step): + return min([1.0, float(current_step)/float(warmup_steps)]) + +def rsqrt_decay(warmup_steps, current_step): + return float(max([current_step, warmup_steps])) ** -0.5 + +def rsqrt_hidden(hidden_size): + return float(hidden_size) ** -0.5 + +def create_dynamic_lr(schedule, training_steps, learning_rate, warmup_steps, hidden_size, + start_decay_step=0, min_lr=0.): + """ + Generate dynamic learning rate. 
+ """ + if start_decay_step < warmup_steps: + start_decay_step = warmup_steps + lr = [] + for current_step in range(1, training_steps+1): + cur_lr = 1.0 + for name in schedule.split("*"): + if name == "constant": + cur_lr *= float(learning_rate) + elif name == "rsqrt_hidden": + cur_lr *= rsqrt_hidden(hidden_size) + elif name == "linear_warmup": + cur_lr *= linear_warmup(warmup_steps, current_step) + elif name == "rsqrt_decay": + cur_lr *= rsqrt_decay(warmup_steps, current_step-start_decay_step+warmup_steps) + else: + raise ValueError("unknown learning rate schedule") + if warmup_steps < current_step < start_decay_step: + cur_lr = lr[-1] + if current_step > warmup_steps: + cur_lr = max([cur_lr, min_lr]) + lr.append(cur_lr) + return lr diff --git a/model_zoo/Transformer/src/process_output.py b/model_zoo/Transformer/src/process_output.py new file mode 100644 index 00000000000..cccff02e090 --- /dev/null +++ b/model_zoo/Transformer/src/process_output.py @@ -0,0 +1,47 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Convert ids to tokens.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import tokenization + +# Explicitly set the encoding +sys.stdin = open(sys.stdin.fileno(), mode='r', encoding='utf-8', buffering=True) +sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=True) + +def main(): + parser = argparse.ArgumentParser( + description="recore nbest with smoothed sentence-level bleu.") + parser.add_argument("--vocab_file", type=str, default="", required=True, help="vocab file path.") + args = parser.parse_args() + + tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file) + + for line in sys.stdin: + token_ids = [int(x) for x in line.strip().split()] + tokens = tokenizer.convert_ids_to_tokens(token_ids) + sent = " ".join(tokens) + sent = sent.split("")[-1] + sent = sent.split("")[0] + print(sent.strip()) + +if __name__ == "__main__": + main() diff --git a/model_zoo/Transformer/src/tokenization.py b/model_zoo/Transformer/src/tokenization.py new file mode 100644 index 00000000000..fd0fc979556 --- /dev/null +++ b/model_zoo/Transformer/src/tokenization.py @@ -0,0 +1,193 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +############################################################################### +# Modified by Huawei Technologies Co., Ltd, May, 2020, with following changes: +# - Remove some unused classes and functions +# - Modify load_vocab, convert_to_unicode, printable_text function +# - Modify BasicTokenizer class +# - Add WhiteSpaceTokenizer class +############################################################################### + +"""Tokenization utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import unicodedata +import six + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + if isinstance(text, bytes): + return text.decode("utf-8", "ignore") + raise ValueError("Unsupported string type: %s" % (type(text))) + if six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + if isinstance(text, unicode): + return text + raise ValueError("Unsupported string type: %s" % (type(text))) + raise ValueError("Not running on Python2 or Python 3?") + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + if isinstance(text, bytes): + return text.decode("utf-8", "ignore") + raise ValueError("Unsupported string type: %s" % (type(text))) + if six.PY2: + if isinstance(text, str): + return text + if isinstance(text, unicode): + return text.encode("utf-8") + raise ValueError("Unsupported string type: %s" % (type(text))) + raise ValueError("Not running on Python2 or Python 3?") + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r") as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + if item in vocab: + output.append(vocab[item]) + else: + output.append(vocab[""]) + return output + + +def convert_tokens_to_ids(vocab, tokens): + return convert_by_vocab(vocab, tokens) + + +def convert_ids_to_tokens(inv_vocab, ids): + return convert_by_vocab(inv_vocab, ids) + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class WhiteSpaceTokenizer(): + """Runs end-to-end tokenziation.""" + def __init__(self, vocab_file): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer() + + def tokenize(self, text): + return self.basic_tokenizer.tokenize(text) + + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) + + +class BasicTokenizer(): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self): + """Constructs a BasicTokenizer.""" + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = 
convert_to_unicode(text) + text = self._clean_text(text) + return whitespace_tokenize(text) + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char in (" ", "\t", "\n", "\r"): + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char in ("\t", "\n", "\r"): + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/model_zoo/Transformer/src/transformer_for_train.py b/model_zoo/Transformer/src/transformer_for_train.py new file mode 100644 index 00000000000..ac54aee7f9b --- /dev/null +++ b/model_zoo/Transformer/src/transformer_for_train.py @@ -0,0 +1,341 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer for training.""" + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.common import dtype as mstype +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.train.parallel_utils import ParallelMode +from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean +from mindspore.communication.management import get_group_size +from mindspore import context +from .transformer_model import TransformerModel + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 5.0 + +class ClipGradients(nn.Cell): + """ + Clip gradients. + + Args: + grads (list): List of gradient tuples. + clip_type (Tensor): The way to clip, 'value' or 'norm'. + clip_value (Tensor): Specifies how much to clip. 
+ + Returns: + List, a list of clipped_grad tuples. + """ + def __init__(self): + super(ClipGradients, self).__init__() + self.clip_by_norm = nn.ClipByNorm() + self.cast = P.Cast() + self.dtype = P.DType() + def construct(self, + grads, + clip_type, + clip_value): + #return grads + if clip_type != 0 and clip_type != 1: + return grads + + new_grads = () + for grad in grads: + dt = self.dtype(grad) + if clip_type == 0: + t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt), + self.cast(F.tuple_to_array((clip_value,)), dt)) + else: + t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt)) + new_grads = new_grads + (t,) + + return new_grads + + +class TransformerTrainingLoss(nn.Cell): + """ + Provide transformer training loss. + + Args: + config (TransformerConfig): The config of Transformer. + + Returns: + Tensor, total loss. + """ + def __init__(self, config): + super(TransformerTrainingLoss, self).__init__(auto_prefix=False) + self.vocab_size = config.vocab_size + self.onehot = P.OneHot() + self.on_value = Tensor(float(1-config.label_smoothing), mstype.float32) + self.off_value = Tensor(config.label_smoothing/float(self.vocab_size-1), mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.flatten = P.Flatten() + self.neg = P.Neg() + self.cast = P.Cast() + self.flat_shape = (config.batch_size*config.seq_length,) + + def construct(self, prediction_scores, label_ids, label_weights): + """Defines the computation performed.""" + label_ids = self.reshape(label_ids, self.flat_shape) + label_weights = self.cast(self.reshape(label_weights, self.flat_shape), mstype.float32) + one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) + + per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) + numerator = self.reduce_sum(label_weights * per_example_loss, ()) + denominator = self.reduce_sum(label_weights, ()) + \ + self.cast(F.tuple_to_array((1e-5,)), mstype.float32) + loss = numerator / denominator + return loss + + +class TransformerNetworkWithLoss(nn.Cell): + """ + Provide transformer training loss through network. + + Args: + config (TransformerConfig): The config of Transformer. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + + Returns: + Tensor, the loss of the network. + """ + def __init__(self, config, is_training, use_one_hot_embeddings=False): + super(TransformerNetworkWithLoss, self).__init__(auto_prefix=False) + self.transformer = TransformerModel(config, is_training, use_one_hot_embeddings) + self.loss = TransformerTrainingLoss(config) + self.cast = P.Cast() + + def construct(self, + source_ids, + source_mask, + target_ids, + target_mask, + label_ids, + label_weights): + prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask) + total_loss = self.loss(prediction_scores, label_ids, label_weights) + return self.cast(total_loss, mstype.float32) + + +class TransformerTrainOneStepCell(nn.Cell): + """ + Encapsulation class of transformer network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. 
+ sens (Number): The adjust parameter. Default: 1.0. + """ + def __init__(self, network, optimizer, sens=1.0): + super(TransformerTrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode not in ParallelMode.MODE_LIST: + raise ValueError("Parallel mode does not support: ", parallel_mode) + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + self.clip_gradients = ClipGradients() + self.cast = P.Cast() + + def set_sens(self, value): + self.sens = value + + def construct(self, + source_eos_ids, + source_eos_mask, + target_sos_ids, + target_sos_mask, + target_eos_ids, + target_eos_mask,): + """Defines the computation performed.""" + source_ids = source_eos_ids + source_mask = source_eos_mask + target_ids = target_sos_ids + target_mask = target_sos_mask + label_ids = target_eos_ids + label_weights = target_eos_mask + + weights = self.weights + loss = self.network(source_ids, + source_mask, + target_ids, + target_mask, + label_ids, + label_weights) + grads = self.grad(self.network, weights)(source_ids, + source_mask, + target_ids, + target_mask, + label_ids, + label_weights, + self.cast(F.tuple_to_array((self.sens,)), + mstype.float32)) + grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + + succ = self.optimizer(grads) + return F.depend(loss, succ) + + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * F.cast(reciprocal(scale), F.dtype(grad)) + +class TransformerTrainOneStepWithLossScaleCell(nn.Cell): + """ + Encapsulation class of Transformer network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. 
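+
+    Example (illustrative sketch, not part of the original API docs; the hyper-parameter
+    values below are placeholders, a real run wires them from src/config.py exactly as
+    train.py does):
+        >>> from mindspore.nn.optim import Adam
+        >>> from mindspore.train.loss_scale_manager import DynamicLossScaleManager
+        >>> from src.config import transformer_net_cfg
+        >>> netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
+        >>> optimizer = Adam(netwithloss.trainable_params(), 1e-4)
+        >>> scale_manager = DynamicLossScaleManager(init_loss_scale=2**10,
+        ...                                         scale_factor=2, scale_window=2000)
+        >>> net = TransformerTrainOneStepWithLossScaleCell(
+        ...     netwithloss, optimizer, scale_update_cell=scale_manager.get_update_cell())
+        >>> net.set_train(True)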
+ """ + def __init__(self, network, optimizer, scale_update_cell=None): + super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.network.add_flags(defer_inline=True) + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + + self.parallel_mode = _get_parallel_mode() + if self.parallel_mode not in ParallelMode.MODE_LIST: + raise ValueError("Parallel mode does not support: ", parallel_mode) + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = _get_mirror_mean() + degree = _get_device_num() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.clip_gradients = ClipGradients() + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + self.add_flags(has_effect=True) + + def construct(self, + source_eos_ids, + source_eos_mask, + target_sos_ids, + target_sos_mask, + target_eos_ids, + target_eos_mask, + sens=None): + """Defines the computation performed.""" + source_ids = source_eos_ids + source_mask = source_eos_mask + target_ids = target_sos_ids + target_mask = target_sos_mask + label_ids = target_eos_ids + label_weights = target_eos_mask + + weights = self.weights + loss = self.network(source_ids, + source_mask, + target_ids, + target_mask, + label_ids, + label_weights) + # alloc status + init = self.alloc_status() + # clear overflow buffer + self.clear_before_grad(init) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + grads = self.grad(self.network, weights)(source_ids, + source_mask, + target_ids, + target_mask, + label_ids, + label_weights, + self.cast(scaling_sens, + mstype.float32)) + + grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) + grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) diff --git a/model_zoo/Transformer/src/transformer_model.py b/model_zoo/Transformer/src/transformer_model.py new file mode 100644 index 00000000000..17b5127dca4 --- /dev/null +++ b/model_zoo/Transformer/src/transformer_model.py @@ -0,0 +1,1237 @@ +# Copyright 2020 
Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer model.""" + +import math +import copy +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops.functional as F +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from .beam_search import BeamSearchDecoder, TileBeam + +class TransformerConfig: + """ + Configuration for `Transformer`. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. Default: 128. + vocab_size (int): The shape of each embedding vector. Default: 36560. + hidden_size (int): Size of the layers. Default: 1024. + num_hidden_layers (int): Number of hidden layers in the Transformer encoder/decoder + cell. Default: 6. + num_attention_heads (int): Number of attention heads in the Transformer + encoder/decoder cell. Default: 16. + intermediate_size (int): Size of intermediate layer in the Transformer + encoder/decoder cell. Default: 4096. + hidden_act (str): Activation function used in the Transformer encoder/decoder + cell. Default: "relu". + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.3. + attention_probs_dropout_prob (float): The dropout probability for + MultiheadAttention. Default: 0.3. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 128. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + label_smoothing (float): label smoothing setting. Default: 0.1 + input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from + dataset. Default: True. + beam_width (int): beam width setting. Default: 4 + max_decode_length (int): max decode length in evaluation. Default: 80 + length_penalty_weight (float): normalize scores of translations according to their length. Default: 1.0 + dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. + compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32. 
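+
+    Example (illustrative; the batch size here is an arbitrary placeholder, every other
+    value falls back to the defaults listed above):
+        >>> config = TransformerConfig(batch_size=96)
+        >>> config.hidden_size, config.num_hidden_layers, config.num_attention_heads
+        (1024, 6, 16)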
+ """ + def __init__(self, + batch_size, + seq_length=128, + vocab_size=36560, + hidden_size=1024, + num_hidden_layers=6, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="relu", + hidden_dropout_prob=0.3, + attention_probs_dropout_prob=0.3, + max_position_embeddings=128, + initializer_range=0.02, + label_smoothing=0.1, + input_mask_from_dataset=True, + beam_width=4, + max_decode_length=80, + length_penalty_weight=1.0, + dtype=mstype.float32, + compute_type=mstype.float32): + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.label_smoothing = label_smoothing + self.input_mask_from_dataset = input_mask_from_dataset + self.beam_width = beam_width + self.max_decode_length = max_decode_length + self.length_penalty_weight = length_penalty_weight + self.dtype = dtype + self.compute_type = compute_type + + +class EmbeddingLookup(nn.Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + """ + def __init__(self, + vocab_size, + embedding_size, + use_one_hot_embeddings=False, + initializer_range=0.02): + super(EmbeddingLookup, self).__init__() + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [vocab_size, embedding_size]), + name='embedding_table') + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.GatherV2() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = P.Shape() + + def construct(self, input_ids): + input_shape = self.shape(input_ids) + + flat_ids = self.reshape(input_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul( + one_hot_ids, self.embedding_table) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + + out_shape = input_shape + (self.embedding_size,) + output = self.reshape(output_for_reshape, out_shape) + return output, self.embedding_table + + +def position_encoding(length, + depth, + min_timescale=1, + max_timescale=1e4): + """ + Create Tensor of sinusoids of different frequencies. + + Args: + length (int): Length of the Tensor to create, i.e. Number of steps. + depth (int): Hidden size. + min_timescale (float): Default: 1. + max_timescale (float): Default: 10000. 
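+
+    This follows the sinusoidal encoding of "Attention Is All You Need": with
+    min_timescale=1, column i of the first depth/2 columns holds
+    sin(pos / max_timescale**(i / (depth/2 - 1))) and the second half holds the matching
+    cosine terms. A quick shape check (illustrative):
+        >>> position_encoding(128, 1024).shape
+        (128, 1024)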
+ + Returns: + Tensor of shape (length, depth) + """ + depth = depth // 2 + positions = np.arange(length, dtype=np.float32) + log_timescale_increment = (np.log(max_timescale / min_timescale) / (depth - 1)) + inv_timescales = min_timescale * np.exp(np.arange(depth, dtype=np.float32) * -log_timescale_increment) + scaled_time = np.expand_dims(positions, 1) * np.expand_dims(inv_timescales, 0) + x = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) + return x + + +class EmbeddingPostprocessor(nn.Cell): + """ + Postprocessors apply positional embeddings to word embeddings. + + Args: + embedding_size (int): The size of each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 128. + dropout_prob (float): The dropout probability. Default: 0.1. + """ + def __init__(self, + embedding_size, + use_one_hot_embeddings=False, + initializer_range=0.02, + max_position_embeddings=128, + dropout_prob=0.1): + super(EmbeddingPostprocessor, self).__init__() + self.scores_mul = Tensor([math.sqrt(float(embedding_size))], dtype=mstype.float32) + self.multiply = P.Mul() + self.add = P.TensorAdd() + self.dropout = nn.Dropout(1 - dropout_prob, dtype=mstype.float32) + self.use_dropout = dropout_prob > 0 + self.expand_dims = P.ExpandDims() + self.position_embedding_table = Tensor(position_encoding(max_position_embeddings, embedding_size), + mstype.float32) + self.shape = P.Shape() + + def construct(self, word_embeddings): + input_shape = self.shape(word_embeddings) + input_len = input_shape[1] + + output = self.multiply(word_embeddings, self.scores_mul) + + # add position embeddings + position_embeddings = self.position_embedding_table[0:input_len:1, ::] + position_embeddings = self.expand_dims(position_embeddings, 0) + output = self.add(output, position_embeddings) + + if self.use_dropout: + output = self.dropout(output) + return output + + +class CastWrapper(nn.Cell): + """ + Cast wrapper. + """ + def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): + super(CastWrapper, self).__init__() + self.cast = P.Cast() + self.dst_type = dst_type + + def construct(self, x): + return self.cast(x, self.dst_type) + + +class LayerPreprocess(nn.Cell): + """ + preprocess input of each layer. + """ + def __init__(self, + in_channels=None): + super(LayerPreprocess, self).__init__() + self.layernorm = nn.LayerNorm((in_channels,)) + self.cast = P.Cast() + self.get_dtype = P.DType() + + def construct(self, input_tensor): + output = self.cast(input_tensor, mstype.float32) + output = self.layernorm(output) + output = self.cast(output, self.get_dtype(input_tensor)) + return output + + +class LayerPostprocess(nn.Cell): + """ + postprocess ouput of each layer. + """ + def __init__(self, + dropout_prob=0.1): + super(LayerPostprocess, self).__init__() + self.add = P.TensorAdd() + self.dropout = nn.Dropout(1 - dropout_prob) + self.use_dropout = dropout_prob > 0 + + def construct(self, hidden_tensor, input_tensor): + output = hidden_tensor + if self.use_dropout: + output = self.dropout(output) + output = self.add(output, input_tensor) + return output + + +class MultiheadAttention(nn.Cell): + """ + Apply multi-headed attention from "from_tensor" to "to_tensor". + + Args: + batch_size (int): Batch size of input datasets. + from_tensor_width (int): Size of last dim of from_tensor. 
+ to_tensor_width (int): Size of last dim of to_tensor. + from_seq_length (int): Length of from_tensor sequence. + to_seq_length (int): Length of to_tensor sequence. + num_attention_heads (int): Number of attention heads. Default: 1. + size_per_head (int): Size of each attention head. Default: 512. + query_act (str): Activation function for the query transform. Default: None. + key_act (str): Activation function for the key transform. Default: None. + value_act (str): Activation function for the value transform. Default: None. + has_attention_mask (bool): Specifies whether to use attention mask. Default: False. + attention_probs_dropout_prob (float): The dropout probability for + MultiheadAttention. Default: 0.0. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d + tensor. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + from_tensor_width, + to_tensor_width, + out_tensor_width, + from_seq_length, + to_seq_length, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + out_act=None, + has_attention_mask=True, + attention_probs_dropout_prob=0.0, + use_one_hot_embeddings=False, + initializer_range=0.02, + do_return_2d_tensor=True, + compute_type=mstype.float32): + super(MultiheadAttention, self).__init__() + self.batch_size = batch_size + self.from_seq_length = from_seq_length + self.to_seq_length = to_seq_length + self.num_attention_heads = num_attention_heads + self.size_per_head = size_per_head + self.has_attention_mask = has_attention_mask + assert has_attention_mask + + self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type) + self.reshape = P.Reshape() + self.shape_from_2d = (-1, from_tensor_width) + self.shape_to_2d = (-1, to_tensor_width) + units = num_attention_heads * size_per_head + self.query_layer = nn.Dense(from_tensor_width, + units, + activation=query_act, + has_bias=False, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.key_layer = nn.Dense(to_tensor_width, + units, + activation=key_act, + has_bias=False, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.value_layer = nn.Dense(to_tensor_width, + units, + activation=value_act, + has_bias=False, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.out_layer = nn.Dense(units, + out_tensor_width, + activation=out_act, + has_bias=False, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + + self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) + self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head) + + self.matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.multiply = P.Mul() + self.transpose = P.Transpose() + self.trans_shape = (0, 2, 1, 3) + self.trans_shape_relative = (2, 0, 1, 3) + self.trans_shape_position = (1, 2, 0, 3) + self.multiply_data = Tensor([-10000.0,], dtype=compute_type) + self.batch_num = batch_size * num_attention_heads + self.matmul = P.BatchMatMul() + + self.softmax = nn.Softmax() + self.dropout = nn.Dropout(1 - attention_probs_dropout_prob) + self.use_dropout = attention_probs_dropout_prob > 0 + + if self.has_attention_mask: + self.expand_dims = 
P.ExpandDims() + self.sub = P.Sub() + self.add = P.TensorAdd() + self.cast = P.Cast() + self.get_dtype = P.DType() + if do_return_2d_tensor: + self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head) + if from_seq_length == -1: + self.shape_return = (-1, num_attention_heads * size_per_head) + else: + self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) + + self.cast_compute_type = CastWrapper(dst_type=compute_type) + self.softmax_cast = P.Cast() + + def construct(self, from_tensor, to_tensor, attention_mask=None): + # reshape 2d/3d input tensors to 2d + from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) + to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) + query_out = self.query_layer(from_tensor_2d) + key_out = self.key_layer(to_tensor_2d) + value_out = self.value_layer(to_tensor_2d) + + query_layer = self.reshape(query_out, self.shape_from) + query_layer = self.transpose(query_layer, self.trans_shape) + key_layer = self.reshape(key_out, self.shape_to) + key_layer = self.transpose(key_layer, self.trans_shape) + + attention_scores = self.matmul_trans_b(query_layer, key_layer) + attention_scores = self.multiply(attention_scores, self.scores_mul) + + if self.has_attention_mask: + attention_mask = self.expand_dims(attention_mask, 1) + multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), + self.cast(attention_mask, self.get_dtype(attention_scores))) + adder = self.multiply(multiply_out, self.multiply_data) + attention_scores = self.add(adder, attention_scores) + + attention_scores = self.softmax_cast(attention_scores, mstype.float32) + attention_probs = self.softmax(attention_scores) + attention_probs = self.softmax_cast(attention_probs, self.get_dtype(key_layer)) + if self.use_dropout: + attention_probs = self.dropout(attention_probs) + + value_layer = self.reshape(value_out, self.shape_to) + value_layer = self.transpose(value_layer, self.trans_shape) + context_layer = self.matmul(attention_probs, value_layer) + + context_layer = self.transpose(context_layer, self.trans_shape) + context_layer = self.reshape(context_layer, self.shape_return) + context_layer = self.out_layer(context_layer) + return context_layer + + +class SelfAttention(nn.Cell): + """ + Apply self-attention. + + Args: + batch_size (int): Batch size of input dataset. + from_seq_length (int): Length of query sequence. + to_seq_length (int): Length of memory sequence. + hidden_size (int): Size of attention layers. + num_attention_heads (int): Number of attention heads. Default: 16. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + has_attention_mask (bool): Specifies whether has attention mask. Default: True. + is_encdec_att (bool): Specifies whether query sequence and memory sequence are different. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32. 
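+
+    Note (editorial): when is_encdec_att is False, the memory_tensor argument of
+    construct() is replaced by the layer-normalized input, i.e. plain self-attention;
+    with is_encdec_att=True it is kept, giving encoder-decoder cross-attention over the
+    encoder output.
+
+    Example (illustrative; sizes are placeholders):
+        >>> block = SelfAttention(batch_size=96, from_seq_length=128, to_seq_length=128,
+        ...                       hidden_size=1024)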
+ """ + def __init__(self, + batch_size, + from_seq_length, + to_seq_length, + hidden_size, + num_attention_heads=16, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + has_attention_mask=True, + is_encdec_att=False, + compute_type=mstype.float32): + super(SelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError("The hidden size (%d) is not a multiple of the number " + "of attention heads (%d)" % (hidden_size, num_attention_heads)) + self.size_per_head = int(hidden_size / num_attention_heads) + self.is_encdec_att = is_encdec_att + + self.attention = MultiheadAttention( + batch_size=batch_size, + from_tensor_width=hidden_size, + to_tensor_width=hidden_size, + out_tensor_width=hidden_size, + from_seq_length=from_seq_length, + to_seq_length=to_seq_length, + num_attention_heads=num_attention_heads, + size_per_head=self.size_per_head, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + has_attention_mask=has_attention_mask, + do_return_2d_tensor=True, + compute_type=compute_type) + + self.preprocess = LayerPreprocess(in_channels=hidden_size) + self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + def construct(self, input_tensor, memory_tensor, attention_mask): + input_tensor = self.reshape(input_tensor, self.shape) + memory_tensor = self.reshape(memory_tensor, self.shape) + + output = self.preprocess(input_tensor) + + if not self.is_encdec_att: + memory_tensor = output + + attention_output = self.attention(output, memory_tensor, attention_mask) + output = self.postprocess(attention_output, input_tensor) + return output + + +class FeedForward(nn.Cell): + """ + Apply two-layer feed forward + + Args: + in_channels (int): Size of the input layer. + hidden_size (int): Size of the hidden layer. + out_channels (int): Size of the output layers. + hidden_act (str): name of the activation function. Default: relu + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + compute_type (:class:`mindspore.dtype`): Compute type in FeedForward. Default: mstype.float32. 
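+
+    Example (illustrative; these sizes mirror the defaults used by the encoder and
+    decoder cells):
+        >>> ffn = FeedForward(in_channels=1024, hidden_size=4096, out_channels=1024)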
+ """ + def __init__(self, + in_channels, + hidden_size, + out_channels, + hidden_act="relu", + initializer_range=0.02, + hidden_dropout_prob=0.1, + compute_type=mstype.float32): + super(FeedForward, self).__init__() + + self.conv1 = nn.Dense(in_channels, + hidden_size, + activation=hidden_act, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.conv2 = nn.Dense(hidden_size, + out_channels, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + + self.preprocess = LayerPreprocess(in_channels=in_channels) + self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob) + + self.reshape = P.Reshape() + self.shape = (-1, in_channels) + self.dropout = nn.Dropout(1 - hidden_dropout_prob) + self.use_dropout = hidden_dropout_prob > 0 + + def construct(self, input_tensor): + input_tensor = self.reshape(input_tensor, self.shape) + output = self.preprocess(input_tensor) + output = self.conv1(output) + if self.use_dropout: + output = self.dropout(output) + output = self.conv2(output) + output = self.postprocess(output, input_tensor) + return output + + +class EncoderCell(nn.Cell): + """ + Encoder cells used in Transformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. Default: 1024. + seq_length (int): Length of input sequence. Default: 128. + num_attention_heads (int): Number of attention heads. Default: 16. + intermediate_size (int): Size of intermediate layer. Default: 4096. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.1. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + hidden_act (str): Activation function. Default: "relu". + compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + hidden_size=1024, + seq_length=128, + num_attention_heads=16, + intermediate_size=4096, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + hidden_act="relu", + compute_type=mstype.float32): + super(EncoderCell, self).__init__() + self.attention = SelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + from_seq_length=seq_length, + to_seq_length=seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + is_encdec_att=False, + compute_type=compute_type) + self.feedforward = FeedForward( + in_channels=hidden_size, + hidden_size=intermediate_size, + out_channels=hidden_size, + hidden_act=hidden_act, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + + def construct(self, hidden_states, attention_mask): + # self-attention with ln, res + attention_output = self.attention(hidden_states, hidden_states, attention_mask) + # feed forward with ln, res + output = self.feedforward(attention_output) + return output + + +class TransformerEncoder(nn.Cell): + """ + Multi-layer transformer encoder. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. 
+ seq_length (int): Length of input sequence. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 16. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. + """ + def __init__(self, + batch_size, + hidden_size, + seq_length, + num_hidden_layers, + num_attention_heads=16, + intermediate_size=4096, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + hidden_act="relu", + compute_type=mstype.float32): + super(TransformerEncoder, self).__init__() + self.num_hidden_layers = num_hidden_layers + + layers = [] + for _ in range(num_hidden_layers): + layer = EncoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + layers.append(layer) + self.layers = nn.CellList(layers) + + self.layer_preprocess = LayerPreprocess(in_channels=hidden_size) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + self.out_shape = (batch_size, seq_length, hidden_size) + + def construct(self, input_tensor, attention_mask): + prev_output = self.reshape(input_tensor, self.shape) + + for layer_module in self.layers: + layer_output = layer_module(prev_output, attention_mask) + prev_output = layer_output + + prev_output = self.layer_preprocess(prev_output) + output = self.reshape(prev_output, self.out_shape) + return output + + +class DecoderCell(nn.Cell): + """ + decoder cells used in Transformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the Transformer decoder layers. Default: 1024. + seq_length (int): Length of input sequence. Default: 128. + enc_seq_length (int): Length of source sentences. Default:128 + num_attention_heads (int): Number of attention heads. Default: 12. + intermediate_size (int): Size of intermediate layer. Default: 4096. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + hidden_act (str): Activation function. Default: "relu". + compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. 
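+
+    Note (editorial): the cell chains masked self-attention, encoder-decoder
+    cross-attention (is_encdec_att=True) and the feed-forward block; each sub-layer
+    layer-normalizes its input and adds a residual connection with dropout to its output.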
+ """ + def __init__(self, + batch_size, + hidden_size=1024, + seq_length=128, + enc_seq_length=128, + num_attention_heads=12, + intermediate_size=4096, + attention_probs_dropout_prob=0.02, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + hidden_act="relu", + compute_type=mstype.float32): + super(DecoderCell, self).__init__() + self.self_attention = SelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + from_seq_length=seq_length, + to_seq_length=seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + is_encdec_att=False, + hidden_dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + self.cross_attention = SelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + from_seq_length=seq_length, + to_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + is_encdec_att=True, + hidden_dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + self.feedforward = FeedForward( + in_channels=hidden_size, + hidden_size=intermediate_size, + out_channels=hidden_size, + hidden_act=hidden_act, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + + def construct(self, hidden_states, attention_mask, enc_states, enc_attention_mask): + # self-attention with ln, res + attention_output = self.self_attention(hidden_states, hidden_states, attention_mask) + # cross-attention with ln, res + attention_output = self.cross_attention(attention_output, enc_states, enc_attention_mask) + # feed forward with ln, res + output = self.feedforward(attention_output) + return output + + +class TransformerDecoder(nn.Cell): + """ + Multi-layer transformer decoder. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. + seq_length (int): Length of input sequence. + enc_seq_length (int): Length of source sentences. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 16. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. 
+ """ + def __init__(self, + batch_size, + hidden_size, + seq_length, + enc_seq_length, + num_hidden_layers, + num_attention_heads=16, + intermediate_size=4096, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + hidden_act="relu", + compute_type=mstype.float32): + super(TransformerDecoder, self).__init__() + self.num_hidden_layers = num_hidden_layers + + # wait to be supported + # layers = [] + # for _ in range(num_hidden_layers): + # layer = DecoderCell(batch_size=batch_size, + # hidden_size=hidden_size, + # seq_length=seq_length, + # enc_seq_length=enc_seq_length, + # num_attention_heads=num_attention_heads, + # intermediate_size=intermediate_size, + # attention_probs_dropout_prob=attention_probs_dropout_prob, + # use_one_hot_embeddings=use_one_hot_embeddings, + # initializer_range=initializer_range, + # hidden_dropout_prob=hidden_dropout_prob, + # hidden_act=hidden_act, + # compute_type=compute_type) + # layers.append(layer) + # self.layers = nn.CellList(layers) + self.layer0 = DecoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + enc_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + self.layer1 = DecoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + enc_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + self.layer2 = DecoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + enc_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + self.layer3 = DecoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + enc_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + self.layer4 = DecoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + enc_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + self.layer5 = DecoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + enc_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + 
intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + + self.layer_preprocess = LayerPreprocess(in_channels=hidden_size) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + self.out_shape = (batch_size, seq_length, hidden_size) + + def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask): + prev_output = self.reshape(input_tensor, self.shape) + + # wait to be supported + # for layer_module in self.layers: + # layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask) + # prev_output = layer_output + prev_output = self.layer0(prev_output, attention_mask, enc_states, enc_attention_mask) + prev_output = self.layer1(prev_output, attention_mask, enc_states, enc_attention_mask) + prev_output = self.layer2(prev_output, attention_mask, enc_states, enc_attention_mask) + prev_output = self.layer3(prev_output, attention_mask, enc_states, enc_attention_mask) + prev_output = self.layer4(prev_output, attention_mask, enc_states, enc_attention_mask) + prev_output = self.layer5(prev_output, attention_mask, enc_states, enc_attention_mask) + + prev_output = self.layer_preprocess(prev_output) + output = self.reshape(prev_output, self.out_shape) + return output + + +class CreateAttentionMaskFromInputMask(nn.Cell): + """ + Create attention mask according to input mask. + + Args: + config (:class:`TransformerConfig`): Configuration for Transformer. + """ + def __init__(self): + super(CreateAttentionMaskFromInputMask, self).__init__() + self.cast = P.Cast() + self.reshape = P.Reshape() + self.shape = P.Shape() + self.batch_matmul = P.BatchMatMul() + + def construct(self, input_mask): + input_shape = self.shape(input_mask) + shape_right = (input_shape[0], 1, input_shape[1]) + shape_left = input_shape + (1,) + + input_mask = self.cast(input_mask, mstype.float32) + mask_left = self.reshape(input_mask, shape_left) + mask_right = self.reshape(input_mask, shape_right) + attention_mask = self.batch_matmul(mask_left, mask_right) + + return attention_mask + + +class PredLogProbs(nn.Cell): + """ + Get log probs. + + Args: + batch_size (int): Batch size. + seq_length (int): Length of input sequence. + width (int): Hidden size. + compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. + dtype (:class:`mindspore.dtype`): Compute type to compute log_softmax. Default: mstype.float32. 
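+
+    Note (editorial): construct() multiplies the flattened hidden states with the
+    transposed output_weights matrix and applies log_softmax; TransformerModel passes
+    the embedding table here, so input and output embeddings are shared.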
+ """ + def __init__(self, + batch_size, + seq_length, + width, + compute_type=mstype.float32, + dtype=mstype.float32): + super(PredLogProbs, self).__init__() + self.batch_size = batch_size + self.seq_length = seq_length + self.width = width + self.compute_type = compute_type + self.dtype = dtype + + self.reshape = P.Reshape() + self.matmul = P.MatMul(transpose_b=True) + self.log_softmax = nn.LogSoftmax(axis=-1) + self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width) + self.cast = P.Cast() + + def construct(self, + input_tensor, + output_weights): + input_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) + input_tensor = self.cast(input_tensor, self.compute_type) + output_weights = self.cast(output_weights, self.compute_type) + + logits = self.matmul(input_tensor, output_weights) + logits = self.cast(logits, self.dtype) + + log_probs = self.log_softmax(logits) + return log_probs + + +class TransformerDecoderStep(nn.Cell): + """ + Multi-layer transformer decoder step. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. + max_decode_length (int): Max decode length. + enc_seq_length (int): Length of source sentences. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 16. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096. + attention_probs_dropout_prob (float): The dropout probability for + SelfAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32. + embedding_lookup (:class:`EmbeddingLookup`): Embedding lookup module. + embedding_processor (:class:`EmbeddingPostprocessor`) Embedding postprocessor module. 
+ projection (:class:`PredLogProbs`): PredLogProbs module + """ + def __init__(self, + batch_size, + hidden_size, + enc_seq_length, + max_decode_length, + num_hidden_layers, + num_attention_heads=16, + intermediate_size=4096, + attention_probs_dropout_prob=0.3, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.3, + hidden_act="relu", + compute_type=mstype.float32, + embedding_lookup=None, + embedding_processor=None, + projection=None): + super(TransformerDecoderStep, self).__init__(auto_prefix=False) + self.num_hidden_layers = num_hidden_layers + + self.tfm_embedding_lookup = embedding_lookup + self.tfm_embedding_processor = embedding_processor + self.projection = projection + + self.tfm_decoder = TransformerDecoder( + batch_size=batch_size, + hidden_size=hidden_size, + seq_length=-1, # -1 means length is not fixed + enc_seq_length=enc_seq_length, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + hidden_act=hidden_act, + compute_type=compute_type) + + self.ones_like = P.OnesLike() + self.shape = P.Shape() + + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask() + self.expand = P.ExpandDims() + self.multiply = P.Mul() + + ones = np.ones(shape=(max_decode_length, max_decode_length)) + self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32) + + self.cast_compute_type = CastWrapper(dst_type=compute_type) + + def construct(self, input_ids, enc_states, enc_attention_mask): + # input_ids: [batch_size * beam_width] + # process embedding + input_embedding, embedding_tables = self.tfm_embedding_lookup(input_ids) + input_embedding = self.tfm_embedding_processor(input_embedding) + input_embedding = self.cast_compute_type(input_embedding) + + input_shape = self.shape(input_ids) + input_len = input_shape[1] + future_mask = self.future_mask[0:input_len:1, 0:input_len:1] + + input_mask = self.ones_like(input_ids) + input_mask = self._create_attention_mask_from_input_mask(input_mask) + input_mask = self.multiply(input_mask, self.expand(future_mask, 0)) + input_mask = self.cast_compute_type(input_mask) + + enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::] + + # call TransformerDecoder + decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask) + + # take the last step + decoder_output = decoder_output[::, input_len-1:input_len:1, ::] + + # projection and log_prob + log_probs = self.projection(decoder_output, embedding_tables) + + return log_probs + + +class TransformerModel(nn.Cell): + """ + Transformer with encoder and decoder. + + Args: + config (Class): Configuration for Transformer. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. 
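+
+    Example (illustrative; the batch size is a placeholder and must match the bound
+    dataset):
+        >>> config = TransformerConfig(batch_size=96)
+        >>> network = TransformerModel(config, is_training=True)
+
+    In training mode construct() returns per-token log probabilities for the loss; with
+    is_training=False the decoder is wrapped in BeamSearchDecoder and construct() returns
+    the predicted target token ids instead.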
+ """ + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(TransformerModel, self).__init__() + config = copy.deepcopy(config) + self.is_training = is_training + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + + self.input_mask_from_dataset = config.input_mask_from_dataset + self.batch_size = config.batch_size + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + + self.last_idx = self.num_hidden_layers - 1 + + self.tfm_embedding_lookup = EmbeddingLookup( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range) + self.tfm_embedding_postprocessor = EmbeddingPostprocessor( + embedding_size=self.embedding_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + self.tfm_encoder = TransformerEncoder( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + hidden_act=config.hidden_act, + compute_type=config.compute_type) + + if is_training: + self.projection = PredLogProbs( + batch_size=self.batch_size, + seq_length=self.seq_length, + width=self.hidden_size, + compute_type=config.compute_type, + dtype=config.dtype) + self.tfm_decoder = TransformerDecoder( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + enc_seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + hidden_act=config.hidden_act, + compute_type=config.compute_type) + else: + self.projection = PredLogProbs( + batch_size=self.batch_size * config.beam_width, + seq_length=1, + width=self.hidden_size, + compute_type=config.compute_type, + dtype=config.dtype) + self.tfm_decoder = TransformerDecoderStep( + batch_size=self.batch_size * config.beam_width, + hidden_size=self.hidden_size, + enc_seq_length=self.seq_length, + max_decode_length=config.max_decode_length, + num_hidden_layers=config.num_hidden_layers, + num_attention_heads=config.num_attention_heads, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=False, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + embedding_lookup=self.tfm_embedding_lookup, + embedding_processor=self.tfm_embedding_postprocessor, + projection=self.projection) + self.tfm_decoder = BeamSearchDecoder( + batch_size=config.batch_size, + seq_length=config.seq_length, + vocab_size=config.vocab_size, + decoder=self.tfm_decoder, + 
beam_width=config.beam_width, + length_penalty_weight=config.length_penalty_weight, + max_decode_length=config.max_decode_length) + self.tfm_decoder.add_flags(loop_can_unroll=True) + + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = CastWrapper(dst_type=config.compute_type) + self.expand = P.ExpandDims() + self.multiply = P.Mul() + + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask() + + if is_training: + ones = np.ones(shape=(self.seq_length, self.seq_length)) + self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32) + else: + self.tile_beam = TileBeam( + beam_width=config.beam_width) + ones = np.ones(shape=(config.batch_size, config.max_decode_length)) + self.encdec_mask = Tensor(ones, dtype=mstype.float32) + + def construct(self, source_ids, source_mask, target_ids=None, target_mask=None): + # process source sentence + src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids) + src_embedding_output = self.tfm_embedding_postprocessor(src_word_embeddings) + # attention mask [batch_size, seq_length, seq_length] + enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask) + # transformer encoder + encoder_output = self.tfm_encoder(self.cast_compute_type(src_embedding_output), + self.cast_compute_type(enc_attention_mask)) + + if self.is_training: + # process target sentence + tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids) + tgt_embedding_output = self.tfm_embedding_postprocessor(tgt_word_embeddings) + # attention mask [batch_size, seq_length, seq_length] + tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask) + tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0)) + # transformer decoder + decoder_output = self.tfm_decoder(self.cast_compute_type(tgt_embedding_output), + self.cast_compute_type(tgt_attention_mask), + encoder_output, enc_attention_mask) + # calculate logits and log_probs + log_probs = self.projection(decoder_output, embedding_tables) + return log_probs + + beam_encoder_output = self.tile_beam(encoder_output) + + enc_attention_mask = self.multiply( + enc_attention_mask[::, 0:1:1, ::], + self.expand(self.encdec_mask, -1)) + + beam_enc_attention_mask = self.tile_beam(enc_attention_mask) + beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask) + predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask) + return predicted_ids diff --git a/model_zoo/Transformer/src/weight_init.py b/model_zoo/Transformer/src/weight_init.py new file mode 100644 index 00000000000..f2f048063d5 --- /dev/null +++ b/model_zoo/Transformer/src/weight_init.py @@ -0,0 +1,52 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Weight init utilities.""" + +import math +import numpy as np +from mindspore.common.tensor import Tensor + +def _average_units(shape): + """ + Average shape dim. + """ + if not shape: + return 1. + if len(shape) == 1: + return float(shape[0]) + if len(shape) == 2: + return float(shape[0] + shape[1]) / 2. + raise RuntimeError("not support shape.") + +def weight_variable(shape): + scale_shape = shape + avg_units = _average_units(scale_shape) + scale = 1.0 / max(1., avg_units) + limit = math.sqrt(3.0 * scale) + values = np.random.uniform(-limit, limit, shape).astype(np.float32) + return Tensor(values) + +def one_weight(shape): + ones = np.ones(shape).astype(np.float32) + return Tensor(ones) + +def zero_weight(shape): + zeros = np.zeros(shape).astype(np.float32) + return Tensor(zeros) + +def normal_weight(shape, num_units): + norm = np.random.normal(0.0, num_units**-0.5, shape).astype(np.float32) + return Tensor(norm) + \ No newline at end of file diff --git a/model_zoo/Transformer/train.py b/model_zoo/Transformer/train.py new file mode 100644 index 00000000000..37165a6c206 --- /dev/null +++ b/model_zoo/Transformer/train.py @@ -0,0 +1,179 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformer training script.""" + +import time +import argparse + +import mindspore.common.dtype as mstype +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.nn.optim import Adam +from mindspore.train.model import Model +from mindspore.train.loss_scale_manager import DynamicLossScaleManager +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint +from mindspore.train.callback import Callback, TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.communication.management as D +from mindspore.train.parallel_utils import ParallelMode +from mindspore import context + +from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \ + TransformerTrainOneStepWithLossScaleCell +from src.config import cfg, transformer_net_cfg +from src.dataset import create_transformer_dataset +from src.weight_init import weight_variable, one_weight, zero_weight, normal_weight +from src.lr_schedule import create_dynamic_lr + + +def get_ms_timestamp(): + t = time.time() + return int(round(t * 1000)) +time_stamp_init = False +time_stamp_first = 0 + +class LossCallBack(Callback): + """ + Monitor the loss in training. + If the loss is NAN or INF terminating training. + Note: + If per_print_times is 0 do not print loss. + Args: + per_print_times (int): Print loss every times. Default: 1. 
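+
+    Besides printing to stdout, step_end() appends the same record to ./loss.log in the
+    current working directory.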
+    """
+    def __init__(self, per_print_times=1):
+        super(LossCallBack, self).__init__()
+        if not isinstance(per_print_times, int) or per_print_times < 0:
+            raise ValueError("per_print_times must be an int and >= 0.")
+        self._per_print_times = per_print_times
+        global time_stamp_init, time_stamp_first
+        if not time_stamp_init:
+            time_stamp_first = get_ms_timestamp()
+            time_stamp_init = True
+
+    def step_end(self, run_context):
+        global time_stamp_first
+        cb_params = run_context.original_args()
+        # Honor per_print_times: 0 disables printing, N prints every N steps.
+        if self._per_print_times == 0 or cb_params.cur_step_num % self._per_print_times != 0:
+            return
+        time_stamp_current = get_ms_timestamp()
+        print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
+                                                                     cb_params.cur_epoch_num, cb_params.cur_step_num,
+                                                                     str(cb_params.net_outputs)))
+        with open("./loss.log", "a+") as f:
+            f.write("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
+                                                                           cb_params.cur_epoch_num,
+                                                                           cb_params.cur_step_num,
+                                                                           str(cb_params.net_outputs)))
+            f.write('\n')
+
+
+def argparse_init():
+    """
+    Initialize command-line arguments for Transformer training.
+    """
+    parser = argparse.ArgumentParser(description='transformer')
+    parser.add_argument("--distribute", type=str, default="false", help="Run distributed training, default is false.")
+    parser.add_argument("--epoch_size", type=int, default=52, help="Epoch size, default is 52.")
+    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
+    parser.add_argument("--device_num", type=int, default=1, help="Number of devices to use, default is 1.")
+    parser.add_argument("--enable_lossscale", type=str, default="true", help="Use loss scale or not, default is true.")
+    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
+    parser.add_argument("--enable_data_sink", type=str, default="false", help="Enable data sink, default is false.")
+    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.")
+    parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable checkpoint saving, "
+                                                                             "default is true.")
+    parser.add_argument("--save_checkpoint_steps", type=int, default=2500, help="Save checkpoint steps, "
+                                                                                "default is 2500.")
+    parser.add_argument("--save_checkpoint_num", type=int, default=30, help="Max checkpoints to keep, default is 30.")
+    parser.add_argument("--save_checkpoint_path", type=str, default="./checkpoint/", help="Save checkpoint file path, "
+                                                                                          "default is ./checkpoint/")
+    parser.add_argument("--data_path", type=str, default="", help="Data path; an absolute path is recommended.")
+    return parser
+
+def run_transformer_train():
+    """
+    Transformer training.
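+
+    A hypothetical single-device invocation (flag values and the data path below
+    are illustrative placeholders, not project-mandated defaults):
+
+        python train.py --device_id=0 --epoch_size=52 --data_path=/path/to/ende-l128-mindrecord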
+ """ + parser = argparse_init() + args, _ = parser.parse_known_args() + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) + context.set_context(save_graphs=True, reserve_class_name_in_scope=False, enable_auto_mixed_precision=False) + + if args.distribute == "true": + device_num = args.device_num + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, + parameter_broadcast=True, device_num=device_num) + D.init() + rank_id = args.device_id % device_num + else: + device_num = 1 + rank_id = 0 + dataset, repeat_count = create_transformer_dataset(epoch_count=args.epoch_size, rank_size=device_num, + rank_id=rank_id, do_shuffle=args.do_shuffle, + enable_data_sink=args.enable_data_sink, + dataset_path=args.data_path) + + netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True) + + if args.checkpoint_path: + parameter_dict = load_checkpoint(args.checkpoint_path) + else: + parameter_dict = {} + params = netwithloss.trainable_params() + for param in params: + name = param.name + value = param.default_input + if isinstance(value, Tensor): + if name.endswith(".gamma"): + parameter_dict[name] = Parameter(one_weight(value.asnumpy().shape), name=name) + elif name.endswith(".beta") or name.endswith(".bias"): + parameter_dict[name] = Parameter(zero_weight(value.asnumpy().shape), name=name) + elif "embedding" in name: + parameter_dict[name] = Parameter(normal_weight(value.asnumpy().shape, + transformer_net_cfg.hidden_size), name=name) + else: + parameter_dict[name] = Parameter(weight_variable(value.asnumpy().shape), name=name) + load_param_into_net(netwithloss, parameter_dict) + + lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay", + training_steps=dataset.get_dataset_size()*args.epoch_size, + learning_rate=cfg.lr_schedule.learning_rate, + warmup_steps=cfg.lr_schedule.warmup_steps, + hidden_size=transformer_net_cfg.hidden_size), mstype.float32) + optimizer = Adam(netwithloss.trainable_params(), lr) + + callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()] + if args.enable_save_ckpt == "true": + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps, + keep_checkpoint_max=args.save_checkpoint_num) + ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config) + callbacks.append(ckpoint_cb) + + if args.enable_lossscale == "true": + scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value, + scale_factor=cfg.scale_factor, + scale_window=cfg.scale_window) + update_cell = scale_manager.get_update_cell() + netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, + scale_update_cell=update_cell) + else: + netwithgrads = TransformerTrainOneStepCell(netwithloss, optimizer=optimizer) + + netwithgrads.set_train(True) + model = Model(netwithgrads) + model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true")) + +if __name__ == '__main__': + run_transformer_train()