forked from mindspore-Ecosystem/mindspore

transformer bucket batch modification

This commit is contained in: parent a0e3fd6bf3, commit fa1247a85e
@@ -54,10 +54,10 @@ After dataset preparation, you can start training and evaluation as follows:

 ```bash
 # run training example
-sh scripts/run_standalone_train_ascend.sh 0 52 /path/ende-l128-mindrecord00
+sh scripts/run_standalone_train_ascend.sh 0 52 /path/ende-l128-mindrecord

 # run distributed training example
-sh scripts/run_distribute_train_ascend.sh 8 52 /path/newstest2014-l128-mindrecord rank_table.json
+sh scripts/run_distribute_train_ascend.sh 8 52 /path/ende-l128-mindrecord rank_table.json

 # run evaluation example
 python eval.py > eval.log 2>&1 &
@@ -104,6 +104,7 @@ usage: train.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [
                 [--enable_data_sink ENABLE_DATA_SINK] [--save_checkpoint_steps N]
                 [--save_checkpoint_num N] [--save_checkpoint_path SAVE_CHECKPOINT_PATH]
                 [--data_path DATA_PATH]
+                [--bucket_boundaries BUCKET_LENGTH]

 options:
     --distribute pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false"

@@ -119,6 +120,7 @@ options:
     --save_checkpoint_num number for saving checkpoint files: N, default is 30
     --save_checkpoint_path path to save checkpoint files: PATH, default is "./checkpoint/"
     --data_path path to dataset file: PATH, default is ""
+    --bucket_boundaries sequence lengths for different bucket: LIST, default is [16, 32, 48, 64, 128]
 ```

 ### Running Options
@@ -179,13 +181,13 @@ Parameters for learning rate:

 ``` bash
 paste train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.all
-python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128
+python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128 --bucket [16, 32, 48, 64, 128]
 ```
 - Convert the original data to mindrecord for evaluation:

 ``` bash
 paste newstest2014.tok.bpe.32000.en newstest2014.tok.bpe.32000.de > test.all
-python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True
+python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True --bucket [128]
 ```
@@ -51,20 +51,29 @@ class SampleInstance():
         return self.__str__()


-def write_instance_to_file(writer, instance, tokenizer, max_seq_length):
+def write_instance_to_file(writer, instance, tokenizer, max_seq_length, bucket):
     """Create files from `SampleInstance`s."""
+    def _find_bucket_length(num):
+        assert num <= bucket[-1]
+
+        for index in range(1, len(bucket)):
+            if bucket[index - 1] < num <= bucket[index]:
+                return bucket[index]
+        return bucket[0]

     def _convert_ids_and_mask(input_tokens):
         input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
         input_mask = [1] * len(input_ids)
         assert len(input_ids) <= max_seq_length

-        while len(input_ids) < max_seq_length:
+        seq_max_bucket_length = _find_bucket_length(len(input_ids))
+
+        while len(input_ids) < seq_max_bucket_length:
             input_ids.append(0)
             input_mask.append(0)

-        assert len(input_ids) == max_seq_length
-        assert len(input_mask) == max_seq_length
+        assert len(input_ids) == seq_max_bucket_length
+        assert len(input_mask) == seq_max_bucket_length

         return input_ids, input_mask
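For reference, the padding behaviour introduced above can be exercised on its own. The sketch below is plain Python with illustrative names (`find_bucket_length` and `pad_to_bucket` are not part of the diff); it mirrors the logic of `_find_bucket_length`: each example is padded only up to the smallest bucket boundary that fits it, rather than always to `max_seq_length`.

```python
def find_bucket_length(num, bucket):
    """Return the smallest bucket boundary that can hold a sequence of length `num`."""
    assert num <= bucket[-1]
    for index in range(1, len(bucket)):
        if bucket[index - 1] < num <= bucket[index]:
            return bucket[index]
    return bucket[0]

def pad_to_bucket(input_ids, bucket):
    """Pad token ids (and build the matching mask) up to the selected bucket length."""
    input_mask = [1] * len(input_ids)
    target_len = find_bucket_length(len(input_ids), bucket)
    while len(input_ids) < target_len:
        input_ids.append(0)
        input_mask.append(0)
    return input_ids, input_mask

# e.g. a 20-token sentence is padded to 32, not to 128
ids, mask = pad_to_bucket(list(range(1, 21)), [16, 32, 48, 64, 128])
assert len(ids) == 32 and sum(mask) == 20
```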
@@ -93,7 +102,6 @@ def create_training_instance(source_words, target_words, max_seq_length, clip_to

     if len(source_words) >= max_seq_length or len(target_words) >= max_seq_length:
         if clip_to_max_len:
-            print("####lalalal")
             source_words = source_words[:min([len(source_words, max_seq_length-1)])]
             target_words = target_words[:min([len(target_words, max_seq_length-1)])]
         else:
@@ -123,6 +131,8 @@ def main():
     parser.add_argument("--clip_to_max_len", type=bool, default=False,
                         help='clip sequences to maximum sequence length.')
     parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
+    parser.add_argument("--bucket", type=list, default=[16, 32, 48, 64, 128], help='bucket sequence length')

     args = parser.parse_args()

     tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)
@@ -179,7 +189,7 @@ def main():
         if instance is None:
             continue

-        features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length)
+        features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length, args.bucket)
         total_written += 1

         if total_written <= 20:
@@ -52,7 +52,7 @@ do
         --enable_save_ckpt="true" \
         --enable_lossscale="true" \
         --do_shuffle="true" \
-        --enable_data_sink="true" \
+        --enable_data_sink="false" \
         --checkpoint_path="" \
        --save_checkpoint_steps=2500 \
        --save_checkpoint_num=30 \
@@ -37,7 +37,7 @@ python train.py \
     --enable_save_ckpt="true" \
     --enable_lossscale="true" \
     --do_shuffle="true" \
-    --enable_data_sink="true" \
+    --enable_data_sink="false" \
     --checkpoint_path="" \
     --save_checkpoint_steps=2500 \
     --save_checkpoint_num=30 \
@@ -134,6 +134,7 @@ class BeamSearchDecoder(nn.Cell):
                  eos_id=2,
                  compute_type=mstype.float32):
         super(BeamSearchDecoder, self).__init__(auto_prefix=False)
+        self.seq_length = seq_length
         self.batch_size = batch_size
         self.vocab_size = vocab_size
         self.beam_width = beam_width

@@ -182,7 +183,7 @@ class BeamSearchDecoder(nn.Cell):
         """
         One step for decode
         """
-        log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask)
+        log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask, self.seq_length)
         log_probs = self.reshape(log_probs, (self.batch_size, self.beam_width, self.vocab_size))

         # select topk indices
@@ -15,30 +15,40 @@
 """Data operations, will be used in train.py."""

 import mindspore.common.dtype as mstype
-import mindspore.dataset.engine.datasets as de
+import mindspore.dataset as de
 import mindspore.dataset.transforms.c_transforms as deC
 from .config import transformer_net_cfg
+de.config.set_seed(1)
 def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true",
-                               dataset_path=None):
+                               dataset_path=None, bucket_boundaries=None):
     """create dataset"""
-    repeat_count = epoch_count
-    ds = de.MindDataset(dataset_path,
-                        columns_list=["source_eos_ids", "source_eos_mask",
-                                      "target_sos_ids", "target_sos_mask",
-                                      "target_eos_ids", "target_eos_mask"],
-                        shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id)
-
-    type_cast_op = deC.TypeCast(mstype.int32)
-    ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
-    ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
-    ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
-    ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
-    ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
-    ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
-
-    # apply batch operations
-    ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(repeat_count)
-
+    def batch_per_bucket(bucket_len, dataset_path):
+        dataset_path = dataset_path + "_" + str(bucket_len) + "_00"
+        ds = de.MindDataset(dataset_path,
+                            columns_list=["source_eos_ids", "source_eos_mask",
+                                          "target_sos_ids", "target_sos_mask",
+                                          "target_eos_ids", "target_eos_mask"],
+                            shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id)
+        type_cast_op = deC.TypeCast(mstype.int32)
+        ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
+        ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
+        ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
+        ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
+        ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
+        ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
+
+        # apply batch operations
+        ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
+        ds = ds.repeat(epoch_count)
+        return ds
+
+    for i, _ in enumerate(bucket_boundaries):
+        bucket_len = bucket_boundaries[i]
+        ds_per = batch_per_bucket(bucket_len, dataset_path)
+        if i == 0:
+            ds = ds_per
+        else:
+            ds = ds + ds_per
+    ds = ds.shuffle(ds.get_dataset_size())
+    ds.channel_name = 'transformer'
     return ds
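A hedged usage sketch of the new `create_transformer_dataset` signature. Because `batch_per_bucket` appends `"_" + str(bucket_len) + "_00"` to `dataset_path`, the call below assumes the MindRecord files were generated per bucket (e.g. `/path/ende-l128-mindrecord_16_00` through `..._128_00`); the import path is also an assumption, not part of this diff.

```python
# Illustrative sketch only: module path, file locations and bucket list are assumptions.
from src.dataset import create_transformer_dataset

ds = create_transformer_dataset(
    epoch_count=1,
    rank_size=1,
    rank_id=0,
    do_shuffle="true",
    enable_data_sink="false",
    dataset_path="/path/ende-l128-mindrecord",   # resolved to .../_16_00, .../_32_00, ... per bucket
    bucket_boundaries=[16, 32, 48, 64, 128])
print(ds.get_dataset_size())  # total number of batches across all buckets
```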
@@ -95,12 +95,13 @@ class TransformerTrainingLoss(nn.Cell):
         self.flatten = P.Flatten()
         self.neg = P.Neg()
         self.cast = P.Cast()
-        self.flat_shape = (config.batch_size * config.seq_length,)
+        self.batch_size = config.batch_size

-    def construct(self, prediction_scores, label_ids, label_weights):
+    def construct(self, prediction_scores, label_ids, label_weights, seq_length):
         """Defines the computation performed."""
-        label_ids = self.reshape(label_ids, self.flat_shape)
-        label_weights = self.cast(self.reshape(label_weights, self.flat_shape), mstype.float32)
+        flat_shape = (self.batch_size * seq_length,)
+        label_ids = self.reshape(label_ids, flat_shape)
+        label_weights = self.cast(self.reshape(label_weights, flat_shape), mstype.float32)
         one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)

         per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
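The point of this change is that the flattened label shape is now computed from the sequence length of the incoming batch rather than frozen at construction time, so batches from different buckets can flow through the same loss cell. A numpy-only sketch of the reshape with illustrative sizes:

```python
import numpy as np

batch_size = 96  # illustrative value, not taken from the diff
for seq_length in (16, 32, 64, 128):   # one bucket boundary per batch
    label_ids = np.zeros((batch_size, seq_length), dtype=np.int32)
    flat_shape = (batch_size * seq_length,)
    assert label_ids.reshape(flat_shape).shape == flat_shape
```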
@@ -128,6 +129,7 @@ class TransformerNetworkWithLoss(nn.Cell):
         self.transformer = TransformerModel(config, is_training, use_one_hot_embeddings)
         self.loss = TransformerTrainingLoss(config)
         self.cast = P.Cast()
+        self.shape = P.Shape()

     def construct(self,
                   source_ids,

@@ -136,8 +138,10 @@ class TransformerNetworkWithLoss(nn.Cell):
                   target_mask,
                   label_ids,
                   label_weights):
+        """Transformer network with loss."""
         prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask)
-        total_loss = self.loss(prediction_scores, label_ids, label_weights)
+        seq_length = self.shape(source_ids)[1]
+        total_loss = self.loss(prediction_scores, label_ids, label_weights, seq_length)
         return self.cast(total_loss, mstype.float32)
@@ -156,7 +160,6 @@ class TransformerTrainOneStepCell(nn.Cell):
     def __init__(self, network, optimizer, sens=1.0):
         super(TransformerTrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.network.set_grad()
         self.weights = ParameterTuple(network.trainable_params())
         self.optimizer = optimizer
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
@@ -23,6 +23,7 @@ import mindspore.ops.functional as F
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
+from mindspore.ops.primitive import constexpr
 from .beam_search import BeamSearchDecoder, TileBeam
 from .weight_init import normal_weight, weight_variable
@@ -296,8 +297,6 @@ class MultiheadAttention(nn.Cell):
                  from_tensor_width,
                  to_tensor_width,
                  out_tensor_width,
-                 from_seq_length,
-                 to_seq_length,
                  num_attention_heads=1,
                  size_per_head=512,
                  query_act=None,

@@ -312,12 +311,13 @@ class MultiheadAttention(nn.Cell):
                  compute_type=mstype.float32):
         super(MultiheadAttention, self).__init__()
         self.batch_size = batch_size
-        self.from_seq_length = from_seq_length
-        self.to_seq_length = to_seq_length
         self.num_attention_heads = num_attention_heads
         self.size_per_head = size_per_head
         self.has_attention_mask = has_attention_mask
         assert has_attention_mask
+        self.use_one_hot_embeddings = use_one_hot_embeddings
+        self.initializer_range = initializer_range
+        self.do_return_2d_tensor = do_return_2d_tensor

         self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
         self.reshape = P.Reshape()

@@ -345,9 +345,6 @@ class MultiheadAttention(nn.Cell):
                                    has_bias=False,
                                    weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type)

-        self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
-        self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)
-
         self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
         self.multiply = P.Mul()
         self.transpose = P.Transpose()

@@ -368,27 +365,33 @@ class MultiheadAttention(nn.Cell):
         self.add = P.TensorAdd()
         self.cast = P.Cast()
         self.get_dtype = P.DType()
-        if do_return_2d_tensor:
-            self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
-            if from_seq_length == -1:
-                self.shape_return = (-1, num_attention_heads * size_per_head)
-        else:
-            self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)

         self.cast_compute_type = CastWrapper(dst_type=compute_type)
         self.softmax_cast = P.Cast()

-    def construct(self, from_tensor, to_tensor, attention_mask=None):
-        """reshape 2d/3d input tensors to 2d"""
+    def construct(self, from_tensor, to_tensor, seq_length, enc_seq_length, attention_mask=None):
+        """Apply multihead attention."""
+        from_seq_length = seq_length
+        to_seq_length = enc_seq_length
+        shape_from = (self.batch_size, from_seq_length, self.num_attention_heads, self.size_per_head)
+        shape_to = (self.batch_size, to_seq_length, self.num_attention_heads, self.size_per_head)
+        if self.do_return_2d_tensor:
+            shape_return = (self.batch_size * from_seq_length, self.num_attention_heads * self.size_per_head)
+            if from_seq_length == -1:
+                shape_return = (-1, self.num_attention_heads * self.size_per_head)
+        else:
+            shape_return = (self.batch_size, from_seq_length, self.num_attention_heads * self.size_per_head)

+        # reshape 2d/3d input tensors to 2d
         from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
         to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
         query_out = self.query_layer(from_tensor_2d)
         key_out = self.key_layer(to_tensor_2d)
         value_out = self.value_layer(to_tensor_2d)

-        query_layer = self.reshape(query_out, self.shape_from)
+        query_layer = self.reshape(query_out, shape_from)
         query_layer = self.transpose(query_layer, self.trans_shape)
-        key_layer = self.reshape(key_out, self.shape_to)
+        key_layer = self.reshape(key_out, shape_to)
         key_layer = self.transpose(key_layer, self.trans_shape)

         attention_scores = self.matmul_trans_b(query_layer, key_layer)
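With `from_seq_length`/`to_seq_length` removed from the constructor, every reshape target is recomputed inside `construct` from the `seq_length`/`enc_seq_length` arguments. A small standalone sketch of that shape arithmetic (dimensions are illustrative, not taken from the config):

```python
batch_size, num_attention_heads, size_per_head = 96, 16, 64  # illustrative values

def attention_shapes(seq_length, enc_seq_length, do_return_2d_tensor=True):
    """Mirror of the shape computation now done in construct."""
    shape_from = (batch_size, seq_length, num_attention_heads, size_per_head)
    shape_to = (batch_size, enc_seq_length, num_attention_heads, size_per_head)
    if do_return_2d_tensor:
        shape_return = (batch_size * seq_length, num_attention_heads * size_per_head)
        if seq_length == -1:
            shape_return = (-1, num_attention_heads * size_per_head)
    else:
        shape_return = (batch_size, seq_length, num_attention_heads * size_per_head)
    return shape_from, shape_to, shape_return

print(attention_shapes(32, 32))    # encoder self-attention on the 32-token bucket
print(attention_shapes(-1, 128))   # beam-search decoder step with a free target length
```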
@@ -407,12 +410,12 @@ class MultiheadAttention(nn.Cell):
         if self.use_dropout:
             attention_probs = self.dropout(attention_probs)

-        value_layer = self.reshape(value_out, self.shape_to)
+        value_layer = self.reshape(value_out, shape_to)
         value_layer = self.transpose(value_layer, self.trans_shape)
         context_layer = self.matmul(attention_probs, value_layer)

         context_layer = self.transpose(context_layer, self.trans_shape)
-        context_layer = self.reshape(context_layer, self.shape_return)
+        context_layer = self.reshape(context_layer, shape_return)
         context_layer = self.out_layer(context_layer)
         return context_layer
@@ -438,8 +441,6 @@ class SelfAttention(nn.Cell):
     """
     def __init__(self,
                  batch_size,
-                 from_seq_length,
-                 to_seq_length,
                  hidden_size,
                  num_attention_heads=16,
                  attention_probs_dropout_prob=0.1,

@@ -461,8 +462,6 @@ class SelfAttention(nn.Cell):
             from_tensor_width=hidden_size,
             to_tensor_width=hidden_size,
             out_tensor_width=hidden_size,
-            from_seq_length=from_seq_length,
-            to_seq_length=to_seq_length,
             num_attention_heads=num_attention_heads,
             size_per_head=self.size_per_head,
             attention_probs_dropout_prob=attention_probs_dropout_prob,

@@ -477,7 +476,7 @@ class SelfAttention(nn.Cell):

         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
-    def construct(self, input_tensor, memory_tensor, attention_mask):
+    def construct(self, input_tensor, memory_tensor, attention_mask, seq_length, enc_seq_length):
         """Apply self-attention."""
         input_tensor = self.reshape(input_tensor, self.shape)
         memory_tensor = self.reshape(memory_tensor, self.shape)

@@ -487,7 +486,7 @@ class SelfAttention(nn.Cell):
         if not self.is_encdec_att:
             memory_tensor = output

-        attention_output = self.attention(output, memory_tensor, attention_mask)
+        attention_output = self.attention(output, memory_tensor, seq_length, enc_seq_length, attention_mask)
         output = self.postprocess(attention_output, input_tensor)
         return output
@@ -563,7 +562,6 @@ class EncoderCell(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size=1024,
-                 seq_length=128,
                  num_attention_heads=16,
                  intermediate_size=4096,
                  attention_probs_dropout_prob=0.1,

@@ -576,8 +574,6 @@ class EncoderCell(nn.Cell):
         self.attention = SelfAttention(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            from_seq_length=seq_length,
-            to_seq_length=seq_length,
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             use_one_hot_embeddings=use_one_hot_embeddings,

@@ -594,9 +590,9 @@ class EncoderCell(nn.Cell):
             hidden_dropout_prob=hidden_dropout_prob,
             compute_type=compute_type)

-    def construct(self, hidden_states, attention_mask):
+    def construct(self, hidden_states, attention_mask, seq_length):
         # self-attention with ln, res
-        attention_output = self.attention(hidden_states, hidden_states, attention_mask)
+        attention_output = self.attention(hidden_states, hidden_states, attention_mask, seq_length, seq_length)
         # feed forward with ln, res
         output = self.feedforward(attention_output)
         return output
@@ -624,7 +620,6 @@ class TransformerEncoder(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size,
-                 seq_length,
                  num_hidden_layers,
                  num_attention_heads=16,
                  intermediate_size=4096,

@@ -636,12 +631,13 @@ class TransformerEncoder(nn.Cell):
                  compute_type=mstype.float32):
         super(TransformerEncoder, self).__init__()
         self.num_hidden_layers = num_hidden_layers
+        self.batch_size = batch_size
+        self.hidden_size = hidden_size

         layers = []
         for _ in range(num_hidden_layers):
             layer = EncoderCell(batch_size=batch_size,
                                 hidden_size=hidden_size,
-                                seq_length=seq_length,
                                 num_attention_heads=num_attention_heads,
                                 intermediate_size=intermediate_size,
                                 attention_probs_dropout_prob=attention_probs_dropout_prob,

@@ -657,17 +653,18 @@ class TransformerEncoder(nn.Cell):

         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
-        self.out_shape = (batch_size, seq_length, hidden_size)

-    def construct(self, input_tensor, attention_mask):
+    def construct(self, input_tensor, attention_mask, seq_length):
+        """Apply encoder."""
+        out_shape = (self.batch_size, seq_length, self.hidden_size)
         prev_output = self.reshape(input_tensor, self.shape)

         for layer_module in self.layers:
-            layer_output = layer_module(prev_output, attention_mask)
+            layer_output = layer_module(prev_output, attention_mask, seq_length)
             prev_output = layer_output

         prev_output = self.layer_preprocess(prev_output)
-        output = self.reshape(prev_output, self.out_shape)
+        output = self.reshape(prev_output, out_shape)
         return output
@@ -693,8 +690,6 @@ class DecoderCell(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size=1024,
-                 seq_length=128,
-                 enc_seq_length=128,
                  num_attention_heads=12,
                  intermediate_size=4096,
                  attention_probs_dropout_prob=0.02,

@@ -707,8 +702,6 @@ class DecoderCell(nn.Cell):
         self.self_attention = SelfAttention(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            from_seq_length=seq_length,
-            to_seq_length=seq_length,
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             use_one_hot_embeddings=use_one_hot_embeddings,

@@ -719,8 +712,6 @@ class DecoderCell(nn.Cell):
         self.cross_attention = SelfAttention(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            from_seq_length=seq_length,
-            to_seq_length=enc_seq_length,
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             use_one_hot_embeddings=use_one_hot_embeddings,

@@ -737,11 +728,12 @@ class DecoderCell(nn.Cell):
             hidden_dropout_prob=hidden_dropout_prob,
             compute_type=compute_type)

-    def construct(self, hidden_states, attention_mask, enc_states, enc_attention_mask):
+    def construct(self, hidden_states, attention_mask, enc_states, enc_attention_mask, seq_length, enc_seq_length):
         # self-attention with ln, res
-        attention_output = self.self_attention(hidden_states, hidden_states, attention_mask)
+        attention_output = self.self_attention(hidden_states, hidden_states, attention_mask, seq_length, seq_length)
         # cross-attention with ln, res
-        attention_output = self.cross_attention(attention_output, enc_states, enc_attention_mask)
+        attention_output = self.cross_attention(attention_output, enc_states, enc_attention_mask,
+                                                seq_length, enc_seq_length)
         # feed forward with ln, res
         output = self.feedforward(attention_output)
         return output
@@ -770,8 +762,6 @@ class TransformerDecoder(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size,
-                 seq_length,
-                 enc_seq_length,
                  num_hidden_layers,
                  num_attention_heads=16,
                  intermediate_size=4096,

@@ -788,8 +778,6 @@ class TransformerDecoder(nn.Cell):
         for _ in range(num_hidden_layers):
             layer = DecoderCell(batch_size=batch_size,
                                 hidden_size=hidden_size,
-                                seq_length=seq_length,
-                                enc_seq_length=enc_seq_length,
                                 num_attention_heads=num_attention_heads,
                                 intermediate_size=intermediate_size,
                                 attention_probs_dropout_prob=attention_probs_dropout_prob,

@@ -805,17 +793,21 @@ class TransformerDecoder(nn.Cell):

         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
-        self.out_shape = (batch_size, seq_length, hidden_size)
+        self.hidden_size = hidden_size
+        self.batch_size = batch_size

-    def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask):
+    def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask, seq_length, enc_seq_length):
+        """Apply decoder."""
+        out_shape = (self.batch_size, seq_length, self.hidden_size)
         prev_output = self.reshape(input_tensor, self.shape)

         for layer_module in self.layers:
-            layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
+            layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask,
+                                        seq_length, enc_seq_length)
             prev_output = layer_output

         prev_output = self.layer_preprocess(prev_output)
-        output = self.reshape(prev_output, self.out_shape)
+        output = self.reshape(prev_output, out_shape)
         return output
@@ -860,13 +852,11 @@ class PredLogProbs(nn.Cell):
     """
     def __init__(self,
                  batch_size,
-                 seq_length,
                  width,
                  compute_type=mstype.float32,
                  dtype=mstype.float32):
         super(PredLogProbs, self).__init__()
         self.batch_size = batch_size
-        self.seq_length = seq_length
         self.width = width
         self.compute_type = compute_type
         self.dtype = dtype

@@ -874,14 +864,16 @@ class PredLogProbs(nn.Cell):
         self.reshape = P.Reshape()
         self.matmul = P.MatMul(transpose_b=True)
         self.log_softmax = nn.LogSoftmax(axis=-1)
-        self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width)
         self.cast = P.Cast()

     def construct(self,
                   input_tensor,
-                  output_weights):
+                  output_weights,
+                  seq_length):
         """Get log probs."""
-        input_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
+        shape_flat_sequence_tensor = (self.batch_size * seq_length, self.width)
+
+        input_tensor = self.reshape(input_tensor, shape_flat_sequence_tensor)
         input_tensor = self.cast(input_tensor, self.compute_type)
         output_weights = self.cast(output_weights, self.compute_type)
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
batch_size,
|
batch_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
enc_seq_length,
|
|
||||||
max_decode_length,
|
max_decode_length,
|
||||||
num_hidden_layers,
|
num_hidden_layers,
|
||||||
num_attention_heads=16,
|
num_attention_heads=16,
|
||||||
|
@ -942,8 +933,6 @@ class TransformerDecoderStep(nn.Cell):
|
||||||
self.tfm_decoder = TransformerDecoder(
|
self.tfm_decoder = TransformerDecoder(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
seq_length=-1, # -1 means length is not fixed
|
|
||||||
enc_seq_length=enc_seq_length,
|
|
||||||
num_attention_heads=num_attention_heads,
|
num_attention_heads=num_attention_heads,
|
||||||
num_hidden_layers=num_hidden_layers,
|
num_hidden_layers=num_hidden_layers,
|
||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
|
@ -966,7 +955,7 @@ class TransformerDecoderStep(nn.Cell):
|
||||||
|
|
||||||
self.cast_compute_type = CastWrapper(dst_type=compute_type)
|
self.cast_compute_type = CastWrapper(dst_type=compute_type)
|
||||||
|
|
||||||
def construct(self, input_ids, enc_states, enc_attention_mask):
|
def construct(self, input_ids, enc_states, enc_attention_mask, seq_length):
|
||||||
"""
|
"""
|
||||||
Multi-layer transformer decoder step.
|
Multi-layer transformer decoder step.
|
||||||
input_ids: [batch_size * beam_width]
|
input_ids: [batch_size * beam_width]
|
||||||
|
@ -988,17 +977,23 @@ class TransformerDecoderStep(nn.Cell):
|
||||||
enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::]
|
enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::]
|
||||||
|
|
||||||
# call TransformerDecoder
|
# call TransformerDecoder
|
||||||
decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask)
|
decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask, -1, seq_length)
|
||||||
|
|
||||||
# take the last step
|
# take the last step
|
||||||
decoder_output = decoder_output[::, input_len-1:input_len:1, ::]
|
decoder_output = decoder_output[::, input_len-1:input_len:1, ::]
|
||||||
|
|
||||||
# projection and log_prob
|
# projection and log_prob
|
||||||
log_probs = self.projection(decoder_output, embedding_tables)
|
log_probs = self.projection(decoder_output, embedding_tables, 1)
|
||||||
|
|
||||||
return log_probs
|
return log_probs
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def convert_np_to_tensor_encoder(seq_length):
|
||||||
|
ones = np.ones(shape=(seq_length, seq_length))
|
||||||
|
return Tensor(np.tril(ones), dtype=mstype.float32)
|
||||||
|
|
||||||
|
|
||||||
class TransformerModel(nn.Cell):
|
class TransformerModel(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Transformer with encoder and decoder.
|
Transformer with encoder and decoder.
|
||||||
|
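The new `convert_np_to_tensor_encoder` helper builds the causal (lower-triangular) future mask for whatever sequence length the current bucket has; `@constexpr` lets it run at graph-compile time for each distinct `seq_length`. A numpy-only sketch of the mask it produces:

```python
import numpy as np

def future_mask(seq_length):
    # 1.0 where a target position may attend (itself and earlier positions), 0.0 elsewhere
    return np.tril(np.ones((seq_length, seq_length), dtype=np.float32))

print(future_mask(4))
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
```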
@@ -1021,12 +1016,13 @@ class TransformerModel(nn.Cell):

         self.input_mask_from_dataset = config.input_mask_from_dataset
         self.batch_size = config.batch_size
-        self.seq_length = config.seq_length
         self.hidden_size = config.hidden_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embedding_size = config.hidden_size

         self.last_idx = self.num_hidden_layers - 1
+        self.beam_width = config.beam_width
+        self.max_decode_length = config.max_decode_length

         self.tfm_embedding_lookup = EmbeddingLookup(
             vocab_size=config.vocab_size,

@@ -1048,7 +1044,6 @@ class TransformerModel(nn.Cell):
         self.tfm_encoder = TransformerEncoder(
             batch_size=self.batch_size,
             hidden_size=self.hidden_size,
-            seq_length=self.seq_length,
             num_attention_heads=config.num_attention_heads,
             num_hidden_layers=self.num_hidden_layers,
             intermediate_size=config.intermediate_size,

@@ -1062,15 +1057,12 @@ class TransformerModel(nn.Cell):
         if is_training:
             self.projection = PredLogProbs(
                 batch_size=self.batch_size,
-                seq_length=self.seq_length,
                 width=self.hidden_size,
                 compute_type=config.compute_type,
                 dtype=config.dtype)
             self.tfm_decoder = TransformerDecoder(
                 batch_size=self.batch_size,
                 hidden_size=self.hidden_size,
-                seq_length=self.seq_length,
-                enc_seq_length=self.seq_length,
                 num_attention_heads=config.num_attention_heads,
                 num_hidden_layers=self.num_hidden_layers,
                 intermediate_size=config.intermediate_size,

@@ -1083,14 +1075,12 @@ class TransformerModel(nn.Cell):
         else:
             self.projection = PredLogProbs(
                 batch_size=self.batch_size * config.beam_width,
-                seq_length=1,
                 width=self.hidden_size,
                 compute_type=config.compute_type,
                 dtype=config.dtype)
             self.tfm_decoder = TransformerDecoderStep(
                 batch_size=self.batch_size * config.beam_width,
                 hidden_size=self.hidden_size,
-                enc_seq_length=self.seq_length,
                 max_decode_length=config.max_decode_length,
                 num_hidden_layers=config.num_hidden_layers,
                 num_attention_heads=config.num_attention_heads,
@@ -1113,24 +1103,24 @@ class TransformerModel(nn.Cell):
                 length_penalty_weight=config.length_penalty_weight,
                 max_decode_length=config.max_decode_length)

+            self.tfm_decoder.add_flags(loop_can_unroll=True)
+            self.tile_beam = TileBeam(beam_width=self.beam_width)
+            ones = np.ones(shape=(self.batch_size, self.max_decode_length))
+            self.encdec_mask = Tensor(ones, mstype.float32)
+
         self.cast = P.Cast()
         self.dtype = config.dtype
         self.cast_compute_type = CastWrapper(dst_type=config.compute_type)
         self.expand = P.ExpandDims()
         self.multiply = P.Mul()
+        self.shape = P.Shape()

         self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask()

-        if is_training:
-            ones = np.ones(shape=(self.seq_length, self.seq_length))
-            self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
-        else:
-            self.tile_beam = TileBeam(beam_width=config.beam_width)
-            ones = np.ones(shape=(config.batch_size, config.max_decode_length))
-            self.encdec_mask = Tensor(ones, dtype=mstype.float32)

     def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
         """Transformer with encoder and decoder."""
+        seq_length = self.shape(source_ids)[1]

         # process source sentence
         src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
         src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
@@ -1138,21 +1128,24 @@ class TransformerModel(nn.Cell):
         enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
         # transformer encoder
         encoder_output = self.tfm_encoder(self.cast_compute_type(src_embedding_output),
-                                          self.cast_compute_type(enc_attention_mask))
+                                          self.cast_compute_type(enc_attention_mask),
+                                          seq_length)

         if self.is_training:
+            future_mask = convert_np_to_tensor_encoder(seq_length)
             # process target sentence
             tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
             tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings)
             # attention mask [batch_size, seq_length, seq_length]
             tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask)
-            tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0))
+            tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(future_mask, 0))
             # transformer decoder
             decoder_output = self.tfm_decoder(self.cast_compute_type(tgt_embedding_output),
                                               self.cast_compute_type(tgt_attention_mask),
-                                              encoder_output, enc_attention_mask)
+                                              encoder_output, enc_attention_mask,
+                                              seq_length, seq_length)
             # calculate logits and log_probs
-            log_probs = self.projection(decoder_output, embedding_tables)
+            log_probs = self.projection(decoder_output, embedding_tables, seq_length)
             ret = log_probs
         else:
             beam_encoder_output = self.tile_beam(encoder_output)
@@ -105,6 +105,9 @@ def argparse_init():
     parser.add_argument("--save_checkpoint_path", type=str, default="./checkpoint/", help="Save checkpoint file path, "
                                                                                           "default is ./checkpoint/")
     parser.add_argument("--data_path", type=str, default="", help="Data path, it is better to use absolute path")
+    parser.add_argument("--bucket_boundaries", type=list, default=[16, 32, 48, 64, 128], help="sequence length for "
+                                                                                              "different bucket")

     return parser

 def run_transformer_train():

@@ -129,7 +132,8 @@ def run_transformer_train():
     dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
                                          rank_id=rank_id, do_shuffle=args.do_shuffle,
                                          enable_data_sink=args.enable_data_sink,
-                                         dataset_path=args.data_path)
+                                         dataset_path=args.data_path,
+                                         bucket_boundaries=args.bucket_boundaries)

     netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
@@ -24,12 +24,13 @@ from mindspore.nn.optim import Adam
 from mindspore.train.model import Model
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager
 from mindspore.train.callback import Callback
+import mindspore.dataset.engine as de
+import mindspore.dataset.transforms.c_transforms as deC
 from mindspore import context
 from model_zoo.official.nlp.transformer.src.transformer_model import TransformerConfig
 from model_zoo.official.nlp.transformer.src.transformer_for_train import TransformerNetworkWithLoss, \
     TransformerTrainOneStepWithLossScaleCell
-from model_zoo.official.nlp.transformer.src.config import cfg
-from model_zoo.official.nlp.transformer.src.dataset import create_transformer_dataset
+from model_zoo.official.nlp.transformer.src.config import cfg, transformer_net_cfg
 from model_zoo.official.nlp.transformer.src.lr_schedule import create_dynamic_lr

 DATA_DIR = ["/home/workspace/mindspore_dataset/transformer/test-mindrecord"]

@@ -76,6 +77,24 @@ def get_config(version='base', batch_size=1):
     transformer_cfg = TransformerConfig(batch_size=batch_size)
     return transformer_cfg

+def load_test_data(batch_size=1, data_file=None):
+    """Load test dataset."""
+    ds = de.MindDataset(data_file,
+                        columns_list=["source_eos_ids", "source_eos_mask",
+                                      "target_sos_ids", "target_sos_mask",
+                                      "target_eos_ids", "target_eos_mask"],
+                        shuffle=False)
+    type_cast_op = deC.TypeCast(mstype.int32)
+    ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
+    ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
+    ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
+    ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
+    ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
+    ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
+    # apply batch operations
+    ds = ds.batch(batch_size, drop_remainder=True)
+    return ds
+
 class ModelCallback(Callback):
     def __init__(self):
         super(ModelCallback, self).__init__()

@@ -120,10 +139,7 @@ def test_transformer():
     batch_size = 96
     epoch_size = 3
     config = get_config(version=version, batch_size=batch_size)
-    dataset = create_transformer_dataset(epoch_count=1,
-                                         do_shuffle="false",
-                                         enable_data_sink="false",
-                                         dataset_path=DATA_DIR)
+    dataset = load_test_data(batch_size=transformer_net_cfg.batch_size, data_file=DATA_DIR)

     netwithloss = TransformerNetworkWithLoss(config, True)