From 96db04b49d95f649a1d779a696a118532692e893 Mon Sep 17 00:00:00 2001
From: yuchaojie
Date: Wed, 17 Jun 2020 11:19:13 +0800
Subject: [PATCH] fix decoder loop for Transformer model

---
 .../Transformer/src/transformer_model.py | 118 +++---------------
 model_zoo/Transformer/train.py           |   9 +-
 2 files changed, 25 insertions(+), 102 deletions(-)

diff --git a/model_zoo/Transformer/src/transformer_model.py b/model_zoo/Transformer/src/transformer_model.py
index da1dfc0955a..409f8965eb3 100644
--- a/model_zoo/Transformer/src/transformer_model.py
+++ b/model_zoo/Transformer/src/transformer_model.py
@@ -781,95 +781,22 @@ class TransformerDecoder(nn.Cell):
         super(TransformerDecoder, self).__init__()
         self.num_hidden_layers = num_hidden_layers
 
-        # wait to be supported
-        # layers = []
-        # for _ in range(num_hidden_layers):
-        #     layer = DecoderCell(batch_size=batch_size,
-        #                         hidden_size=hidden_size,
-        #                         seq_length=seq_length,
-        #                         enc_seq_length=enc_seq_length,
-        #                         num_attention_heads=num_attention_heads,
-        #                         intermediate_size=intermediate_size,
-        #                         attention_probs_dropout_prob=attention_probs_dropout_prob,
-        #                         use_one_hot_embeddings=use_one_hot_embeddings,
-        #                         initializer_range=initializer_range,
-        #                         hidden_dropout_prob=hidden_dropout_prob,
-        #                         hidden_act=hidden_act,
-        #                         compute_type=compute_type)
-        #     layers.append(layer)
-        # self.layers = nn.CellList(layers)
-        self.layer0 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
-        self.layer1 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
-        self.layer2 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
-        self.layer3 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
-        self.layer4 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
-        self.layer5 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
+        layers = []
+        for _ in range(num_hidden_layers):
+            layer = DecoderCell(batch_size=batch_size,
+                                hidden_size=hidden_size,
+                                seq_length=seq_length,
+                                enc_seq_length=enc_seq_length,
+                                num_attention_heads=num_attention_heads,
+                                intermediate_size=intermediate_size,
+                                attention_probs_dropout_prob=attention_probs_dropout_prob,
+                                use_one_hot_embeddings=use_one_hot_embeddings,
+                                initializer_range=initializer_range,
+                                hidden_dropout_prob=hidden_dropout_prob,
+                                hidden_act=hidden_act,
+                                compute_type=compute_type)
+            layers.append(layer)
+        self.layers = nn.CellList(layers)
 
         self.layer_preprocess = LayerPreprocess(in_channels=hidden_size)
 
@@ -880,16 +807,9 @@ class TransformerDecoder(nn.Cell):
     def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask):
         prev_output = self.reshape(input_tensor, self.shape)
 
-        # wait to be supported
-        # for layer_module in self.layers:
-        #     layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
-        #     prev_output = layer_output
-        prev_output = self.layer0(prev_output, attention_mask, enc_states, enc_attention_mask)
-        prev_output = self.layer1(prev_output, attention_mask, enc_states, enc_attention_mask)
-        prev_output = self.layer2(prev_output, attention_mask, enc_states, enc_attention_mask)
-        prev_output = self.layer3(prev_output, attention_mask, enc_states, enc_attention_mask)
-        prev_output = self.layer4(prev_output, attention_mask, enc_states, enc_attention_mask)
-        prev_output = self.layer5(prev_output, attention_mask, enc_states, enc_attention_mask)
+        for layer_module in self.layers:
+            layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
+            prev_output = layer_output
 
         prev_output = self.layer_preprocess(prev_output)
         output = self.reshape(prev_output, self.out_shape)
diff --git a/model_zoo/Transformer/train.py b/model_zoo/Transformer/train.py
index 4e1a1b4a091..23c0eb78fd5 100644
--- a/model_zoo/Transformer/train.py
+++ b/model_zoo/Transformer/train.py
@@ -16,6 +16,7 @@
 
 import time
 import argparse
+import random
 import numpy as np
 
 import mindspore.common.dtype as mstype
@@ -26,6 +27,7 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.train.callback import Callback, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
+import mindspore.dataset.engine as de
 import mindspore.communication.management as D
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore import context
@@ -36,6 +38,10 @@ from src.config import cfg, transformer_net_cfg
 from src.dataset import create_transformer_dataset
 from src.lr_schedule import create_dynamic_lr
 
+random_seed = 1
+random.seed(random_seed)
+np.random.seed(random_seed)
+de.config.set_seed(random_seed)
 
 def get_ms_timestamp():
     t = time.time()
@@ -161,7 +167,4 @@ def run_transformer_train():
     model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
 
 if __name__ == '__main__':
-    random_seed = 1
-    np.random.seed(random_seed)
-
     run_transformer_train()
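
Note (illustration, not part of the patch): the change replaces the six hand-unrolled DecoderCell attributes (layer0 ... layer5) with an nn.CellList built in a loop and iterated inside construct(), so the decoder depth now follows num_hidden_layers instead of being hard-coded to 6; train.py additionally fixes the Python, NumPy, and MindSpore dataset seeds at module load so runs are repeatable. A minimal standalone sketch of the same CellList pattern, using a hypothetical StackedCells cell with nn.Dense standing in for DecoderCell:

    import mindspore.nn as nn

    class StackedCells(nn.Cell):
        """Illustrative sketch only: stack num_layers identical sub-cells."""
        def __init__(self, num_layers, hidden_size):
            super(StackedCells, self).__init__()
            # nn.CellList registers each sub-cell as a child, so its parameters
            # are tracked and the list can be iterated inside construct().
            self.layers = nn.CellList([nn.Dense(hidden_size, hidden_size)
                                       for _ in range(num_layers)])

        def construct(self, x):
            # Feed each layer's output into the next, mirroring how the patched
            # TransformerDecoder.construct loops over its DecoderCells.
            for layer in self.layers:
                x = layer(x)
            return x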