!47686 transformer st用例新增pynative模式下支持动态shape

Merge pull request !47686 from zhangdong/zd_3
This commit is contained in:
i-robot 2023-02-25 02:58:07 +00:00 committed by Gitee
commit 890a949b4a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 268 additions and 451 deletions

View File

@ -142,12 +142,18 @@ def create_transformer_dynamic_dataset(rank_size=1, rank_id=0, do_shuffle="true"
return dataset
def get_train_loss():
def get_train_loss(is_graph_mode, device_target, device_id=0):
"""
Transformer training.
"""
ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU", reserve_class_name_in_scope=False,
if is_graph_mode:
mode = ms.GRAPH_MODE
else:
mode = ms.PYNATIVE_MODE
ms.set_context(mode=mode, device_target=device_target, reserve_class_name_in_scope=False,
enable_graph_kernel=False)
# Set mempool block size in PYNATIVE_MODE for improving memory utilization, which will not take effect in GRAPH_MODE
if ms.get_context("mode") == ms.PYNATIVE_MODE:
ms.set_context(mempool_block_size="31GB")
@ -156,7 +162,7 @@ def get_train_loss():
rank_id = 0
dataset = create_transformer_dynamic_dataset(rank_size=device_num, rank_id=rank_id, do_shuffle=True)
netwithloss = TransformerNetworkWithLoss(True)
netwithloss = TransformerNetworkWithLoss(True, is_graph_mode=is_graph_mode)
hidden_size = 1024
learning_rate = 1.0
@ -177,8 +183,12 @@ def get_train_loss():
update_cell = scale_manager.get_update_cell()
netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
scale_update_cell=update_cell)
data_col = Tensor(shape=[BATCH_SIZE, None], dtype=ms.int64)
netwithgrads.set_inputs(data_col, data_col, data_col, data_col, data_col, data_col, data_col)
if is_graph_mode:
data_col_int64 = Tensor(shape=[BATCH_SIZE, None], dtype=ms.int64)
data_col = Tensor(shape=[BATCH_SIZE, None], dtype=ms.float32)
netwithgrads.set_inputs(data_col_int64, data_col_int64, data_col_int64,
data_col_int64, data_col_int64, data_col_int64,
data_col)
netwithgrads.set_train(True)
model = Model(netwithgrads)
@ -187,15 +197,29 @@ def get_train_loss():
return loss_list
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train():
def test_train_graph_mode_gpu():
"""
Feature: Test the simplified dynamic shape transformer network with small data.
Description: The sequence length of inputs is dynamic.
Expectation: Assert that the training loss of fixed data is consistent with the expected loss.
"""
graph_loss = get_train_loss()
graph_loss = get_train_loss(True, "GPU")
expect_loss = [11.193909]
assert np.allclose(graph_loss, expect_loss, 5e-3, 5e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_pynative_mode_gpu():
"""
Feature: Test the simplified dynamic shape transformer network with small data.
Description: The sequence length of inputs is dynamic.
Expectation: Assert that the training loss of fixed data is consistent with the expected loss.
"""
graph_loss = get_train_loss(False, "GPU")
expect_loss = [11.112342]
assert np.allclose(graph_loss[0], expect_loss, 5e-3, 5e-3)

View File

@ -13,7 +13,9 @@
# limitations under the License.
# ============================================================================
"""Transformer for training."""
import mindspore
from mindspore import jit
import mindspore as ms
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
@ -25,7 +27,9 @@ from tests.st.dynamic_shape.transformer.transformer_model import TransformerMode
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 5.0
BATCH_SIZE_VALUE = 32
VOCAB_SIZE = 36560
LABEL_SMOOTHING = 0.1
BATCH_SIZE = 32
clip_grad = ops.MultitypeFuncGraph("clip_grad")
@ -55,12 +59,21 @@ def _clip_grad(clip_type, clip_value, grad):
class TransformerTrainingLoss(nn.Cell):
def __init__(self):
"""
Provide transformer training loss.
Args:
is_graph_mode (bool): is graph mode.
Returns:
Tensor, total loss.
"""
def __init__(self, is_graph_mode):
super(TransformerTrainingLoss, self).__init__(auto_prefix=False)
self.vocab_size = 36560
self.vocab_size = VOCAB_SIZE
self.onehot = ops.OneHot()
self.on_value = Tensor(float(1 - 0.1), mindspore.float32)
self.off_value = Tensor(0.1 / float(self.vocab_size - 1), mindspore.float32)
self.on_value = Tensor(float(1 - LABEL_SMOOTHING), ms.float32)
self.off_value = Tensor(LABEL_SMOOTHING / float(self.vocab_size - 1), ms.float32)
self.reduce_sum = ops.ReduceSum()
self.reduce_mean = ops.ReduceMean()
self.reshape = ops.Reshape()
@ -68,20 +81,23 @@ class TransformerTrainingLoss(nn.Cell):
self.flatten = ops.Flatten()
self.neg = ops.Neg()
self.cast = ops.Cast()
self.batch_size = BATCH_SIZE_VALUE
self.batch_size = BATCH_SIZE
self.is_graph_mode = is_graph_mode
def construct(self, prediction_scores, label_ids, label_weights, seq_length):
"""Defines the computation performed."""
flat_shape = (self.batch_size * seq_length,)
flat_shape = (-1,)
if self.is_graph_mode:
flat_shape = (-1,)
label_ids = self.reshape(label_ids, flat_shape)
label_weights = self.cast(self.reshape(label_weights, flat_shape), mindspore.float32)
one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)
label_weights = self.cast(self.reshape(label_weights, flat_shape), ms.float32)
one_hot_labels = self.onehot(self.cast(label_ids, ms.int32), self.cast(self.vocab_size, ms.int32),
self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
numerator = self.reduce_sum(label_weights * per_example_loss, ())
denominator = self.reduce_sum(label_weights, ()) + \
self.cast(ops.tuple_to_array((1e-5,)), mindspore.float32)
self.cast(ops.tuple_to_array((1e-5,)), ms.float32)
loss = numerator / denominator
return loss
@ -93,16 +109,19 @@ class TransformerNetworkWithLoss(nn.Cell):
Args:
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
is_graph_mode (bool): is graph mode.
Returns:
Tensor, the loss of the network.
"""
def __init__(self, is_training, use_one_hot_embeddings=False):
def __init__(self, is_training, use_one_hot_embeddings=False, is_graph_mode=False):
super(TransformerNetworkWithLoss, self).__init__(auto_prefix=False)
self.transformer = TransformerModel(is_training, use_one_hot_embeddings)
self.loss = TransformerTrainingLoss()
self.transformer = TransformerModel(is_training, use_one_hot_embeddings, is_graph_mode)
self.loss = TransformerTrainingLoss(is_graph_mode)
self.cast = ops.Cast()
self.shape = ops.TensorShape()
self.shape = ops.Shape()
if is_graph_mode:
self.shape = ops.TensorShape()
def construct(self,
source_ids,
@ -115,7 +134,7 @@ class TransformerNetworkWithLoss(nn.Cell):
prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask)
seq_length = self.shape(source_ids)[1]
total_loss = self.loss(prediction_scores, label_ids, label_weights, seq_length)
return self.cast(total_loss, mindspore.float32)
return self.cast(total_loss, ms.float32)
grad_scale = ops.MultitypeFuncGraph("grad_scale")
@ -158,7 +177,18 @@ class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell)
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mindspore.float32))
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=ms.float32))
self.enable_tuple_broaden = True
@jit
def clip_grads(self, grads):
grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
return grads
@jit
def clip_scale_grads(self, scale, grads):
grads = self.hyper_map(ops.partial(grad_scale, scale * self.degree), grads)
return grads
def construct(self,
source_eos_ids,
@ -195,12 +225,12 @@ class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell)
label_ids,
label_weights,
self.cast(scaling_sens,
mindspore.float32))
ms.float32))
# apply grad reducer on grads
grads = self.grad_reducer(grads)
grads = self.hyper_map(ops.partial(grad_scale, scaling_sens * self.degree), grads)
grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
grads = self.clip_scale_grads(scaling_sens, grads)
grads = self.clip_grads(grads)
cond = self.get_overflow_status(status, grads)
overflow = cond
@ -208,4 +238,39 @@ class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell)
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if not overflow:
self.optimizer(grads)
return (loss, cond, scaling_sens)
return (loss, cond, scaling_sens.value())
cast = ops.Cast()
add_grads = ops.MultitypeFuncGraph("add_grads")
@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
return accu_grad + cast(grad, ms.float32)
update_accu_grads = ops.MultitypeFuncGraph("update_accu_grads")
@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
succ = True
return ops.depend(succ, ops.assign(accu_grad, cast(grad, ms.float32)))
accumulate_accu_grads = ops.MultitypeFuncGraph("accumulate_accu_grads")
@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
succ = True
return ops.depend(succ, ops.assign_add(accu_grad, cast(grad, ms.float32)))
zeroslike = ops.ZerosLike()
reset_accu_grads = ops.MultitypeFuncGraph("reset_accu_grads")
@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
succ = True
return ops.depend(succ, ops.assign(accu_grad, zeroslike(accu_grad)))

View File

@ -23,234 +23,30 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops.primitive import constexpr
from mindspore.ops.operations import array_ops as array_op
import mindspore.ops.function as func
from mindspore.ops.operations import _inner_ops as inner_op
BATCH_SIZE_VALUE = 32
INF = 1. * 1e9
BATCH_SIZE = 32
HIDDEN_SIZE = 1024
NUMBER_HIDDEN_LAYERS = 6
MAX_DECODE_LENGTH = 80
VOCAB_SIZE = 36560
MAX_POSITION_EMBEDDINGS = 128
INTERMEDIATE_SIZE = 4096
INIT_RANGE = 0.02
HIDDEN_DROPOUT_PROB = 0.2
NUM_ATTENTION_HEADS = 16
ATTENTION_PROBS_DROPOUT_PROB = 0.2
HIDDEN_ACT = "relu"
COMPUTE_TYPE = ms.float16
DTYPE = ms.float32
BEAM_WIDTH = 4
SEQ_LENGTH = 128
LENGTH_PENALTY_WEIGHT = 1.0
class LengthPenalty(nn.Cell):
"""
Normalize scores of translations according to their length.
Args:
weight (float): Weight of length penalty. Default: 1.0.
compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: ms.float32.
"""
def __init__(self,
weight=1.0,
compute_type=ms.float32):
super(LengthPenalty, self).__init__()
self.weight = weight
self.add = ops.Add()
self.pow = ops.Pow()
self.div = ops.RealDiv()
self.cast = ops.Cast()
self.five = Tensor(5.0, ms.float32)
self.six = Tensor(6.0, ms.float32)
def construct(self, length_tensor):
length_tensor = self.cast(length_tensor, ms.float32)
output = self.add(length_tensor, self.five)
output = self.div(output, self.six)
output = self.pow(output, self.weight)
return output
class Mod(nn.Cell):
"""
Mod function.
Args:
compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: ms.float32.
"""
def __init__(self,
compute_type=ms.float32):
super(Mod, self).__init__()
self.compute_type = compute_type
self.floor_div = ops.FloorDiv()
self.sub = ops.Sub()
self.multiply = ops.Mul()
def construct(self, input_x, input_y):
x = self.floor_div(input_x, input_y)
x = self.multiply(x, input_y)
x = self.sub(input_x, x)
return x
class BeamSearchDecoder(nn.Cell):
"""
Beam search decoder.
Args:
batch_size (int): Batch size of input dataset.
seq_length (int): Length of input sequence.
vocab_size (int): Size of vocabulary.
decoder (:class:`TransformerDecoderStep`): Decoder module.
beam_width (int): beam width setting. Default: 4.
length_penalty_weight (float): Weight of length penalty. Default: 1.0.
max_decode_length (int): max decode length. Default: 128.
sos_id (int): Id of sequence start token. Default: 1.
eos_id (int): Id of sequence end token. Default: 2.
compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: ms.float32.
"""
def __init__(self,
batch_size,
seq_length,
vocab_size,
decoder,
beam_width=4,
length_penalty_weight=1.0,
max_decode_length=128,
sos_id=1,
eos_id=2,
compute_type=ms.float32):
super(BeamSearchDecoder, self).__init__(auto_prefix=False)
self.seq_length = seq_length
self.batch_size = batch_size
self.vocab_size = vocab_size
self.beam_width = beam_width
self.length_penalty_weight = length_penalty_weight
self.max_decode_length = max_decode_length
self.decoder = decoder
self.add = ops.Add()
self.expand = ops.ExpandDims()
self.reshape = ops.Reshape()
self.shape_flat = (-1,)
self.shape = ops.TensorShape()
self.zero_tensor = Tensor(np.zeros([batch_size, beam_width]), ms.float32)
self.ninf_tensor = Tensor(np.full([batch_size, beam_width], -INF), ms.float32)
self.select = ops.Select()
self.flat_shape = (batch_size, beam_width * vocab_size)
self.topk = ops.TopK(sorted=True)
self.floor_div = ops.FloorDiv()
self.vocab_size_tensor = Tensor(self.vocab_size, ms.int32)
self.real_div = ops.RealDiv()
self.mod = Mod()
self.equal = ops.Equal()
self.eos_ids = Tensor(np.full([batch_size, beam_width], eos_id), ms.int32)
beam_ids = np.tile(np.arange(beam_width).reshape((1, beam_width)), [batch_size, 1])
self.beam_ids = Tensor(beam_ids, ms.int32)
batch_ids = np.arange(batch_size*beam_width).reshape((batch_size, beam_width)) // beam_width
self.batch_ids = Tensor(batch_ids, ms.int32)
self.concat = ops.Concat(axis=-1)
self.gather_nd = ops.GatherNd()
self.greater_equal = ops.GreaterEqual()
self.sub = ops.Sub()
self.cast = ops.Cast()
self.zeroslike = ops.ZerosLike()
# init inputs and states
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), ms.int32)
self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), ms.int32)
init_scores = np.tile(np.array([[0.] + [-INF]*(beam_width-1)]), [batch_size, 1])
self.init_scores = Tensor(init_scores, ms.float32)
self.init_finished = Tensor(np.zeros([batch_size, beam_width], dtype=np.bool))
self.init_length = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))
self.length_penalty = LengthPenalty(weight=length_penalty_weight)
self.one = Tensor(1, ms.int32)
def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
state_seq, state_finished, state_length):
"""
One step for decode
"""
log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask, self.seq_length)
log_probs = self.reshape(log_probs, (self.batch_size, self.beam_width, self.vocab_size))
# select topk indices
total_log_probs = self.add(log_probs, self.expand(state_log_probs, -1))
# mask finished beams
mask_tensor = self.select(state_finished, self.ninf_tensor, self.zero_tensor)
total_log_probs = self.add(total_log_probs, self.expand(mask_tensor, -1))
# reshape scores to [batch, beam*vocab]
flat_scores = self.reshape(total_log_probs, self.flat_shape)
# select topk
topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)
temp = topk_indices
beam_indices = self.zeroslike(topk_indices)
for _ in range(self.beam_width - 1):
temp = self.sub(temp, self.vocab_size_tensor)
res = self.cast(self.greater_equal(temp, 0), ms.int32)
beam_indices = beam_indices + res
word_indices = topk_indices - beam_indices * self.vocab_size_tensor
#======================================================================
# mask finished indices
beam_indices = self.select(state_finished, self.beam_ids, beam_indices)
word_indices = self.select(state_finished, self.eos_ids, word_indices)
topk_scores = self.select(state_finished, state_log_probs, topk_scores)
###### put finished sequences to the end
# sort according to scores with -inf for finished beams
tmp_log_probs = self.select(
self.equal(word_indices, self.eos_ids),
self.ninf_tensor,
topk_scores)
_, tmp_indices = self.topk(tmp_log_probs, self.beam_width)
# update
tmp_gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(tmp_indices, -1)))
beam_indices = self.gather_nd(beam_indices, tmp_gather_indices)
word_indices = self.gather_nd(word_indices, tmp_gather_indices)
topk_scores = self.gather_nd(topk_scores, tmp_gather_indices)
###### generate new beam_search states
# gather indices for selecting alive beams
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(beam_indices, -1)))
# length add 1 if not finished in the previous step
length_add = self.add(state_length, self.one)
state_length = self.select(state_finished, state_length, length_add)
state_length = self.gather_nd(state_length, gather_indices)
# concat seq
seq = self.gather_nd(state_seq, gather_indices)
state_seq = self.concat((seq, self.expand(word_indices, -1)))
# new finished flag and log_probs
state_finished = self.equal(word_indices, self.eos_ids)
state_log_probs = topk_scores
###### generate new inputs and decoder states
cur_input_ids = self.reshape(state_seq, (self.batch_size*self.beam_width, -1))
return cur_input_ids, state_log_probs, state_seq, state_finished, state_length
def construct(self, enc_states, enc_attention_mask):
"""Get beam search result."""
cur_input_ids = self.start_ids
# beam search states
state_log_probs = self.init_scores
state_seq = self.init_seq
state_finished = self.init_finished
state_length = self.init_length
for _ in range(self.max_decode_length):
# run one step decoder to get outputs of the current step
cur_input_ids, state_log_probs, state_seq, state_finished, state_length = self.one_step(
cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq, state_finished, state_length)
# add length penalty scores
penalty_len = self.length_penalty(state_length)
# get penalty length
log_probs = self.real_div(state_log_probs, penalty_len)
# sort according to scores
_, top_beam_indices = self.topk(log_probs, self.beam_width)
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
# sort sequence
predicted_ids = self.gather_nd(state_seq, gather_indices)
# take the first one
predicted_ids = predicted_ids[::, 0:1:1, ::]
return predicted_ids
def normal_weight(shape, num_units):
norm = np.random.normal(0.0, num_units**-0.5, shape).astype(np.float32)
return Tensor(norm)
def _average_units(shape):
@ -275,11 +71,6 @@ def weight_variable(shape):
return Tensor(values)
def normal_weight(shape, num_units):
norm = np.random.normal(0.0, num_units**-0.5, shape).astype(np.float32)
return Tensor(norm)
class EmbeddingLookup(nn.Cell):
"""
A embeddings lookup table with a fixed dictionary and size.
@ -291,12 +82,14 @@ class EmbeddingLookup(nn.Cell):
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
"""
def __init__(self,
is_graph_mode,
batch_size,
vocab_size,
embedding_size,
use_one_hot_embeddings=False,
initializer_range=0.02):
super(EmbeddingLookup, self).__init__()
self.is_graph_mode = is_graph_mode
self.batch_size = batch_size
self.vocab_size = vocab_size
self.embedding_size = embedding_size
@ -310,10 +103,14 @@ class EmbeddingLookup(nn.Cell):
self.off_value = Tensor(0.0, ms.float32)
self.array_mul = ops.MatMul()
self.reshape = ops.Reshape()
self.shape = ops.TensorShape()
self.shape = ops.Shape()
if is_graph_mode:
self.shape = ops.TensorShape()
def construct(self, input_ids):
"""Get a embeddings lookup table with a fixed dictionary and size."""
input_shape = self.shape(input_ids)
flat_ids = self.reshape(input_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
@ -321,9 +118,11 @@ class EmbeddingLookup(nn.Cell):
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
out_shape = (self.batch_size, -1, self.embedding_size)
out_shape = input_shape + (self.embedding_size,)
if self.is_graph_mode:
out_shape = (self.batch_size, -1, self.embedding_size)
output = self.reshape(output_for_reshape, out_shape)
return output, self.embedding_table
return output, self.embedding_table.value()
def position_encoding(length,
@ -364,6 +163,7 @@ class EmbeddingPostprocessor(nn.Cell):
dropout_prob (float): The dropout probability. Default: 0.1.
"""
def __init__(self,
is_graph_mode,
embedding_size,
use_one_hot_embeddings=False,
initializer_range=0.02,
@ -378,7 +178,9 @@ class EmbeddingPostprocessor(nn.Cell):
self.expand_dims = ops.ExpandDims()
self.position_embedding_table = Tensor(position_encoding(max_position_embeddings, embedding_size),
ms.float32)
self.shape = ops.TensorShape()
self.shape = ops.Shape()
if is_graph_mode:
self.shape = ops.TensorShape()
def construct(self, word_embeddings):
"""Postprocessors apply positional embeddings to word embeddings."""
@ -472,6 +274,7 @@ class MultiheadAttention(nn.Cell):
compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: ms.float32.
"""
def __init__(self,
is_graph_mode,
batch_size,
from_tensor_width,
to_tensor_width,
@ -490,6 +293,7 @@ class MultiheadAttention(nn.Cell):
compute_type=ms.float32):
super(MultiheadAttention, self).__init__()
self.batch_size = batch_size
self.is_graph_mode = is_graph_mode
self.num_attention_heads = num_attention_heads
self.size_per_head = size_per_head
self.has_attention_mask = has_attention_mask
@ -551,10 +355,16 @@ class MultiheadAttention(nn.Cell):
def construct(self, from_tensor, to_tensor, seq_length, enc_seq_length, attention_mask=None):
"""Apply multihead attention."""
from_seq_length = seq_length
shape_from = (self.batch_size, -1, self.num_attention_heads, self.size_per_head)
shape_to = (self.batch_size, -1, self.num_attention_heads, self.size_per_head)
to_seq_length = enc_seq_length
shape_from = (self.batch_size, from_seq_length, self.num_attention_heads, self.size_per_head)
shape_to = (self.batch_size, to_seq_length, self.num_attention_heads, self.size_per_head)
if self.is_graph_mode:
shape_from = (self.batch_size, -1, self.num_attention_heads, self.size_per_head)
shape_to = (self.batch_size, -1, self.num_attention_heads, self.size_per_head)
if self.do_return_2d_tensor:
shape_return = (-1, self.num_attention_heads * self.size_per_head)
shape_return = (self.batch_size * from_seq_length, self.num_attention_heads * self.size_per_head)
if self.is_graph_mode:
shape_return = (-1, self.num_attention_heads * self.size_per_head)
else:
shape_return = (self.batch_size, from_seq_length, self.num_attention_heads * self.size_per_head)
@ -616,6 +426,7 @@ class SelfAttention(nn.Cell):
compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: ms.float32.
"""
def __init__(self,
is_graph_mode,
batch_size,
hidden_size,
num_attention_heads=16,
@ -634,6 +445,7 @@ class SelfAttention(nn.Cell):
self.is_encdec_att = is_encdec_att
self.attention = MultiheadAttention(
is_graph_mode=is_graph_mode,
batch_size=batch_size,
from_tensor_width=hidden_size,
to_tensor_width=hidden_size,
@ -737,6 +549,7 @@ class EncoderCell(nn.Cell):
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: ms.float32.
"""
def __init__(self,
is_graph_mode,
batch_size,
hidden_size=1024,
num_attention_heads=16,
@ -749,6 +562,7 @@ class EncoderCell(nn.Cell):
compute_type=ms.float32):
super(EncoderCell, self).__init__()
self.attention = SelfAttention(
is_graph_mode=is_graph_mode,
batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
@ -795,6 +609,7 @@ class TransformerEncoder(nn.Cell):
compute_type (:class:`mindspore.dtype`): Compute type. Default: ms.float32.
"""
def __init__(self,
is_graph_mode,
batch_size,
hidden_size,
num_hidden_layers,
@ -809,11 +624,13 @@ class TransformerEncoder(nn.Cell):
super(TransformerEncoder, self).__init__()
self.num_hidden_layers = num_hidden_layers
self.batch_size = batch_size
self.is_graph_mode = is_graph_mode
self.hidden_size = hidden_size
layers = []
for _ in range(num_hidden_layers):
layer = EncoderCell(batch_size=batch_size,
layer = EncoderCell(is_graph_mode=is_graph_mode,
batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
@ -833,7 +650,9 @@ class TransformerEncoder(nn.Cell):
def construct(self, input_tensor, attention_mask, seq_length):
"""Apply encoder."""
out_shape = (self.batch_size, -1, self.hidden_size)
out_shape = (self.batch_size, seq_length, self.hidden_size)
if self.is_graph_mode:
out_shape = (self.batch_size, -1, self.hidden_size)
prev_output = self.reshape(input_tensor, self.shape)
for layer_module in self.layers:
@ -865,6 +684,7 @@ class DecoderCell(nn.Cell):
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: ms.float32.
"""
def __init__(self,
is_graph_mode,
batch_size,
hidden_size=1024,
num_attention_heads=12,
@ -877,6 +697,7 @@ class DecoderCell(nn.Cell):
compute_type=ms.float32):
super(DecoderCell, self).__init__()
self.self_attention = SelfAttention(
is_graph_mode=is_graph_mode,
batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
@ -887,6 +708,7 @@ class DecoderCell(nn.Cell):
hidden_dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
self.cross_attention = SelfAttention(
is_graph_mode=is_graph_mode,
batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
@ -937,6 +759,7 @@ class TransformerDecoder(nn.Cell):
compute_type (:class:`mindspore.dtype`): Compute type. Default: ms.float32.
"""
def __init__(self,
is_graph_mode,
batch_size,
hidden_size,
num_hidden_layers,
@ -953,7 +776,8 @@ class TransformerDecoder(nn.Cell):
layers = []
for _ in range(num_hidden_layers):
layer = DecoderCell(batch_size=batch_size,
layer = DecoderCell(is_graph_mode=is_graph_mode,
batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
@ -972,10 +796,13 @@ class TransformerDecoder(nn.Cell):
self.shape = (-1, hidden_size)
self.hidden_size = hidden_size
self.batch_size = batch_size
self.is_graph_mode = is_graph_mode
def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask, seq_length, enc_seq_length):
"""Apply decoder."""
out_shape = (self.batch_size, -1, self.hidden_size)
out_shape = (self.batch_size, seq_length, self.hidden_size)
if self.is_graph_mode:
out_shape = (self.batch_size, -1, self.hidden_size)
prev_output = self.reshape(input_tensor, self.shape)
for layer_module in self.layers:
@ -989,19 +816,36 @@ class TransformerDecoder(nn.Cell):
class CreateAttentionMaskFromInputMask(nn.Cell):
def __init__(self):
"""
Create attention mask according to input mask.
Args:
is_graph_mode (bool): is graph mode.
"""
def __init__(self, is_graph_mode):
super(CreateAttentionMaskFromInputMask, self).__init__()
self.cast = ops.Cast()
self.reshape = ops.Reshape()
self.shape = ops.TensorShape()
self.shape = ops.Shape()
if is_graph_mode:
self.shape = ops.TensorShape()
self.batch_matmul = ops.BatchMatMul()
self.expand_dims = ops.ExpandDims()
self.is_graph_mode = is_graph_mode
def construct(self, input_mask):
"""Create attention mask according to input mask."""
input_shape = self.shape(input_mask)
shape_right = (input_shape[0], 1, input_shape[1])
shape_left = input_shape + (1,)
input_mask = self.cast(input_mask, ms.float32)
mask_left = self.expand_dims(input_mask, 2)
mask_right = self.expand_dims(input_mask, 1)
if self.is_graph_mode:
mask_left = self.expand_dims(input_mask, 2)
mask_right = self.expand_dims(input_mask, 1)
else:
mask_left = self.reshape(input_mask, shape_left)
mask_right = self.reshape(input_mask, shape_right)
attention_mask = self.batch_matmul(mask_left, mask_right)
return attention_mask
@ -1019,11 +863,13 @@ class PredLogProbs(nn.Cell):
dtype (:class:`mindspore.dtype`): Compute type to compute log_softmax. Default: ms.float32.
"""
def __init__(self,
is_graph_mode,
batch_size,
width,
compute_type=ms.float32,
dtype=ms.float32):
super(PredLogProbs, self).__init__()
self.is_graph_mode = is_graph_mode
self.batch_size = batch_size
self.width = width
self.compute_type = compute_type
@ -1039,7 +885,9 @@ class PredLogProbs(nn.Cell):
output_weights,
seq_length):
"""Get log probs."""
shape_flat_sequence_tensor = (-1, self.width)
shape_flat_sequence_tensor = (self.batch_size * seq_length, self.width)
if self.is_graph_mode:
shape_flat_sequence_tensor = (-1, self.width)
input_tensor = self.reshape(input_tensor, shape_flat_sequence_tensor)
input_tensor = self.cast(input_tensor, self.compute_type)
@ -1052,110 +900,6 @@ class PredLogProbs(nn.Cell):
return log_probs
class TransformerDecoderStep(nn.Cell):
"""
Multi-layer transformer decoder step.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of the encoder layers.
max_decode_length (int): Max decode length.
enc_seq_length (int): Length of source sentences.
num_hidden_layers (int): Number of hidden layers in encoder cells.
num_attention_heads (int): Number of attention heads in encoder cells. Default: 16.
intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096.
attention_probs_dropout_prob (float): The dropout probability for
SelfAttention. Default: 0.1.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
compute_type (:class:`mindspore.dtype`): Compute type. Default: ms.float32.
embedding_lookup (:class:`EmbeddingLookup`): Embedding lookup module.
embedding_processor (:class:`EmbeddingPostprocessor`) Embedding postprocessor module.
projection (:class:`PredLogProbs`): PredLogProbs module
"""
def __init__(self,
batch_size,
hidden_size,
max_decode_length,
num_hidden_layers,
num_attention_heads=16,
intermediate_size=4096,
attention_probs_dropout_prob=0.3,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.3,
hidden_act="relu",
compute_type=ms.float32,
embedding_lookup=None,
embedding_processor=None,
projection=None):
super(TransformerDecoderStep, self).__init__(auto_prefix=False)
self.num_hidden_layers = num_hidden_layers
self.tfm_embedding_lookup = embedding_lookup
self.tfm_embedding_processor = embedding_processor
self.projection = projection
self.tfm_decoder = TransformerDecoder(
batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.ones_like = ops.OnesLike()
self.shape = ops.TensorShape()
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask()
self.expand = ops.ExpandDims()
self.multiply = ops.Mul()
ones = np.ones(shape=(max_decode_length, max_decode_length))
self.future_mask = Tensor(np.tril(ones), dtype=ms.float32)
self.cast_compute_type = CastWrapper(dst_type=compute_type)
def construct(self, input_ids, enc_states, enc_attention_mask, seq_length):
"""
Multi-layer transformer decoder step.
input_ids: [batch_size * beam_width]
"""
# process embedding
input_embedding, embedding_tables = self.tfm_embedding_lookup(input_ids)
input_embedding = self.tfm_embedding_processor(input_embedding)
input_embedding = self.cast_compute_type(input_embedding)
input_shape = self.shape(input_ids)
input_len = input_shape[1]
future_mask = self.future_mask[0:input_len:1, 0:input_len:1]
input_mask = self.ones_like(input_ids)
input_mask = self._create_attention_mask_from_input_mask(input_mask)
input_mask = self.multiply(input_mask, self.expand(future_mask, 0))
input_mask = self.cast_compute_type(input_mask)
enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::]
# call TransformerDecoder
decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask, -1, seq_length)
# take the last step
decoder_output = decoder_output[::, input_len-1:input_len:1, ::]
# projection and log_prob
log_probs = self.projection(decoder_output, embedding_tables, 1)
return log_probs
@constexpr
def convert_np_to_tensor_encoder(seq_length):
ones = np.ones(shape=(seq_length, seq_length))
@ -1169,118 +913,97 @@ class TransformerModel(nn.Cell):
Args:
is_training (bool): True for training mode. False for eval mode.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
is_graph_mode (bool): Whether the network runs under graph mode (ms.GRAPH_MODE).
    Selects the dynamic-shape code paths (e.g. ops.TensorShape / DynamicBroadcastTo)
    instead of the PyNative static-shape ones. Default: False.
"""
# NOTE(review): this span is a unified diff rendered WITHOUT its +/- markers,
# so pre-refactor lines (literal constants: 1024, 16, 4096, 0.02, "relu",
# ms.float16) and post-refactor lines (named constants: HIDDEN_SIZE,
# NUM_ATTENTION_HEADS, INIT_RANGE, COMPUTE_TYPE, ...) are interleaved.
# It is NOT valid Python as-is: the duplicated signature terminator and the
# duplicated assignments below must be reconciled (keep exactly one of each
# old/new pair) before this file can run. Comments mark the apparent pairs.
def __init__(self,
             is_training,
             use_one_hot_embeddings=False):  # NOTE(review): old signature terminator — superseded by the two lines below
             use_one_hot_embeddings=False,
             is_graph_mode=False):
    super(TransformerModel, self).__init__()
    self.is_training = is_training
    # NOTE(review): old literal-constant config (next 4 lines) vs the
    # named-constant config that follows — keep only one set.
    self.batch_size = BATCH_SIZE_VALUE
    self.hidden_size = 1024
    self.num_hidden_layers = 6
    self.embedding_size = 1024
    self.batch_size = BATCH_SIZE
    self.is_graph_mode = is_graph_mode
    self.hidden_size = HIDDEN_SIZE
    self.num_hidden_layers = NUMBER_HIDDEN_LAYERS
    self.embedding_size = HIDDEN_SIZE
    self.last_idx = self.num_hidden_layers - 1
    # NOTE(review): old literals vs new named constants — duplicated pair.
    self.beam_width = 4
    self.max_decode_length = 80
    self.beam_width = BEAM_WIDTH
    self.max_decode_length = MAX_DECODE_LENGTH
    # Token-id -> embedding table lookup shared by encoder and decoder.
    self.tfm_embedding_lookup = EmbeddingLookup(
        is_graph_mode=self.is_graph_mode,
        batch_size=self.batch_size,
        vocab_size=36560,  # NOTE(review): old literal — superseded by VOCAB_SIZE below
        vocab_size=VOCAB_SIZE,
        embedding_size=self.embedding_size,
        use_one_hot_embeddings=use_one_hot_embeddings,
        initializer_range=0.02)  # NOTE(review): old literal — superseded by INIT_RANGE below
        initializer_range=INIT_RANGE)
    # Adds position embeddings + dropout on top of the encoder-side lookup.
    self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor(
        is_graph_mode=self.is_graph_mode,
        embedding_size=self.embedding_size,
        use_one_hot_embeddings=use_one_hot_embeddings,
        initializer_range=0.02,
        max_position_embeddings=128,  # NOTE(review): old literal pair — superseded by the two named-constant lines below
        dropout_prob=0.2)
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        dropout_prob=HIDDEN_DROPOUT_PROB)
    # Same postprocessing for the decoder-side embeddings.
    self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor(
        is_graph_mode=self.is_graph_mode,
        embedding_size=self.embedding_size,
        use_one_hot_embeddings=use_one_hot_embeddings,
        initializer_range=0.02,
        max_position_embeddings=128,  # NOTE(review): old literal pair — superseded by the two named-constant lines below
        dropout_prob=0.2)
        max_position_embeddings=MAX_POSITION_EMBEDDINGS,
        dropout_prob=HIDDEN_DROPOUT_PROB)
    self.tfm_encoder = TransformerEncoder(
        is_graph_mode=self.is_graph_mode,
        batch_size=self.batch_size,
        hidden_size=self.hidden_size,
        num_attention_heads=16,  # NOTE(review): old literal — superseded by NUM_ATTENTION_HEADS below
        num_attention_heads=NUM_ATTENTION_HEADS,
        num_hidden_layers=self.num_hidden_layers,
        intermediate_size=4096,  # NOTE(review): old literal pair — superseded by the two named-constant lines below
        attention_probs_dropout_prob=0.2,
        intermediate_size=INTERMEDIATE_SIZE,
        attention_probs_dropout_prob=ATTENTION_PROBS_DROPOUT_PROB,
        use_one_hot_embeddings=use_one_hot_embeddings,
        initializer_range=0.02,  # NOTE(review): old literal block (4 lines) — superseded by the named-constant block below
        hidden_dropout_prob=0.2,
        hidden_act="relu",
        compute_type=ms.float16)
        initializer_range=INIT_RANGE,
        hidden_dropout_prob=HIDDEN_DROPOUT_PROB,
        hidden_act=HIDDEN_ACT,
        compute_type=COMPUTE_TYPE)
    if is_training:
        # Training path: teacher-forced TransformerDecoder + log-prob head.
        self.projection = PredLogProbs(
            is_graph_mode=self.is_graph_mode,
            batch_size=self.batch_size,
            width=self.hidden_size,
            compute_type=ms.float16,  # NOTE(review): old literal pair — superseded by COMPUTE_TYPE/DTYPE below
            dtype=ms.float32)
            compute_type=COMPUTE_TYPE,
            dtype=DTYPE)
        self.tfm_decoder = TransformerDecoder(
            is_graph_mode=self.is_graph_mode,
            batch_size=self.batch_size,
            hidden_size=self.hidden_size,
            num_attention_heads=16,  # NOTE(review): old literal — superseded by NUM_ATTENTION_HEADS below
            num_attention_heads=NUM_ATTENTION_HEADS,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=4096,  # NOTE(review): old literal pair — superseded by the two named-constant lines below
            attention_probs_dropout_prob=0.2,
            intermediate_size=INTERMEDIATE_SIZE,
            attention_probs_dropout_prob=ATTENTION_PROBS_DROPOUT_PROB,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,  # NOTE(review): old literal block — its named-constant replacement appears AFTER the else branch below (diff ordering artifact)
            hidden_dropout_prob=0.2,
            hidden_act="relu",
            compute_type=ms.float16)
    # NOTE(review): the whole eval-mode else branch below uses only the old
    # literal constants and has no new-version counterpart in this diff — it
    # appears to have been DELETED by the refactor. Reconcile before use.
    else:
        self.projection = PredLogProbs(
            batch_size=self.batch_size * 4,
            width=self.hidden_size,
            compute_type=ms.float16,
            dtype=ms.float32)
        self.tfm_decoder = TransformerDecoderStep(
            batch_size=self.batch_size * 4,
            hidden_size=self.hidden_size,
            max_decode_length=80,
            num_hidden_layers=6,
            num_attention_heads=16,
            intermediate_size=4096,
            attention_probs_dropout_prob=0.2,
            use_one_hot_embeddings=False,
            initializer_range=0.02,
            hidden_dropout_prob=0.2,
            hidden_act="relu",
            compute_type=ms.float16,
            embedding_lookup=self.tfm_embedding_lookup,
            embedding_processor=self.tfm_embedding_postprocessor_for_decoder,
            projection=self.projection)
        self.tfm_decoder = BeamSearchDecoder(
            batch_size=BATCH_SIZE_VALUE,
            seq_length=128,
            vocab_size=36560,
            decoder=self.tfm_decoder,
            beam_width=4,
            length_penalty_weight=1.0,
            max_decode_length=80)
        self.tfm_decoder.add_flags(loop_can_unroll=True)
        ones = np.ones(shape=(self.batch_size, self.max_decode_length))
        self.encdec_mask = Tensor(ones, ms.float32)
    # NOTE(review): these four named-constant kwargs belong to the training
    # TransformerDecoder(...) call above (they replace its four old-literal
    # lines); the stripped diff displaced them to after the else branch.
            initializer_range=INIT_RANGE,
            hidden_dropout_prob=HIDDEN_DROPOUT_PROB,
            hidden_act=HIDDEN_ACT,
            compute_type=COMPUTE_TYPE)
    self.cast = ops.Cast()
    # NOTE(review): old literal pair — superseded by the DTYPE/COMPUTE_TYPE pair below.
    self.dtype = ms.float32
    self.cast_compute_type = CastWrapper(dst_type=ms.float16)
    self.dtype = DTYPE
    self.cast_compute_type = CastWrapper(dst_type=COMPUTE_TYPE)
    self.expand = ops.ExpandDims()
    self.multiply = ops.Mul()
    # NOTE(review): old version used TensorShape unconditionally; new version
    # defaults to static ops.Shape and upgrades to the dynamic-shape ops only
    # under graph mode (see the if block below).
    self.shape = ops.TensorShape()
    self.shape = ops.Shape()
    if self.is_graph_mode:
        # Dynamic-shape operators: shapes are tensors computed at runtime.
        self.shape = ops.TensorShape()
        self.tril = array_op.Tril()
        self.dynamic_broadcast_to = inner_op.DynamicBroadcastTo()
        self.ones_like = array_op.OnesLike()
        self.stack = array_op.Stack(0)
        self.concatenate = array_op.Concat()
    # NOTE(review): old no-arg call vs new call threading is_graph_mode — duplicated pair.
    self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask()
    self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(self.is_graph_mode)
def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
"""Transformer with encoder and decoder."""
@ -1296,11 +1019,16 @@ class TransformerModel(nn.Cell):
self.cast_compute_type(enc_attention_mask),
seq_length)
ones_inner = self.ones_like(source_ids[0, :])
seq_length = self.shape(ones_inner)
broadcast_shape = self.concatenate([seq_length, seq_length])
ones = func.broadcast_to(ones_inner, broadcast_shape)
future_mask = self.tril(ones)
if self.is_graph_mode:
ones_inner = self.ones_like(source_ids[0, :])
seq_length = self.shape(ones_inner)
broadcast_shape = self.concatenate([seq_length, seq_length])
ones = self.dynamic_broadcast_to(ones_inner, broadcast_shape)
future_mask = self.tril(ones)
else:
future_mask = convert_np_to_tensor_encoder(seq_length)
# process target sentence
tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)