transformer bucket batch modification

commit fa1247a85e
parent a0e3fd6bf3
yuchaojie 2020-09-11 16:10:29 +08:00
10 changed files with 168 additions and 129 deletions

View File

@@ -54,10 +54,10 @@ After dataset preparation, you can start training and evaluation as follows:
 ```bash
 # run training example
-sh scripts/run_standalone_train_ascend.sh 0 52 /path/ende-l128-mindrecord00
+sh scripts/run_standalone_train_ascend.sh 0 52 /path/ende-l128-mindrecord
 # run distributed training example
-sh scripts/run_distribute_train_ascend.sh 8 52 /path/newstest2014-l128-mindrecord rank_table.json
+sh scripts/run_distribute_train_ascend.sh 8 52 /path/ende-l128-mindrecord rank_table.json
 # run evaluation example
 python eval.py > eval.log 2>&1 &
@@ -104,6 +104,7 @@ usage: train.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [
                 [--enable_data_sink ENABLE_DATA_SINK] [--save_checkpoint_steps N]
                 [--save_checkpoint_num N] [--save_checkpoint_path SAVE_CHECKPOINT_PATH]
                 [--data_path DATA_PATH]
+                [--bucket_boundaries BUCKET_LENGTH]
 options:
     --distribute               pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false"
@@ -119,6 +120,7 @@ options:
     --save_checkpoint_num      number for saving checkpoint files: N, default is 30
     --save_checkpoint_path     path to save checkpoint files: PATH, default is "./checkpoint/"
     --data_path                path to dataset file: PATH, default is ""
+    --bucket_boundaries        sequence lengths for different bucket: LIST, default is [16, 32, 48, 64, 128]
 ```
 ### Running Options
@@ -179,13 +181,13 @@ Parameters for learning rate:
 ``` bash
 paste train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.all
-python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128
+python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128 --bucket [16, 32, 48, 64, 128]
 ```
 - Convert the original data to mindrecord for evaluation:
 ``` bash
 paste newstest2014.tok.bpe.32000.en newstest2014.tok.bpe.32000.de > test.all
-python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True
+python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True --bucket [128]
 ```
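As a rough illustration of why the new bucket options matter (the sentence lengths below are invented, not taken from the commit), padding every example to a single 128-token length wastes many more positions than padding to the nearest bucket boundary:

```python
# Back-of-the-envelope sketch only; example lengths are hypothetical.
BUCKETS = [16, 32, 48, 64, 128]

def padded_length(num_tokens, buckets=None):
    """Length a sentence is padded to: the smallest bucket that fits it,
    or the single maximum length (128) when no buckets are used."""
    if buckets is None:
        return 128
    return next(b for b in buckets if num_tokens <= b)

lengths = [9, 20, 35, 50, 70, 120]                      # hypothetical tokenized lengths
print(sum(padded_length(n) for n in lengths))           # 768 positions with fixed 128-padding
print(sum(padded_length(n, BUCKETS) for n in lengths))  # 416 positions with bucketing
```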

View File

@@ -51,20 +51,29 @@ class SampleInstance():
         return self.__str__()
 
-def write_instance_to_file(writer, instance, tokenizer, max_seq_length):
+def write_instance_to_file(writer, instance, tokenizer, max_seq_length, bucket):
     """Create files from `SampleInstance`s."""
+    def _find_bucket_length(num):
+        assert num <= bucket[-1]
+        for index in range(1, len(bucket)):
+            if bucket[index - 1] < num <= bucket[index]:
+                return bucket[index]
+        return bucket[0]
+
     def _convert_ids_and_mask(input_tokens):
         input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
         input_mask = [1] * len(input_ids)
         assert len(input_ids) <= max_seq_length
 
-        while len(input_ids) < max_seq_length:
+        seq_max_bucket_length = _find_bucket_length(len(input_ids))
+        while len(input_ids) < seq_max_bucket_length:
             input_ids.append(0)
             input_mask.append(0)
 
-        assert len(input_ids) == max_seq_length
-        assert len(input_mask) == max_seq_length
+        assert len(input_ids) == seq_max_bucket_length
+        assert len(input_mask) == seq_max_bucket_length
 
         return input_ids, input_mask
@@ -93,7 +102,6 @@ def create_training_instance(source_words, target_words, max_seq_length, clip_to
     if len(source_words) >= max_seq_length or len(target_words) >= max_seq_length:
         if clip_to_max_len:
-            print("####lalalal")
             source_words = source_words[:min([len(source_words, max_seq_length-1)])]
             target_words = target_words[:min([len(target_words, max_seq_length-1)])]
         else:
@@ -123,6 +131,8 @@ def main():
     parser.add_argument("--clip_to_max_len", type=bool, default=False,
                         help='clip sequences to maximum sequence length.')
     parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
+    parser.add_argument("--bucket", type=list, default=[16, 32, 48, 64, 128], help='bucket sequence length')
+
     args = parser.parse_args()
 
     tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)
@@ -179,7 +189,7 @@ def main():
             if instance is None:
                 continue
 
-            features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length)
+            features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length, args.bucket)
             total_written += 1
 
             if total_written <= 20:
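For reference, here is a self-contained sketch of the bucketing logic this file gains: the selection rule mirrors `_find_bucket_length` above, while the `pad_to_bucket` wrapper is hypothetical (not part of the commit) and only illustrates the padding loop outside of MindRecord writing:

```python
def find_bucket_length(num, bucket):
    """Smallest bucket boundary that can hold a sequence of `num` tokens."""
    assert num <= bucket[-1]
    for index in range(1, len(bucket)):
        if bucket[index - 1] < num <= bucket[index]:
            return bucket[index]
    return bucket[0]

def pad_to_bucket(input_ids, bucket=(16, 32, 48, 64, 128)):
    """Pad token ids (and build the mask) up to the bucket length instead of max_seq_length."""
    input_mask = [1] * len(input_ids)
    target_len = find_bucket_length(len(input_ids), list(bucket))
    while len(input_ids) < target_len:
        input_ids.append(0)
        input_mask.append(0)
    return input_ids, input_mask

ids, mask = pad_to_bucket(list(range(1, 21)))   # 20 tokens are padded to the 32 bucket
assert len(ids) == len(mask) == 32
```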

View File

@@ -52,7 +52,7 @@ do
         --enable_save_ckpt="true" \
         --enable_lossscale="true" \
         --do_shuffle="true" \
-        --enable_data_sink="true" \
+        --enable_data_sink="false" \
         --checkpoint_path="" \
         --save_checkpoint_steps=2500 \
         --save_checkpoint_num=30 \

View File

@@ -37,7 +37,7 @@ python train.py \
     --enable_save_ckpt="true" \
     --enable_lossscale="true" \
     --do_shuffle="true" \
-    --enable_data_sink="true" \
+    --enable_data_sink="false" \
     --checkpoint_path="" \
     --save_checkpoint_steps=2500 \
     --save_checkpoint_num=30 \

View File

@@ -134,6 +134,7 @@ class BeamSearchDecoder(nn.Cell):
                  eos_id=2,
                  compute_type=mstype.float32):
         super(BeamSearchDecoder, self).__init__(auto_prefix=False)
+        self.seq_length = seq_length
         self.batch_size = batch_size
         self.vocab_size = vocab_size
         self.beam_width = beam_width
@@ -182,7 +183,7 @@ class BeamSearchDecoder(nn.Cell):
         """
         One step for decode
         """
-        log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask)
+        log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask, self.seq_length)
         log_probs = self.reshape(log_probs, (self.batch_size, self.beam_width, self.vocab_size))
 
         # select topk indices

View File

@@ -15,30 +15,40 @@
 """Data operations, will be used in train.py."""
 import mindspore.common.dtype as mstype
-import mindspore.dataset.engine.datasets as de
+import mindspore.dataset as de
 import mindspore.dataset.transforms.c_transforms as deC
 from .config import transformer_net_cfg
 
+de.config.set_seed(1)
+
 def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true",
-                               dataset_path=None):
+                               dataset_path=None, bucket_boundaries=None):
     """create dataset"""
-    repeat_count = epoch_count
-    ds = de.MindDataset(dataset_path,
-                        columns_list=["source_eos_ids", "source_eos_mask",
-                                      "target_sos_ids", "target_sos_mask",
-                                      "target_eos_ids", "target_eos_mask"],
-                        shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id)
-
-    type_cast_op = deC.TypeCast(mstype.int32)
-    ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
-    ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
-    ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
-    ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
-    ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
-    ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
-
-    # apply batch operations
-    ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(repeat_count)
+    def batch_per_bucket(bucket_len, dataset_path):
+        dataset_path = dataset_path + "_" + str(bucket_len) + "_00"
+        ds = de.MindDataset(dataset_path,
+                            columns_list=["source_eos_ids", "source_eos_mask",
+                                          "target_sos_ids", "target_sos_mask",
+                                          "target_eos_ids", "target_eos_mask"],
+                            shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id)
+        type_cast_op = deC.TypeCast(mstype.int32)
+        ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
+        ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
+        ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
+        ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
+        ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
+        ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
+
+        # apply batch operations
+        ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
+        ds = ds.repeat(epoch_count)
+        return ds
+
+    for i, _ in enumerate(bucket_boundaries):
+        bucket_len = bucket_boundaries[i]
+        ds_per = batch_per_bucket(bucket_len, dataset_path)
+        if i == 0:
+            ds = ds_per
+        else:
+            ds = ds + ds_per
+    ds = ds.shuffle(ds.get_dataset_size())
+    ds.channel_name = 'transformer'
 
     return ds
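A minimal usage sketch of the reworked loader follows. The import path mirrors the one used by train.py, the data path is a placeholder, and the per-bucket MindRecord files that `batch_per_bucket` opens (e.g. `/path/ende-l128-mindrecord_16_00` … `_128_00`) are assumed to have been generated already:

```python
# Hypothetical call site; path and import are assumptions, not part of the commit.
from src.dataset import create_transformer_dataset

ds = create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0,
                                do_shuffle="true",
                                dataset_path="/path/ende-l128-mindrecord",
                                bucket_boundaries=[16, 32, 48, 64, 128])
print(ds.get_dataset_size())  # number of batches summed over all buckets
```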

View File

@ -95,12 +95,13 @@ class TransformerTrainingLoss(nn.Cell):
self.flatten = P.Flatten() self.flatten = P.Flatten()
self.neg = P.Neg() self.neg = P.Neg()
self.cast = P.Cast() self.cast = P.Cast()
self.flat_shape = (config.batch_size * config.seq_length,) self.batch_size = config.batch_size
def construct(self, prediction_scores, label_ids, label_weights): def construct(self, prediction_scores, label_ids, label_weights, seq_length):
"""Defines the computation performed.""" """Defines the computation performed."""
label_ids = self.reshape(label_ids, self.flat_shape) flat_shape = (self.batch_size * seq_length,)
label_weights = self.cast(self.reshape(label_weights, self.flat_shape), mstype.float32) label_ids = self.reshape(label_ids, flat_shape)
label_weights = self.cast(self.reshape(label_weights, flat_shape), mstype.float32)
one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
@ -128,6 +129,7 @@ class TransformerNetworkWithLoss(nn.Cell):
self.transformer = TransformerModel(config, is_training, use_one_hot_embeddings) self.transformer = TransformerModel(config, is_training, use_one_hot_embeddings)
self.loss = TransformerTrainingLoss(config) self.loss = TransformerTrainingLoss(config)
self.cast = P.Cast() self.cast = P.Cast()
self.shape = P.Shape()
def construct(self, def construct(self,
source_ids, source_ids,
@ -136,8 +138,10 @@ class TransformerNetworkWithLoss(nn.Cell):
target_mask, target_mask,
label_ids, label_ids,
label_weights): label_weights):
"""Transformer network with loss."""
prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask) prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask)
total_loss = self.loss(prediction_scores, label_ids, label_weights) seq_length = self.shape(source_ids)[1]
total_loss = self.loss(prediction_scores, label_ids, label_weights, seq_length)
return self.cast(total_loss, mstype.float32) return self.cast(total_loss, mstype.float32)
@ -156,7 +160,6 @@ class TransformerTrainOneStepCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(TransformerTrainOneStepCell, self).__init__(auto_prefix=False) super(TransformerTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
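The loss now derives its flatten shape from the incoming batch instead of `config.seq_length`. A small numpy-only sketch of the same reshape arithmetic (shapes are illustrative, not taken from the commit):

```python
import numpy as np

batch_size = 96
for seq_length in (16, 32, 48, 64, 128):       # whichever bucket the current batch came from
    label_ids = np.zeros((batch_size, seq_length), dtype=np.int32)
    flat_shape = (batch_size * seq_length,)    # computed per batch, not fixed at config time
    assert label_ids.reshape(flat_shape).shape == flat_shape
```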

View File

@@ -23,6 +23,7 @@ import mindspore.ops.functional as F
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
+from mindspore.ops.primitive import constexpr
 from .beam_search import BeamSearchDecoder, TileBeam
 from .weight_init import normal_weight, weight_variable
@@ -296,8 +297,6 @@ class MultiheadAttention(nn.Cell):
                  from_tensor_width,
                  to_tensor_width,
                  out_tensor_width,
-                 from_seq_length,
-                 to_seq_length,
                  num_attention_heads=1,
                  size_per_head=512,
                  query_act=None,
@@ -312,12 +311,13 @@ class MultiheadAttention(nn.Cell):
                  compute_type=mstype.float32):
         super(MultiheadAttention, self).__init__()
         self.batch_size = batch_size
-        self.from_seq_length = from_seq_length
-        self.to_seq_length = to_seq_length
         self.num_attention_heads = num_attention_heads
         self.size_per_head = size_per_head
         self.has_attention_mask = has_attention_mask
         assert has_attention_mask
+        self.use_one_hot_embeddings = use_one_hot_embeddings
+        self.initializer_range = initializer_range
+        self.do_return_2d_tensor = do_return_2d_tensor
 
         self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
         self.reshape = P.Reshape()
@@ -345,9 +345,6 @@ class MultiheadAttention(nn.Cell):
                                    has_bias=False,
                                    weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type)
 
-        self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
-        self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)
-
         self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
         self.multiply = P.Mul()
         self.transpose = P.Transpose()
@@ -368,27 +365,33 @@ class MultiheadAttention(nn.Cell):
         self.add = P.TensorAdd()
         self.cast = P.Cast()
         self.get_dtype = P.DType()
-        if do_return_2d_tensor:
-            self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
-            if from_seq_length == -1:
-                self.shape_return = (-1, num_attention_heads * size_per_head)
-        else:
-            self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)
         self.cast_compute_type = CastWrapper(dst_type=compute_type)
         self.softmax_cast = P.Cast()
 
-    def construct(self, from_tensor, to_tensor, attention_mask=None):
-        """reshape 2d/3d input tensors to 2d"""
+    def construct(self, from_tensor, to_tensor, seq_length, enc_seq_length, attention_mask=None):
+        """Apply multihead attention."""
+        from_seq_length = seq_length
+        to_seq_length = enc_seq_length
+        shape_from = (self.batch_size, from_seq_length, self.num_attention_heads, self.size_per_head)
+        shape_to = (self.batch_size, to_seq_length, self.num_attention_heads, self.size_per_head)
+        if self.do_return_2d_tensor:
+            shape_return = (self.batch_size * from_seq_length, self.num_attention_heads * self.size_per_head)
+            if from_seq_length == -1:
+                shape_return = (-1, self.num_attention_heads * self.size_per_head)
+        else:
+            shape_return = (self.batch_size, from_seq_length, self.num_attention_heads * self.size_per_head)
+
+        # reshape 2d/3d input tensors to 2d
         from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
         to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
         query_out = self.query_layer(from_tensor_2d)
         key_out = self.key_layer(to_tensor_2d)
         value_out = self.value_layer(to_tensor_2d)
 
-        query_layer = self.reshape(query_out, self.shape_from)
+        query_layer = self.reshape(query_out, shape_from)
         query_layer = self.transpose(query_layer, self.trans_shape)
-        key_layer = self.reshape(key_out, self.shape_to)
+        key_layer = self.reshape(key_out, shape_to)
         key_layer = self.transpose(key_layer, self.trans_shape)
 
         attention_scores = self.matmul_trans_b(query_layer, key_layer)
@@ -438,8 +441,6 @@ class SelfAttention(nn.Cell):
     """
     def __init__(self,
                  batch_size,
-                 from_seq_length,
-                 to_seq_length,
                  hidden_size,
                  num_attention_heads=16,
                  attention_probs_dropout_prob=0.1,
@@ -461,8 +462,6 @@ class SelfAttention(nn.Cell):
             from_tensor_width=hidden_size,
             to_tensor_width=hidden_size,
             out_tensor_width=hidden_size,
-            from_seq_length=from_seq_length,
-            to_seq_length=to_seq_length,
             num_attention_heads=num_attention_heads,
             size_per_head=self.size_per_head,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
@@ -477,7 +476,7 @@ class SelfAttention(nn.Cell):
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
 
-    def construct(self, input_tensor, memory_tensor, attention_mask):
+    def construct(self, input_tensor, memory_tensor, attention_mask, seq_length, enc_seq_length):
         """Apply self-attention."""
         input_tensor = self.reshape(input_tensor, self.shape)
         memory_tensor = self.reshape(memory_tensor, self.shape)
@@ -487,7 +486,7 @@ class SelfAttention(nn.Cell):
 
         if not self.is_encdec_att:
             memory_tensor = output
-        attention_output = self.attention(output, memory_tensor, attention_mask)
+        attention_output = self.attention(output, memory_tensor, seq_length, enc_seq_length, attention_mask)
         output = self.postprocess(attention_output, input_tensor)
         return output
@@ -563,7 +562,6 @@ class EncoderCell(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size=1024,
-                 seq_length=128,
                  num_attention_heads=16,
                  intermediate_size=4096,
                  attention_probs_dropout_prob=0.1,
@@ -576,8 +574,6 @@ class EncoderCell(nn.Cell):
         self.attention = SelfAttention(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            from_seq_length=seq_length,
-            to_seq_length=seq_length,
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             use_one_hot_embeddings=use_one_hot_embeddings,
@@ -594,9 +590,9 @@ class EncoderCell(nn.Cell):
             hidden_dropout_prob=hidden_dropout_prob,
             compute_type=compute_type)
 
-    def construct(self, hidden_states, attention_mask):
+    def construct(self, hidden_states, attention_mask, seq_length):
         # self-attention with ln, res
-        attention_output = self.attention(hidden_states, hidden_states, attention_mask)
+        attention_output = self.attention(hidden_states, hidden_states, attention_mask, seq_length, seq_length)
         # feed forward with ln, res
         output = self.feedforward(attention_output)
         return output
@@ -624,7 +620,6 @@ class TransformerEncoder(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size,
-                 seq_length,
                  num_hidden_layers,
                  num_attention_heads=16,
                  intermediate_size=4096,
@@ -636,12 +631,13 @@ class TransformerEncoder(nn.Cell):
                  compute_type=mstype.float32):
         super(TransformerEncoder, self).__init__()
         self.num_hidden_layers = num_hidden_layers
+        self.batch_size = batch_size
+        self.hidden_size = hidden_size
 
         layers = []
         for _ in range(num_hidden_layers):
             layer = EncoderCell(batch_size=batch_size,
                                 hidden_size=hidden_size,
-                                seq_length=seq_length,
                                 num_attention_heads=num_attention_heads,
                                 intermediate_size=intermediate_size,
                                 attention_probs_dropout_prob=attention_probs_dropout_prob,
@@ -657,17 +653,18 @@ class TransformerEncoder(nn.Cell):
 
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
-        self.out_shape = (batch_size, seq_length, hidden_size)
 
-    def construct(self, input_tensor, attention_mask):
+    def construct(self, input_tensor, attention_mask, seq_length):
+        """Apply encoder."""
+        out_shape = (self.batch_size, seq_length, self.hidden_size)
         prev_output = self.reshape(input_tensor, self.shape)
 
         for layer_module in self.layers:
-            layer_output = layer_module(prev_output, attention_mask)
+            layer_output = layer_module(prev_output, attention_mask, seq_length)
             prev_output = layer_output
 
         prev_output = self.layer_preprocess(prev_output)
-        output = self.reshape(prev_output, self.out_shape)
+        output = self.reshape(prev_output, out_shape)
         return output
@@ -693,8 +690,6 @@ class DecoderCell(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size=1024,
-                 seq_length=128,
-                 enc_seq_length=128,
                  num_attention_heads=12,
                  intermediate_size=4096,
                  attention_probs_dropout_prob=0.02,
@@ -707,8 +702,6 @@ class DecoderCell(nn.Cell):
         self.self_attention = SelfAttention(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            from_seq_length=seq_length,
-            to_seq_length=seq_length,
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             use_one_hot_embeddings=use_one_hot_embeddings,
@@ -719,8 +712,6 @@ class DecoderCell(nn.Cell):
         self.cross_attention = SelfAttention(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            from_seq_length=seq_length,
-            to_seq_length=enc_seq_length,
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             use_one_hot_embeddings=use_one_hot_embeddings,
@@ -737,11 +728,12 @@ class DecoderCell(nn.Cell):
             hidden_dropout_prob=hidden_dropout_prob,
             compute_type=compute_type)
 
-    def construct(self, hidden_states, attention_mask, enc_states, enc_attention_mask):
+    def construct(self, hidden_states, attention_mask, enc_states, enc_attention_mask, seq_length, enc_seq_length):
         # self-attention with ln, res
-        attention_output = self.self_attention(hidden_states, hidden_states, attention_mask)
+        attention_output = self.self_attention(hidden_states, hidden_states, attention_mask, seq_length, seq_length)
         # cross-attention with ln, res
-        attention_output = self.cross_attention(attention_output, enc_states, enc_attention_mask)
+        attention_output = self.cross_attention(attention_output, enc_states, enc_attention_mask,
+                                                seq_length, enc_seq_length)
         # feed forward with ln, res
         output = self.feedforward(attention_output)
         return output
@@ -770,8 +762,6 @@ class TransformerDecoder(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size,
-                 seq_length,
-                 enc_seq_length,
                  num_hidden_layers,
                  num_attention_heads=16,
                  intermediate_size=4096,
@@ -788,8 +778,6 @@ class TransformerDecoder(nn.Cell):
         for _ in range(num_hidden_layers):
             layer = DecoderCell(batch_size=batch_size,
                                 hidden_size=hidden_size,
-                                seq_length=seq_length,
-                                enc_seq_length=enc_seq_length,
                                 num_attention_heads=num_attention_heads,
                                 intermediate_size=intermediate_size,
                                 attention_probs_dropout_prob=attention_probs_dropout_prob,
@@ -805,17 +793,21 @@ class TransformerDecoder(nn.Cell):
 
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
-        self.out_shape = (batch_size, seq_length, hidden_size)
+        self.hidden_size = hidden_size
+        self.batch_size = batch_size
 
-    def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask):
+    def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask, seq_length, enc_seq_length):
+        """Apply decoder."""
+        out_shape = (self.batch_size, seq_length, self.hidden_size)
         prev_output = self.reshape(input_tensor, self.shape)
 
         for layer_module in self.layers:
-            layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
+            layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask,
+                                        seq_length, enc_seq_length)
             prev_output = layer_output
 
         prev_output = self.layer_preprocess(prev_output)
-        output = self.reshape(prev_output, self.out_shape)
+        output = self.reshape(prev_output, out_shape)
         return output
@@ -860,13 +852,11 @@ class PredLogProbs(nn.Cell):
     """
     def __init__(self,
                  batch_size,
-                 seq_length,
                  width,
                  compute_type=mstype.float32,
                  dtype=mstype.float32):
         super(PredLogProbs, self).__init__()
         self.batch_size = batch_size
-        self.seq_length = seq_length
         self.width = width
         self.compute_type = compute_type
         self.dtype = dtype
@@ -874,14 +864,16 @@ class PredLogProbs(nn.Cell):
         self.reshape = P.Reshape()
         self.matmul = P.MatMul(transpose_b=True)
         self.log_softmax = nn.LogSoftmax(axis=-1)
-        self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width)
         self.cast = P.Cast()
 
     def construct(self,
                   input_tensor,
-                  output_weights):
+                  output_weights,
+                  seq_length):
         """Get log probs."""
-        input_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
+        shape_flat_sequence_tensor = (self.batch_size * seq_length, self.width)
+        input_tensor = self.reshape(input_tensor, shape_flat_sequence_tensor)
         input_tensor = self.cast(input_tensor, self.compute_type)
         output_weights = self.cast(output_weights, self.compute_type)
@@ -918,7 +910,6 @@ class TransformerDecoderStep(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size,
-                 enc_seq_length,
                  max_decode_length,
                  num_hidden_layers,
                  num_attention_heads=16,
@@ -942,8 +933,6 @@ class TransformerDecoderStep(nn.Cell):
         self.tfm_decoder = TransformerDecoder(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            seq_length=-1,  # -1 means length is not fixed
-            enc_seq_length=enc_seq_length,
             num_attention_heads=num_attention_heads,
             num_hidden_layers=num_hidden_layers,
             intermediate_size=intermediate_size,
@@ -966,7 +955,7 @@ class TransformerDecoderStep(nn.Cell):
 
         self.cast_compute_type = CastWrapper(dst_type=compute_type)
 
-    def construct(self, input_ids, enc_states, enc_attention_mask):
+    def construct(self, input_ids, enc_states, enc_attention_mask, seq_length):
         """
         Multi-layer transformer decoder step.
         input_ids: [batch_size * beam_width]
@@ -988,17 +977,23 @@ class TransformerDecoderStep(nn.Cell):
         enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::]
 
         # call TransformerDecoder
-        decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask)
+        decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask, -1, seq_length)
 
         # take the last step
         decoder_output = decoder_output[::, input_len-1:input_len:1, ::]
 
         # projection and log_prob
-        log_probs = self.projection(decoder_output, embedding_tables)
+        log_probs = self.projection(decoder_output, embedding_tables, 1)
 
         return log_probs
 
+
+@constexpr
+def convert_np_to_tensor_encoder(seq_length):
+    ones = np.ones(shape=(seq_length, seq_length))
+    return Tensor(np.tril(ones), dtype=mstype.float32)
+
+
 class TransformerModel(nn.Cell):
     """
     Transformer with encoder and decoder.
@@ -1021,12 +1016,13 @@ class TransformerModel(nn.Cell):
         self.input_mask_from_dataset = config.input_mask_from_dataset
         self.batch_size = config.batch_size
-        self.seq_length = config.seq_length
         self.hidden_size = config.hidden_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embedding_size = config.hidden_size
 
         self.last_idx = self.num_hidden_layers - 1
+        self.beam_width = config.beam_width
+        self.max_decode_length = config.max_decode_length
 
         self.tfm_embedding_lookup = EmbeddingLookup(
             vocab_size=config.vocab_size,
@@ -1048,7 +1044,6 @@ class TransformerModel(nn.Cell):
         self.tfm_encoder = TransformerEncoder(
             batch_size=self.batch_size,
             hidden_size=self.hidden_size,
-            seq_length=self.seq_length,
             num_attention_heads=config.num_attention_heads,
             num_hidden_layers=self.num_hidden_layers,
             intermediate_size=config.intermediate_size,
@@ -1062,15 +1057,12 @@ class TransformerModel(nn.Cell):
         if is_training:
             self.projection = PredLogProbs(
                 batch_size=self.batch_size,
-                seq_length=self.seq_length,
                 width=self.hidden_size,
                 compute_type=config.compute_type,
                 dtype=config.dtype)
             self.tfm_decoder = TransformerDecoder(
                 batch_size=self.batch_size,
                 hidden_size=self.hidden_size,
-                seq_length=self.seq_length,
-                enc_seq_length=self.seq_length,
                 num_attention_heads=config.num_attention_heads,
                 num_hidden_layers=self.num_hidden_layers,
                 intermediate_size=config.intermediate_size,
@@ -1083,14 +1075,12 @@ class TransformerModel(nn.Cell):
         else:
             self.projection = PredLogProbs(
                 batch_size=self.batch_size * config.beam_width,
-                seq_length=1,
                 width=self.hidden_size,
                 compute_type=config.compute_type,
                 dtype=config.dtype)
             self.tfm_decoder = TransformerDecoderStep(
                 batch_size=self.batch_size * config.beam_width,
                 hidden_size=self.hidden_size,
-                enc_seq_length=self.seq_length,
                 max_decode_length=config.max_decode_length,
                 num_hidden_layers=config.num_hidden_layers,
                 num_attention_heads=config.num_attention_heads,
@@ -1113,24 +1103,24 @@ class TransformerModel(nn.Cell):
                 length_penalty_weight=config.length_penalty_weight,
                 max_decode_length=config.max_decode_length)
 
+            self.tfm_decoder.add_flags(loop_can_unroll=True)
+            self.tile_beam = TileBeam(beam_width=self.beam_width)
+            ones = np.ones(shape=(self.batch_size, self.max_decode_length))
+            self.encdec_mask = Tensor(ones, mstype.float32)
+
         self.cast = P.Cast()
         self.dtype = config.dtype
         self.cast_compute_type = CastWrapper(dst_type=config.compute_type)
         self.expand = P.ExpandDims()
         self.multiply = P.Mul()
+        self.shape = P.Shape()
 
         self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask()
-        if is_training:
-            ones = np.ones(shape=(self.seq_length, self.seq_length))
-            self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
-        else:
-            self.tile_beam = TileBeam(beam_width=config.beam_width)
-            ones = np.ones(shape=(config.batch_size, config.max_decode_length))
-            self.encdec_mask = Tensor(ones, dtype=mstype.float32)
 
     def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
         """Transformer with encoder and decoder."""
+        seq_length = self.shape(source_ids)[1]
         # process source sentence
         src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
         src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
@@ -1138,21 +1128,24 @@ class TransformerModel(nn.Cell):
         enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
         # transformer encoder
         encoder_output = self.tfm_encoder(self.cast_compute_type(src_embedding_output),
-                                          self.cast_compute_type(enc_attention_mask))
+                                          self.cast_compute_type(enc_attention_mask),
+                                          seq_length)
 
         if self.is_training:
+            future_mask = convert_np_to_tensor_encoder(seq_length)
             # process target sentence
             tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
             tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings)
             # attention mask [batch_size, seq_length, seq_length]
             tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask)
-            tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0))
+            tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(future_mask, 0))
             # transformer decoder
             decoder_output = self.tfm_decoder(self.cast_compute_type(tgt_embedding_output),
                                               self.cast_compute_type(tgt_attention_mask),
-                                              encoder_output, enc_attention_mask)
+                                              encoder_output, enc_attention_mask,
+                                              seq_length, seq_length)
             # calculate logits and log_probs
-            log_probs = self.projection(decoder_output, embedding_tables)
+            log_probs = self.projection(decoder_output, embedding_tables, seq_length)
             ret = log_probs
         else:
             beam_encoder_output = self.tile_beam(encoder_output)
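The new `convert_np_to_tensor_encoder` helper builds the causal mask from the runtime sequence length rather than from a precomputed `self.future_mask`. A numpy-only sketch of the array it wraps in a Tensor:

```python
import numpy as np

def future_mask(seq_length):
    """Lower-triangular (causal) mask of shape [seq_length, seq_length];
    numpy analogue of convert_np_to_tensor_encoder."""
    return np.tril(np.ones(shape=(seq_length, seq_length), dtype=np.float32))

print(future_mask(4))
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
```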

View File

@@ -105,6 +105,9 @@ def argparse_init():
     parser.add_argument("--save_checkpoint_path", type=str, default="./checkpoint/", help="Save checkpoint file path, "
                                                                                           "default is ./checkpoint/")
     parser.add_argument("--data_path", type=str, default="", help="Data path, it is better to use absolute path")
+    parser.add_argument("--bucket_boundaries", type=list, default=[16, 32, 48, 64, 128], help="sequence length for "
+                                                                                              "different bucket")
+
     return parser
 
 def run_transformer_train():
@@ -129,7 +132,8 @@ def run_transformer_train():
     dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
                                          rank_id=rank_id, do_shuffle=args.do_shuffle,
                                          enable_data_sink=args.enable_data_sink,
-                                         dataset_path=args.data_path)
+                                         dataset_path=args.data_path,
+                                         bucket_boundaries=args.bucket_boundaries)
 
     netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
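Note that `argparse` applies `type=list` to the raw command-line string, so in practice `--bucket_boundaries` is usable only through its default value. A common workaround (not part of this commit, shown purely as a sketch) is to parse the flag with `ast.literal_eval`:

```python
import argparse
import ast

parser = argparse.ArgumentParser()
# Workaround sketch only: turn the string "[16, 32, 48, 64, 128]" into a real list of ints.
parser.add_argument("--bucket_boundaries", type=ast.literal_eval, default=[16, 32, 48, 64, 128],
                    help="sequence length for different bucket")

args = parser.parse_args(["--bucket_boundaries", "[16, 32, 48, 64, 128]"])
assert args.bucket_boundaries == [16, 32, 48, 64, 128]
```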

View File

@@ -24,12 +24,13 @@ from mindspore.nn.optim import Adam
 from mindspore.train.model import Model
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager
 from mindspore.train.callback import Callback
+import mindspore.dataset.engine as de
+import mindspore.dataset.transforms.c_transforms as deC
 from mindspore import context
 from model_zoo.official.nlp.transformer.src.transformer_model import TransformerConfig
 from model_zoo.official.nlp.transformer.src.transformer_for_train import TransformerNetworkWithLoss, \
     TransformerTrainOneStepWithLossScaleCell
-from model_zoo.official.nlp.transformer.src.config import cfg
-from model_zoo.official.nlp.transformer.src.dataset import create_transformer_dataset
+from model_zoo.official.nlp.transformer.src.config import cfg, transformer_net_cfg
 from model_zoo.official.nlp.transformer.src.lr_schedule import create_dynamic_lr
 
 DATA_DIR = ["/home/workspace/mindspore_dataset/transformer/test-mindrecord"]
@@ -76,6 +77,24 @@ def get_config(version='base', batch_size=1):
     transformer_cfg = TransformerConfig(batch_size=batch_size)
     return transformer_cfg
 
+def load_test_data(batch_size=1, data_file=None):
+    """Load test dataset."""
+    ds = de.MindDataset(data_file,
+                        columns_list=["source_eos_ids", "source_eos_mask",
+                                      "target_sos_ids", "target_sos_mask",
+                                      "target_eos_ids", "target_eos_mask"],
+                        shuffle=False)
+    type_cast_op = deC.TypeCast(mstype.int32)
+    ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
+    ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
+    ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
+    ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
+    ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
+    ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
+    # apply batch operations
+    ds = ds.batch(batch_size, drop_remainder=True)
+    return ds
+
 class ModelCallback(Callback):
     def __init__(self):
         super(ModelCallback, self).__init__()
@@ -120,10 +139,7 @@ def test_transformer():
     batch_size = 96
     epoch_size = 3
     config = get_config(version=version, batch_size=batch_size)
-    dataset = create_transformer_dataset(epoch_count=1,
-                                         do_shuffle="false",
-                                         enable_data_sink="false",
-                                         dataset_path=DATA_DIR)
+    dataset = load_test_data(batch_size=transformer_net_cfg.batch_size, data_file=DATA_DIR)
 
     netwithloss = TransformerNetworkWithLoss(config, True)