forked from mindspore-Ecosystem/mindspore
!1611 Transformer model scripts merge
Merge pull request !1611 from yuchaojie/transformer
Commit 9d9cd3c1ef
|
@@ -0,0 +1,176 @@ README.md
|
|||
# Transformer Example
|
||||
## Description
|
||||
This example implements training and evaluation of the Transformer model, which is introduced in the following paper:
|
||||
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In NIPS 2017, pages 5998–6008.
|
||||
|
||||
## Requirements
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
- Download and preprocess the WMT English-German dataset for training and evaluation.
|
||||
|
||||
> Note: If you are running an evaluation task, prepare the corresponding checkpoint file.
|
||||
|
||||
## Example structure
|
||||
|
||||
```shell
|
||||
.
└─Transformer
  ├─README.md
  ├─scripts
    ├─process_output.sh
    ├─replace-quote.perl
    ├─run_distribute_train.sh
    └─run_standalone_train.sh
  ├─src
    ├─__init__.py
    ├─beam_search.py
    ├─config.py
    ├─dataset.py
    ├─eval_config.py
    ├─lr_schedule.py
    ├─process_output.py
    ├─tokenization.py
    ├─transformer_for_train.py
    ├─transformer_model.py
    └─weight_init.py
  ├─create_data.py
  ├─eval.py
  └─train.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Prepare the dataset
|
||||
- You may use this [shell script](https://github.com/tensorflow/nmt/blob/master/nmt/scripts/wmt16_en_de.sh) to download and preprocess the WMT English-German dataset. After preprocessing, you should have the following files:
|
||||
- train.tok.clean.bpe.32000.en
|
||||
- train.tok.clean.bpe.32000.de
|
||||
- vocab.bpe.32000
|
||||
- newstest2014.tok.bpe.32000.en
|
||||
- newstest2014.tok.bpe.32000.de
|
||||
- newstest2014.tok.de
|
||||
|
||||
- Convert the original data to mindrecord for training:
|
||||
|
||||
``` bash
|
||||
paste train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.all
|
||||
python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128
|
||||
```
|
||||
- Convert the original data to mindrecord for evaluation:
|
||||
|
||||
``` bash
|
||||
paste newstest2014.tok.bpe.32000.en newstest2014.tok.bpe.32000.de > test.all
|
||||
python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True
|
||||
```
|
||||
|
||||
## Running the example
|
||||
|
||||
### Training
|
||||
- Set options in `config.py`, including the loss scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#mindspore) for more information about loading the dataset.
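These options live in the `cfg` dictionary at the top of `config.py`. An excerpt of the shipped defaults (edit the values to tune loss scaling and the learning-rate schedule):

```python
# config.py (excerpt)
cfg = edict({
    'transformer_network': 'large',   # 'base' or 'large'
    'init_loss_scale_value': 1024,    # initial loss scale (2^10)
    'scale_factor': 2,                # factor used when the loss scale is updated
    'scale_window': 2000,             # steps between loss-scale updates
    'optimizer': 'Adam',
    'lr_schedule': edict({
        'learning_rate': 2.0,
        'warmup_steps': 8000,
        'start_decay_step': 16000,
        'min_lr': 0.0,
    }),
})
```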
|
||||
|
||||
- Run `run_standalone_train.sh` for non-distributed training of the Transformer model.
|
||||
|
||||
``` bash
|
||||
sh scripts/run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_PATH
|
||||
```
|
||||
- Run `run_distribute_train.sh` for distributed training of the Transformer model.
|
||||
|
||||
``` bash
|
||||
sh scripts/run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_PATH MINDSPORE_HCCL_CONFIG_PATH
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
- Set options in `eval_config.py`. Make sure the 'data_file', 'model_file' and 'output_file' are set to your own paths.
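For reference, the relevant block at the top of `eval_config.py` looks like this; replace the placeholder paths with your own files:

```python
# eval_config.py (excerpt)
cfg = edict({
    'transformer_network': 'large',
    'data_file': '/your/path/evaluation.mindrecord',   # e.g. the newstest2014 mindrecord created above
    'model_file': '/your/path/checkpoint_file',        # trained checkpoint (.ckpt or .npz)
    'output_file': '/your/path/output',                # predicted token ids are written here
})
```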
|
||||
|
||||
- Run `eval.py` for evaluation of the Transformer model.
|
||||
|
||||
```bash
|
||||
python eval.py
|
||||
```
|
||||
|
||||
- Run `process_output.sh` to process the output token ids to get the real translation results.
|
||||
|
||||
```bash
|
||||
sh scripts/process_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE
|
||||
```
|
||||
You will get two files, REF_DATA.forbleu and EVAL_OUTPUT.forbleu, for BLEU score calculation.
|
||||
|
||||
- To calculate the BLEU score, you may use this [perl script](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl) and run the following command:
|
||||
|
||||
```bash
|
||||
perl multi-bleu.perl REF_DATA.forbleu < EVAL_OUTPUT.forbleu
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Usage
|
||||
|
||||
### Training
|
||||
```
|
||||
usage: train.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_num N] [--device_id N]
|
||||
[--enable_save_ckpt ENABLE_SAVE_CKPT]
|
||||
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
|
||||
[--enable_data_sink ENABLE_DATA_SINK] [--save_checkpoint_steps N]
|
||||
[--save_checkpoint_num N] [--save_checkpoint_path SAVE_CHECKPOINT_PATH]
|
||||
[--checkpoint_path CHECKPOINT_PATH] [--data_path DATA_PATH]
|
||||
|
||||
options:
|
||||
--distribute training by several devices: "true" (training by more than 1 device) | "false", default is "false"
|
||||
--epoch_size epoch size: N, default is 52
|
||||
--device_num number of used devices: N, default is 1
|
||||
--device_id device id: N, default is 0
|
||||
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
|
||||
--enable_lossscale enable lossscale: "true" | "false", default is "true"
|
||||
--do_shuffle enable shuffle: "true" | "false", default is "true"
|
||||
--enable_data_sink enable data sink: "true" | "false", default is "false"
|
||||
--checkpoint_path path to load checkpoint files: PATH, default is ""
|
||||
--save_checkpoint_steps steps for saving checkpoint files: N, default is 2500
|
||||
--save_checkpoint_num number for saving checkpoint files: N, default is 30
|
||||
--save_checkpoint_path path to save checkpoint files: PATH, default is "./checkpoint/"
|
||||
--data_path path to dataset file: PATH, default is ""
|
||||
```
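For example, a single-device run equivalent to what `run_standalone_train.sh` launches (paths and values below are placeholders):

```bash
python train.py \
    --distribute="false" \
    --epoch_size=52 \
    --device_id=0 \
    --enable_save_ckpt="true" \
    --enable_lossscale="true" \
    --do_shuffle="true" \
    --enable_data_sink="false" \
    --checkpoint_path="" \
    --save_checkpoint_steps=2500 \
    --save_checkpoint_num=30 \
    --data_path=/path/ende-l128-mindrecord
```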
|
||||
|
||||
## Options and Parameters
|
||||
Parameters of the Transformer model and options for training and evaluation are set in `config.py` and `eval_config.py`, respectively.
|
||||
### Options:
|
||||
```
|
||||
config.py:
|
||||
transformer_network version of Transformer model: base | large, default is large
|
||||
init_loss_scale_value initial value of loss scale: N, default is 2^10
|
||||
scale_factor factor used to update loss scale: N, default is 2
|
||||
scale_window steps between updates of loss scale: N, default is 2000
|
||||
optimizer optimizer used in the network: Adam, default is "Adam"
|
||||
|
||||
eval_config.py:
|
||||
transformer_network version of Transformer model: base | large, default is large
|
||||
data_file data file: PATH
|
||||
model_file checkpoint file to be loaded: PATH
|
||||
output_file output file of evaluation: PATH
|
||||
```
|
||||
|
||||
### Parameters:
|
||||
```
|
||||
Parameters for dataset and network (Training/Evaluation):
|
||||
batch_size batch size of input dataset: N, default is 96
|
||||
seq_length length of input sequence: N, default is 128
|
||||
vocab_size size of the vocabulary: N, default is 36560
|
||||
hidden_size size of Transformer encoder layers: N, default is 1024
|
||||
num_hidden_layers number of hidden layers: N, default is 6
|
||||
num_attention_heads number of attention heads: N, default is 16
|
||||
intermediate_size size of intermediate layer: N, default is 4096
|
||||
hidden_act activation function used: ACTIVATION, default is "relu"
|
||||
hidden_dropout_prob dropout probability for TransformerOutput: Q, default is 0.3
|
||||
attention_probs_dropout_prob dropout probability for TransformerAttention: Q, default is 0.3
|
||||
max_position_embeddings maximum length of sequences: N, default is 128
|
||||
initializer_range initialization value of TruncatedNormal: Q, default is 0.02
|
||||
label_smoothing label smoothing setting: Q, default is 0.1
|
||||
input_mask_from_dataset use the input mask loaded from the dataset or not: True | False, default is True
|
||||
beam_width beam width setting: N, default is 4
|
||||
max_decode_length max decode length in evaluation: N, default is 80
|
||||
length_penalty_weight normalize scores of translations according to their length: Q, default is 1.0
|
||||
compute_type compute type in Transformer: mstype.float16 | mstype.float32, default is mstype.float16
|
||||
|
||||
Parameters for learning rate:
|
||||
learning_rate value of learning rate: Q
|
||||
warmup_steps steps of the learning rate warm up: N
|
||||
start_decay_step step of the learning rate to decay: N
|
||||
min_lr minimal learning rate: Q
|
||||
```
|
|
@@ -0,0 +1,201 @@ create_data.py
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Create training instances for Transformer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import logging
|
||||
import numpy as np
|
||||
import src.tokenization as tokenization
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
class SampleInstance():
|
||||
"""A single sample instance (sentence pair)."""
|
||||
|
||||
def __init__(self, source_sos_tokens, source_eos_tokens, target_sos_tokens, target_eos_tokens):
|
||||
self.source_sos_tokens = source_sos_tokens
|
||||
self.source_eos_tokens = source_eos_tokens
|
||||
self.target_sos_tokens = target_sos_tokens
|
||||
self.target_eos_tokens = target_eos_tokens
|
||||
|
||||
def __str__(self):
|
||||
s = ""
|
||||
s += "source sos tokens: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.source_sos_tokens]))
|
||||
s += "source eos tokens: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.source_eos_tokens]))
|
||||
s += "target sos tokens: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.target_sos_tokens]))
|
||||
s += "target eos tokens: %s\n" % (" ".join(
|
||||
[tokenization.printable_text(x) for x in self.target_eos_tokens]))
|
||||
s += "\n"
|
||||
return s
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def write_instance_to_file(writer, instance, tokenizer, max_seq_length):
|
||||
"""Create files from `SampleInstance`s."""
|
||||
|
||||
def _convert_ids_and_mask(input_tokens):
|
||||
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
assert len(input_ids) <= max_seq_length
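# pad ids and mask with zeros up to max_seq_length; the mask is 1 for real tokens and 0 for padding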
|
||||
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
|
||||
return input_ids, input_mask
|
||||
|
||||
source_sos_ids, source_sos_mask = _convert_ids_and_mask(instance.source_sos_tokens)
|
||||
source_eos_ids, source_eos_mask = _convert_ids_and_mask(instance.source_eos_tokens)
|
||||
target_sos_ids, target_sos_mask = _convert_ids_and_mask(instance.target_sos_tokens)
|
||||
target_eos_ids, target_eos_mask = _convert_ids_and_mask(instance.target_eos_tokens)
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["source_sos_ids"] = np.asarray(source_sos_ids)
|
||||
features["source_sos_mask"] = np.asarray(source_sos_mask)
|
||||
features["source_eos_ids"] = np.asarray(source_eos_ids)
|
||||
features["source_eos_mask"] = np.asarray(source_eos_mask)
|
||||
features["target_sos_ids"] = np.asarray(target_sos_ids)
|
||||
features["target_sos_mask"] = np.asarray(target_sos_mask)
|
||||
features["target_eos_ids"] = np.asarray(target_eos_ids)
|
||||
features["target_eos_mask"] = np.asarray(target_eos_mask)
|
||||
|
||||
writer.write_raw_data([features])
|
||||
return features
|
||||
|
||||
def create_training_instance(source_words, target_words, max_seq_length, clip_to_max_len):
|
||||
"""Creates `SampleInstance`s for a single sentence pair."""
|
||||
EOS = "</s>"
|
||||
SOS = "<s>"
|
||||
|
||||
if len(source_words) >= max_seq_length or len(target_words) >= max_seq_length:
|
||||
if clip_to_max_len:
|
||||
print("####lalalal")
|
||||
source_words = source_words[:min(len(source_words), max_seq_length - 1)]
|
||||
target_words = target_words[:min(len(target_words), max_seq_length - 1)]
|
||||
else:
|
||||
return None
|
||||
|
||||
source_sos_tokens = [SOS] + source_words
|
||||
source_eos_tokens = source_words + [EOS]
|
||||
target_sos_tokens = [SOS] + target_words
|
||||
target_eos_tokens = target_words + [EOS]
|
||||
|
||||
instance = SampleInstance(
|
||||
source_sos_tokens=source_sos_tokens,
|
||||
source_eos_tokens=source_eos_tokens,
|
||||
target_sos_tokens=target_sos_tokens,
|
||||
target_eos_tokens=target_eos_tokens)
|
||||
return instance
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_file", type=str, required=True,
|
||||
help='Input raw text file (or comma-separated list of files).')
|
||||
parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
|
||||
parser.add_argument("--num_splits", type=int, default=16,
|
||||
help='Number of partitions the output MindRecord file will be split into.')
|
||||
parser.add_argument("--vocab_file", type=str, required=True,
|
||||
help='The vocabulary file that the Transformer model was trained on.')
|
||||
parser.add_argument("--clip_to_max_len", type=bool, default=False,
|
||||
help='clip sequences to maximum sequence length.')
|
||||
parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
|
||||
args = parser.parse_args()
|
||||
|
||||
tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)
|
||||
|
||||
input_files = []
|
||||
for input_pattern in args.input_file.split(","):
|
||||
input_files.append(input_pattern)
|
||||
|
||||
logging.info("*** Reading from input files ***")
|
||||
for input_file in input_files:
|
||||
logging.info(" %s", input_file)
|
||||
|
||||
output_file = args.output_file
|
||||
logging.info("*** Writing to output files ***")
|
||||
logging.info(" %s", output_file)
|
||||
|
||||
writer = FileWriter(output_file, args.num_splits)
|
||||
data_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
|
||||
"source_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"source_eos_ids": {"type": "int64", "shape": [-1]},
|
||||
"source_eos_mask": {"type": "int64", "shape": [-1]},
|
||||
"target_sos_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"target_eos_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_eos_mask": {"type": "int64", "shape": [-1]}
|
||||
}
|
||||
writer.add_schema(data_schema, "transformer")
|
||||
|
||||
total_written = 0
|
||||
total_read = 0
|
||||
|
||||
for input_file in input_files:
|
||||
logging.info("*** Reading from %s ***", input_file)
|
||||
with open(input_file, "r") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
break
|
||||
|
||||
total_read += 1
|
||||
if total_read % 100000 == 0:
|
||||
logging.info("%d ...", total_read)
|
||||
|
||||
source_line, target_line = line.strip().split("\t")
|
||||
source_tokens = tokenizer.tokenize(source_line)
|
||||
target_tokens = tokenizer.tokenize(target_line)
|
||||
|
||||
if len(source_tokens) >= args.max_seq_length or len(target_tokens) >= args.max_seq_length:
|
||||
logging.info("ignore long sentence!")
|
||||
continue
|
||||
|
||||
instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length,
|
||||
clip_to_max_len=args.clip_to_max_len)
|
||||
if instance is None:
|
||||
continue
|
||||
|
||||
features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length)
|
||||
total_written += 1
|
||||
|
||||
if total_written <= 20:
|
||||
logging.info("*** Example ***")
|
||||
logging.info("source tokens: %s", " ".join(
|
||||
[tokenization.printable_text(x) for x in instance.source_eos_tokens]))
|
||||
logging.info("target tokens: %s", " ".join(
|
||||
[tokenization.printable_text(x) for x in instance.target_sos_tokens]))
|
||||
|
||||
for feature_name in features.keys():
|
||||
feature = features[feature_name]
|
||||
logging.info("%s: %s", feature_name, feature)
|
||||
|
||||
writer.commit()
|
||||
logging.info("Wrote %d total instances", total_written)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -0,0 +1,136 @@ eval.py
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Transformer evaluation script."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
from mindspore import context
|
||||
|
||||
from src.transformer_model import TransformerModel
|
||||
from src.eval_config import cfg, transformer_net_cfg
|
||||
|
||||
def load_test_data(batch_size=1, data_file=None):
|
||||
"""
|
||||
Load test dataset
|
||||
"""
|
||||
ds = de.MindDataset(data_file,
|
||||
columns_list=["source_eos_ids", "source_eos_mask",
|
||||
"target_sos_ids", "target_sos_mask",
|
||||
"target_eos_ids", "target_eos_mask"],
|
||||
shuffle=False)
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="source_eos_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="source_eos_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="target_sos_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="target_sos_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="target_eos_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="target_eos_mask", operations=type_cast_op)
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds.channel_name = 'transformer'
|
||||
return ds
|
||||
|
||||
class TransformerInferCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of transformer network infer.
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(TransformerInferCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
||||
def construct(self,
|
||||
source_ids,
|
||||
source_mask):
|
||||
predicted_ids = self.network(source_ids, source_mask)
|
||||
return predicted_ids
|
||||
|
||||
def load_weights(model_path):
|
||||
"""
|
||||
Load checkpoint as parameter dict, support both npz file and mindspore checkpoint file.
|
||||
"""
|
||||
if model_path.endswith(".npz"):
|
||||
ms_ckpt = np.load(model_path)
|
||||
is_npz = True
|
||||
else:
|
||||
ms_ckpt = load_checkpoint(model_path)
|
||||
is_npz = False
|
||||
|
||||
weights = {}
|
||||
for msname in ms_ckpt:
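# remap names from the training checkpoint to the inference graph: drop the
# "transformer.transformer." prefix and move decoder layers under "tfm_decoder.decoder."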
|
||||
infer_name = msname.replace("transformer.transformer.", "")
|
||||
if "tfm_decoder" in msname:
|
||||
infer_name = infer_name.replace(".layers.", ".layer")
|
||||
infer_name = "tfm_decoder.decoder." + infer_name
|
||||
if is_npz:
|
||||
weights[infer_name] = ms_ckpt[msname]
|
||||
else:
|
||||
weights[infer_name] = ms_ckpt[msname].data.asnumpy()
|
||||
weights["tfm_decoder.decoder.tfm_embedding_lookup.embedding_table"] = \
|
||||
weights["tfm_embedding_lookup.embedding_table"]
|
||||
|
||||
parameter_dict = {}
|
||||
for name in weights:
|
||||
parameter_dict[name] = Parameter(Tensor(weights[name]), name=name)
|
||||
return parameter_dict
|
||||
|
||||
def run_transformer_eval():
|
||||
"""
|
||||
Transformer evaluation.
|
||||
"""
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False,
|
||||
device_id=device_id)
|
||||
|
||||
dataset = load_test_data(batch_size=transformer_net_cfg.batch_size, data_file=cfg.data_file)
|
||||
tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False)
|
||||
|
||||
parameter_dict = load_weights(cfg.model_file)
|
||||
load_param_into_net(tfm_model, parameter_dict)
|
||||
|
||||
tfm_infer = TransformerInferCell(tfm_model)
|
||||
model = Model(tfm_infer)
|
||||
|
||||
predictions = []
|
||||
source_sents = []
|
||||
target_sents = []
|
||||
for batch in dataset.create_dict_iterator():
|
||||
source_sents.append(batch["source_eos_ids"])
|
||||
target_sents.append(batch["target_eos_ids"])
|
||||
source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
|
||||
source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
|
||||
predicted_ids = model.predict(source_ids, source_mask)
|
||||
predictions.append(predicted_ids.asnumpy())
|
||||
|
||||
# decode and write to file
|
||||
f = open(cfg.output_file, 'w')
|
||||
for batch_out in predictions:
|
||||
for i in range(transformer_net_cfg.batch_size):
|
||||
if batch_out.ndim == 3:
|
||||
batch_out = batch_out[:, 0]
|
||||
token_ids = [str(x) for x in batch_out[i].tolist()]
|
||||
f.write(" ".join(token_ids) + "\n")
|
||||
f.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_transformer_eval()
|
|
@@ -0,0 +1,35 @@ scripts/process_output.sh
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh process_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE"
|
||||
echo "for example: sh process_output.sh /path/newstest2014.tok.de /path/eval_output_file /path/vocab.bpe.32000"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
BASEDIR=$(dirname "$0")
|
||||
|
||||
ref_data=$1
|
||||
eval_output=$2
|
||||
vocab_file=$3
|
||||
|
||||
cat $eval_output \
|
||||
| python src/process_output.py --vocab_file $vocab_file \
|
||||
| sed 's/@@ //g' > ${eval_output}.processed
|
||||
|
||||
perl -ple 's/(\S)-(\S)/$1 #@#-#@# $2/g' < $ref_data | perl ${BASEDIR}/replace-quote.perl > ${ref_data}.forbleu
|
||||
perl -ple 's/(\S)-(\S)/$1 #@#-#@# $2/g' < ${eval_output}.processed > ${eval_output}.forbleu
|
|
@@ -0,0 +1,11 @@ scripts/replace-quote.perl
|
|||
#!/usr/bin/env perl
|
||||
|
||||
use warnings;
|
||||
use strict;
|
||||
|
||||
while(<STDIN>) {
|
||||
s/”/\"/g;
|
||||
s/“/\"/g;
|
||||
s/„/\"/g;
|
||||
print $_;
|
||||
}
|
|
@@ -0,0 +1,63 @@ scripts/run_distribute_train.sh
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_PATH MINDSPORE_HCCL_CONFIG_PATH"
|
||||
echo "for example: sh run_distribute_pretrain.sh 8 52 /path/ende-l128-mindrecord00 /path/hccl.json"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
rm -rf run_distribute_train
|
||||
mkdir run_distribute_train
|
||||
cd run_distribute_train || exit
|
||||
|
||||
EPOCH_SIZE=$2
|
||||
DATA_PATH=$3
|
||||
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$4
|
||||
export RANK_TABLE_FILE=$4
|
||||
export RANK_SIZE=$1
|
||||
export HCCL_FLAG=1
|
||||
export DEPLOY_MODE=0
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
export GE_USE_STATIC_MEMORY=1
|
||||
|
||||
mkdir helper$i
|
||||
cp -rf ../src/ ../train.py ./helper$i
|
||||
cd ./helper$i || exit
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py \
|
||||
--distribute="true" \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--device_num=$RANK_SIZE \
|
||||
--enable_save_ckpt="true" \
|
||||
--enable_lossscale="true" \
|
||||
--do_shuffle="true" \
|
||||
--enable_data_sink="false" \
|
||||
--checkpoint_path="" \
|
||||
--save_checkpoint_steps=2500 \
|
||||
--save_checkpoint_num=30 \
|
||||
--data_path=$DATA_PATH > log.txt 2>&1 &
|
||||
cd ../
|
||||
done
|
||||
cd ..
|
|
@@ -0,0 +1,45 @@ scripts/run_standalone_train.sh
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_PATH"
|
||||
echo "for example: sh run_standalone_train.sh 0 52 /path/ende-l128-mindrecord00"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
rm -rf run_standalone_train
|
||||
mkdir run_standalone_train
|
||||
cp -rf ./src/ train.py ./run_standalone_train
|
||||
cd run_standalone_train || exit
|
||||
|
||||
export DEVICE_ID=$1
|
||||
EPOCH_SIZE=$2
|
||||
DATA_PATH=$3
|
||||
|
||||
python train.py \
|
||||
--distribute="false" \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--enable_save_ckpt="true" \
|
||||
--enable_lossscale="true" \
|
||||
--do_shuffle="true" \
|
||||
--enable_data_sink="false" \
|
||||
--checkpoint_path="" \
|
||||
--save_checkpoint_steps=2500 \
|
||||
--save_checkpoint_num=30 \
|
||||
--data_path=$DATA_PATH > log.txt 2>&1 &
|
||||
cd ..
|
|
@@ -0,0 +1,269 @@ src/beam_search.py
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Transformer beam search module."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
INF = 1. * 1e9
|
||||
|
||||
class LengthPenalty(nn.Cell):
|
||||
"""
|
||||
Normalize scores of translations according to their length.
|
||||
|
||||
Args:
|
||||
weight (float): Weight of length penalty. Default: 1.0.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
weight=1.0,
|
||||
compute_type=mstype.float32):
|
||||
super(LengthPenalty, self).__init__()
|
||||
self.weight = weight
|
||||
self.add = P.TensorAdd()
|
||||
self.pow = P.Pow()
|
||||
self.div = P.RealDiv()
|
||||
self.cast = P.Cast()
|
||||
self.five = Tensor(5.0, mstype.float32)
|
||||
self.six = Tensor(6.0, mstype.float32)
|
||||
|
||||
def construct(self, length_tensor):
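# GNMT-style length penalty: ((length + 5) / 6) ** weight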
|
||||
length_tensor = self.cast(length_tensor, mstype.float32)
|
||||
output = self.add(length_tensor, self.five)
|
||||
output = self.div(output, self.six)
|
||||
output = self.pow(output, self.weight)
|
||||
return output
|
||||
|
||||
|
||||
class TileBeam(nn.Cell):
|
||||
"""
|
||||
TileBeam.
|
||||
|
||||
Args:
|
||||
beam_width (int): beam width setting. Default: 4.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
beam_width,
|
||||
compute_type=mstype.float32):
|
||||
super(TileBeam, self).__init__()
|
||||
self.beam_width = beam_width
|
||||
self.expand = P.ExpandDims()
|
||||
self.tile = P.Tile()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, input_tensor):
|
||||
"""
|
||||
input_tensor: shape [batch, dim1, dim2]
|
||||
output_tensor: shape [batch*beam, dim1, dim2]
|
||||
"""
|
||||
shape = self.shape(input_tensor)
|
||||
input_tensor = self.expand(input_tensor, 1)
|
||||
tile_shape = (1,) + (self.beam_width,)
|
||||
for _ in range(len(shape)-1):
|
||||
tile_shape = tile_shape + (1,)
|
||||
output = self.tile(input_tensor, tile_shape)
|
||||
out_shape = (shape[0]*self.beam_width,) + shape[1:]
|
||||
output = self.reshape(output, out_shape)
|
||||
return output
|
||||
|
||||
|
||||
class Mod(nn.Cell):
|
||||
"""
|
||||
Mod function.
|
||||
|
||||
Args:
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
compute_type=mstype.float32):
|
||||
super(Mod, self).__init__()
|
||||
self.compute_type = compute_type
|
||||
self.floor_div = P.FloorDiv()
|
||||
self.sub = P.Sub()
|
||||
self.multiply = P.Mul()
|
||||
|
||||
def construct(self, input_x, input_y):
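# compute input_x mod input_y as input_x - floor(input_x / input_y) * input_y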
|
||||
x = self.floor_div(input_x, input_y)
|
||||
x = self.multiply(x, input_y)
|
||||
x = self.sub(input_x, x)
|
||||
return x
|
||||
|
||||
|
||||
class BeamSearchDecoder(nn.Cell):
|
||||
"""
|
||||
Beam search decoder.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size of input dataset.
|
||||
seq_length (int): Length of input sequence.
|
||||
vocab_size (int): Size of vocabulary.
|
||||
decoder (:class:`TransformerDecoderStep`): Decoder module.
|
||||
beam_width (int): beam width setting. Default: 4.
|
||||
length_penalty_weight (float): Weight of length penalty. Default: 1.0.
|
||||
max_decode_length (int): max decode length. Default: 128.
|
||||
sos_id (int): Id of sequence start token. Default: 1.
|
||||
eos_id (int): Id of sequence end token. Default: 2.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
|
||||
"""
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
seq_length,
|
||||
vocab_size,
|
||||
decoder,
|
||||
beam_width=4,
|
||||
length_penalty_weight=1.0,
|
||||
max_decode_length=128,
|
||||
sos_id=1,
|
||||
eos_id=2,
|
||||
compute_type=mstype.float32):
|
||||
super(BeamSearchDecoder, self).__init__(auto_prefix=False)
|
||||
self.batch_size = batch_size
|
||||
self.vocab_size = vocab_size
|
||||
self.beam_width = beam_width
|
||||
self.length_penalty_weight = length_penalty_weight
|
||||
self.max_decode_length = max_decode_length
|
||||
self.decoder = decoder
|
||||
|
||||
self.add = P.TensorAdd()
|
||||
self.expand = P.ExpandDims()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape_flat = (-1,)
|
||||
self.shape = P.Shape()
|
||||
|
||||
self.zero_tensor = Tensor(np.zeros([batch_size, beam_width]), mstype.float32)
|
||||
self.ninf_tensor = Tensor(np.full([batch_size, beam_width], -INF), mstype.float32)
|
||||
|
||||
self.select = P.Select()
|
||||
self.flat_shape = (batch_size, beam_width * vocab_size)
|
||||
self.topk = P.TopK(sorted=True)
|
||||
self.floor_div = P.FloorDiv()
|
||||
self.vocab_size_tensor = Tensor(self.vocab_size, mstype.int32)
|
||||
self.real_div = P.RealDiv()
|
||||
self.mod = Mod()
|
||||
self.equal = P.Equal()
|
||||
self.eos_ids = Tensor(np.full([batch_size, beam_width], eos_id), mstype.int32)
|
||||
|
||||
beam_ids = np.tile(np.arange(beam_width).reshape((1, beam_width)), [batch_size, 1])
|
||||
self.beam_ids = Tensor(beam_ids, mstype.int32)
|
||||
batch_ids = np.arange(batch_size*beam_width).reshape((batch_size, beam_width)) // beam_width
|
||||
self.batch_ids = Tensor(batch_ids, mstype.int32)
|
||||
self.concat = P.Concat(axis=-1)
|
||||
self.gather_nd = P.GatherNd()
|
||||
|
||||
# init inputs and states
|
||||
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
|
||||
self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
|
||||
init_scores = np.tile(np.array([[0.] + [-INF]*(beam_width-1)]), [batch_size, 1])
|
||||
self.init_scores = Tensor(init_scores, mstype.float32)
|
||||
self.init_finished = Tensor(np.zeros([batch_size, beam_width], dtype=np.bool))
|
||||
self.init_length = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))
|
||||
self.length_penalty = LengthPenalty(weight=length_penalty_weight)
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
|
||||
def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
|
||||
state_seq, state_finished, state_length):
|
||||
"""
|
||||
One step for decode
|
||||
"""
|
||||
log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask)
|
||||
log_probs = self.reshape(log_probs, (self.batch_size, self.beam_width, self.vocab_size))
|
||||
|
||||
# select topk indices
|
||||
total_log_probs = self.add(log_probs, self.expand(state_log_probs, -1))
|
||||
|
||||
# mask finished beams
|
||||
mask_tensor = self.select(state_finished, self.ninf_tensor, self.zero_tensor)
|
||||
total_log_probs = self.add(total_log_probs, self.expand(mask_tensor, -1))
|
||||
|
||||
# reshape scores to [batch, beam*vocab]
|
||||
flat_scores = self.reshape(total_log_probs, self.flat_shape)
|
||||
# select topk
|
||||
topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)
|
||||
|
||||
# convert to beam and word indices
|
||||
beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
|
||||
word_indices = self.mod(topk_indices, self.vocab_size_tensor)
|
||||
|
||||
# mask finished indices
|
||||
beam_indices = self.select(state_finished, self.beam_ids, beam_indices)
|
||||
word_indices = self.select(state_finished, self.eos_ids, word_indices)
|
||||
topk_scores = self.select(state_finished, state_log_probs, topk_scores)
|
||||
|
||||
###### put finished sequences to the end
|
||||
# sort according to scores with -inf for finished beams
|
||||
tmp_log_probs = self.select(
|
||||
self.equal(word_indices, self.eos_ids),
|
||||
self.ninf_tensor,
|
||||
topk_scores)
|
||||
_, tmp_indices = self.topk(tmp_log_probs, self.beam_width)
|
||||
# update
|
||||
tmp_gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(tmp_indices, -1)))
|
||||
beam_indices = self.gather_nd(beam_indices, tmp_gather_indices)
|
||||
word_indices = self.gather_nd(word_indices, tmp_gather_indices)
|
||||
topk_scores = self.gather_nd(topk_scores, tmp_gather_indices)
|
||||
|
||||
###### generate new beam_search states
|
||||
# gather indices for selecting alive beams
|
||||
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(beam_indices, -1)))
|
||||
|
||||
# length add 1 if not finished in the previous step
|
||||
length_add = self.add(state_length, self.one)
|
||||
state_length = self.select(state_finished, state_length, length_add)
|
||||
state_length = self.gather_nd(state_length, gather_indices)
|
||||
|
||||
# concat seq
|
||||
seq = self.gather_nd(state_seq, gather_indices)
|
||||
state_seq = self.concat((seq, self.expand(word_indices, -1)))
|
||||
|
||||
# new finished flag and log_probs
|
||||
state_finished = self.equal(word_indices, self.eos_ids)
|
||||
state_log_probs = topk_scores
|
||||
|
||||
###### generate new inputs and decoder states
|
||||
cur_input_ids = self.reshape(state_seq, (self.batch_size*self.beam_width, -1))
|
||||
return cur_input_ids, state_log_probs, state_seq, state_finished, state_length
|
||||
|
||||
def construct(self, enc_states, enc_attention_mask):
|
||||
cur_input_ids = self.start_ids
|
||||
# beam search states
|
||||
state_log_probs = self.init_scores
|
||||
state_seq = self.init_seq
|
||||
state_finished = self.init_finished
|
||||
state_length = self.init_length
|
||||
|
||||
for _ in range(self.max_decode_length):
|
||||
# run one step decoder to get outputs of the current step
|
||||
# shape [batch*beam, 1, vocab]
|
||||
cur_input_ids, state_log_probs, state_seq, state_finished, state_length = self.one_step(
|
||||
cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq, state_finished, state_length)
|
||||
|
||||
# add length penalty scores
|
||||
penalty_len = self.length_penalty(state_length)
|
||||
# return penalty_len
|
||||
log_probs = self.real_div(state_log_probs, penalty_len)
|
||||
|
||||
# sort according to scores
|
||||
_, top_beam_indices = self.topk(log_probs, self.beam_width)
|
||||
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
|
||||
# sort sequence
|
||||
predicted_ids = self.gather_nd(state_seq, gather_indices)
|
||||
# take the first one
|
||||
predicted_ids = predicted_ids[::, 0:1:1, ::]
|
||||
return predicted_ids
|
|
@@ -0,0 +1,71 @@ src/config.py
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Network config setting, will be used in dataset.py, train.py."""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
import mindspore.common.dtype as mstype
|
||||
from .transformer_model import TransformerConfig
|
||||
cfg = edict({
|
||||
'transformer_network': 'large',
|
||||
'init_loss_scale_value': 1024,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 2000,
|
||||
'optimizer': 'Adam',
|
||||
'lr_schedule': edict({
|
||||
'learning_rate': 2.0,
|
||||
'warmup_steps': 8000,
|
||||
'start_decay_step': 16000,
|
||||
'min_lr': 0.0,
|
||||
}),
|
||||
})
|
||||
'''
|
||||
Two kinds of Transformer model versions: 'base' and 'large'.
|
||||
'''
|
||||
if cfg.transformer_network == 'large':
|
||||
transformer_net_cfg = TransformerConfig(
|
||||
batch_size=96,
|
||||
seq_length=128,
|
||||
vocab_size=36560,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=6,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="relu",
|
||||
hidden_dropout_prob=0.2,
|
||||
attention_probs_dropout_prob=0.2,
|
||||
max_position_embeddings=128,
|
||||
initializer_range=0.02,
|
||||
label_smoothing=0.1,
|
||||
input_mask_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16)
|
||||
if cfg.transformer_network == 'base':
|
||||
transformer_net_cfg = TransformerConfig(
|
||||
batch_size=96,
|
||||
seq_length=128,
|
||||
vocab_size=36560,
|
||||
hidden_size=512,
|
||||
num_hidden_layers=6,
|
||||
num_attention_heads=8,
|
||||
intermediate_size=2048,
|
||||
hidden_act="relu",
|
||||
hidden_dropout_prob=0.2,
|
||||
attention_probs_dropout_prob=0.2,
|
||||
max_position_embeddings=128,
|
||||
initializer_range=0.02,
|
||||
label_smoothing=0.1,
|
||||
input_mask_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16)
|
|
@@ -0,0 +1,48 @@ src/dataset.py
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Data operations, will be used in train.py."""
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine.datasets as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
from mindspore import log as logger
|
||||
from .config import transformer_net_cfg
|
||||
|
||||
def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true",
|
||||
dataset_path=None):
|
||||
"""create dataset"""
|
||||
repeat_count = epoch_count
|
||||
ds = de.MindDataset(dataset_path,
|
||||
columns_list=["source_eos_ids", "source_eos_mask",
|
||||
"target_sos_ids", "target_sos_mask",
|
||||
"target_eos_ids", "target_eos_mask"],
|
||||
shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id)
|
||||
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="source_eos_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="source_eos_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="target_sos_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="target_sos_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="target_eos_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="target_eos_mask", operations=type_cast_op)
|
||||
|
||||
# apply batch operations
|
||||
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_count)
|
||||
|
||||
ds.channel_name = 'transformer'
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
|
||||
return ds, repeat_count
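# Example usage sketch (dataset path is a placeholder):
#   ds, repeat_count = create_transformer_dataset(epoch_count=52, rank_size=1, rank_id=0,
#                                                 do_shuffle="true", enable_data_sink="false",
#                                                 dataset_path="/path/ende-l128-mindrecord")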
|
|
@@ -0,0 +1,69 @@ src/eval_config.py
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Network evaluation config setting, will be used in eval.py."""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
import mindspore.common.dtype as mstype
|
||||
from .transformer_model import TransformerConfig
|
||||
|
||||
cfg = edict({
|
||||
'transformer_network': 'large',
|
||||
'data_file': '/your/path/evaluation.mindrecord',
|
||||
'model_file': '/your/path/checkpoint_file',
|
||||
'output_file': '/your/path/output',
|
||||
})
|
||||
'''
|
||||
Two kinds of Transformer model versions: 'base' and 'large'.
|
||||
'''
|
||||
if cfg.transformer_network == 'large':
|
||||
transformer_net_cfg = TransformerConfig(
|
||||
batch_size=1,
|
||||
seq_length=128,
|
||||
vocab_size=36560,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=6,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="relu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
max_position_embeddings=128,
|
||||
label_smoothing=0.1,
|
||||
input_mask_from_dataset=True,
|
||||
beam_width=4,
|
||||
max_decode_length=80,
|
||||
length_penalty_weight=1.0,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16)
|
||||
if cfg.transformer_network == 'base':
|
||||
transformer_net_cfg = TransformerConfig(
|
||||
batch_size=1,
|
||||
seq_length=128,
|
||||
vocab_size=36560,
|
||||
hidden_size=512,
|
||||
num_hidden_layers=6,
|
||||
num_attention_heads=8,
|
||||
intermediate_size=2048,
|
||||
hidden_act="relu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
max_position_embeddings=128,
|
||||
label_smoothing=0.1,
|
||||
input_mask_from_dataset=True,
|
||||
beam_width=4,
|
||||
max_decode_length=80,
|
||||
length_penalty_weight=1.0,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16)
|
|
@@ -0,0 +1,52 @@ src/lr_schedule.py
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Learning rate utilities."""
|
||||
|
||||
def linear_warmup(warmup_steps, current_step):
|
||||
return min([1.0, float(current_step)/float(warmup_steps)])
|
||||
|
||||
def rsqrt_decay(warmup_steps, current_step):
|
||||
return float(max([current_step, warmup_steps])) ** -0.5
|
||||
|
||||
def rsqrt_hidden(hidden_size):
|
||||
return float(hidden_size) ** -0.5
|
||||
|
||||
def create_dynamic_lr(schedule, training_steps, learning_rate, warmup_steps, hidden_size,
|
||||
start_decay_step=0, min_lr=0.):
|
||||
"""
|
||||
Generate dynamic learning rate.
|
||||
"""
|
||||
if start_decay_step < warmup_steps:
|
||||
start_decay_step = warmup_steps
|
||||
lr = []
|
||||
for current_step in range(1, training_steps+1):
|
||||
cur_lr = 1.0
|
||||
for name in schedule.split("*"):
|
||||
if name == "constant":
|
||||
cur_lr *= float(learning_rate)
|
||||
elif name == "rsqrt_hidden":
|
||||
cur_lr *= rsqrt_hidden(hidden_size)
|
||||
elif name == "linear_warmup":
|
||||
cur_lr *= linear_warmup(warmup_steps, current_step)
|
||||
elif name == "rsqrt_decay":
|
||||
cur_lr *= rsqrt_decay(warmup_steps, current_step-start_decay_step+warmup_steps)
|
||||
else:
|
||||
raise ValueError("unknown learning rate schedule")
|
||||
if warmup_steps < current_step < start_decay_step:
|
||||
cur_lr = lr[-1]
|
||||
if current_step > warmup_steps:
|
||||
cur_lr = max([cur_lr, min_lr])
|
||||
lr.append(cur_lr)
|
||||
return lr
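# Example usage sketch (schedule string and values are illustrative):
#   lr = create_dynamic_lr("constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
#                          training_steps=20000, learning_rate=2.0, warmup_steps=8000,
#                          hidden_size=1024, start_decay_step=16000, min_lr=0.0)
# which gives a linear warmup followed by inverse-square-root decay, floored at min_lr.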
|
|
@@ -0,0 +1,47 @@ src/process_output.py
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert ids to tokens."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import tokenization
|
||||
|
||||
# Explicitly set the encoding
|
||||
sys.stdin = open(sys.stdin.fileno(), mode='r', encoding='utf-8', buffering=True)
|
||||
sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=True)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="recore nbest with smoothed sentence-level bleu.")
|
||||
parser.add_argument("--vocab_file", type=str, default="", required=True, help="vocab file path.")
|
||||
args = parser.parse_args()
|
||||
|
||||
tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)
|
||||
|
||||
for line in sys.stdin:
|
||||
token_ids = [int(x) for x in line.strip().split()]
|
||||
tokens = tokenizer.convert_ids_to_tokens(token_ids)
|
||||
sent = " ".join(tokens)
|
||||
sent = sent.split("<s>")[-1]
|
||||
sent = sent.split("</s>")[0]
|
||||
print(sent.strip())
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -0,0 +1,193 @@ src/tokenization.py
|
|||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
###############################################################################
|
||||
# Modified by Huawei Technologies Co., Ltd, May, 2020, with following changes:
|
||||
# - Remove some unused classes and functions
|
||||
# - Modify load_vocab, convert_to_unicode, printable_text function
|
||||
# - Modify BasicTokenizer class
|
||||
# - Add WhiteSpaceTokenizer class
|
||||
###############################################################################
|
||||
|
||||
"""Tokenization utilities."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import unicodedata
|
||||
import six
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
if isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
if six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
if isinstance(text, unicode):
|
||||
return text
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def printable_text(text):
|
||||
"""Returns text encoded in a way suitable for print or `logging`."""
|
||||
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
if isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
if six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
if isinstance(text, unicode):
|
||||
return text.encode("utf-8")
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def convert_by_vocab(vocab, items):
|
||||
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||
output = []
|
||||
for item in items:
|
||||
if item in vocab:
|
||||
output.append(vocab[item])
|
||||
else:
|
||||
output.append(vocab["<unk>"])
|
||||
return output
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab, tokens):
|
||||
return convert_by_vocab(vocab, tokens)
|
||||
|
||||
|
||||
def convert_ids_to_tokens(inv_vocab, ids):
|
||||
return convert_by_vocab(inv_vocab, ids)
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class WhiteSpaceTokenizer():
|
||||
"""Runs end-to-end tokenziation."""
|
||||
def __init__(self, vocab_file):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.basic_tokenizer = BasicTokenizer()
|
||||
|
||||
def tokenize(self, text):
|
||||
return self.basic_tokenizer.tokenize(text)
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_by_vocab(self.vocab, tokens)
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
return convert_by_vocab(self.inv_vocab, ids)
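# Example usage sketch (vocab path is a placeholder):
#   tokenizer = WhiteSpaceTokenizer(vocab_file="vocab.bpe.32000")
#   ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("ein Beispiel"))
#   tokens = tokenizer.convert_ids_to_tokens(ids)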
|
||||
|
||||
|
||||
class BasicTokenizer():
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self):
|
||||
"""Constructs a BasicTokenizer."""
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
return whitespace_tokenize(text)
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char in (" ", "\t", "\n", "\r"):
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char in ("\t", "\n", "\r"):
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyway, for
    # consistency.
    if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
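
The helpers above are enough to round-trip text through token ids. As a quick smoke test, a minimal sketch might look like the following; the `toy.vocab` file and its three tokens are invented purely for illustration, and real runs use the BPE vocabulary produced during data preparation.

``` python
from src.tokenization import WhiteSpaceTokenizer

# Invented three-token vocabulary file, purely for illustration.
with open("toy.vocab", "w") as f:
    f.write("\n".join(["<unk>", "hello", "world"]))

tokenizer = WhiteSpaceTokenizer(vocab_file="toy.vocab")
tokens = tokenizer.tokenize("hello brave world")      # -> ['hello', 'brave', 'world']
ids = tokenizer.convert_tokens_to_ids(tokens)         # 'brave' is out-of-vocabulary -> <unk>
print(ids)                                            # [1, 0, 2]
print(tokenizer.convert_ids_to_tokens(ids))           # ['hello', '<unk>', 'world']
```
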
src/transformer_for_train.py
@@ -0,0 +1,341 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer for training."""

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.communication.management import get_group_size
from mindspore import context

from .transformer_model import TransformerModel

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 5.0


class ClipGradients(nn.Cell):
    """
    Clip gradients.

    Args:
        grads (list): List of gradient tuples.
        clip_type (Tensor): The way to clip, 'value' or 'norm'.
        clip_value (Tensor): Specifies how much to clip.

    Returns:
        List, a list of clipped_grad tuples.
    """
    def __init__(self):
        super(ClipGradients, self).__init__()
        self.clip_by_norm = nn.ClipByNorm()
        self.cast = P.Cast()
        self.dtype = P.DType()

    def construct(self,
                  grads,
                  clip_type,
                  clip_value):
        if clip_type != 0 and clip_type != 1:
            return grads

        new_grads = ()
        for grad in grads:
            dt = self.dtype(grad)
            if clip_type == 0:
                t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
                                    self.cast(F.tuple_to_array((clip_value,)), dt))
            else:
                t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
            new_grads = new_grads + (t,)

        return new_grads


class TransformerTrainingLoss(nn.Cell):
    """
    Provide transformer training loss.

    Args:
        config (TransformerConfig): The config of Transformer.

    Returns:
        Tensor, total loss.
    """
    def __init__(self, config):
        super(TransformerTrainingLoss, self).__init__(auto_prefix=False)
        self.vocab_size = config.vocab_size
        self.onehot = P.OneHot()
        self.on_value = Tensor(float(1 - config.label_smoothing), mstype.float32)
        self.off_value = Tensor(config.label_smoothing / float(self.vocab_size - 1), mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.flatten = P.Flatten()
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.flat_shape = (config.batch_size * config.seq_length,)

    def construct(self, prediction_scores, label_ids, label_weights):
        """Defines the computation performed."""
        label_ids = self.reshape(label_ids, self.flat_shape)
        label_weights = self.cast(self.reshape(label_weights, self.flat_shape), mstype.float32)
        one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)

        per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
        numerator = self.reduce_sum(label_weights * per_example_loss, ())
        denominator = self.reduce_sum(label_weights, ()) + \
                      self.cast(F.tuple_to_array((1e-5,)), mstype.float32)
        loss = numerator / denominator
        return loss


class TransformerNetworkWithLoss(nn.Cell):
    """
    Provide transformer training loss through network.

    Args:
        config (TransformerConfig): The config of Transformer.
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.

    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        super(TransformerNetworkWithLoss, self).__init__(auto_prefix=False)
        self.transformer = TransformerModel(config, is_training, use_one_hot_embeddings)
        self.loss = TransformerTrainingLoss(config)
        self.cast = P.Cast()

    def construct(self,
                  source_ids,
                  source_mask,
                  target_ids,
                  target_mask,
                  label_ids,
                  label_weights):
        prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask)
        total_loss = self.loss(prediction_scores, label_ids, label_weights)
        return self.cast(total_loss, mstype.float32)


class TransformerTrainOneStepCell(nn.Cell):
    """
    Encapsulation class of transformer network training.

    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(TransformerTrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode not in ParallelMode.MODE_LIST:
            raise ValueError("Parallel mode does not support: ", self.parallel_mode)
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("mirror_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

        self.clip_gradients = ClipGradients()
        self.cast = P.Cast()

    def set_sens(self, value):
        self.sens = value

    def construct(self,
                  source_eos_ids,
                  source_eos_mask,
                  target_sos_ids,
                  target_sos_mask,
                  target_eos_ids,
                  target_eos_mask):
        """Defines the computation performed."""
        source_ids = source_eos_ids
        source_mask = source_eos_mask
        target_ids = target_sos_ids
        target_mask = target_sos_mask
        label_ids = target_eos_ids
        label_weights = target_eos_mask

        weights = self.weights
        loss = self.network(source_ids,
                            source_mask,
                            target_ids,
                            target_mask,
                            label_ids,
                            label_weights)
        grads = self.grad(self.network, weights)(source_ids,
                                                 source_mask,
                                                 target_ids,
                                                 target_mask,
                                                 label_ids,
                                                 label_weights,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)

        succ = self.optimizer(grads)
        return F.depend(loss, succ)


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * F.cast(reciprocal(scale), F.dtype(grad))


class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of Transformer network training with loss scaling.

    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()

        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode not in ParallelMode.MODE_LIST:
            raise ValueError("Parallel mode does not support: ", self.parallel_mode)
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = _get_mirror_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.clip_gradients = ClipGradients()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()

        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
        self.add_flags(has_effect=True)

    def construct(self,
                  source_eos_ids,
                  source_eos_mask,
                  target_sos_ids,
                  target_sos_mask,
                  target_eos_ids,
                  target_eos_mask,
                  sens=None):
        """Defines the computation performed."""
        source_ids = source_eos_ids
        source_mask = source_eos_mask
        target_ids = target_sos_ids
        target_mask = target_sos_mask
        label_ids = target_eos_ids
        label_weights = target_eos_mask

        weights = self.weights
        loss = self.network(source_ids,
                            source_mask,
                            target_ids,
                            target_mask,
                            label_ids,
                            label_weights)
        # alloc status
        init = self.alloc_status()
        # clear overflow buffer
        self.clear_before_grad(init)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        grads = self.grad(self.network, weights)(source_ids,
                                                 source_mask,
                                                 target_ids,
                                                 target_mask,
                                                 label_ids,
                                                 label_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))

        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)

        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)

        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
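
For intuition about the label-smoothed cross entropy that `TransformerTrainingLoss` computes, here is a small NumPy sketch on toy numbers. The vocabulary size, smoothing value, and scores are invented for illustration; `prediction_scores` is assumed to already hold log-probabilities, and the third position stands for padding (weight 0).

``` python
import numpy as np

# Toy, invented setup: vocabulary of 4 tokens, label smoothing of 0.1.
vocab_size = 4
smoothing = 0.1
on_value = 1.0 - smoothing
off_value = smoothing / (vocab_size - 1)

# Two real target positions plus one padded position (weight 0).
label_ids = np.array([2, 1, 0])
label_weights = np.array([1.0, 1.0, 0.0])
# Pretend log-probabilities from the decoder, shape [positions, vocab].
prediction_scores = np.log(np.array([[0.1, 0.2, 0.6, 0.1],
                                     [0.1, 0.7, 0.1, 0.1],
                                     [0.25, 0.25, 0.25, 0.25]]))

# Smoothed one-hot targets, mirroring P.OneHot(on_value, off_value).
one_hot = np.full((len(label_ids), vocab_size), off_value)
one_hot[np.arange(len(label_ids)), label_ids] = on_value

# Same reduction as TransformerTrainingLoss.construct.
per_example_loss = -np.sum(prediction_scores * one_hot, axis=-1)
loss = np.sum(label_weights * per_example_loss) / (np.sum(label_weights) + 1e-5)
print(loss)  # mean smoothed cross entropy over the two non-padded positions
```
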
File diff suppressed because it is too large
src/weight_init.py
@@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Weight init utilities."""

import math
import numpy as np
from mindspore.common.tensor import Tensor


def _average_units(shape):
    """
    Average shape dim.
    """
    if not shape:
        return 1.
    if len(shape) == 1:
        return float(shape[0])
    if len(shape) == 2:
        return float(shape[0] + shape[1]) / 2.
    raise RuntimeError("Unsupported shape.")


def weight_variable(shape):
    scale_shape = shape
    avg_units = _average_units(scale_shape)
    scale = 1.0 / max(1., avg_units)
    limit = math.sqrt(3.0 * scale)
    values = np.random.uniform(-limit, limit, shape).astype(np.float32)
    return Tensor(values)


def one_weight(shape):
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones)


def zero_weight(shape):
    zeros = np.zeros(shape).astype(np.float32)
    return Tensor(zeros)


def normal_weight(shape, num_units):
    norm = np.random.normal(0.0, num_units**-0.5, shape).astype(np.float32)
    return Tensor(norm)
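
As a rough usage sketch of the initializers above, with made-up shapes: `weight_variable` draws from a Glorot-style uniform range with limit sqrt(3 / avg_units), `one_weight` and `zero_weight` suit layer-norm gammas and betas/biases, and `normal_weight` draws from N(0, num_units**-0.5), which is how embedding tables are initialized in the training script below.

``` python
from src.weight_init import weight_variable, one_weight, zero_weight, normal_weight

# Hypothetical layer shapes, chosen only for illustration.
dense_w = weight_variable((512, 2048))    # uniform in [-sqrt(3/1280), +sqrt(3/1280)]
gamma = one_weight((512,))                # e.g. a layer-norm gamma, initialized to ones
beta = zero_weight((512,))                # e.g. a layer-norm beta or bias, initialized to zeros
emb = normal_weight((32000, 512), 512)    # N(0, 512**-0.5), e.g. for an embedding table

print(dense_w.shape, gamma.shape, beta.shape, emb.shape)
```
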
train.py
@@ -0,0 +1,179 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer training script."""

import time
import argparse

import mindspore.common.dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.nn.optim import Adam
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.callback import Callback, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.communication.management as D
from mindspore.train.parallel_utils import ParallelMode
from mindspore import context

from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
                                      TransformerTrainOneStepWithLossScaleCell
from src.config import cfg, transformer_net_cfg
from src.dataset import create_transformer_dataset
from src.weight_init import weight_variable, one_weight, zero_weight, normal_weight
from src.lr_schedule import create_dynamic_lr


def get_ms_timestamp():
    t = time.time()
    return int(round(t * 1000))

time_stamp_init = False
time_stamp_first = 0


class LossCallBack(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, training is terminated.

    Note:
        If per_print_times is 0, the loss is not printed.

    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be an int and >= 0.")
        self._per_print_times = per_print_times
        global time_stamp_init, time_stamp_first
        if not time_stamp_init:
            time_stamp_first = get_ms_timestamp()
            time_stamp_init = True

    def step_end(self, run_context):
        """Print the loss and append it to ./loss.log at the end of each step."""
        global time_stamp_first
        time_stamp_current = get_ms_timestamp()
        cb_params = run_context.original_args()
        print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
                                                                     cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                                     str(cb_params.net_outputs)))
        with open("./loss.log", "a+") as f:
            f.write("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
                                                                           cb_params.cur_epoch_num,
                                                                           cb_params.cur_step_num,
                                                                           str(cb_params.net_outputs)))
            f.write('\n')


def argparse_init():
    """
    Argparse init.
    """
    parser = argparse.ArgumentParser(description='transformer')
    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=52, help="Epoch size, default is 52.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="false", help="Enable data sink, default is false.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, "
                                                                             "default is true.")
    parser.add_argument("--save_checkpoint_steps", type=int, default=2500, help="Save checkpoint steps, "
                                                                                "default is 2500.")
    parser.add_argument("--save_checkpoint_num", type=int, default=30, help="Save checkpoint numbers, default is 30.")
    parser.add_argument("--save_checkpoint_path", type=str, default="./checkpoint/", help="Save checkpoint file path, "
                                                                                          "default is ./checkpoint/.")
    parser.add_argument("--data_path", type=str, default="", help="Data path, it is better to use absolute path.")
    return parser


def run_transformer_train():
    """
    Transformer training.
    """
    parser = argparse_init()
    args, _ = parser.parse_known_args()
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
    context.set_context(save_graphs=True, reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)

    if args.distribute == "true":
        device_num = args.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          parameter_broadcast=True, device_num=device_num)
        D.init()
        rank_id = args.device_id % device_num
    else:
        device_num = 1
        rank_id = 0
    dataset, repeat_count = create_transformer_dataset(epoch_count=args.epoch_size, rank_size=device_num,
                                                       rank_id=rank_id, do_shuffle=args.do_shuffle,
                                                       enable_data_sink=args.enable_data_sink,
                                                       dataset_path=args.data_path)

    netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)

    if args.checkpoint_path:
        parameter_dict = load_checkpoint(args.checkpoint_path)
    else:
        parameter_dict = {}
        params = netwithloss.trainable_params()
        for param in params:
            name = param.name
            value = param.default_input
            if isinstance(value, Tensor):
                if name.endswith(".gamma"):
                    parameter_dict[name] = Parameter(one_weight(value.asnumpy().shape), name=name)
                elif name.endswith(".beta") or name.endswith(".bias"):
                    parameter_dict[name] = Parameter(zero_weight(value.asnumpy().shape), name=name)
                elif "embedding" in name:
                    parameter_dict[name] = Parameter(normal_weight(value.asnumpy().shape,
                                                                   transformer_net_cfg.hidden_size), name=name)
                else:
                    parameter_dict[name] = Parameter(weight_variable(value.asnumpy().shape), name=name)
    load_param_into_net(netwithloss, parameter_dict)

    lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
                                  training_steps=dataset.get_dataset_size()*args.epoch_size,
                                  learning_rate=cfg.lr_schedule.learning_rate,
                                  warmup_steps=cfg.lr_schedule.warmup_steps,
                                  hidden_size=transformer_net_cfg.hidden_size), mstype.float32)
    optimizer = Adam(netwithloss.trainable_params(), lr)

    callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
    if args.enable_save_ckpt == "true":
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
                                       keep_checkpoint_max=args.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
        callbacks.append(ckpoint_cb)

    if args.enable_lossscale == "true":
        scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value,
                                                scale_factor=cfg.scale_factor,
                                                scale_window=cfg.scale_window)
        update_cell = scale_manager.get_update_cell()
        netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                                scale_update_cell=update_cell)
    else:
        netwithgrads = TransformerTrainOneStepCell(netwithloss, optimizer=optimizer)

    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))


if __name__ == '__main__':
    run_transformer_train()
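
Tying the pieces together, a plain single-device invocation of the script could look like the following sketch. The data path is a placeholder and the remaining flags simply echo the argparse defaults; the `run_standalone_train.sh` and `run_distribute_train.sh` wrappers under `scripts/` invoke the script in a similar way with their own arguments.

``` bash
python train.py \
    --distribute="false" \
    --epoch_size=52 \
    --device_id=0 \
    --enable_lossscale="true" \
    --enable_save_ckpt="true" \
    --save_checkpoint_steps=2500 \
    --save_checkpoint_num=30 \
    --save_checkpoint_path=./checkpoint/ \
    --data_path=/your/path/to/mindrecord-data
```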