forked from mindspore-Ecosystem/mindspore
!47686 transformer ST test cases: add dynamic shape support in PyNative mode
Merge pull request !47686 from zhangdong/zd_3
commit 890a949b4a
@@ -142,12 +142,18 @@ def create_transformer_dynamic_dataset(rank_size=1, rank_id=0, do_shuffle="true"
    return dataset


def get_train_loss():
def get_train_loss(is_graph_mode, device_target, device_id=0):
    """
    Transformer training.
    """
    ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU", reserve_class_name_in_scope=False,
    if is_graph_mode:
        mode = ms.GRAPH_MODE
    else:
        mode = ms.PYNATIVE_MODE

    ms.set_context(mode=mode, device_target=device_target, reserve_class_name_in_scope=False,
                   enable_graph_kernel=False)

    # Set mempool block size in PYNATIVE_MODE for improving memory utilization, which will not take effect in GRAPH_MODE
    if ms.get_context("mode") == ms.PYNATIVE_MODE:
        ms.set_context(mempool_block_size="31GB")
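The hunk above boils down to choosing the execution mode from a flag and then applying a PyNative-only memory-pool setting. A minimal sketch of the same pattern outside the test; the helper name `configure_context` is illustrative, not part of the patch.

import mindspore as ms

def configure_context(is_graph_mode, device_target="GPU"):
    """Pick GRAPH or PYNATIVE mode, then apply PyNative-only memory settings (sketch)."""
    mode = ms.GRAPH_MODE if is_graph_mode else ms.PYNATIVE_MODE
    ms.set_context(mode=mode, device_target=device_target)
    # mempool_block_size only takes effect in PYNATIVE_MODE; GRAPH_MODE ignores it.
    if ms.get_context("mode") == ms.PYNATIVE_MODE:
        ms.set_context(mempool_block_size="31GB")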
@@ -156,7 +162,7 @@ def get_train_loss():
    rank_id = 0

    dataset = create_transformer_dynamic_dataset(rank_size=device_num, rank_id=rank_id, do_shuffle=True)
    netwithloss = TransformerNetworkWithLoss(True)
    netwithloss = TransformerNetworkWithLoss(True, is_graph_mode=is_graph_mode)

    hidden_size = 1024
    learning_rate = 1.0
@@ -177,8 +183,12 @@ def get_train_loss():
    update_cell = scale_manager.get_update_cell()
    netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                            scale_update_cell=update_cell)
    data_col = Tensor(shape=[BATCH_SIZE, None], dtype=ms.int64)
    netwithgrads.set_inputs(data_col, data_col, data_col, data_col, data_col, data_col, data_col)
    if is_graph_mode:
        data_col_int64 = Tensor(shape=[BATCH_SIZE, None], dtype=ms.int64)
        data_col = Tensor(shape=[BATCH_SIZE, None], dtype=ms.float32)
        netwithgrads.set_inputs(data_col_int64, data_col_int64, data_col_int64,
                                data_col_int64, data_col_int64, data_col_int64,
                                data_col)

    netwithgrads.set_train(True)
    model = Model(netwithgrads)
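The `set_inputs` call in the hunk above is what declares the dynamic sequence axis for graph compilation. A self-contained sketch of the idea, assuming a two-input network; the placeholder names are illustrative.

import mindspore as ms
from mindspore import Tensor, nn

def register_dynamic_inputs(net: nn.Cell, batch_size=32):
    """Declare which input axes may vary between steps before graph compilation (sketch)."""
    ids_placeholder = Tensor(shape=[batch_size, None], dtype=ms.int64)      # None marks the dynamic axis
    weights_placeholder = Tensor(shape=[batch_size, None], dtype=ms.float32)
    net.set_inputs(ids_placeholder, weights_placeholder)

In PyNative mode the placeholders are unnecessary, since each op sees the real shapes at call time; that is why the patch only registers them under `is_graph_mode`.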
@@ -187,15 +197,29 @@ def get_train_loss():
    return loss_list


@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train():
def test_train_graph_mode_gpu():
    """
    Feature: Test the simplified dynamic shape transformer network with small data.
    Description: The sequence length of inputs is dynamic.
    Expectation: Assert that the training loss of fixed data is consistent with the expected loss.
    """
    graph_loss = get_train_loss()
    graph_loss = get_train_loss(True, "GPU")
    expect_loss = [11.193909]
    assert np.allclose(graph_loss, expect_loss, 5e-3, 5e-3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_pynative_mode_gpu():
    """
    Feature: Test the simplified dynamic shape transformer network with small data.
    Description: The sequence length of inputs is dynamic.
    Expectation: Assert that the training loss of fixed data is consistent with the expected loss.
    """
    graph_loss = get_train_loss(False, "GPU")
    expect_loss = [11.112342]
    assert np.allclose(graph_loss[0], expect_loss, 5e-3, 5e-3)
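Both tests accept a small drift around the recorded loss. For readers unfamiliar with `np.allclose`, the accepted band with the tolerances used here works out roughly as follows (illustrative arithmetic only).

import numpy as np

# np.allclose(actual, expected, rtol, atol) passes when
#   |actual - expected| <= atol + rtol * |expected|   (element-wise)
expected = 11.193909
rtol = atol = 5e-3
accepted_band = atol + rtol * abs(expected)   # roughly 0.061
print(round(accepted_band, 4))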
@@ -13,7 +13,9 @@
# limitations under the License.
# ============================================================================
"""Transformer for training."""
import mindspore
from mindspore import jit

import mindspore as ms
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
@@ -25,7 +27,9 @@ from tests.st.dynamic_shape.transformer.transformer_model import TransformerMode

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 5.0
BATCH_SIZE_VALUE = 32
VOCAB_SIZE = 36560
LABEL_SMOOTHING = 0.1
BATCH_SIZE = 32

clip_grad = ops.MultitypeFuncGraph("clip_grad")
@@ -55,12 +59,21 @@ def _clip_grad(clip_type, clip_value, grad):


class TransformerTrainingLoss(nn.Cell):
    def __init__(self):
    """
    Provide transformer training loss.

    Args:
        is_graph_mode (bool): is graph mode.

    Returns:
        Tensor, total loss.
    """
    def __init__(self, is_graph_mode):
        super(TransformerTrainingLoss, self).__init__(auto_prefix=False)
        self.vocab_size = 36560
        self.vocab_size = VOCAB_SIZE
        self.onehot = ops.OneHot()
        self.on_value = Tensor(float(1 - 0.1), mindspore.float32)
        self.off_value = Tensor(0.1 / float(self.vocab_size - 1), mindspore.float32)
        self.on_value = Tensor(float(1 - LABEL_SMOOTHING), ms.float32)
        self.off_value = Tensor(LABEL_SMOOTHING / float(self.vocab_size - 1), ms.float32)
        self.reduce_sum = ops.ReduceSum()
        self.reduce_mean = ops.ReduceMean()
        self.reshape = ops.Reshape()
@@ -68,20 +81,23 @@ class TransformerTrainingLoss(nn.Cell):
        self.flatten = ops.Flatten()
        self.neg = ops.Neg()
        self.cast = ops.Cast()
        self.batch_size = BATCH_SIZE_VALUE
        self.batch_size = BATCH_SIZE
        self.is_graph_mode = is_graph_mode

    def construct(self, prediction_scores, label_ids, label_weights, seq_length):
        """Defines the computation performed."""
        flat_shape = (self.batch_size * seq_length,)
        flat_shape = (-1,)
        if self.is_graph_mode:
            flat_shape = (-1,)
        label_ids = self.reshape(label_ids, flat_shape)
        label_weights = self.cast(self.reshape(label_weights, flat_shape), mindspore.float32)
        one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)
        label_weights = self.cast(self.reshape(label_weights, flat_shape), ms.float32)
        one_hot_labels = self.onehot(self.cast(label_ids, ms.int32), self.cast(self.vocab_size, ms.int32),
                                     self.on_value, self.off_value)

        per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
        numerator = self.reduce_sum(label_weights * per_example_loss, ())
        denominator = self.reduce_sum(label_weights, ()) + \
            self.cast(ops.tuple_to_array((1e-5,)), mindspore.float32)
            self.cast(ops.tuple_to_array((1e-5,)), ms.float32)
        loss = numerator / denominator
        return loss
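The reshape target above differs by mode because graph mode may compile the cell before the sequence length is known, while PyNative always has concrete Python integers. A toy comparison, not part of the patch:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

batch_size, seq_length = 2, 5
label_ids = Tensor(np.zeros((batch_size, seq_length)), ms.int32)
reshape = ops.Reshape()

flat_pynative = reshape(label_ids, (batch_size * seq_length,))   # concrete sizes available eagerly
flat_graph = reshape(label_ids, (-1,))                           # -1 lets a dynamic length be inferred
assert flat_pynative.shape == flat_graph.shape == (10,)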
@@ -93,16 +109,19 @@ class TransformerNetworkWithLoss(nn.Cell):
    Args:
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
        is_graph_mode (bool): is graph mode.

    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, is_training, use_one_hot_embeddings=False):
    def __init__(self, is_training, use_one_hot_embeddings=False, is_graph_mode=False):
        super(TransformerNetworkWithLoss, self).__init__(auto_prefix=False)
        self.transformer = TransformerModel(is_training, use_one_hot_embeddings)
        self.loss = TransformerTrainingLoss()
        self.transformer = TransformerModel(is_training, use_one_hot_embeddings, is_graph_mode)
        self.loss = TransformerTrainingLoss(is_graph_mode)
        self.cast = ops.Cast()
        self.shape = ops.TensorShape()
        self.shape = ops.Shape()
        if is_graph_mode:
            self.shape = ops.TensorShape()

    def construct(self,
                  source_ids,
@@ -115,7 +134,7 @@ class TransformerNetworkWithLoss(nn.Cell):
        prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask)
        seq_length = self.shape(source_ids)[1]
        total_loss = self.loss(prediction_scores, label_ids, label_weights, seq_length)
        return self.cast(total_loss, mindspore.float32)
        return self.cast(total_loss, ms.float32)


grad_scale = ops.MultitypeFuncGraph("grad_scale")
@@ -158,7 +177,18 @@ class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell)
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mindspore.float32))
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=ms.float32))
        self.enable_tuple_broaden = True

    @jit
    def clip_grads(self, grads):
        grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        return grads

    @jit
    def clip_scale_grads(self, scale, grads):
        grads = self.hyper_map(ops.partial(grad_scale, scale * self.degree), grads)
        return grads
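The new `@jit` methods compile the clip/unscale hyper_map loops even when the surrounding cell runs in PyNative mode. A standalone sketch of that pattern, assuming the same MultitypeFuncGraph style as above; names are illustrative.

import mindspore as ms
from mindspore import Tensor, jit, ops

ms.set_context(mode=ms.PYNATIVE_MODE)

grad_scale = ops.MultitypeFuncGraph("grad_scale")
reciprocal = ops.Reciprocal()

@grad_scale.register("Tensor", "Tensor")
def _tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)

hyper_map = ops.HyperMap()

@jit
def unscale_grads(scale, grads):
    # Compiled as a graph although the caller is eager, so the per-gradient
    # division is dispatched once rather than op by op.
    return hyper_map(ops.partial(grad_scale, scale), grads)

grads = (Tensor([2.0, 4.0]), Tensor([8.0]))
print(unscale_grads(Tensor(2.0), grads))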
    def construct(self,
                  source_eos_ids,
@@ -195,12 +225,12 @@ class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell)
                                                 label_ids,
                                                 label_weights,
                                                 self.cast(scaling_sens,
                                                           mindspore.float32))
                                                           ms.float32))

        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(ops.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        grads = self.clip_scale_grads(scaling_sens, grads)
        grads = self.clip_grads(grads)

        cond = self.get_overflow_status(status, grads)
        overflow = cond
@@ -208,4 +238,39 @@ class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell)
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if not overflow:
            self.optimizer(grads)
        return (loss, cond, scaling_sens)
        return (loss, cond, scaling_sens.value())


cast = ops.Cast()
add_grads = ops.MultitypeFuncGraph("add_grads")


@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
    return accu_grad + cast(grad, ms.float32)

update_accu_grads = ops.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return ops.depend(succ, ops.assign(accu_grad, cast(grad, ms.float32)))

accumulate_accu_grads = ops.MultitypeFuncGraph("accumulate_accu_grads")


@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
    succ = True
    return ops.depend(succ, ops.assign_add(accu_grad, cast(grad, ms.float32)))


zeroslike = ops.ZerosLike()
reset_accu_grads = ops.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    succ = True
    return ops.depend(succ, ops.assign(accu_grad, zeroslike(accu_grad)))
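The accumulation helpers above rely on in-place `assign`/`assign_add` plus `ops.depend` to keep the side effect ordered while returning a plain flag. A minimal illustration with a throwaway accumulator:

import mindspore as ms
from mindspore import Parameter, Tensor, ops

accu_grad = Parameter(Tensor([0.0, 0.0], ms.float32), name="accu_grad")
grad = Tensor([1.5, 2.5], ms.float32)

succ = True
# assign_add updates the accumulator in place; depend pins the side effect
# into the dataflow while the helper itself only returns `succ`.
succ = ops.depend(succ, ops.assign_add(accu_grad, grad))
print(accu_grad.asnumpy())   # [1.5 2.5]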
@@ -23,234 +23,30 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops.primitive import constexpr
from mindspore.ops.operations import array_ops as array_op
import mindspore.ops.function as func
from mindspore.ops.operations import _inner_ops as inner_op

BATCH_SIZE_VALUE = 32
INF = 1. * 1e9
BATCH_SIZE = 32
HIDDEN_SIZE = 1024
NUMBER_HIDDEN_LAYERS = 6
MAX_DECODE_LENGTH = 80
VOCAB_SIZE = 36560
MAX_POSITION_EMBEDDINGS = 128
INTERMEDIATE_SIZE = 4096
INIT_RANGE = 0.02
HIDDEN_DROPOUT_PROB = 0.2
NUM_ATTENTION_HEADS = 16
ATTENTION_PROBS_DROPOUT_PROB = 0.2
HIDDEN_ACT = "relu"
COMPUTE_TYPE = ms.float16
DTYPE = ms.float32
BEAM_WIDTH = 4
SEQ_LENGTH = 128
LENGTH_PENALTY_WEIGHT = 1.0


class LengthPenalty(nn.Cell):
    """
    Normalize scores of translations according to their length.

    Args:
        weight (float): Weight of length penalty. Default: 1.0.
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: ms.float32.
    """
    def __init__(self,
                 weight=1.0,
                 compute_type=ms.float32):
        super(LengthPenalty, self).__init__()
        self.weight = weight
        self.add = ops.Add()
        self.pow = ops.Pow()
        self.div = ops.RealDiv()
        self.cast = ops.Cast()
        self.five = Tensor(5.0, ms.float32)
        self.six = Tensor(6.0, ms.float32)

    def construct(self, length_tensor):
        length_tensor = self.cast(length_tensor, ms.float32)
        output = self.add(length_tensor, self.five)
        output = self.div(output, self.six)
        output = self.pow(output, self.weight)
        return output


class Mod(nn.Cell):
    """
    Mod function.

    Args:
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: ms.float32.
    """
    def __init__(self,
                 compute_type=ms.float32):
        super(Mod, self).__init__()
        self.compute_type = compute_type
        self.floor_div = ops.FloorDiv()
        self.sub = ops.Sub()
        self.multiply = ops.Mul()

    def construct(self, input_x, input_y):
        x = self.floor_div(input_x, input_y)
        x = self.multiply(x, input_y)
        x = self.sub(input_x, x)
        return x
class BeamSearchDecoder(nn.Cell):
    """
    Beam search decoder.

    Args:
        batch_size (int): Batch size of input dataset.
        seq_length (int): Length of input sequence.
        vocab_size (int): Size of vocabulary.
        decoder (:class:`TransformerDecoderStep`): Decoder module.
        beam_width (int): beam width setting. Default: 4.
        length_penalty_weight (float): Weight of length penalty. Default: 1.0.
        max_decode_length (int): max decode length. Default: 128.
        sos_id (int): Id of sequence start token. Default: 1.
        eos_id (int): Id of sequence end token. Default: 2.
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: ms.float32.
    """
    def __init__(self,
                 batch_size,
                 seq_length,
                 vocab_size,
                 decoder,
                 beam_width=4,
                 length_penalty_weight=1.0,
                 max_decode_length=128,
                 sos_id=1,
                 eos_id=2,
                 compute_type=ms.float32):
        super(BeamSearchDecoder, self).__init__(auto_prefix=False)
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.beam_width = beam_width
        self.length_penalty_weight = length_penalty_weight
        self.max_decode_length = max_decode_length
        self.decoder = decoder

        self.add = ops.Add()
        self.expand = ops.ExpandDims()
        self.reshape = ops.Reshape()
        self.shape_flat = (-1,)
        self.shape = ops.TensorShape()

        self.zero_tensor = Tensor(np.zeros([batch_size, beam_width]), ms.float32)
        self.ninf_tensor = Tensor(np.full([batch_size, beam_width], -INF), ms.float32)

        self.select = ops.Select()
        self.flat_shape = (batch_size, beam_width * vocab_size)
        self.topk = ops.TopK(sorted=True)
        self.floor_div = ops.FloorDiv()
        self.vocab_size_tensor = Tensor(self.vocab_size, ms.int32)
        self.real_div = ops.RealDiv()
        self.mod = Mod()
        self.equal = ops.Equal()
        self.eos_ids = Tensor(np.full([batch_size, beam_width], eos_id), ms.int32)

        beam_ids = np.tile(np.arange(beam_width).reshape((1, beam_width)), [batch_size, 1])
        self.beam_ids = Tensor(beam_ids, ms.int32)
        batch_ids = np.arange(batch_size*beam_width).reshape((batch_size, beam_width)) // beam_width
        self.batch_ids = Tensor(batch_ids, ms.int32)
        self.concat = ops.Concat(axis=-1)
        self.gather_nd = ops.GatherNd()

        self.greater_equal = ops.GreaterEqual()
        self.sub = ops.Sub()
        self.cast = ops.Cast()
        self.zeroslike = ops.ZerosLike()

        # init inputs and states
        self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), ms.int32)
        self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), ms.int32)
        init_scores = np.tile(np.array([[0.] + [-INF]*(beam_width-1)]), [batch_size, 1])
        self.init_scores = Tensor(init_scores, ms.float32)
        self.init_finished = Tensor(np.zeros([batch_size, beam_width], dtype=np.bool))
        self.init_length = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))
        self.length_penalty = LengthPenalty(weight=length_penalty_weight)
        self.one = Tensor(1, ms.int32)

    def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
                 state_seq, state_finished, state_length):
        """
        One step for decode
        """
        log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask, self.seq_length)
        log_probs = self.reshape(log_probs, (self.batch_size, self.beam_width, self.vocab_size))

        # select topk indices
        total_log_probs = self.add(log_probs, self.expand(state_log_probs, -1))

        # mask finished beams
        mask_tensor = self.select(state_finished, self.ninf_tensor, self.zero_tensor)
        total_log_probs = self.add(total_log_probs, self.expand(mask_tensor, -1))

        # reshape scores to [batch, beam*vocab]
        flat_scores = self.reshape(total_log_probs, self.flat_shape)
        # select topk
        topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)

        temp = topk_indices
        beam_indices = self.zeroslike(topk_indices)
        for _ in range(self.beam_width - 1):
            temp = self.sub(temp, self.vocab_size_tensor)
            res = self.cast(self.greater_equal(temp, 0), ms.int32)
            beam_indices = beam_indices + res
        word_indices = topk_indices - beam_indices * self.vocab_size_tensor
        #======================================================================

        # mask finished indices
        beam_indices = self.select(state_finished, self.beam_ids, beam_indices)
        word_indices = self.select(state_finished, self.eos_ids, word_indices)
        topk_scores = self.select(state_finished, state_log_probs, topk_scores)

        ###### put finished sequences to the end
        # sort according to scores with -inf for finished beams
        tmp_log_probs = self.select(
            self.equal(word_indices, self.eos_ids),
            self.ninf_tensor,
            topk_scores)
        _, tmp_indices = self.topk(tmp_log_probs, self.beam_width)
        # update
        tmp_gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(tmp_indices, -1)))
        beam_indices = self.gather_nd(beam_indices, tmp_gather_indices)
        word_indices = self.gather_nd(word_indices, tmp_gather_indices)
        topk_scores = self.gather_nd(topk_scores, tmp_gather_indices)

        ###### generate new beam_search states
        # gather indices for selecting alive beams
        gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(beam_indices, -1)))

        # length add 1 if not finished in the previous step
        length_add = self.add(state_length, self.one)
        state_length = self.select(state_finished, state_length, length_add)
        state_length = self.gather_nd(state_length, gather_indices)

        # concat seq
        seq = self.gather_nd(state_seq, gather_indices)
        state_seq = self.concat((seq, self.expand(word_indices, -1)))

        # new finished flag and log_probs
        state_finished = self.equal(word_indices, self.eos_ids)
        state_log_probs = topk_scores

        ###### generate new inputs and decoder states
        cur_input_ids = self.reshape(state_seq, (self.batch_size*self.beam_width, -1))
        return cur_input_ids, state_log_probs, state_seq, state_finished, state_length
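The subtraction loop above recovers beam and word indices from top-k indices taken over the flattened [beam_width * vocab_size] axis; it is equivalent to a divmod by the vocabulary size. A toy check with numpy, illustrative only:

import numpy as np

vocab_size = 7
topk_indices = np.array([[5, 9, 16]])                        # flat indices for one batch row
beam_indices, word_indices = np.divmod(topk_indices, vocab_size)
assert (beam_indices == [[0, 1, 2]]).all()
assert (word_indices == [[5, 2, 2]]).all()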
    def construct(self, enc_states, enc_attention_mask):
        """Get beam search result."""
        cur_input_ids = self.start_ids
        # beam search states
        state_log_probs = self.init_scores
        state_seq = self.init_seq
        state_finished = self.init_finished
        state_length = self.init_length

        for _ in range(self.max_decode_length):
            # run one step decoder to get outputs of the current step
            cur_input_ids, state_log_probs, state_seq, state_finished, state_length = self.one_step(
                cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq, state_finished, state_length)

        # add length penalty scores
        penalty_len = self.length_penalty(state_length)
        # get penalty length
        log_probs = self.real_div(state_log_probs, penalty_len)

        # sort according to scores
        _, top_beam_indices = self.topk(log_probs, self.beam_width)
        gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
        # sort sequence
        predicted_ids = self.gather_nd(state_seq, gather_indices)
        # take the first one
        predicted_ids = predicted_ids[::, 0:1:1, ::]
        return predicted_ids

def normal_weight(shape, num_units):
    norm = np.random.normal(0.0, num_units**-0.5, shape).astype(np.float32)
    return Tensor(norm)


def _average_units(shape):
@@ -275,11 +71,6 @@ def weight_variable(shape):
    return Tensor(values)


def normal_weight(shape, num_units):
    norm = np.random.normal(0.0, num_units**-0.5, shape).astype(np.float32)
    return Tensor(norm)


class EmbeddingLookup(nn.Cell):
    """
    A embeddings lookup table with a fixed dictionary and size.
@@ -291,12 +82,14 @@ class EmbeddingLookup(nn.Cell):
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
    """
    def __init__(self,
                 is_graph_mode,
                 batch_size,
                 vocab_size,
                 embedding_size,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02):
        super(EmbeddingLookup, self).__init__()
        self.is_graph_mode = is_graph_mode
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
@@ -310,10 +103,14 @@ class EmbeddingLookup(nn.Cell):
        self.off_value = Tensor(0.0, ms.float32)
        self.array_mul = ops.MatMul()
        self.reshape = ops.Reshape()
        self.shape = ops.TensorShape()
        self.shape = ops.Shape()
        if is_graph_mode:
            self.shape = ops.TensorShape()

    def construct(self, input_ids):
        """Get a embeddings lookup table with a fixed dictionary and size."""
        input_shape = self.shape(input_ids)

        flat_ids = self.reshape(input_ids, self.shape_flat)
        if self.use_one_hot_embeddings:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
@@ -321,9 +118,11 @@ class EmbeddingLookup(nn.Cell):
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)

        out_shape = (self.batch_size, -1, self.embedding_size)
        out_shape = input_shape + (self.embedding_size,)
        if self.is_graph_mode:
            out_shape = (self.batch_size, -1, self.embedding_size)
        output = self.reshape(output_for_reshape, out_shape)
        return output, self.embedding_table
        return output, self.embedding_table.value()
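The two shape ops kept above are not interchangeable. `ops.Shape()` returns a Python tuple, which the PyNative path needs for tuple arithmetic such as `input_shape + (self.embedding_size,)`; `ops.TensorShape()` returns a 1-D Tensor, which is what a compiled graph needs when the shape itself is dynamic. A small illustration of the PyNative semantics:

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.ones((32, 17), np.int32))

static_shape = ops.Shape()(x)          # Python tuple: (32, 17)
dynamic_shape = ops.TensorShape()(x)   # 1-D Tensor: [32 17]

out_shape = static_shape + (1024,)     # tuple arithmetic, as used in the PyNative branch
print(out_shape, dynamic_shape)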
def position_encoding(length,
@@ -364,6 +163,7 @@ class EmbeddingPostprocessor(nn.Cell):
        dropout_prob (float): The dropout probability. Default: 0.1.
    """
    def __init__(self,
                 is_graph_mode,
                 embedding_size,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
@@ -378,7 +178,9 @@ class EmbeddingPostprocessor(nn.Cell):
        self.expand_dims = ops.ExpandDims()
        self.position_embedding_table = Tensor(position_encoding(max_position_embeddings, embedding_size),
                                               ms.float32)
        self.shape = ops.TensorShape()
        self.shape = ops.Shape()
        if is_graph_mode:
            self.shape = ops.TensorShape()

    def construct(self, word_embeddings):
        """Postprocessors apply positional embeddings to word embeddings."""
@@ -472,6 +274,7 @@ class MultiheadAttention(nn.Cell):
        compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: ms.float32.
    """
    def __init__(self,
                 is_graph_mode,
                 batch_size,
                 from_tensor_width,
                 to_tensor_width,
@@ -490,6 +293,7 @@ class MultiheadAttention(nn.Cell):
                 compute_type=ms.float32):
        super(MultiheadAttention, self).__init__()
        self.batch_size = batch_size
        self.is_graph_mode = is_graph_mode
        self.num_attention_heads = num_attention_heads
        self.size_per_head = size_per_head
        self.has_attention_mask = has_attention_mask
@@ -551,10 +355,16 @@ class MultiheadAttention(nn.Cell):
    def construct(self, from_tensor, to_tensor, seq_length, enc_seq_length, attention_mask=None):
        """Apply multihead attention."""
        from_seq_length = seq_length
        shape_from = (self.batch_size, -1, self.num_attention_heads, self.size_per_head)
        shape_to = (self.batch_size, -1, self.num_attention_heads, self.size_per_head)
        to_seq_length = enc_seq_length
        shape_from = (self.batch_size, from_seq_length, self.num_attention_heads, self.size_per_head)
        shape_to = (self.batch_size, to_seq_length, self.num_attention_heads, self.size_per_head)
        if self.is_graph_mode:
            shape_from = (self.batch_size, -1, self.num_attention_heads, self.size_per_head)
            shape_to = (self.batch_size, -1, self.num_attention_heads, self.size_per_head)
        if self.do_return_2d_tensor:
            shape_return = (-1, self.num_attention_heads * self.size_per_head)
            shape_return = (self.batch_size * from_seq_length, self.num_attention_heads * self.size_per_head)
            if self.is_graph_mode:
                shape_return = (-1, self.num_attention_heads * self.size_per_head)
        else:
            shape_return = (self.batch_size, from_seq_length, self.num_attention_heads * self.size_per_head)
@@ -616,6 +426,7 @@ class SelfAttention(nn.Cell):
        compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: ms.float32.
    """
    def __init__(self,
                 is_graph_mode,
                 batch_size,
                 hidden_size,
                 num_attention_heads=16,
@@ -634,6 +445,7 @@ class SelfAttention(nn.Cell):
        self.is_encdec_att = is_encdec_att

        self.attention = MultiheadAttention(
            is_graph_mode=is_graph_mode,
            batch_size=batch_size,
            from_tensor_width=hidden_size,
            to_tensor_width=hidden_size,
@@ -737,6 +549,7 @@ class EncoderCell(nn.Cell):
        compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: ms.float32.
    """
    def __init__(self,
                 is_graph_mode,
                 batch_size,
                 hidden_size=1024,
                 num_attention_heads=16,
@@ -749,6 +562,7 @@ class EncoderCell(nn.Cell):
                 compute_type=ms.float32):
        super(EncoderCell, self).__init__()
        self.attention = SelfAttention(
            is_graph_mode=is_graph_mode,
            batch_size=batch_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
@@ -795,6 +609,7 @@ class TransformerEncoder(nn.Cell):
        compute_type (:class:`mindspore.dtype`): Compute type. Default: ms.float32.
    """
    def __init__(self,
                 is_graph_mode,
                 batch_size,
                 hidden_size,
                 num_hidden_layers,
@@ -809,11 +624,13 @@ class TransformerEncoder(nn.Cell):
        super(TransformerEncoder, self).__init__()
        self.num_hidden_layers = num_hidden_layers
        self.batch_size = batch_size
        self.is_graph_mode = is_graph_mode
        self.hidden_size = hidden_size

        layers = []
        for _ in range(num_hidden_layers):
            layer = EncoderCell(batch_size=batch_size,
            layer = EncoderCell(is_graph_mode=is_graph_mode,
                                batch_size=batch_size,
                                hidden_size=hidden_size,
                                num_attention_heads=num_attention_heads,
                                intermediate_size=intermediate_size,
@@ -833,7 +650,9 @@ class TransformerEncoder(nn.Cell):

    def construct(self, input_tensor, attention_mask, seq_length):
        """Apply encoder."""
        out_shape = (self.batch_size, -1, self.hidden_size)
        out_shape = (self.batch_size, seq_length, self.hidden_size)
        if self.is_graph_mode:
            out_shape = (self.batch_size, -1, self.hidden_size)
        prev_output = self.reshape(input_tensor, self.shape)

        for layer_module in self.layers:
@@ -865,6 +684,7 @@ class DecoderCell(nn.Cell):
        compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: ms.float32.
    """
    def __init__(self,
                 is_graph_mode,
                 batch_size,
                 hidden_size=1024,
                 num_attention_heads=12,
@@ -877,6 +697,7 @@ class DecoderCell(nn.Cell):
                 compute_type=ms.float32):
        super(DecoderCell, self).__init__()
        self.self_attention = SelfAttention(
            is_graph_mode=is_graph_mode,
            batch_size=batch_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
@@ -887,6 +708,7 @@ class DecoderCell(nn.Cell):
            hidden_dropout_prob=hidden_dropout_prob,
            compute_type=compute_type)
        self.cross_attention = SelfAttention(
            is_graph_mode=is_graph_mode,
            batch_size=batch_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
@@ -937,6 +759,7 @@ class TransformerDecoder(nn.Cell):
        compute_type (:class:`mindspore.dtype`): Compute type. Default: ms.float32.
    """
    def __init__(self,
                 is_graph_mode,
                 batch_size,
                 hidden_size,
                 num_hidden_layers,
@@ -953,7 +776,8 @@ class TransformerDecoder(nn.Cell):

        layers = []
        for _ in range(num_hidden_layers):
            layer = DecoderCell(batch_size=batch_size,
            layer = DecoderCell(is_graph_mode=is_graph_mode,
                                batch_size=batch_size,
                                hidden_size=hidden_size,
                                num_attention_heads=num_attention_heads,
                                intermediate_size=intermediate_size,
@@ -972,10 +796,13 @@ class TransformerDecoder(nn.Cell):
        self.shape = (-1, hidden_size)
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.is_graph_mode = is_graph_mode

    def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask, seq_length, enc_seq_length):
        """Apply decoder."""
        out_shape = (self.batch_size, -1, self.hidden_size)
        out_shape = (self.batch_size, seq_length, self.hidden_size)
        if self.is_graph_mode:
            out_shape = (self.batch_size, -1, self.hidden_size)
        prev_output = self.reshape(input_tensor, self.shape)

        for layer_module in self.layers:
@@ -989,19 +816,36 @@ class TransformerDecoder(nn.Cell):


class CreateAttentionMaskFromInputMask(nn.Cell):
    def __init__(self):
    """
    Create attention mask according to input mask.

    Args:
        is_graph_mode (bool): is graph mode.
    """
    def __init__(self, is_graph_mode):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.cast = ops.Cast()
        self.reshape = ops.Reshape()
        self.shape = ops.TensorShape()
        self.shape = ops.Shape()
        if is_graph_mode:
            self.shape = ops.TensorShape()
        self.batch_matmul = ops.BatchMatMul()
        self.expand_dims = ops.ExpandDims()
        self.is_graph_mode = is_graph_mode

    def construct(self, input_mask):
        """Create attention mask according to input mask."""
        input_shape = self.shape(input_mask)
        shape_right = (input_shape[0], 1, input_shape[1])
        shape_left = input_shape + (1,)

        input_mask = self.cast(input_mask, ms.float32)
        mask_left = self.expand_dims(input_mask, 2)
        mask_right = self.expand_dims(input_mask, 1)
        if self.is_graph_mode:
            mask_left = self.expand_dims(input_mask, 2)
            mask_right = self.expand_dims(input_mask, 1)
        else:
            mask_left = self.reshape(input_mask, shape_left)
            mask_right = self.reshape(input_mask, shape_right)
        attention_mask = self.batch_matmul(mask_left, mask_right)

        return attention_mask
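In both branches above the attention mask is the batched outer product of the padding mask with itself, so position i may only attend to non-padded positions j. A toy shape walk-through in numpy, illustrative only:

import numpy as np

input_mask = np.array([[1, 1, 1, 0]], dtype=np.float32)   # (batch=1, seq=4)
mask_left = input_mask[:, :, None]                        # (1, 4, 1)
mask_right = input_mask[:, None, :]                       # (1, 1, 4)
attention_mask = mask_left @ mask_right                   # (1, 4, 4)
assert attention_mask[0, 0, 2] == 1 and attention_mask[0, 0, 3] == 0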
@@ -1019,11 +863,13 @@ class PredLogProbs(nn.Cell):
        dtype (:class:`mindspore.dtype`): Compute type to compute log_softmax. Default: ms.float32.
    """
    def __init__(self,
                 is_graph_mode,
                 batch_size,
                 width,
                 compute_type=ms.float32,
                 dtype=ms.float32):
        super(PredLogProbs, self).__init__()
        self.is_graph_mode = is_graph_mode
        self.batch_size = batch_size
        self.width = width
        self.compute_type = compute_type
@@ -1039,7 +885,9 @@ class PredLogProbs(nn.Cell):
                  output_weights,
                  seq_length):
        """Get log probs."""
        shape_flat_sequence_tensor = (-1, self.width)
        shape_flat_sequence_tensor = (self.batch_size * seq_length, self.width)
        if self.is_graph_mode:
            shape_flat_sequence_tensor = (-1, self.width)

        input_tensor = self.reshape(input_tensor, shape_flat_sequence_tensor)
        input_tensor = self.cast(input_tensor, self.compute_type)
@@ -1052,110 +900,6 @@ class PredLogProbs(nn.Cell):
        return log_probs


class TransformerDecoderStep(nn.Cell):
    """
    Multi-layer transformer decoder step.

    Args:
        batch_size (int): Batch size of input dataset.
        hidden_size (int): Size of the encoder layers.
        max_decode_length (int): Max decode length.
        enc_seq_length (int): Length of source sentences.
        num_hidden_layers (int): Number of hidden layers in encoder cells.
        num_attention_heads (int): Number of attention heads in encoder cells. Default: 16.
        intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096.
        attention_probs_dropout_prob (float): The dropout probability for
                                      SelfAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
        hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type. Default: ms.float32.
        embedding_lookup (:class:`EmbeddingLookup`): Embedding lookup module.
        embedding_processor (:class:`EmbeddingPostprocessor`) Embedding postprocessor module.
        projection (:class:`PredLogProbs`): PredLogProbs module
    """
    def __init__(self,
                 batch_size,
                 hidden_size,
                 max_decode_length,
                 num_hidden_layers,
                 num_attention_heads=16,
                 intermediate_size=4096,
                 attention_probs_dropout_prob=0.3,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.3,
                 hidden_act="relu",
                 compute_type=ms.float32,
                 embedding_lookup=None,
                 embedding_processor=None,
                 projection=None):
        super(TransformerDecoderStep, self).__init__(auto_prefix=False)
        self.num_hidden_layers = num_hidden_layers

        self.tfm_embedding_lookup = embedding_lookup
        self.tfm_embedding_processor = embedding_processor
        self.projection = projection

        self.tfm_decoder = TransformerDecoder(
            batch_size=batch_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
            intermediate_size=intermediate_size,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            hidden_dropout_prob=hidden_dropout_prob,
            hidden_act=hidden_act,
            compute_type=compute_type)

        self.ones_like = ops.OnesLike()
        self.shape = ops.TensorShape()

        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask()
        self.expand = ops.ExpandDims()
        self.multiply = ops.Mul()

        ones = np.ones(shape=(max_decode_length, max_decode_length))
        self.future_mask = Tensor(np.tril(ones), dtype=ms.float32)

        self.cast_compute_type = CastWrapper(dst_type=compute_type)

    def construct(self, input_ids, enc_states, enc_attention_mask, seq_length):
        """
        Multi-layer transformer decoder step.
        input_ids: [batch_size * beam_width]
        """
        # process embedding
        input_embedding, embedding_tables = self.tfm_embedding_lookup(input_ids)
        input_embedding = self.tfm_embedding_processor(input_embedding)
        input_embedding = self.cast_compute_type(input_embedding)

        input_shape = self.shape(input_ids)
        input_len = input_shape[1]
        future_mask = self.future_mask[0:input_len:1, 0:input_len:1]

        input_mask = self.ones_like(input_ids)
        input_mask = self._create_attention_mask_from_input_mask(input_mask)
        input_mask = self.multiply(input_mask, self.expand(future_mask, 0))
        input_mask = self.cast_compute_type(input_mask)

        enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::]

        # call TransformerDecoder
        decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask, -1, seq_length)

        # take the last step
        decoder_output = decoder_output[::, input_len-1:input_len:1, ::]

        # projection and log_prob
        log_probs = self.projection(decoder_output, embedding_tables, 1)

        return log_probs


@constexpr
def convert_np_to_tensor_encoder(seq_length):
    ones = np.ones(shape=(seq_length, seq_length))
@@ -1169,118 +913,97 @@ class TransformerModel(nn.Cell):
    Args:
        is_training (bool): True for training mode. False for eval mode.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        is_graph_mode (bool): is graph mode.
    """
    def __init__(self,
                 is_training,
                 use_one_hot_embeddings=False):
                 use_one_hot_embeddings=False,
                 is_graph_mode=False):
        super(TransformerModel, self).__init__()
        self.is_training = is_training

        self.batch_size = BATCH_SIZE_VALUE
        self.hidden_size = 1024
        self.num_hidden_layers = 6
        self.embedding_size = 1024

        self.batch_size = BATCH_SIZE
        self.is_graph_mode = is_graph_mode
        self.hidden_size = HIDDEN_SIZE
        self.num_hidden_layers = NUMBER_HIDDEN_LAYERS
        self.embedding_size = HIDDEN_SIZE

        self.last_idx = self.num_hidden_layers - 1
        self.beam_width = 4
        self.max_decode_length = 80
        self.beam_width = BEAM_WIDTH
        self.max_decode_length = MAX_DECODE_LENGTH

        self.tfm_embedding_lookup = EmbeddingLookup(
            is_graph_mode=self.is_graph_mode,
            batch_size=self.batch_size,
            vocab_size=36560,
            vocab_size=VOCAB_SIZE,
            embedding_size=self.embedding_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02)
            initializer_range=INIT_RANGE)
        self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor(
            is_graph_mode=self.is_graph_mode,
            embedding_size=self.embedding_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,
            max_position_embeddings=128,
            dropout_prob=0.2)
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            dropout_prob=HIDDEN_DROPOUT_PROB)
        self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor(
            is_graph_mode=self.is_graph_mode,
            embedding_size=self.embedding_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,
            max_position_embeddings=128,
            dropout_prob=0.2)
            max_position_embeddings=MAX_POSITION_EMBEDDINGS,
            dropout_prob=HIDDEN_DROPOUT_PROB)
        self.tfm_encoder = TransformerEncoder(
            is_graph_mode=self.is_graph_mode,
            batch_size=self.batch_size,
            hidden_size=self.hidden_size,
            num_attention_heads=16,
            num_attention_heads=NUM_ATTENTION_HEADS,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=4096,
            attention_probs_dropout_prob=0.2,
            intermediate_size=INTERMEDIATE_SIZE,
            attention_probs_dropout_prob=ATTENTION_PROBS_DROPOUT_PROB,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,
            hidden_dropout_prob=0.2,
            hidden_act="relu",
            compute_type=ms.float16)
            initializer_range=INIT_RANGE,
            hidden_dropout_prob=HIDDEN_DROPOUT_PROB,
            hidden_act=HIDDEN_ACT,
            compute_type=COMPUTE_TYPE)
        if is_training:
            self.projection = PredLogProbs(
                is_graph_mode=self.is_graph_mode,
                batch_size=self.batch_size,
                width=self.hidden_size,
                compute_type=ms.float16,
                dtype=ms.float32)
                compute_type=COMPUTE_TYPE,
                dtype=DTYPE)
            self.tfm_decoder = TransformerDecoder(
                is_graph_mode=self.is_graph_mode,
                batch_size=self.batch_size,
                hidden_size=self.hidden_size,
                num_attention_heads=16,
                num_attention_heads=NUM_ATTENTION_HEADS,
                num_hidden_layers=self.num_hidden_layers,
                intermediate_size=4096,
                attention_probs_dropout_prob=0.2,
                intermediate_size=INTERMEDIATE_SIZE,
                attention_probs_dropout_prob=ATTENTION_PROBS_DROPOUT_PROB,
                use_one_hot_embeddings=use_one_hot_embeddings,
                initializer_range=0.02,
                hidden_dropout_prob=0.2,
                hidden_act="relu",
                compute_type=ms.float16)
        else:
            self.projection = PredLogProbs(
                batch_size=self.batch_size * 4,
                width=self.hidden_size,
                compute_type=ms.float16,
                dtype=ms.float32)
            self.tfm_decoder = TransformerDecoderStep(
                batch_size=self.batch_size * 4,
                hidden_size=self.hidden_size,
                max_decode_length=80,
                num_hidden_layers=6,
                num_attention_heads=16,
                intermediate_size=4096,
                attention_probs_dropout_prob=0.2,
                use_one_hot_embeddings=False,
                initializer_range=0.02,
                hidden_dropout_prob=0.2,
                hidden_act="relu",
                compute_type=ms.float16,
                embedding_lookup=self.tfm_embedding_lookup,
                embedding_processor=self.tfm_embedding_postprocessor_for_decoder,
                projection=self.projection)
            self.tfm_decoder = BeamSearchDecoder(
                batch_size=BATCH_SIZE_VALUE,
                seq_length=128,
                vocab_size=36560,
                decoder=self.tfm_decoder,
                beam_width=4,
                length_penalty_weight=1.0,
                max_decode_length=80)

            self.tfm_decoder.add_flags(loop_can_unroll=True)
            ones = np.ones(shape=(self.batch_size, self.max_decode_length))
            self.encdec_mask = Tensor(ones, ms.float32)
                initializer_range=INIT_RANGE,
                hidden_dropout_prob=HIDDEN_DROPOUT_PROB,
                hidden_act=HIDDEN_ACT,
                compute_type=COMPUTE_TYPE)

        self.cast = ops.Cast()
        self.dtype = ms.float32
        self.cast_compute_type = CastWrapper(dst_type=ms.float16)
        self.dtype = DTYPE
        self.cast_compute_type = CastWrapper(dst_type=COMPUTE_TYPE)
        self.expand = ops.ExpandDims()
        self.multiply = ops.Mul()
        self.shape = ops.TensorShape()
        self.shape = ops.Shape()
        if self.is_graph_mode:
            self.shape = ops.TensorShape()
        self.tril = array_op.Tril()
        self.dynamic_broadcast_to = inner_op.DynamicBroadcastTo()
        self.ones_like = array_op.OnesLike()
        self.stack = array_op.Stack(0)
        self.concatenate = array_op.Concat()

        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask()
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(self.is_graph_mode)

    def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
        """Transformer with encoder and decoder."""
@@ -1296,11 +1019,16 @@ class TransformerModel(nn.Cell):
                                          self.cast_compute_type(enc_attention_mask),
                                          seq_length)

        ones_inner = self.ones_like(source_ids[0, :])
        seq_length = self.shape(ones_inner)
        broadcast_shape = self.concatenate([seq_length, seq_length])
        ones = func.broadcast_to(ones_inner, broadcast_shape)
        future_mask = self.tril(ones)

        if self.is_graph_mode:
            ones_inner = self.ones_like(source_ids[0, :])
            seq_length = self.shape(ones_inner)
            broadcast_shape = self.concatenate([seq_length, seq_length])
            ones = self.dynamic_broadcast_to(ones_inner, broadcast_shape)
            future_mask = self.tril(ones)
        else:
            future_mask = convert_np_to_tensor_encoder(seq_length)


        # process target sentence
        tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
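The branch above builds the causal (future) mask two ways: graph mode stays in tensor ops (`DynamicBroadcastTo` plus `Tril`) so a dynamic sequence length still compiles, while PyNative can fall back to the `@constexpr` numpy helper. A toy comparison using the same imports as this file; availability of `func.broadcast_to` and `array_op.Tril` is assumed from those imports.

import numpy as np
import mindspore as ms
from mindspore import Tensor
import mindspore.ops.function as func
from mindspore.ops.operations import array_ops as array_op

seq_length = 4   # toy value; in graph mode this may only be known at run time

future_mask_np = Tensor(np.tril(np.ones((seq_length, seq_length))), ms.float32)

ones_inner = Tensor(np.ones((seq_length,)), ms.float32)
ones_square = func.broadcast_to(ones_inner, (seq_length, seq_length))
future_mask_ops = array_op.Tril()(ones_square)

assert (future_mask_np.asnumpy() == future_mask_ops.asnumpy()).all()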