forked from mindspore-Ecosystem/mindspore
!2202 fix decoder loop for Transformer model
Merge pull request !2202 from yuchaojie/transformer
This commit is contained in:
commit
961e29b211
|
@ -781,84 +781,9 @@ class TransformerDecoder(nn.Cell):
|
||||||
super(TransformerDecoder, self).__init__()
|
super(TransformerDecoder, self).__init__()
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
|
||||||
# wait to be supported
|
layers = []
|
||||||
# layers = []
|
for _ in range(num_hidden_layers):
|
||||||
# for _ in range(num_hidden_layers):
|
layer = DecoderCell(batch_size=batch_size,
|
||||||
# layer = DecoderCell(batch_size=batch_size,
|
|
||||||
# hidden_size=hidden_size,
|
|
||||||
# seq_length=seq_length,
|
|
||||||
# enc_seq_length=enc_seq_length,
|
|
||||||
# num_attention_heads=num_attention_heads,
|
|
||||||
# intermediate_size=intermediate_size,
|
|
||||||
# attention_probs_dropout_prob=attention_probs_dropout_prob,
|
|
||||||
# use_one_hot_embeddings=use_one_hot_embeddings,
|
|
||||||
# initializer_range=initializer_range,
|
|
||||||
# hidden_dropout_prob=hidden_dropout_prob,
|
|
||||||
# hidden_act=hidden_act,
|
|
||||||
# compute_type=compute_type)
|
|
||||||
# layers.append(layer)
|
|
||||||
# self.layers = nn.CellList(layers)
|
|
||||||
self.layer0 = DecoderCell(batch_size=batch_size,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
seq_length=seq_length,
|
|
||||||
enc_seq_length=enc_seq_length,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
|
||||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
|
||||||
initializer_range=initializer_range,
|
|
||||||
hidden_dropout_prob=hidden_dropout_prob,
|
|
||||||
hidden_act=hidden_act,
|
|
||||||
compute_type=compute_type)
|
|
||||||
self.layer1 = DecoderCell(batch_size=batch_size,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
seq_length=seq_length,
|
|
||||||
enc_seq_length=enc_seq_length,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
|
||||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
|
||||||
initializer_range=initializer_range,
|
|
||||||
hidden_dropout_prob=hidden_dropout_prob,
|
|
||||||
hidden_act=hidden_act,
|
|
||||||
compute_type=compute_type)
|
|
||||||
self.layer2 = DecoderCell(batch_size=batch_size,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
seq_length=seq_length,
|
|
||||||
enc_seq_length=enc_seq_length,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
|
||||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
|
||||||
initializer_range=initializer_range,
|
|
||||||
hidden_dropout_prob=hidden_dropout_prob,
|
|
||||||
hidden_act=hidden_act,
|
|
||||||
compute_type=compute_type)
|
|
||||||
self.layer3 = DecoderCell(batch_size=batch_size,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
seq_length=seq_length,
|
|
||||||
enc_seq_length=enc_seq_length,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
|
||||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
|
||||||
initializer_range=initializer_range,
|
|
||||||
hidden_dropout_prob=hidden_dropout_prob,
|
|
||||||
hidden_act=hidden_act,
|
|
||||||
compute_type=compute_type)
|
|
||||||
self.layer4 = DecoderCell(batch_size=batch_size,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
seq_length=seq_length,
|
|
||||||
enc_seq_length=enc_seq_length,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
|
||||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
|
||||||
initializer_range=initializer_range,
|
|
||||||
hidden_dropout_prob=hidden_dropout_prob,
|
|
||||||
hidden_act=hidden_act,
|
|
||||||
compute_type=compute_type)
|
|
||||||
self.layer5 = DecoderCell(batch_size=batch_size,
|
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
seq_length=seq_length,
|
seq_length=seq_length,
|
||||||
enc_seq_length=enc_seq_length,
|
enc_seq_length=enc_seq_length,
|
||||||
|
@ -870,6 +795,8 @@ class TransformerDecoder(nn.Cell):
|
||||||
hidden_dropout_prob=hidden_dropout_prob,
|
hidden_dropout_prob=hidden_dropout_prob,
|
||||||
hidden_act=hidden_act,
|
hidden_act=hidden_act,
|
||||||
compute_type=compute_type)
|
compute_type=compute_type)
|
||||||
|
layers.append(layer)
|
||||||
|
self.layers = nn.CellList(layers)
|
||||||
|
|
||||||
self.layer_preprocess = LayerPreprocess(in_channels=hidden_size)
|
self.layer_preprocess = LayerPreprocess(in_channels=hidden_size)
|
||||||
|
|
||||||
|
@ -880,16 +807,9 @@ class TransformerDecoder(nn.Cell):
|
||||||
def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask):
|
def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask):
|
||||||
prev_output = self.reshape(input_tensor, self.shape)
|
prev_output = self.reshape(input_tensor, self.shape)
|
||||||
|
|
||||||
# wait to be supported
|
for layer_module in self.layers:
|
||||||
# for layer_module in self.layers:
|
layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
|
||||||
# layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
|
prev_output = layer_output
|
||||||
# prev_output = layer_output
|
|
||||||
prev_output = self.layer0(prev_output, attention_mask, enc_states, enc_attention_mask)
|
|
||||||
prev_output = self.layer1(prev_output, attention_mask, enc_states, enc_attention_mask)
|
|
||||||
prev_output = self.layer2(prev_output, attention_mask, enc_states, enc_attention_mask)
|
|
||||||
prev_output = self.layer3(prev_output, attention_mask, enc_states, enc_attention_mask)
|
|
||||||
prev_output = self.layer4(prev_output, attention_mask, enc_states, enc_attention_mask)
|
|
||||||
prev_output = self.layer5(prev_output, attention_mask, enc_states, enc_attention_mask)
|
|
||||||
|
|
||||||
prev_output = self.layer_preprocess(prev_output)
|
prev_output = self.layer_preprocess(prev_output)
|
||||||
output = self.reshape(prev_output, self.out_shape)
|
output = self.reshape(prev_output, self.out_shape)
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
|
@ -26,6 +27,7 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
||||||
from mindspore.train.callback import Callback, TimeMonitor
|
from mindspore.train.callback import Callback, TimeMonitor
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
import mindspore.dataset.engine as de
|
||||||
import mindspore.communication.management as D
|
import mindspore.communication.management as D
|
||||||
from mindspore.train.parallel_utils import ParallelMode
|
from mindspore.train.parallel_utils import ParallelMode
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
@ -36,6 +38,10 @@ from src.config import cfg, transformer_net_cfg
|
||||||
from src.dataset import create_transformer_dataset
|
from src.dataset import create_transformer_dataset
|
||||||
from src.lr_schedule import create_dynamic_lr
|
from src.lr_schedule import create_dynamic_lr
|
||||||
|
|
||||||
|
random_seed = 1
|
||||||
|
random.seed(random_seed)
|
||||||
|
np.random.seed(random_seed)
|
||||||
|
de.config.set_seed(random_seed)
|
||||||
|
|
||||||
def get_ms_timestamp():
|
def get_ms_timestamp():
|
||||||
t = time.time()
|
t = time.time()
|
||||||
|
@ -161,7 +167,4 @@ def run_transformer_train():
|
||||||
model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
|
model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
random_seed = 1
|
|
||||||
np.random.seed(random_seed)
|
|
||||||
|
|
||||||
run_transformer_train()
|
run_transformer_train()
|
||||||
|
|
Loading…
Reference in New Issue