forked from mindspore-Ecosystem/mindspore
!35210 Fix ST probabilistic failure in daily version
Merge pull request !35210 from 王程浩/master
commit c79d9cb6d6

@@ -0,0 +1,911 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert for pretraining."""
import numpy as np

import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer, TruncatedNormal
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from src.bert_model import BertModel

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

clip_grad = C.MultitypeFuncGraph("clip_grad")


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (Tensor): Gradient.

    Outputs:
        Tensor, the clipped gradient.
    """
    if clip_type not in (0, 1):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
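
# Illustrative usage sketch (added, not part of the original network code):
# clip_grad is a MultitypeFuncGraph, so it is mapped over the gradient tuple
# rather than called directly. Assuming `grads` is a tuple of gradient Tensors:
#   grads = C.HyperMap()(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)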


class GetMaskedLMOutput(nn.Cell):
    """
    Get masked lm output.

    Args:
        config (BertConfig): The config of BertModel.

    Returns:
        Tensor, masked lm output.
    """

    def __init__(self, config):
        super(GetMaskedLMOutput, self).__init__()
        self.width = config.hidden_size
        self.reshape = P.Reshape()
        self.gather = P.Gather()

        weight_init = TruncatedNormal(config.initializer_range)
        self.dense = nn.Dense(self.width,
                              config.hidden_size,
                              weight_init=weight_init,
                              activation=config.hidden_act).to_float(config.compute_type)
        self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
        self.output_bias = Parameter(
            initializer(
                'zero',
                config.vocab_size))
        self.matmul = P.MatMul(transpose_b=True)
        self.log_softmax = nn.LogSoftmax(axis=-1)
        self.shape_flat_offsets = (-1, 1)
        self.last_idx = (-1,)
        self.shape_flat_sequence_tensor = (-1, self.width)
        self.cast = P.Cast()
        self.compute_type = config.compute_type
        self.dtype = config.dtype

    def construct(self,
                  input_tensor,
                  output_weights,
                  positions):
        """Get output log_probs"""
        input_shape = P.Shape()(input_tensor)
        rng = F.tuple_to_array(F.make_range(input_shape[0]))
        flat_offsets = self.reshape(rng * input_shape[1], self.shape_flat_offsets)
        flat_position = self.reshape(positions + flat_offsets, self.last_idx)
        flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
        input_tensor = self.gather(flat_sequence_tensor, flat_position, 0)
        input_tensor = self.cast(input_tensor, self.compute_type)
        output_weights = self.cast(output_weights, self.compute_type)
        input_tensor = self.dense(input_tensor)
        input_tensor = self.layernorm(input_tensor)
        logits = self.matmul(input_tensor, output_weights)
        logits = self.cast(logits, self.dtype)
        logits = logits + self.output_bias
        log_probs = self.log_softmax(logits)
        return log_probs
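
# Added note: construct gathers the hidden vectors of the masked positions by
# flattening the batch. For example, with batch_size=2, seq_length=128 and
# positions=[[3, 7], [0, 5]], flat_offsets is [[0], [128]], so the flattened
# indices become [3, 7, 128, 133] into the (batch*seq, hidden) tensor.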


class GetNextSentenceOutput(nn.Cell):
    """
    Get next sentence output.

    Args:
        config (BertConfig): The config of Bert.

    Returns:
        Tensor, next sentence output.
    """

    def __init__(self, config):
        super(GetNextSentenceOutput, self).__init__()
        self.log_softmax = P.LogSoftmax()
        weight_init = TruncatedNormal(config.initializer_range)
        self.dense = nn.Dense(config.hidden_size, 2,
                              weight_init=weight_init, has_bias=True).to_float(config.compute_type)
        self.dtype = config.dtype
        self.cast = P.Cast()

    def construct(self, input_tensor):
        logits = self.dense(input_tensor)
        logits = self.cast(logits, self.dtype)
        log_prob = self.log_softmax(logits)
        return log_prob


class BertPreTraining(nn.Cell):
    """
    Bert pretraining network.

    Args:
        config (BertConfig): The config of BertModel.
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.

    Returns:
        Tensor, prediction_scores, seq_relationship_score.
    """

    def __init__(self, config, is_training, use_one_hot_embeddings):
        super(BertPreTraining, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cls1 = GetMaskedLMOutput(config)
        self.cls2 = GetNextSentenceOutput(config)

    def construct(self, input_ids, input_mask, token_type_id,
                  masked_lm_positions):
        sequence_output, pooled_output, embedding_table = \
            self.bert(input_ids, token_type_id, input_mask)
        prediction_scores = self.cls1(sequence_output,
                                      embedding_table,
                                      masked_lm_positions)
        seq_relationship_score = self.cls2(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPretrainingLoss(nn.Cell):
    """
    Provide bert pre-training loss.

    Args:
        config (BertConfig): The config of BertModel.

    Returns:
        Tensor, total loss.
    """

    def __init__(self, config):
        super(BertPretrainingLoss, self).__init__()
        self.vocab_size = config.vocab_size
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()

    def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids,
                  masked_lm_weights, next_sentence_labels):
        """Defines the computation performed."""
        label_ids = self.reshape(masked_lm_ids, self.last_idx)
        label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32)
        one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)

        per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
        numerator = self.reduce_sum(label_weights * per_example_loss, ())
        denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32)
        masked_lm_loss = numerator / denominator

        # next_sentence_loss
        labels = self.reshape(next_sentence_labels, self.last_idx)
        one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value)
        per_example_loss = self.neg(self.reduce_sum(
            one_hot_labels * seq_relationship_score, self.last_idx))
        next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx)

        # total_loss
        total_loss = masked_lm_loss + next_sentence_loss

        return total_loss
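
# Added note: the masked-LM term is a weighted negative log-likelihood,
# masked_lm_loss = sum(w_i * -log p(y_i)) / (sum(w_i) + 1e-5), where the
# weights w_i zero out padded prediction slots; the 1e-5 guards against a
# zero denominator when a batch contains no masked positions.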


class BertNetworkWithLoss(nn.Cell):
    """
    Provide bert pre-training loss through network.

    Args:
        config (BertConfig): The config of BertModel.
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.

    Returns:
        Tensor, the loss of the network.
    """

    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        super(BertNetworkWithLoss, self).__init__()
        self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings)
        self.loss = BertPretrainingLoss(config)
        self.cast = P.Cast()

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights):
        """Get pre-training loss"""
        prediction_scores, seq_relationship_score = \
            self.bert(input_ids, input_mask, token_type_id, masked_lm_positions)
        total_loss = self.loss(prediction_scores, seq_relationship_score,
                               masked_lm_ids, masked_lm_weights, next_sentence_labels)
        return self.cast(total_loss, mstype.float32)
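
# Minimal usage sketch (illustrative; assumes `config` is a populated
# BertConfig from src.bert_model):
#   net_with_loss = BertNetworkWithLoss(config, is_training=True)
#   loss = net_with_loss(input_ids, input_mask, token_type_id,
#                        next_sentence_labels, masked_lm_positions,
#                        masked_lm_ids, masked_lm_weights)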


class BertTrainOneStepCell(nn.TrainOneStepCell):
    """
    Encapsulation class of bert network training.

    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.
        enable_clip_grad (boolean): If True, clip gradients in BertTrainOneStepCell. Default: True.
    """

    def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=True):
        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()
        self.enable_clip_grad = enable_clip_grad

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights):
        """Defines the computation performed."""
        weights = self.weights

        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        if self.enable_clip_grad:
            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        grads = self.grad_reducer(grads)
        self.optimizer(grads)
        return loss
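
# Minimal training-step sketch (illustrative; `net_with_loss` as above, and
# `optimizer` any MindSpore optimizer such as nn.Lamb or nn.AdamWeightDecay):
#   train_cell = BertTrainOneStepCell(net_with_loss, optimizer)
#   loss = train_cell(input_ids, input_mask, token_type_id,
#                     next_sentence_labels, masked_lm_positions,
#                     masked_lm_ids, masked_lm_weights)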


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)
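
# Added note: this undoes loss scaling, computing grad / scale, so that the
# optimizer sees gradients on the original scale after the scaled backward pass.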


_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    return grad_overflow(grad)


class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
    """
    Encapsulation class of bert network training.

    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """

    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
        self.cast = P.Cast()
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)

        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        degree_sens = self.cast(scaling_sens * self.degree, mstype.float32)
        grads = self.hyper_map(F.partial(grad_scale, degree_sens), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

        cond = self.get_overflow_status(status, grads)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if not overflow:
            self.optimizer(grads)
        return (loss, cond, scaling_sens)
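
# Minimal usage sketch (illustrative): a dynamic loss-scale manager supplies
# the scale_update_cell. The names below are standard MindSpore APIs, but the
# hyperparameter values are placeholders:
#   manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**32,
#                                           scale_factor=2, scale_window=1000)
#   train_cell = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer,
#                                                  scale_update_cell=manager)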


class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
    """
    Encapsulation class of bert network training.

    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.
    Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow
    condition as input.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """

    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell)
        self.cast = P.Cast()
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        cond = self.get_overflow_status(status, grads)
        overflow = cond
        if self.loss_scaling_manager is not None:
            overflow = self.loss_scaling_manager(scaling_sens, cond)
        self.optimizer(grads, overflow)
        return (loss, cond, scaling_sens)


cast = P.Cast()
add_grads = C.MultitypeFuncGraph("add_grads")


@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
    return accu_grad + cast(grad, mstype.float32)


update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))


accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")


@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))


zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))


class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    To mimic a higher batch size, gradients are accumulated N times before the weight update.

    In distributed mode, allreduce is applied only in the weight-update step,
    i.e. the sub-step after gradients have been accumulated N times.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                  batch_size * accumulation_steps. Default: 1.
    """

    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
        super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.enable_global_norm = enable_global_norm
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32))
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))

        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_status = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        init = F.depend(init, loss)
        clear_status = self.clear_status(init)
        scaling_sens = F.depend(scaling_sens, clear_status)
        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)

        init = F.depend(init, mean_loss)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)

        if not is_accu_step:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            if self.enable_global_norm:
                grads = C.clip_by_global_norm(grads, 1.0, None)
            else:
                grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
            accu_overflow = F.depend(accu_overflow, grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if not overflow:
                self.optimizer(grads)

        return (mean_loss, overflow, scaling_sens)
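
# Added note: with accumulation_steps = N this cell runs the optimizer only on
# every N-th call; the accumulated gradients are divided by
# scaling_sens * degree * N before clipping, so the update matches a single
# large-batch step.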


class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    To mimic a higher batch size, gradients are accumulated N times before the weight update.

    In distributed mode, allreduce is applied after each sub-step, and the trailing
    communication time is overlapped by a backend optimization pass.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                  batch_size * accumulation_steps. Default: 1.
    """

    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
        super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.enable_global_norm = enable_global_norm
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32))
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))

        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
        scaling = scaling_sens * self.degree * self.accumulation_steps
        grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
        grads = self.grad_reducer(grads)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        flag_reduce = self.overflow_reducer(flag_sum)
        overflow = self.less_equal(self.base, flag_reduce)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
        overflow = self.reshape(overflow, (()))

        if is_accu_step:
            succ = False
            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
            succ = F.depend(succ, accu_succ)
        else:
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                if self.enable_global_norm:
                    grads = C.clip_by_global_norm(grads, 1.0, None)
                else:
                    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

                succ = self.optimizer(grads)

            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            succ = F.depend(succ, accu_succ)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)


class BertNetworkMatchBucket(nn.Cell):
    '''
    Execute bert according to different sentence lengths.
    '''

    def __init__(self, network, seq_length, bucket_list=None):
        super(BertNetworkMatchBucket, self).__init__()
        self.network = network
        if not bucket_list or not isinstance(bucket_list, list):
            bucket_list = [seq_length]
        self.bucket_list = [bucket for bucket in bucket_list if bucket <= seq_length]

        if network.reducer_flag:
            reuse_attr = 'reuse_communication_node'
            if not network.grad_reducer.split_fusion:
                hccl_op = network.grad_reducer.allreduce
                network.grad_reducer.allreduce = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion'))
            else:
                new_op_list = []
                for hccl_op in network.grad_reducer.op_list:
                    new_op = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion'))
                    new_op_list.append(new_op)
                network.grad_reducer.op_list = new_op_list

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sentence_flag):
        """Switch network according to sentence length."""
        for bucket in self.bucket_list:
            if sentence_flag == bucket:
                input_ids = input_ids[:, :bucket]
                input_mask = input_mask[:, :bucket]
                token_type_id = token_type_id[:, :bucket]
                loss = self.network(input_ids,
                                    input_mask,
                                    token_type_id,
                                    next_sentence_labels,
                                    masked_lm_positions,
                                    masked_lm_ids,
                                    masked_lm_weights)
                return loss

        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        return loss
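
# Added note: sentence_flag selects a bucket, and the inputs are sliced to
# that bucket's length. For example, with bucket_list=[32, 64, 128] a batch
# flagged 64 runs on (batch, 64) tensors instead of the full seq_length; a
# flag matching no bucket falls through to the untruncated network call.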


class BertPretrainEval(nn.Cell):
    '''
    Evaluate MaskedLM prediction scores.
    '''

    def __init__(self, config, network=None):
        super(BertPretrainEval, self).__init__(auto_prefix=False)
        if network is None:
            self.network = BertPreTraining(config, False, False)
        else:
            self.network = network
        self.argmax = P.Argmax(axis=-1, output_type=mstype.int32)
        self.equal = P.Equal()
        self.sum = P.ReduceSum()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.cast = P.Cast()
        self.allreduce = P.AllReduce()
        self.reduce_flag = False
        parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reduce_flag = True

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights):
        """Calculate prediction scores"""
        bs, _ = self.shape(input_ids)
        mlm, _ = self.network(input_ids, input_mask, token_type_id, masked_lm_positions)
        index = self.argmax(mlm)
        index = self.reshape(index, (bs, -1))
        eval_acc = self.equal(index, masked_lm_ids)
        eval_acc = self.cast(eval_acc, mstype.float32)
        real_acc = eval_acc * masked_lm_weights
        acc = self.sum(real_acc)
        total = self.sum(masked_lm_weights)

        if self.reduce_flag:
            acc = self.allreduce(acc)
            total = self.allreduce(total)

        return acc, total
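
# Added note: masked-LM accuracy is computed downstream as acc / total, where
# masked_lm_weights masks out padded prediction slots; in data-parallel mode
# both counts are summed across devices by allreduce before the division.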


@@ -0,0 +1,857 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert model."""

import copy
import math
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P


class BertConfig:
    """
    Configuration for `BertModel`.

    Args:
        seq_length (int): Length of input sequence. Default: 128.
        vocab_size (int): Size of the dictionary of embeddings. Default: 32000.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
                           cell. Default: 12.
        num_attention_heads (int): Number of attention heads in the BertTransformer
                             encoder cell. Default: 12.
        intermediate_size (int): Size of intermediate layer in the BertTransformer
                           encoder cell. Default: 3072.
        hidden_act (str): Activation function used in the BertTransformer encoder
                    cell. Default: "gelu".
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this
                                 model. Default: 512.
        type_vocab_size (int): Size of token type vocab. Default: 16.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """

    def __init__(self,
                 seq_length=128,
                 vocab_size=32000,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32):
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.use_relative_positions = use_relative_positions
        self.dtype = dtype
        self.compute_type = compute_type
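
# Minimal configuration sketch (illustrative; the values are placeholders, not
# a published recipe for this repo):
#   config = BertConfig(seq_length=128, vocab_size=32000,
#                       compute_type=mstype.float16)
#   bert = BertModel(config, is_training=True)  # BertModel is defined later in this file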


class EmbeddingLookup(nn.Cell):
    """
    An embedding lookup table with a fixed dictionary and size.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
                         each embedding vector.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
    """

    def __init__(self,
                 vocab_size,
                 embedding_size,
                 embedding_shape,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embedding_table = Parameter(initializer
                                         (TruncatedNormal(initializer_range),
                                          [vocab_size, embedding_size]))
        self.expand = P.ExpandDims()
        self.shape_flat = (-1,)
        self.gather = P.Gather()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)

    def construct(self, input_ids):
        """Get output and embeddings lookup table"""
        extended_ids = self.expand(input_ids, -1)
        flat_ids = self.reshape(extended_ids, self.shape_flat)
        if self.use_one_hot_embeddings:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(
                one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, self.shape)
        return output, self.embedding_table
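
# Added note: the one-hot path computes the same lookup as the gather path,
# since one_hot(ids) @ table selects rows of the table; it trades memory for
# a matmul, which can be friendlier to some accelerator backends.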


class EmbeddingPostprocessor(nn.Cell):
    """
    Postprocessors apply positional and token type embeddings to word embeddings.

    Args:
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
                         each embedding vector.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
        token_type_vocab_size (int): Size of token type vocab. Default: 16.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        max_position_embeddings (int): Maximum length of sequences used in this
                                 model. Default: 512.
        dropout_prob (float): The dropout probability. Default: 0.1.
    """

    def __init__(self,
                 embedding_size,
                 embedding_shape,
                 use_relative_positions=False,
                 use_token_type=False,
                 token_type_vocab_size=16,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 max_position_embeddings=512,
                 dropout_prob=0.1):
        super(EmbeddingPostprocessor, self).__init__()
        self.use_token_type = use_token_type
        self.token_type_vocab_size = token_type_vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.max_position_embeddings = max_position_embeddings
        self.token_type_embedding = nn.Embedding(
            vocab_size=token_type_vocab_size,
            embedding_size=embedding_size,
            use_one_hot=use_one_hot_embeddings)
        self.shape_flat = (-1,)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.gather = P.Gather()
        self.use_relative_positions = use_relative_positions
        self.slice = P.StridedSlice()
        _, seq, _ = self.shape
        self.full_position_embedding = nn.Embedding(
            vocab_size=max_position_embeddings,
            embedding_size=embedding_size,
            use_one_hot=False)
        self.layernorm = nn.LayerNorm((embedding_size,))
        self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
        self.add = P.Add()

    def construct(self, token_type_ids, word_embeddings):
        """Postprocessors apply positional and token type embeddings to word embeddings."""
        output = word_embeddings
        if self.use_token_type:
            token_type_embeddings = self.token_type_embedding(token_type_ids)
            output = self.add(output, token_type_embeddings)
        if not self.use_relative_positions:
            shape = F.shape(output)
            position_ids = self.position_ids[:, :shape[1]]
            position_embeddings = self.full_position_embedding(position_ids)
            output = self.add(output, position_embeddings)
        output = self.layernorm(output)
        output = self.dropout(output)
        return output


class BertOutput(nn.Cell):
    """
    Apply a linear computation to the hidden state and a residual computation to the input.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        dropout_prob (float): The dropout probability. Default: 0.1.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 initializer_range=0.02,
                 dropout_prob=0.1,
                 compute_type=mstype.float32):
        super(BertOutput, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.dropout_prob = dropout_prob
        self.add = P.Add()
        self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(input_tensor, output)
        output = self.layernorm(output)
        return output


class RelaPosMatrixGenerator(nn.Cell):
    """
    Generates a matrix of relative positions between inputs.

    Args:
        max_relative_position (int): Max value of relative position.

    Inputs:
        length (int): Length of one dim for the matrix to be generated.
    """

    def __init__(self, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._max_relative_position = max_relative_position
        self._min_relative_position = -max_relative_position

        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()

    def construct(self, length):
        """Generates matrix of relative positions between inputs."""
        range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(length)), mstype.int32)
        range_vec_col_out = self.range_mat(range_vec_row_out, (length, -1))
        tile_row_out = self.tile(range_vec_row_out, (length,))
        tile_col_out = self.tile(range_vec_col_out, (1, length))
        range_mat_out = self.range_mat(tile_row_out, (length, length))
        transpose_out = self.range_mat(tile_col_out, (length, length))
        distance_mat = self.sub(range_mat_out, transpose_out)

        distance_mat_clipped = C.clip_by_value(distance_mat,
                                               self._min_relative_position,
                                               self._max_relative_position)

        # Shift values to be >= 0. Each integer still uniquely identifies a
        # relative position difference.
        final_mat = distance_mat_clipped + self._max_relative_position
        return final_mat
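
# Worked example (added): for length=3 and max_relative_position=2 the raw
# distance matrix (entry [i][j] = j - i) is
#   [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]];
# after the +2 shift it becomes [[2, 3, 4], [1, 2, 3], [0, 1, 2]], i.e.
# indices into a table of size 2 * max_relative_position + 1 = 5.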


class RelaPosEmbeddingsGenerator(nn.Cell):
    """
    Generates a tensor of size [length, length, depth].

    Args:
        depth (int): Size of each attention head.
        max_relative_position (int): Maximum value of relative position.
        initializer_range (float): Initialization value of TruncatedNormal.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.

    Inputs:
        length (int): Length of one dim for the matrix to be generated.
    """

    def __init__(self,
                 depth,
                 max_relative_position,
                 initializer_range,
                 use_one_hot_embeddings=False):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position * 2 + 1
        self.use_one_hot_embeddings = use_one_hot_embeddings

        self.embeddings_table = Parameter(
            initializer(TruncatedNormal(initializer_range),
                        [self.vocab_size, self.depth]))

        self.relative_positions_matrix = RelaPosMatrixGenerator(max_relative_position=max_relative_position)
        self.reshape = P.Reshape()
        self.one_hot = nn.OneHot(depth=self.vocab_size)
        self.shape = P.Shape()
        self.gather = P.Gather()  # index_select
        self.matmul = P.BatchMatMul()

    def construct(self, length):
        """Generate embedding for each relative position of dimension depth."""
        relative_positions_matrix_out = self.relative_positions_matrix(length)

        if self.use_one_hot_embeddings:
            flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
            one_hot_relative_positions_matrix = self.one_hot(
                flat_relative_positions_matrix)
            embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
            my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
            embeddings = self.reshape(embeddings, my_shape)
        else:
            embeddings = self.gather(self.embeddings_table,
                                     relative_positions_matrix_out, 0)
        return embeddings


class SaturateCast(nn.Cell):
    """
    Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
    the danger that the value will overflow or underflow.

    Args:
        src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
        dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
    """

    def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
        super(SaturateCast, self).__init__()
        np_type = mstype.dtype_to_nptype(dst_type)

        self.tensor_min_type = float(np.finfo(np_type).min)
        self.tensor_max_type = float(np.finfo(np_type).max)

        self.min_op = P.Minimum()
        self.max_op = P.Maximum()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        out = self.max_op(x, self.tensor_min_type)
        out = self.min_op(out, self.tensor_max_type)
        return self.cast(out, self.dst_type)
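
# Added example: casting fp32 to fp16 with a plain P.Cast would turn 1e5 into
# inf, since fp16 tops out at 65504; SaturateCast first clamps the input to
# [-65504, 65504], so the same value becomes 65504 instead.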


class BertAttention(nn.Cell):
    """
    Apply multi-headed attention from "from_tensor" to "to_tensor".

    Args:
        from_tensor_width (int): Size of last dim of from_tensor.
        to_tensor_width (int): Size of last dim of to_tensor.
        num_attention_heads (int): Number of attention heads. Default: 1.
        size_per_head (int): Size of each attention head. Default: 512.
        query_act (str): Activation function for the query transform. Default: None.
        key_act (str): Activation function for the key transform. Default: None.
        value_act (str): Activation function for the value transform. Default: None.
        has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.0.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
    """

    def __init__(self,
                 from_tensor_width,
                 to_tensor_width,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 has_attention_mask=False,
                 attention_probs_dropout_prob=0.0,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 compute_type=mstype.float32):

        super(BertAttention, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.size_per_head = size_per_head
        self.has_attention_mask = has_attention_mask
        self.use_relative_positions = use_relative_positions

        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
        self.reshape = P.Reshape()
        self.shape_from_2d = (-1, from_tensor_width)
        self.shape_to_2d = (-1, to_tensor_width)
        weight = TruncatedNormal(initializer_range)
        units = num_attention_heads * size_per_head
        self.query_layer = nn.Dense(from_tensor_width,
                                    units,
                                    activation=query_act,
                                    weight_init=weight).to_float(compute_type)
        self.key_layer = nn.Dense(to_tensor_width,
                                  units,
                                  activation=key_act,
                                  weight_init=weight).to_float(compute_type)
        self.value_layer = nn.Dense(to_tensor_width,
                                    units,
                                    activation=value_act,
                                    weight_init=weight).to_float(compute_type)

        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        self.multiply_data = -10000.0
        self.matmul = P.BatchMatMul()

        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)

        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.cast = P.Cast()
            self.get_dtype = P.DType()

        self.shape_return = (-1, num_attention_heads * size_per_head)

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        if self.use_relative_positions:
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(depth=size_per_head,
                                           max_relative_position=16,
                                           initializer_range=initializer_range,
                                           use_one_hot_embeddings=use_one_hot_embeddings)

    def construct(self, from_tensor, to_tensor, attention_mask):
        """Apply multi-headed attention; 2d/3d input tensors are first reshaped to 2d."""
        shape_from = F.shape(attention_mask)[2]
        from_tensor = F.depend(from_tensor, shape_from)
        from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
        query_out = self.query_layer(from_tensor_2d)
        key_out = self.key_layer(to_tensor_2d)
        value_out = self.value_layer(to_tensor_2d)

        query_layer = self.reshape(query_out, (-1, shape_from, self.num_attention_heads, self.size_per_head))
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, (-1, shape_from, self.num_attention_heads, self.size_per_head))
        key_layer = self.transpose(key_layer, self.trans_shape)

        attention_scores = self.matmul_trans_b(query_layer, key_layer)

        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            relations_keys = self._generate_relative_positions_embeddings(shape_from)
            relations_keys = self.cast_compute_type(relations_keys)

            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)

            query_layer_r = self.reshape(query_layer_t,
                                         (shape_from,
                                          -1,
                                          self.size_per_head))

            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)

            key_position_scores_r = self.reshape(key_position_scores,
                                                 (shape_from,
                                                  -1,
                                                  self.num_attention_heads,
                                                  shape_from))

            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            attention_scores = attention_scores + key_position_scores_r_t

        attention_scores = self.multiply(self.scores_mul, attention_scores)

        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))

            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)

        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)

        value_layer = self.reshape(value_out, (-1, shape_from, self.num_attention_heads, self.size_per_head))
        value_layer = self.transpose(value_layer, self.trans_shape)
        context_layer = self.matmul(attention_probs, value_layer)

        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            relations_values = self._generate_relative_positions_embeddings(shape_from)
            relations_values = self.cast_compute_type(relations_values)

            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)

            attention_probs_r = self.reshape(
                attention_probs_t,
                (shape_from,
                 -1,
                 shape_from))

            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)

            value_position_scores_r = self.reshape(value_position_scores,
                                                   (shape_from,
                                                    -1,
                                                    self.num_attention_heads,
                                                    self.size_per_head))

            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            context_layer = context_layer + value_position_scores_r_t

        context_layer = self.transpose(context_layer, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shape_return)

        return context_layer
|
||||
|
||||
|
||||
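# Masking note (descriptive): the {0, 1} attention_mask above is turned into
# an additive bias of 0 or -10000 (multiply_data) before the softmax, which
# drives masked positions' probabilities toward zero without boolean indexing.
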
class BertSelfAttention(nn.Cell):
    """
    Apply self-attention.

    Args:
        hidden_size (int): Size of the bert encoder layers.
        num_attention_heads (int): Number of attention heads. Default: 12.
        attention_probs_dropout_prob (float): The dropout probability for
                                              BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads=12,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 compute_type=mstype.float32):
        super(BertSelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (hidden_size, num_attention_heads))

        self.size_per_head = int(hidden_size / num_attention_heads)

        self.attention = BertAttention(
            from_tensor_width=hidden_size,
            to_tensor_width=hidden_size,
            num_attention_heads=num_attention_heads,
            size_per_head=self.size_per_head,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            use_relative_positions=use_relative_positions,
            has_attention_mask=True,
            compute_type=compute_type)

        self.output = BertOutput(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type)
        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)

    def construct(self, input_tensor, attention_mask):
        attention_output = self.attention(input_tensor, input_tensor, attention_mask)
        output = self.output(attention_output, input_tensor)
        return output

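# Usage sketch for BertSelfAttention (shapes are illustrative assumptions, not
# asserted by this file): with hidden_size=768, an input_tensor of shape
# [batch*seq, 768] and an attention_mask of shape [batch, 1, seq] (or
# [batch, seq, seq]), construct() returns a [batch*seq, 768] tensor; the
# residual add and layer norm happen inside BertOutput.
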
class BertEncoderCell(nn.Cell):
    """
    Encoder cells used in BertTransformer.

    Args:
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        num_attention_heads (int): Number of attention heads. Default: 12.
        intermediate_size (int): Size of intermediate layer. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for
                                              BertAttention. Default: 0.02.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
    """

    def __init__(self,
                 hidden_size=768,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.02,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32):
        super(BertEncoderCell, self).__init__()
        self.attention = BertSelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            hidden_dropout_prob=hidden_dropout_prob,
            use_relative_positions=use_relative_positions,
            compute_type=compute_type)
        self.intermediate = nn.Dense(in_channels=hidden_size,
                                     out_channels=intermediate_size,
                                     activation=hidden_act,
                                     weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.output = BertOutput(in_channels=intermediate_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type)

    def construct(self, hidden_states, attention_mask):
        # self-attention
        attention_output = self.attention(hidden_states, attention_mask)
        # feed forward
        intermediate_output = self.intermediate(attention_output)
        # add and normalize
        output = self.output(intermediate_output, attention_output)
        return output

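# Data flow of BertEncoderCell (descriptive): x -> self-attention (residual
# add and layer norm inside BertOutput) -> Dense(intermediate_size, gelu) ->
# BertOutput back to hidden_size. The dense layers are cast to compute_type
# (e.g. float16 in mixed-precision configs) for speed.
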
class BertTransformer(nn.Cell):
    """
    Multi-layer bert transformer.

    Args:
        hidden_size (int): Size of the encoder layers.
        num_hidden_layers (int): Number of hidden layers in encoder cells.
        num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
        intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for
                                              BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
        return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
    """

    def __init__(self,
                 hidden_size,
                 num_hidden_layers,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 return_all_encoders=False):
        super(BertTransformer, self).__init__()
        self.return_all_encoders = return_all_encoders

        layers = []
        for _ in range(num_hidden_layers):
            layer = BertEncoderCell(hidden_size=hidden_size,
                                    num_attention_heads=num_attention_heads,
                                    intermediate_size=intermediate_size,
                                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                                    use_one_hot_embeddings=use_one_hot_embeddings,
                                    initializer_range=initializer_range,
                                    hidden_dropout_prob=hidden_dropout_prob,
                                    use_relative_positions=use_relative_positions,
                                    hidden_act=hidden_act,
                                    compute_type=compute_type)
            layers.append(layer)

        self.layers = nn.CellList(layers)

        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)

    def construct(self, input_tensor, attention_mask):
        """Multi-layer bert transformer."""
        prev_output = self.reshape(input_tensor, self.shape)

        all_encoder_layers = ()
        for layer_module in self.layers:
            layer_output = layer_module(prev_output, attention_mask)
            prev_output = layer_output

            if self.return_all_encoders:
                shape = F.shape(input_tensor)
                layer_output = self.reshape(layer_output, shape)
                all_encoder_layers = all_encoder_layers + (layer_output,)

        if not self.return_all_encoders:
            shape = F.shape(input_tensor)
            prev_output = self.reshape(prev_output, shape)
            all_encoder_layers = all_encoder_layers + (prev_output,)
        return all_encoder_layers

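# Note (descriptive): with return_all_encoders=True the returned tuple holds
# one entry per layer, each reshaped back to the input tensor's shape;
# otherwise it holds a single entry, the final layer's output.
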
class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    Args:
        config (Class): Configuration for BertModel.
    """

    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.input_mask = None

        self.cast = P.Cast()
        self.reshape = P.Reshape()

    def construct(self, input_mask):
        seq_length = F.shape(input_mask)[1]
        attention_mask = self.cast(self.reshape(input_mask, (-1, 1, seq_length)), mstype.float32)
        return attention_mask

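# Shape sketch (assumed from the ops above): input_mask [batch, seq] ->
# attention_mask [batch, 1, seq]; BertAttention later expands this to
# [batch, 1, 1, seq] and broadcasts it across every query row of the
# [batch, heads, seq, seq] score matrix.
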
class BertModel(nn.Cell):
    """
    Bidirectional Encoder Representations from Transformers.

    Args:
        config (Class): Configuration for BertModel.
        is_training (bool): True for training mode. False for eval mode.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """

    def __init__(self,
                 config,
                 is_training,
                 use_one_hot_embeddings=False):
        super(BertModel, self).__init__()
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embedding_size = config.hidden_size
        self.token_type_ids = None

        self.last_idx = self.num_hidden_layers - 1
        output_embedding_shape = [-1, config.seq_length, self.embedding_size]

        self.bert_embedding_lookup = nn.Embedding(
            vocab_size=config.vocab_size,
            embedding_size=self.embedding_size,
            use_one_hot=use_one_hot_embeddings,
            embedding_table=TruncatedNormal(config.initializer_range))

        self.bert_embedding_postprocessor = EmbeddingPostprocessor(
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_relative_positions=config.use_relative_positions,
            use_token_type=True,
            token_type_vocab_size=config.type_vocab_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)

        self.bert_encoder = BertTransformer(
            hidden_size=self.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range,
            hidden_dropout_prob=config.hidden_dropout_prob,
            use_relative_positions=config.use_relative_positions,
            hidden_act=config.hidden_act,
            compute_type=config.compute_type,
            return_all_encoders=True)

        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()

        self.squeeze_1 = P.Squeeze(axis=1)
        self.dense = nn.Dense(self.hidden_size, self.hidden_size,
                              activation="tanh",
                              weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

    def construct(self, input_ids, token_type_ids, input_mask):
        """Bidirectional Encoder Representations from Transformers."""
        # embedding
        embedding_tables = self.bert_embedding_lookup.embedding_table
        word_embeddings = self.bert_embedding_lookup(input_ids)
        embedding_output = self.bert_embedding_postprocessor(token_type_ids,
                                                             word_embeddings)

        # attention mask [batch_size, 1, seq_length]
        attention_mask = self._create_attention_mask_from_input_mask(input_mask)

        # bert encoder
        encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output),
                                           attention_mask)

        sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)

        # pooler
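        # (descriptive) the StridedSlice/Squeeze pair below extracts the first
        # token (the [CLS] position) from sequence_output, and the tanh Dense
        # layer projects it into pooled_output for sentence-level heads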
        batch_size = P.Shape()(input_ids)[0]
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (batch_size, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.dense(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)

        return sequence_output, pooled_output, embedding_tables

@ -0,0 +1,274 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Functional Cells used in Bert finetune and evaluation.
"""

import os
import collections
import math
import numpy as np
import mindspore.nn as nn
from mindspore import log as logger
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from mindspore.nn.metrics import Metric
from mindspore.ops import operations as P
from mindspore.train.callback import Callback

class CrossEntropyCalculation(nn.Cell):
    """
    Cross Entropy loss
    """

    def __init__(self, is_training=True):
        super(CrossEntropyCalculation, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.is_training = is_training

    def construct(self, logits, label_ids, num_labels):
        if self.is_training:
            label_ids = self.reshape(label_ids, self.last_idx)
            one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
            per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
            loss = self.reduce_mean(per_example_loss, self.last_idx)
            return_value = self.cast(loss, mstype.float32)
        else:
            return_value = logits * 1.0
        return return_value

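# Loss formula implemented above (descriptive): assuming `logits` are
# log-probabilities (e.g. LogSoftmax outputs), the training branch computes
# loss = mean_i( -sum_c onehot(y_i)[c] * logits[i, c] ), i.e. the standard
# negative log-likelihood averaged over the batch.
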
def make_directory(path: str):
    """Make directory."""
    if path is None or not isinstance(path, str) or path.strip() == "":
        logger.error("The path(%r) is of invalid type.", path)
        raise TypeError("Input path is invalid type")

    # convert relative paths
    path = os.path.realpath(path)
    logger.debug("The abs path is %r", path)

    # check whether the path already exists and is writable
    if os.path.exists(path):
        real_path = path
    else:
        # directory creation can fail for several reasons (e.g. permissions), so catch and report
        logger.debug("The directory(%s) doesn't exist, will create it", path)
        try:
            os.makedirs(path, exist_ok=True)
            real_path = path
        except PermissionError as e:
            logger.error("No write permission on the directory(%r), error = %r", path, e)
            raise TypeError("No write permission on the directory.")
    return real_path

class LossCallBack(Callback):
    """
    Monitor the loss in training.
    If the loss is NAN or INF, training will be terminated.
    Note:
        If dataset_size is not positive, the loss is printed by global step;
        otherwise the epoch number and the progress within the epoch are
        derived from the global step.
    Args:
        dataset_size (int): Number of steps in one epoch. Default: -1.
    """

    def __init__(self, dataset_size=-1):
        super(LossCallBack, self).__init__()
        self._dataset_size = dataset_size

    def step_end(self, run_context):
        """
        Print the loss after each step.
        """
        cb_params = run_context.original_args()
        if self._dataset_size > 0:
            # math.modf splits cur_step_num / dataset_size into its fractional
            # part (progress within the epoch) and its integer part (epoch index)
            percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size)
            if percent == 0:
                percent = 1
                epoch_num -= 1
            print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)),
                  flush=True)
        else:
            print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                               str(cb_params.net_outputs)), flush=True)

class BertLearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for Bert network.
    """

    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(BertLearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        decay_lr = self.decay_lr(global_step)
        if self.warmup_flag:
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr

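# Behavior sketch (descriptive; the exact WarmUpLR/PolynomialDecayLR formulas
# live in mindspore.nn): while global_step < warmup_steps the is_warmup gate
# equals 1 and the warmup branch is selected, ramping the lr up toward
# learning_rate; afterwards the polynomial decay from learning_rate toward
# end_learning_rate takes over.
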
def convert_labels_to_index(label_list):
    """
    Convert label_list to indices for NER task.
    """
    label2id = collections.OrderedDict()
    label2id["O"] = 0
    prefix = ["S_", "B_", "M_", "E_"]
    index = 0
    for label in label_list:
        for pre in prefix:
            index += 1
            sub_label = pre + label
            label2id[sub_label] = index
    return label2id

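# Example (computed from the function above): convert_labels_to_index(["PER"])
# returns OrderedDict([("O", 0), ("S_PER", 1), ("B_PER", 2), ("M_PER", 3),
# ("E_PER", 4)]); each label expands into the four prefixed segment tags.
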
def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
    """
    generate learning rate array

    Args:
        global_step(int): current step
        lr_init(float): initial learning rate
        lr_end(float): final learning rate
        lr_max(float): maximum learning rate
        warmup_steps(int): number of warmup steps
        total_steps(int): total number of training steps
        poly_power(float): polynomial power of the decay

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    if warmup_steps != 0:
        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    else:
        inc_each_step = 0
    for i in range(total_steps):
        if i < warmup_steps:
            lr = float(lr_init) + inc_each_step * float(i)
        else:
            base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
            lr = float(lr_max - lr_end) * (base ** poly_power)
            lr = lr + lr_end
            if lr < 0.0:
                lr = 0.0
        lr_each_step.append(lr)

    learning_rate = np.array(lr_each_step).astype(np.float32)
    current_step = global_step
    learning_rate = learning_rate[current_step:]
    return learning_rate

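# Worked example (assumed numbers): _get_poly_lr(0, lr_init=0.0, lr_end=0.01,
# lr_max=0.1, warmup_steps=0, total_steps=4, poly_power=1) yields
# [0.1, 0.0775, 0.055, 0.0325]: a linear decay from lr_max toward lr_end.
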
def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000):
    learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0,
                                 total_steps=lr_total_steps, poly_power=lr_power)
    return Tensor(learning_rate)


def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000):
    damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0,
                           total_steps=damping_total_steps, poly_power=damping_power)
    return Tensor(damping)

class EvalCallBack(Callback):
    """
    Evaluate after a certain amount of training samples.
    Args:
        model (Model): The network model.
        eval_ds (Dataset): The eval dataset.
        global_batch (int): The total batch size across all devices.
        eval_samples (int): Evaluate every `eval_samples` training samples.
    """

    def __init__(self, model, eval_ds, global_batch, eval_samples):
        super(EvalCallBack, self).__init__()
        self.model = model
        self.eval_ds = eval_ds
        self.global_batch = global_batch
        self.eval_samples = eval_samples
        self.last_eval_step = 0

    def epoch_end(self, run_context):
        """
        Evaluate after training a certain number of samples.
        """
        cb_params = run_context.original_args()
        num_samples = (cb_params.cur_step_num - self.last_eval_step) * self.global_batch
        if num_samples < self.eval_samples:
            return
        self.last_eval_step = cb_params.cur_step_num
        total_samples = cb_params.cur_step_num * self.global_batch
        res = self.model.eval(self.eval_ds, dataset_sink_mode=True)
        res = res['bert_acc']
        print("====================================", flush=True)
        print("Accuracy is: ", "%.6f" % res, ", current samples is: ", total_samples)
        print("====================================", flush=True)

class BertMetric(Metric):
    """
    The metric of bert network.
    Args:
        batch_size (int): The batch size of each device.
    """

    def __init__(self, batch_size):
        super(BertMetric, self).__init__()
        self.clear()
        self.batch_size = batch_size

    def clear(self):
        self.mlm_total = 0
        self.mlm_acc = 0

    def update(self, *inputs):
        mlm_acc = self._convert_data(inputs[0])
        mlm_total = self._convert_data(inputs[1])
        self.mlm_acc += mlm_acc
        self.mlm_total += mlm_total

    def eval(self):
        return self.mlm_acc / self.mlm_total

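# Note (descriptive): update() accumulates the masked-LM hit count and token
# count produced by the evaluation network, so eval() returns the overall
# MLM accuracy across the entire evaluation dataset.
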
@ -33,9 +33,11 @@ from mindspore.train.model import Model
from mindspore.train.train_thor import ConvertModelUtils
import mindspore.dataset.transforms as C

from tests.models.official.nlp.bert.src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepCell
from tests.models.official.nlp.bert.src.utils import get_bert_thor_lr, get_bert_thor_damping
from tests.models.official.nlp.bert.src.bert_model import BertConfig
from tests.st.networks.models.bert.bert_performance.src.bert_for_pre_training import BertNetworkWithLoss, \
    BertTrainOneStepCell

from tests.st.networks.models.bert.bert_performance.src.utils import get_bert_thor_lr, get_bert_thor_damping
from tests.st.networks.models.bert.bert_performance.src.bert_model import BertConfig

MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_table_8p.json"
DATASET_PATH = "/home/workspace/mindspore_dataset/bert/thor/en-wiki-512_test_first1wan"

@ -188,7 +190,7 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
q.put({'loss': loss_list, 'cost': per_step_mseconds})


@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single