forked from mindspore-Ecosystem/mindspore
!2202 fix decoder loop for Transformer model
Merge pull request !2202 from yuchaojie/transformer
commit 961e29b211
@@ -781,95 +781,22 @@ class TransformerDecoder(nn.Cell):
         super(TransformerDecoder, self).__init__()
         self.num_hidden_layers = num_hidden_layers
 
-        # wait to be supported
-        # layers = []
-        # for _ in range(num_hidden_layers):
-        #     layer = DecoderCell(batch_size=batch_size,
-        #                         hidden_size=hidden_size,
-        #                         seq_length=seq_length,
-        #                         enc_seq_length=enc_seq_length,
-        #                         num_attention_heads=num_attention_heads,
-        #                         intermediate_size=intermediate_size,
-        #                         attention_probs_dropout_prob=attention_probs_dropout_prob,
-        #                         use_one_hot_embeddings=use_one_hot_embeddings,
-        #                         initializer_range=initializer_range,
-        #                         hidden_dropout_prob=hidden_dropout_prob,
-        #                         hidden_act=hidden_act,
-        #                         compute_type=compute_type)
-        #     layers.append(layer)
-        # self.layers = nn.CellList(layers)
-        self.layer0 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
-        self.layer1 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
-        self.layer2 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
-        self.layer3 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
-        self.layer4 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
-        self.layer5 = DecoderCell(batch_size=batch_size,
-                                  hidden_size=hidden_size,
-                                  seq_length=seq_length,
-                                  enc_seq_length=enc_seq_length,
-                                  num_attention_heads=num_attention_heads,
-                                  intermediate_size=intermediate_size,
-                                  attention_probs_dropout_prob=attention_probs_dropout_prob,
-                                  use_one_hot_embeddings=use_one_hot_embeddings,
-                                  initializer_range=initializer_range,
-                                  hidden_dropout_prob=hidden_dropout_prob,
-                                  hidden_act=hidden_act,
-                                  compute_type=compute_type)
+        layers = []
+        for _ in range(num_hidden_layers):
+            layer = DecoderCell(batch_size=batch_size,
+                                hidden_size=hidden_size,
+                                seq_length=seq_length,
+                                enc_seq_length=enc_seq_length,
+                                num_attention_heads=num_attention_heads,
+                                intermediate_size=intermediate_size,
+                                attention_probs_dropout_prob=attention_probs_dropout_prob,
+                                use_one_hot_embeddings=use_one_hot_embeddings,
+                                initializer_range=initializer_range,
+                                hidden_dropout_prob=hidden_dropout_prob,
+                                hidden_act=hidden_act,
+                                compute_type=compute_type)
+            layers.append(layer)
+        self.layers = nn.CellList(layers)
 
         self.layer_preprocess = LayerPreprocess(in_channels=hidden_size)
@@ -880,16 +807,9 @@ class TransformerDecoder(nn.Cell):
     def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask):
         prev_output = self.reshape(input_tensor, self.shape)
 
-        # wait to be supported
-        # for layer_module in self.layers:
-        #     layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
-        #     prev_output = layer_output
-        prev_output = self.layer0(prev_output, attention_mask, enc_states, enc_attention_mask)
-        prev_output = self.layer1(prev_output, attention_mask, enc_states, enc_attention_mask)
-        prev_output = self.layer2(prev_output, attention_mask, enc_states, enc_attention_mask)
-        prev_output = self.layer3(prev_output, attention_mask, enc_states, enc_attention_mask)
-        prev_output = self.layer4(prev_output, attention_mask, enc_states, enc_attention_mask)
-        prev_output = self.layer5(prev_output, attention_mask, enc_states, enc_attention_mask)
+        for layer_module in self.layers:
+            layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
+            prev_output = layer_output
 
         prev_output = self.layer_preprocess(prev_output)
         output = self.reshape(prev_output, self.out_shape)
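The two hunks above replace the six hard-coded DecoderCell attributes (layer0 ... layer5) with a stack that is built in a Python loop, wrapped in nn.CellList, and iterated inside construct, so the decoder depth follows num_hidden_layers instead of being fixed at six. A minimal sketch of the same pattern, assuming only mindspore.nn and using nn.Dense as a stand-in for the repository's DecoderCell (the class and argument names below are illustrative, not part of the commit):

import mindspore.nn as nn

class StackedCells(nn.Cell):
    # Build num_layers identical sub-cells in a loop and register them
    # through nn.CellList so their parameters are tracked by the parent Cell.
    def __init__(self, hidden_size, num_layers):
        super(StackedCells, self).__init__()
        layers = []
        for _ in range(num_layers):
            # nn.Dense stands in for DecoderCell; the real cell also takes
            # attention masks and encoder states, as the diff shows.
            layers.append(nn.Dense(hidden_size, hidden_size))
        self.layers = nn.CellList(layers)

    def construct(self, x):
        prev_output = x
        # Iterating a CellList inside construct is the loop this commit
        # switches the Transformer decoder to.
        for layer_module in self.layers:
            prev_output = layer_module(prev_output)
        return prev_output

With this shape, changing num_hidden_layers in the network config (for example transformer_net_cfg) changes the decoder depth without editing the network definition.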
@@ -16,6 +16,7 @@
 
 import time
 import argparse
+import random
 import numpy as np
 
 import mindspore.common.dtype as mstype
@@ -26,6 +27,7 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.train.callback import Callback, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
+import mindspore.dataset.engine as de
 import mindspore.communication.management as D
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore import context
@@ -36,6 +38,10 @@ from src.config import cfg, transformer_net_cfg
 from src.dataset import create_transformer_dataset
 from src.lr_schedule import create_dynamic_lr
 
+random_seed = 1
+random.seed(random_seed)
+np.random.seed(random_seed)
+de.config.set_seed(random_seed)
 
 def get_ms_timestamp():
     t = time.time()
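This hunk moves seeding to module level, and the following hunk removes the partial seeding that used to sit under __main__, so Python's random module, NumPy, and the MindSpore dataset engine are all seeded before any dataset or network object is created. A small sketch of the same idea wrapped in a hypothetical helper (the function name is illustrative, not part of the commit):

import random

import numpy as np
import mindspore.dataset.engine as de

def set_global_seed(seed=1):
    # Hypothetical convenience wrapper: seed Python, NumPy, and the
    # MindSpore dataset engine together for reproducible shuffling and sampling.
    random.seed(seed)
    np.random.seed(seed)
    de.config.set_seed(seed)

set_global_seed(1)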
@@ -161,7 +167,4 @@ def run_transformer_train():
     model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
 
 if __name__ == '__main__':
-    random_seed = 1
-    np.random.seed(random_seed)
-
     run_transformer_train()