!2202 fix decoder loop for Transformer model

Merge pull request !2202 from yuchaojie/transformer
This commit is contained in:
mindspore-ci-bot 2020-06-18 22:07:38 +08:00 committed by Gitee
commit 961e29b211
2 changed files with 25 additions and 102 deletions

View File

@ -781,84 +781,9 @@ class TransformerDecoder(nn.Cell):
super(TransformerDecoder, self).__init__() super(TransformerDecoder, self).__init__()
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
# wait to be supported layers = []
# layers = [] for _ in range(num_hidden_layers):
# for _ in range(num_hidden_layers): layer = DecoderCell(batch_size=batch_size,
# layer = DecoderCell(batch_size=batch_size,
# hidden_size=hidden_size,
# seq_length=seq_length,
# enc_seq_length=enc_seq_length,
# num_attention_heads=num_attention_heads,
# intermediate_size=intermediate_size,
# attention_probs_dropout_prob=attention_probs_dropout_prob,
# use_one_hot_embeddings=use_one_hot_embeddings,
# initializer_range=initializer_range,
# hidden_dropout_prob=hidden_dropout_prob,
# hidden_act=hidden_act,
# compute_type=compute_type)
# layers.append(layer)
# self.layers = nn.CellList(layers)
self.layer0 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer1 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer2 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer3 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer4 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer5 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size, hidden_size=hidden_size,
seq_length=seq_length, seq_length=seq_length,
enc_seq_length=enc_seq_length, enc_seq_length=enc_seq_length,
@ -870,6 +795,8 @@ class TransformerDecoder(nn.Cell):
hidden_dropout_prob=hidden_dropout_prob, hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act, hidden_act=hidden_act,
compute_type=compute_type) compute_type=compute_type)
layers.append(layer)
self.layers = nn.CellList(layers)
self.layer_preprocess = LayerPreprocess(in_channels=hidden_size) self.layer_preprocess = LayerPreprocess(in_channels=hidden_size)
@ -880,16 +807,9 @@ class TransformerDecoder(nn.Cell):
def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask): def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask):
prev_output = self.reshape(input_tensor, self.shape) prev_output = self.reshape(input_tensor, self.shape)
# wait to be supported for layer_module in self.layers:
# for layer_module in self.layers: layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
# layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask) prev_output = layer_output
# prev_output = layer_output
prev_output = self.layer0(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer1(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer2(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer3(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer4(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer5(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer_preprocess(prev_output) prev_output = self.layer_preprocess(prev_output)
output = self.reshape(prev_output, self.out_shape) output = self.reshape(prev_output, self.out_shape)

View File

@ -16,6 +16,7 @@
import time import time
import argparse import argparse
import random
import numpy as np import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
@ -26,6 +27,7 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.callback import Callback, TimeMonitor from mindspore.train.callback import Callback, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
import mindspore.communication.management as D import mindspore.communication.management as D
from mindspore.train.parallel_utils import ParallelMode from mindspore.train.parallel_utils import ParallelMode
from mindspore import context from mindspore import context
@ -36,6 +38,10 @@ from src.config import cfg, transformer_net_cfg
from src.dataset import create_transformer_dataset from src.dataset import create_transformer_dataset
from src.lr_schedule import create_dynamic_lr from src.lr_schedule import create_dynamic_lr
random_seed = 1
random.seed(random_seed)
np.random.seed(random_seed)
de.config.set_seed(random_seed)
def get_ms_timestamp(): def get_ms_timestamp():
t = time.time() t = time.time()
@ -161,7 +167,4 @@ def run_transformer_train():
model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true")) model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
if __name__ == '__main__': if __name__ == '__main__':
random_seed = 1
np.random.seed(random_seed)
run_transformer_train() run_transformer_train()