forked from mindspore-Ecosystem/mindspore
modify Transformer model
This commit is contained in:
parent
bd3e8da6a7
commit
a9b7861a00
|
@ -78,9 +78,8 @@ def load_weights(model_path):
|
|||
|
||||
weights = {}
|
||||
for msname in ms_ckpt:
|
||||
infer_name = msname.replace("transformer.transformer.", "")
|
||||
infer_name = msname
|
||||
if "tfm_decoder" in msname:
|
||||
infer_name = infer_name.replace(".layers.", ".layer")
|
||||
infer_name = "tfm_decoder.decoder." + infer_name
|
||||
if is_npz:
|
||||
weights[infer_name] = ms_ckpt[msname]
|
||||
|
|
|
@ -20,11 +20,11 @@ import numpy as np
|
|||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from .beam_search import BeamSearchDecoder, TileBeam
|
||||
from .weight_init import normal_weight, weight_variable
|
||||
|
||||
class TransformerConfig:
|
||||
"""
|
||||
|
@ -118,9 +118,7 @@ class EmbeddingLookup(nn.Cell):
|
|||
self.vocab_size = vocab_size
|
||||
self.embedding_size = embedding_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.embedding_table = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
[vocab_size, embedding_size]),
|
||||
self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size),
|
||||
name='embedding_table')
|
||||
self.expand = P.ExpandDims()
|
||||
self.shape_flat = (-1,)
|
||||
|
@ -138,8 +136,7 @@ class EmbeddingLookup(nn.Cell):
|
|||
flat_ids = self.reshape(input_ids, self.shape_flat)
|
||||
if self.use_one_hot_embeddings:
|
||||
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
output_for_reshape = self.array_mul(
|
||||
one_hot_ids, self.embedding_table)
|
||||
output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
|
||||
else:
|
||||
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
|
||||
|
||||
|
@ -329,22 +326,22 @@ class MultiheadAttention(nn.Cell):
|
|||
units,
|
||||
activation=query_act,
|
||||
has_bias=False,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
weight_init=weight_variable([units, from_tensor_width])).to_float(compute_type)
|
||||
self.key_layer = nn.Dense(to_tensor_width,
|
||||
units,
|
||||
activation=key_act,
|
||||
has_bias=False,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
|
||||
self.value_layer = nn.Dense(to_tensor_width,
|
||||
units,
|
||||
activation=value_act,
|
||||
has_bias=False,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
|
||||
self.out_layer = nn.Dense(units,
|
||||
out_tensor_width,
|
||||
activation=out_act,
|
||||
has_bias=False,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type)
|
||||
|
||||
self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
|
||||
self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)
|
||||
|
@ -518,10 +515,10 @@ class FeedForward(nn.Cell):
|
|||
self.conv1 = nn.Dense(in_channels,
|
||||
hidden_size,
|
||||
activation=hidden_act,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
weight_init=weight_variable([hidden_size, in_channels])).to_float(compute_type)
|
||||
self.conv2 = nn.Dense(hidden_size,
|
||||
out_channels,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
weight_init=weight_variable([out_channels, hidden_size])).to_float(compute_type)
|
||||
|
||||
self.preprocess = LayerPreprocess(in_channels=in_channels)
|
||||
self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob)
|
||||
|
@ -1108,7 +1105,13 @@ class TransformerModel(nn.Cell):
|
|||
embedding_size=self.embedding_size,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=config.initializer_range)
|
||||
self.tfm_embedding_postprocessor = EmbeddingPostprocessor(
|
||||
self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor(
|
||||
embedding_size=self.embedding_size,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=0.02,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
dropout_prob=config.hidden_dropout_prob)
|
||||
self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor(
|
||||
embedding_size=self.embedding_size,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=0.02,
|
||||
|
@ -1171,7 +1174,7 @@ class TransformerModel(nn.Cell):
|
|||
hidden_act=config.hidden_act,
|
||||
compute_type=config.compute_type,
|
||||
embedding_lookup=self.tfm_embedding_lookup,
|
||||
embedding_processor=self.tfm_embedding_postprocessor,
|
||||
embedding_processor=self.tfm_embedding_postprocessor_for_decoder,
|
||||
projection=self.projection)
|
||||
self.tfm_decoder = BeamSearchDecoder(
|
||||
batch_size=config.batch_size,
|
||||
|
@ -1195,15 +1198,14 @@ class TransformerModel(nn.Cell):
|
|||
ones = np.ones(shape=(self.seq_length, self.seq_length))
|
||||
self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
|
||||
else:
|
||||
self.tile_beam = TileBeam(
|
||||
beam_width=config.beam_width)
|
||||
self.tile_beam = TileBeam(beam_width=config.beam_width)
|
||||
ones = np.ones(shape=(config.batch_size, config.max_decode_length))
|
||||
self.encdec_mask = Tensor(ones, dtype=mstype.float32)
|
||||
|
||||
def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
|
||||
# process source sentence
|
||||
src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
|
||||
src_embedding_output = self.tfm_embedding_postprocessor(src_word_embeddings)
|
||||
src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
|
||||
# attention mask [batch_size, seq_length, seq_length]
|
||||
enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
|
||||
# transformer encoder
|
||||
|
@ -1213,7 +1215,7 @@ class TransformerModel(nn.Cell):
|
|||
if self.is_training:
|
||||
# process target sentence
|
||||
tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
|
||||
tgt_embedding_output = self.tfm_embedding_postprocessor(tgt_word_embeddings)
|
||||
tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings)
|
||||
# attention mask [batch_size, seq_length, seq_length]
|
||||
tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask)
|
||||
tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0))
|
||||
|
@ -1223,15 +1225,14 @@ class TransformerModel(nn.Cell):
|
|||
encoder_output, enc_attention_mask)
|
||||
# calculate logits and log_probs
|
||||
log_probs = self.projection(decoder_output, embedding_tables)
|
||||
return log_probs
|
||||
ret = log_probs
|
||||
else:
|
||||
beam_encoder_output = self.tile_beam(encoder_output)
|
||||
|
||||
beam_encoder_output = self.tile_beam(encoder_output)
|
||||
enc_attention_mask = self.multiply(enc_attention_mask[::, 0:1:1, ::], self.expand(self.encdec_mask, -1))
|
||||
|
||||
enc_attention_mask = self.multiply(
|
||||
enc_attention_mask[::, 0:1:1, ::],
|
||||
self.expand(self.encdec_mask, -1))
|
||||
|
||||
beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
|
||||
beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
|
||||
predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
|
||||
return predicted_ids
|
||||
beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
|
||||
beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
|
||||
predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
|
||||
ret = predicted_ids
|
||||
return ret
|
||||
|
|
|
@ -16,9 +16,9 @@
|
|||
|
||||
import time
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.train.model import Model
|
||||
|
@ -34,7 +34,6 @@ from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNe
|
|||
TransformerTrainOneStepWithLossScaleCell
|
||||
from src.config import cfg, transformer_net_cfg
|
||||
from src.dataset import create_transformer_dataset
|
||||
from src.weight_init import weight_variable, one_weight, zero_weight, normal_weight
|
||||
from src.lr_schedule import create_dynamic_lr
|
||||
|
||||
|
||||
|
@ -108,7 +107,7 @@ def run_transformer_train():
|
|||
parser = argparse_init()
|
||||
args, _ = parser.parse_known_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
|
||||
context.set_context(save_graphs=True, reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
|
||||
context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
|
||||
|
||||
if args.distribute == "true":
|
||||
device_num = args.device_num
|
||||
|
@ -129,29 +128,15 @@ def run_transformer_train():
|
|||
|
||||
if args.checkpoint_path:
|
||||
parameter_dict = load_checkpoint(args.checkpoint_path)
|
||||
else:
|
||||
parameter_dict = {}
|
||||
params = netwithloss.trainable_params()
|
||||
for param in params:
|
||||
name = param.name
|
||||
value = param.default_input
|
||||
if isinstance(value, Tensor):
|
||||
if name.endswith(".gamma"):
|
||||
parameter_dict[name] = Parameter(one_weight(value.asnumpy().shape), name=name)
|
||||
elif name.endswith(".beta") or name.endswith(".bias"):
|
||||
parameter_dict[name] = Parameter(zero_weight(value.asnumpy().shape), name=name)
|
||||
elif "embedding" in name:
|
||||
parameter_dict[name] = Parameter(normal_weight(value.asnumpy().shape,
|
||||
transformer_net_cfg.hidden_size), name=name)
|
||||
else:
|
||||
parameter_dict[name] = Parameter(weight_variable(value.asnumpy().shape), name=name)
|
||||
load_param_into_net(netwithloss, parameter_dict)
|
||||
load_param_into_net(netwithloss, parameter_dict)
|
||||
|
||||
lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
|
||||
training_steps=dataset.get_dataset_size()*args.epoch_size,
|
||||
learning_rate=cfg.lr_schedule.learning_rate,
|
||||
warmup_steps=cfg.lr_schedule.warmup_steps,
|
||||
hidden_size=transformer_net_cfg.hidden_size), mstype.float32)
|
||||
hidden_size=transformer_net_cfg.hidden_size,
|
||||
start_decay_step=cfg.lr_schedule.start_decay_step,
|
||||
min_lr=cfg.lr_schedule.min_lr), mstype.float32)
|
||||
optimizer = Adam(netwithloss.trainable_params(), lr)
|
||||
|
||||
callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
|
||||
|
@ -176,4 +161,7 @@ def run_transformer_train():
|
|||
model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
|
||||
|
||||
if __name__ == '__main__':
|
||||
random_seed = 1
|
||||
np.random.seed(random_seed)
|
||||
|
||||
run_transformer_train()
|
||||
|
|
Loading…
Reference in New Issue