!2202 fix decoder loop for Transformer model

Merge pull request !2202 from yuchaojie/transformer
This commit is contained in:
mindspore-ci-bot 2020-06-18 22:07:38 +08:00 committed by Gitee
commit 961e29b211
2 changed files with 25 additions and 102 deletions

View File

@ -781,95 +781,22 @@ class TransformerDecoder(nn.Cell):
super(TransformerDecoder, self).__init__()
self.num_hidden_layers = num_hidden_layers
# wait to be supported
# layers = []
# for _ in range(num_hidden_layers):
# layer = DecoderCell(batch_size=batch_size,
# hidden_size=hidden_size,
# seq_length=seq_length,
# enc_seq_length=enc_seq_length,
# num_attention_heads=num_attention_heads,
# intermediate_size=intermediate_size,
# attention_probs_dropout_prob=attention_probs_dropout_prob,
# use_one_hot_embeddings=use_one_hot_embeddings,
# initializer_range=initializer_range,
# hidden_dropout_prob=hidden_dropout_prob,
# hidden_act=hidden_act,
# compute_type=compute_type)
# layers.append(layer)
# self.layers = nn.CellList(layers)
self.layer0 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer1 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer2 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer3 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer4 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer5 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
layers = []
for _ in range(num_hidden_layers):
layer = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
layers.append(layer)
self.layers = nn.CellList(layers)
self.layer_preprocess = LayerPreprocess(in_channels=hidden_size)
@ -880,16 +807,9 @@ class TransformerDecoder(nn.Cell):
def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask):
prev_output = self.reshape(input_tensor, self.shape)
# wait to be supported
# for layer_module in self.layers:
# layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
# prev_output = layer_output
prev_output = self.layer0(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer1(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer2(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer3(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer4(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer5(prev_output, attention_mask, enc_states, enc_attention_mask)
for layer_module in self.layers:
layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = layer_output
prev_output = self.layer_preprocess(prev_output)
output = self.reshape(prev_output, self.out_shape)

View File

@ -16,6 +16,7 @@
import time
import argparse
import random
import numpy as np
import mindspore.common.dtype as mstype
@ -26,6 +27,7 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.callback import Callback, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
import mindspore.communication.management as D
from mindspore.train.parallel_utils import ParallelMode
from mindspore import context
@ -36,6 +38,10 @@ from src.config import cfg, transformer_net_cfg
from src.dataset import create_transformer_dataset
from src.lr_schedule import create_dynamic_lr
random_seed = 1
random.seed(random_seed)
np.random.seed(random_seed)
de.config.set_seed(random_seed)
def get_ms_timestamp():
t = time.time()
@ -161,7 +167,4 @@ def run_transformer_train():
model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
if __name__ == '__main__':
random_seed = 1
np.random.seed(random_seed)
run_transformer_train()