modify Transformer model

This commit is contained in:
yuchaojie 2020-06-05 09:24:23 +08:00
parent bd3e8da6a7
commit a9b7861a00
3 changed files with 39 additions and 51 deletions

View File

@ -78,9 +78,8 @@ def load_weights(model_path):
weights = {}
for msname in ms_ckpt:
infer_name = msname.replace("transformer.transformer.", "")
infer_name = msname
if "tfm_decoder" in msname:
infer_name = infer_name.replace(".layers.", ".layer")
infer_name = "tfm_decoder.decoder." + infer_name
if is_npz:
weights[infer_name] = ms_ckpt[msname]

View File

@ -20,11 +20,11 @@ import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from .beam_search import BeamSearchDecoder, TileBeam
from .weight_init import normal_weight, weight_variable
class TransformerConfig:
"""
@ -118,9 +118,7 @@ class EmbeddingLookup(nn.Cell):
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(initializer_range),
[vocab_size, embedding_size]),
self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size),
name='embedding_table')
self.expand = P.ExpandDims()
self.shape_flat = (-1,)
@ -138,8 +136,7 @@ class EmbeddingLookup(nn.Cell):
flat_ids = self.reshape(input_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
output_for_reshape = self.array_mul(
one_hot_ids, self.embedding_table)
output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
@ -329,22 +326,22 @@ class MultiheadAttention(nn.Cell):
units,
activation=query_act,
has_bias=False,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
weight_init=weight_variable([units, from_tensor_width])).to_float(compute_type)
self.key_layer = nn.Dense(to_tensor_width,
units,
activation=key_act,
has_bias=False,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
self.value_layer = nn.Dense(to_tensor_width,
units,
activation=value_act,
has_bias=False,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
self.out_layer = nn.Dense(units,
out_tensor_width,
activation=out_act,
has_bias=False,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type)
self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)
@ -518,10 +515,10 @@ class FeedForward(nn.Cell):
self.conv1 = nn.Dense(in_channels,
hidden_size,
activation=hidden_act,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
weight_init=weight_variable([hidden_size, in_channels])).to_float(compute_type)
self.conv2 = nn.Dense(hidden_size,
out_channels,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
weight_init=weight_variable([out_channels, hidden_size])).to_float(compute_type)
self.preprocess = LayerPreprocess(in_channels=in_channels)
self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob)
@ -1108,7 +1105,13 @@ class TransformerModel(nn.Cell):
embedding_size=self.embedding_size,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range)
self.tfm_embedding_postprocessor = EmbeddingPostprocessor(
self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=0.02,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)
self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=0.02,
@ -1171,7 +1174,7 @@ class TransformerModel(nn.Cell):
hidden_act=config.hidden_act,
compute_type=config.compute_type,
embedding_lookup=self.tfm_embedding_lookup,
embedding_processor=self.tfm_embedding_postprocessor,
embedding_processor=self.tfm_embedding_postprocessor_for_decoder,
projection=self.projection)
self.tfm_decoder = BeamSearchDecoder(
batch_size=config.batch_size,
@ -1195,15 +1198,14 @@ class TransformerModel(nn.Cell):
ones = np.ones(shape=(self.seq_length, self.seq_length))
self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
else:
self.tile_beam = TileBeam(
beam_width=config.beam_width)
self.tile_beam = TileBeam(beam_width=config.beam_width)
ones = np.ones(shape=(config.batch_size, config.max_decode_length))
self.encdec_mask = Tensor(ones, dtype=mstype.float32)
def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
# process source sentence
src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
src_embedding_output = self.tfm_embedding_postprocessor(src_word_embeddings)
src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
# transformer encoder
@ -1213,7 +1215,7 @@ class TransformerModel(nn.Cell):
if self.is_training:
# process target sentence
tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
tgt_embedding_output = self.tfm_embedding_postprocessor(tgt_word_embeddings)
tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask)
tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0))
@ -1223,15 +1225,14 @@ class TransformerModel(nn.Cell):
encoder_output, enc_attention_mask)
# calculate logits and log_probs
log_probs = self.projection(decoder_output, embedding_tables)
return log_probs
ret = log_probs
else:
beam_encoder_output = self.tile_beam(encoder_output)
beam_encoder_output = self.tile_beam(encoder_output)
enc_attention_mask = self.multiply(enc_attention_mask[::, 0:1:1, ::], self.expand(self.encdec_mask, -1))
enc_attention_mask = self.multiply(
enc_attention_mask[::, 0:1:1, ::],
self.expand(self.encdec_mask, -1))
beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
return predicted_ids
beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
ret = predicted_ids
return ret

View File

@ -16,9 +16,9 @@
import time
import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.nn.optim import Adam
from mindspore.train.model import Model
@ -34,7 +34,6 @@ from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNe
TransformerTrainOneStepWithLossScaleCell
from src.config import cfg, transformer_net_cfg
from src.dataset import create_transformer_dataset
from src.weight_init import weight_variable, one_weight, zero_weight, normal_weight
from src.lr_schedule import create_dynamic_lr
@ -108,7 +107,7 @@ def run_transformer_train():
parser = argparse_init()
args, _ = parser.parse_known_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
context.set_context(save_graphs=True, reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
if args.distribute == "true":
device_num = args.device_num
@ -129,29 +128,15 @@ def run_transformer_train():
if args.checkpoint_path:
parameter_dict = load_checkpoint(args.checkpoint_path)
else:
parameter_dict = {}
params = netwithloss.trainable_params()
for param in params:
name = param.name
value = param.default_input
if isinstance(value, Tensor):
if name.endswith(".gamma"):
parameter_dict[name] = Parameter(one_weight(value.asnumpy().shape), name=name)
elif name.endswith(".beta") or name.endswith(".bias"):
parameter_dict[name] = Parameter(zero_weight(value.asnumpy().shape), name=name)
elif "embedding" in name:
parameter_dict[name] = Parameter(normal_weight(value.asnumpy().shape,
transformer_net_cfg.hidden_size), name=name)
else:
parameter_dict[name] = Parameter(weight_variable(value.asnumpy().shape), name=name)
load_param_into_net(netwithloss, parameter_dict)
load_param_into_net(netwithloss, parameter_dict)
lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
training_steps=dataset.get_dataset_size()*args.epoch_size,
learning_rate=cfg.lr_schedule.learning_rate,
warmup_steps=cfg.lr_schedule.warmup_steps,
hidden_size=transformer_net_cfg.hidden_size), mstype.float32)
hidden_size=transformer_net_cfg.hidden_size,
start_decay_step=cfg.lr_schedule.start_decay_step,
min_lr=cfg.lr_schedule.min_lr), mstype.float32)
optimizer = Adam(netwithloss.trainable_params(), lr)
callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
@ -176,4 +161,7 @@ def run_transformer_train():
model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
if __name__ == '__main__':
random_seed = 1
np.random.seed(random_seed)
run_transformer_train()