!7945 Add self-realized Softmax and LayerNorm

Merge pull request !7945 from wangjun/master
mindspore-ci-bot 2020-10-30 09:26:02 +08:00 committed by Gitee
commit fc8479a459
1 changed file with 59 additions and 8 deletions

@@ -25,6 +25,61 @@ from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import functional as F
+
+
+class LayerNorm(nn.Cell):
+    """
+    Layer Normalization
+    Args:
+        normalized_shape: the shape of the normalized axes (also the shape of gamma and beta)
+        eps: epsilon, a small constant added to the variance to avoid division by zero
+    Inputs:
+        x: input tensor
+    Returns:
+        rescaled_output: Tensor, the tensor after layer normalization
+    """
+    def __init__(self, normalized_shape, eps=1e-5):
+        super(LayerNorm, self).__init__()
+        self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma")
+        self.beta = Parameter(initializer('zeros', normalized_shape), name="beta")
+        self.mean = P.ReduceMean(keep_dims=True)
+        self.eps = eps
+
+    def construct(self, x):
+        mean = self.mean(x, -1)
+        variance = self.mean(F.square(x - mean), -1)
+        output = (x - mean) / F.sqrt(variance + self.eps)
+        rescaled_output = output * self.gamma + self.beta
+        return rescaled_output
+
+
+class Softmax(nn.Cell):
+    """
+    Self-realized softmax (shifted by the per-axis max for numerical stability)
+    Args:
+        axis: the axis along which softmax is applied
+    Inputs:
+        x: input tensor
+    Returns:
+        output: Tensor, the tensor after softmax
+    """
+    def __init__(self, axis=-1):
+        super(Softmax, self).__init__()
+        self.max = P.ArgMaxWithValue(axis=axis, keep_dims=True)
+        self.sum = P.ReduceSum(keep_dims=True)
+        self.axis = axis
+
+    def construct(self, x):
+        _, max_value = self.max(x)
+        exp_x = F.tensor_pow(np.e, x - max_value)
+        sum_x = self.sum(exp_x, self.axis)
+        output = exp_x / sum_x
+        return output
class Mapping(nn.Cell):
    """
    A mapping function with a 3d input
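The first hunk adds the two self-realized cells. A minimal usage sketch follows, for reference only and not part of the diff; the import path, shapes and values are assumptions.

import numpy as np
from mindspore import Tensor
from src.gpt import LayerNorm, Softmax  # assumed import path (GPT model-zoo layout)

x = Tensor(np.random.randn(2, 4, 8).astype(np.float32))

layer_norm = LayerNorm((8,))   # normalized_shape matches the last axis of x
y = layer_norm(x)              # same shape as x, normalized over axis -1

softmax = Softmax(axis=-1)
p = softmax(x)                 # entries along the last axis sum to 1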
@@ -162,7 +217,6 @@ class Attention(nn.Cell):
    def __init__(self, config, scale=1.0, layer_idx=None):
        super(Attention, self).__init__()
        self.get_attention_mask = AttentionMask(config)
-        self.expand_mapping = Mapping(config.embedding_size, 3*config.embedding_size, config.compute_dtype)
        self.projection = Mapping(config.embedding_size, config.embedding_size, config.compute_dtype, scale)
        self.split = P.Split(axis=-1, output_num=3)
        self.transpose = P.Transpose()
@@ -182,7 +236,6 @@ class Attention(nn.Cell):
        self.use_past = config.use_past
        self.dropout = nn.Dropout(1-config.dropout_rate)
        self.prob_dropout = nn.Dropout(1-config.dropout_rate)
-        self.softmax = nn.Softmax()
        self.dense1 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
        self.dense2 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
@@ -285,9 +338,7 @@ class Attention(nn.Cell):
        attention_scores = adder + score
        attention_scores = P.Cast()(attention_scores, ori_dtype)
-        shape = F.shape(attention_scores)
-        attention_probs = nn.Softmax()(F.reshape(attention_scores, (-1, shape[-1])))
-        attention_probs = F.reshape(attention_probs, shape)
+        attention_probs = Softmax()(attention_scores)
        attention_probs = self.prob_dropout(attention_probs)
        weighted_values = self.batch_matmul(attention_probs, value)
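This hunk drops the flatten-and-reshape round trip around nn.Softmax and applies the self-realized Softmax to the attention scores directly. A small sketch of that direct call, with made-up (batch, heads, seq, seq) dimensions and an assumed import path:

import numpy as np
from mindspore import Tensor
from src.gpt import Softmax  # assumed import path

# scores stand in for adder + score above; no reshape to 2-D is needed
scores = Tensor(np.random.randn(2, 4, 16, 16).astype(np.float32))
probs = Softmax(axis=-1)(scores)   # same shape, normalized over the last axis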
@@ -313,9 +364,9 @@ class Block(nn.Cell):
    def __init__(self, config, layer_idx):
        super(Block, self).__init__()
        scale = 1 / math.sqrt(2.0*layer_idx)
-        self.layernorm1 = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
+        self.layernorm1 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
        self.attention = Attention(config, scale, layer_idx)
-        self.layernorm2 = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
+        self.layernorm2 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
        self.output = Output(config, scale)
        self.post_layernorm_residual = config.post_layernorm_residual
@@ -362,7 +413,7 @@ class GPT_Model(nn.Cell):
        self.blocks = nn.CellList()
        for i in range(config.num_layers):
            self.blocks.append(Block(config, i+1))
-        self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
+        self.layernorm = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
        self.use_past = config.use_past
        self.past = tuple([None]*config.num_layers)
        self.num_layers = config.num_layers
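A hedged sanity-check sketch, not part of the commit: it assumes the modified file imports as src.gpt and compares the self-realized cells against the nn.LayerNorm and nn.Softmax they replace in Block and GPT_Model. nn.LayerNorm's default epsilon may differ from the custom 1e-5, hence the loose tolerance.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from src.gpt import LayerNorm, Softmax  # assumed import path

x = Tensor(np.random.randn(2, 16, 32).astype(np.float32))

# both pairs should agree to within a small tolerance
np.testing.assert_allclose(LayerNorm((32,))(x).asnumpy(),
                           nn.LayerNorm((32,))(x).asnumpy(), rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(Softmax(axis=-1)(x).asnumpy(),
                           nn.Softmax(axis=-1)(x).asnumpy(), rtol=1e-5, atol=1e-5)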