forked from mindspore-Ecosystem/mindspore
!7945 Add self-realized Softmax and LayerNorm
Merge pull request !7945 from wangjun/master
commit fc8479a459
@@ -25,6 +25,61 @@ from mindspore.common.initializer import TruncatedNormal, initializer
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 
 
+class LayerNorm(nn.Cell):
+    """
+    Layer Normalization.
+
+    Args:
+        normalized_shape: the shape of the axes to be normalized
+        eps: epsilon, a small constant added to the variance to avoid division by zero
+
+    Inputs:
+        x: input tensor
+
+    Returns:
+        rescaled_output: Tensor, the normalized and rescaled tensor
+    """
+    def __init__(self, normalized_shape, eps=1e-5):
+        super(LayerNorm, self).__init__()
+        self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma")
+        self.beta = Parameter(initializer('zeros', normalized_shape), name="beta")
+        self.mean = P.ReduceMean(keep_dims=True)
+        self.eps = eps
+
+    def construct(self, x):
+        mean = self.mean(x, -1)
+        variance = self.mean(F.square(x - mean), -1)
+        output = (x - mean) / F.sqrt(variance + self.eps)
+        rescaled_output = output * self.gamma + self.beta
+        return rescaled_output
+
+
+class Softmax(nn.Cell):
+    """
+    Softmax realization.
+
+    Args:
+        axis: the axis along which softmax is applied
+
+    Inputs:
+        x: input tensor
+
+    Returns:
+        output: Tensor, the tensor after softmax
+    """
+    def __init__(self, axis=-1):
+        super(Softmax, self).__init__()
+        self.max = P.ArgMaxWithValue(axis=axis, keep_dims=True)
+        self.sum = P.ReduceSum(keep_dims=True)
+        self.axis = axis
+
+    def construct(self, x):
+        _, max_value = self.max(x)
+        exp_x = F.tensor_pow(np.e, x - max_value)
+        sum_x = self.sum(exp_x, self.axis)
+        output = exp_x / sum_x
+        return output
+
+
 class Mapping(nn.Cell):
     """
     A mapping function with a 3d input
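
A minimal smoke-test sketch of the two new cells; the batch, sequence and hidden sizes are illustrative assumptions, LayerNorm and Softmax refer to the classes added above, and the other names come from this file's existing imports:

import numpy as np
from mindspore import Tensor

x = Tensor(np.random.randn(2, 16, 768).astype(np.float32))  # assumed (batch, seq_len, hidden) shape
layer_norm = LayerNorm((768,))   # normalized_shape must match the last axis
softmax = Softmax(axis=-1)
normed = layer_norm(x)           # same shape as x, normalized over the last axis
probs = softmax(normed)          # values along the last axis sum to 1
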
@@ -162,7 +217,6 @@ class Attention(nn.Cell):
     def __init__(self, config, scale=1.0, layer_idx=None):
         super(Attention, self).__init__()
         self.get_attention_mask = AttentionMask(config)
-        self.expand_mapping = Mapping(config.embedding_size, 3*config.embedding_size, config.compute_dtype)
         self.projection = Mapping(config.embedding_size, config.embedding_size, config.compute_dtype, scale)
         self.split = P.Split(axis=-1, output_num=3)
         self.transpose = P.Transpose()
@@ -182,7 +236,6 @@ class Attention(nn.Cell):
         self.use_past = config.use_past
         self.dropout = nn.Dropout(1-config.dropout_rate)
         self.prob_dropout = nn.Dropout(1-config.dropout_rate)
-        self.softmax = nn.Softmax()
 
         self.dense1 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
         self.dense2 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
@@ -285,9 +338,7 @@ class Attention(nn.Cell):
         attention_scores = adder + score
 
         attention_scores = P.Cast()(attention_scores, ori_dtype)
-        shape = F.shape(attention_scores)
-        attention_probs = nn.Softmax()(F.reshape(attention_scores, (-1, shape[-1])))
-        attention_probs = F.reshape(attention_probs, shape)
+        attention_probs = Softmax()(attention_scores)
 
         attention_probs = self.prob_dropout(attention_probs)
         weighted_values = self.batch_matmul(attention_probs, value)
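
The self-realized Softmax reduces over its axis with keep_dims=True, so the 4-D attention scores can be fed to it directly and the flatten/reshape round trip around nn.Softmax is dropped. An illustrative equivalence check, with an assumed score shape and reusing the Tensor and numpy imports from the sketch above:

scores = Tensor(np.random.randn(2, 8, 16, 16).astype(np.float32))  # assumed (batch, heads, q_len, k_len)
direct = Softmax(axis=-1)(scores)
shape = F.shape(scores)
via_reshape = F.reshape(nn.Softmax()(F.reshape(scores, (-1, shape[-1]))), shape)
# direct and via_reshape agree up to floating-point error
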
@@ -313,9 +364,9 @@ class Block(nn.Cell):
     def __init__(self, config, layer_idx):
         super(Block, self).__init__()
         scale = 1 / math.sqrt(2.0*layer_idx)
-        self.layernorm1 = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
+        self.layernorm1 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
         self.attention = Attention(config, scale, layer_idx)
-        self.layernorm2 = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
+        self.layernorm2 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
         self.output = Output(config, scale)
         self.post_layernorm_residual = config.post_layernorm_residual
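
Because the self-realized LayerNorm is an ordinary Cell, the existing .to_float call carries over unchanged. A minimal sketch, assuming float16 compute and a hidden size of 768 in place of config.embedding_size:

import mindspore.common.dtype as mstype

layernorm1 = LayerNorm((768,)).to_float(mstype.float16)  # assumed size and dtype
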
@@ -362,7 +413,7 @@ class GPT_Model(nn.Cell):
         self.blocks = nn.CellList()
         for i in range(config.num_layers):
             self.blocks.append(Block(config, i+1))
-        self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
+        self.layernorm = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
         self.use_past = config.use_past
         self.past = tuple([None]*config.num_layers)
         self.num_layers = config.num_layers