!35210 修复 ST probabilistic failure in daily version

Merge pull request !35210 from 王程浩/master
This commit is contained in:
i-robot 2022-06-10 08:36:31 +00:00 committed by Gitee
commit c79d9cb6d6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 2048 additions and 4 deletions

View File

@ -0,0 +1,911 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert for pretraining."""
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer, TruncatedNormal
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from src.bert_model import BertModel
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
if clip_type not in (0, 1):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
class GetMaskedLMOutput(nn.Cell):
"""
Get masked lm output.
Args:
config (BertConfig): The config of BertModel.
Returns:
Tensor, masked lm output.
"""
def __init__(self, config):
super(GetMaskedLMOutput, self).__init__()
self.width = config.hidden_size
self.reshape = P.Reshape()
self.gather = P.Gather()
weight_init = TruncatedNormal(config.initializer_range)
self.dense = nn.Dense(self.width,
config.hidden_size,
weight_init=weight_init,
activation=config.hidden_act).to_float(config.compute_type)
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
self.output_bias = Parameter(
initializer(
'zero',
config.vocab_size))
self.matmul = P.MatMul(transpose_b=True)
self.log_softmax = nn.LogSoftmax(axis=-1)
self.shape_flat_offsets = (-1, 1)
self.last_idx = (-1,)
self.shape_flat_sequence_tensor = (-1, self.width)
self.cast = P.Cast()
self.compute_type = config.compute_type
self.dtype = config.dtype
def construct(self,
input_tensor,
output_weights,
positions):
"""Get output log_probs"""
input_shape = P.Shape()(input_tensor)
rng = F.tuple_to_array(F.make_range(input_shape[0]))
flat_offsets = self.reshape(rng * input_shape[1], self.shape_flat_offsets)
flat_position = self.reshape(positions + flat_offsets, self.last_idx)
flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
input_tensor = self.gather(flat_sequence_tensor, flat_position, 0)
input_tensor = self.cast(input_tensor, self.compute_type)
output_weights = self.cast(output_weights, self.compute_type)
input_tensor = self.dense(input_tensor)
input_tensor = self.layernorm(input_tensor)
logits = self.matmul(input_tensor, output_weights)
logits = self.cast(logits, self.dtype)
logits = logits + self.output_bias
log_probs = self.log_softmax(logits)
return log_probs
class GetNextSentenceOutput(nn.Cell):
"""
Get next sentence output.
Args:
config (BertConfig): The config of Bert.
Returns:
Tensor, next sentence output.
"""
def __init__(self, config):
super(GetNextSentenceOutput, self).__init__()
self.log_softmax = P.LogSoftmax()
weight_init = TruncatedNormal(config.initializer_range)
self.dense = nn.Dense(config.hidden_size, 2,
weight_init=weight_init, has_bias=True).to_float(config.compute_type)
self.dtype = config.dtype
self.cast = P.Cast()
def construct(self, input_tensor):
logits = self.dense(input_tensor)
logits = self.cast(logits, self.dtype)
log_prob = self.log_softmax(logits)
return log_prob
class BertPreTraining(nn.Cell):
"""
Bert pretraining network.
Args:
config (BertConfig): The config of BertModel.
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
Returns:
Tensor, prediction_scores, seq_relationship_score.
"""
def __init__(self, config, is_training, use_one_hot_embeddings):
super(BertPreTraining, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
self.cls1 = GetMaskedLMOutput(config)
self.cls2 = GetNextSentenceOutput(config)
def construct(self, input_ids, input_mask, token_type_id,
masked_lm_positions):
sequence_output, pooled_output, embedding_table = \
self.bert(input_ids, token_type_id, input_mask)
prediction_scores = self.cls1(sequence_output,
embedding_table,
masked_lm_positions)
seq_relationship_score = self.cls2(pooled_output)
return prediction_scores, seq_relationship_score
class BertPretrainingLoss(nn.Cell):
"""
Provide bert pre-training loss.
Args:
config (BertConfig): The config of BertModel.
Returns:
Tensor, total loss.
"""
def __init__(self, config):
super(BertPretrainingLoss, self).__init__()
self.vocab_size = config.vocab_size
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.reshape = P.Reshape()
self.last_idx = (-1,)
self.neg = P.Neg()
self.cast = P.Cast()
def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids,
masked_lm_weights, next_sentence_labels):
"""Defines the computation performed."""
label_ids = self.reshape(masked_lm_ids, self.last_idx)
label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32)
one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
numerator = self.reduce_sum(label_weights * per_example_loss, ())
denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32)
masked_lm_loss = numerator / denominator
# next_sentence_loss
labels = self.reshape(next_sentence_labels, self.last_idx)
one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(
one_hot_labels * seq_relationship_score, self.last_idx))
next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx)
# total_loss
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
class BertNetworkWithLoss(nn.Cell):
"""
Provide bert pre-training loss through network.
Args:
config (BertConfig): The config of BertModel.
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
Returns:
Tensor, the loss of the network.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
super(BertNetworkWithLoss, self).__init__()
self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings)
self.loss = BertPretrainingLoss(config)
self.cast = P.Cast()
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights):
"""Get pre-training loss"""
prediction_scores, seq_relationship_score = \
self.bert(input_ids, input_mask, token_type_id, masked_lm_positions)
total_loss = self.loss(prediction_scores, seq_relationship_score,
masked_lm_ids, masked_lm_weights, next_sentence_labels)
return self.cast(total_loss, mstype.float32)
class BertTrainOneStepCell(nn.TrainOneStepCell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default: 1.0.
enable_clip_grad (boolean): If True, clip gradients in BertTrainOneStepCell. Default: True.
"""
def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=True):
super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
self.cast = P.Cast()
self.hyper_map = C.HyperMap()
self.enable_clip_grad = enable_clip_grad
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(F.tuple_to_array((self.sens,)),
mstype.float32))
if self.enable_clip_grad:
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
grads = self.grad_reducer(grads)
self.optimizer(grads)
return loss
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)
class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
self.cast = P.Cast()
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(scaling_sens,
mstype.float32))
# apply grad reducer on grads
grads = self.grad_reducer(grads)
degree_sens = self.cast(scaling_sens * self.degree, mstype.float32)
grads = self.hyper_map(F.partial(grad_scale, degree_sens), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
cond = self.get_overflow_status(status, grads)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if not overflow:
self.optimizer(grads)
return (loss, cond, scaling_sens)
class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow
condition as input.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell)
self.cast = P.Cast()
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(scaling_sens,
mstype.float32))
# apply grad reducer on grads
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
cond = self.get_overflow_status(status, grads)
overflow = cond
if self.loss_scaling_manager is not None:
overflow = self.loss_scaling_manager(scaling_sens, cond)
self.optimizer(grads, overflow)
return (loss, cond, scaling_sens)
cast = P.Cast()
add_grads = C.MultitypeFuncGraph("add_grads")
@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
return accu_grad + cast(grad, mstype.float32)
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
succ = True
return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))
accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
succ = True
return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")
@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
succ = True
return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
To mimic higher batch size, gradients are accumulated N times before weight update.
For distribution mode, allreduce will only be implemented in the weight updated step,
i.e. the sub-step after gradients accumulated N times.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
batch_size * accumulation_steps. Default: 1.
"""
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.accumulation_steps = accumulation_steps
self.enable_global_norm = enable_global_norm
self.one = Tensor(np.array([1]).astype(np.int32))
self.zero = Tensor(np.array([0]).astype(np.int32))
self.local_step = Parameter(initializer(0, [1], mstype.int32))
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.overflow_reducer = F.identity
if self.is_distributed:
self.overflow_reducer = P.AllReduce()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_status = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.logical_or = P.LogicalOr()
self.not_equal = P.NotEqual()
self.select = P.Select()
self.reshape = P.Reshape()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
init = F.depend(init, loss)
clear_status = self.clear_status(init)
scaling_sens = F.depend(scaling_sens, clear_status)
# update accumulation parameters
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
mean_loss = self.accu_loss / self.local_step
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(scaling_sens,
mstype.float32))
accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
mean_loss = F.depend(mean_loss, accu_succ)
init = F.depend(init, mean_loss)
get_status = self.get_status(init)
init = F.depend(init, get_status)
flag_sum = self.reduce_sum(init, (0,))
overflow = self.less_equal(self.base, flag_sum)
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
accu_overflow = self.select(overflow, self.one, self.zero)
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
if not is_accu_step:
# apply grad reducer on grads
grads = self.grad_reducer(self.accu_grads)
scaling = scaling_sens * self.degree * self.accumulation_steps
grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
if self.enable_global_norm:
grads = C.clip_by_global_norm(grads, 1.0, None)
else:
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
accu_overflow = F.depend(accu_overflow, grads)
accu_overflow = self.overflow_reducer(accu_overflow)
overflow = self.less_equal(self.base, accu_overflow)
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
overflow = F.depend(overflow, accu_succ)
overflow = self.reshape(overflow, (()))
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
if not overflow:
self.optimizer(grads)
return (mean_loss, overflow, scaling_sens)
class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
To mimic higher batch size, gradients are accumulated N times before weight update.
For distribution mode, allreduce will be implemented after each sub-step and the trailing time
will be overided by backend optimization pass.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
batch_size * accumulation_steps. Default: 1.
"""
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.accumulation_steps = accumulation_steps
self.enable_global_norm = enable_global_norm
self.one = Tensor(np.array([1]).astype(np.int32))
self.zero = Tensor(np.array([0]).astype(np.int32))
self.local_step = Parameter(initializer(0, [1], mstype.int32))
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.overflow_reducer = F.identity
if self.is_distributed:
self.overflow_reducer = P.AllReduce()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.logical_or = P.LogicalOr()
self.not_equal = P.NotEqual()
self.select = P.Select()
self.reshape = P.Reshape()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
@C.add_flags(has_effect=True)
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# update accumulation parameters
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
mean_loss = self.accu_loss / self.local_step
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(scaling_sens,
mstype.float32))
accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
scaling = scaling_sens * self.degree * self.accumulation_steps
grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
grads = self.grad_reducer(grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
flag_reduce = self.overflow_reducer(flag_sum)
overflow = self.less_equal(self.base, flag_reduce)
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
accu_overflow = self.select(overflow, self.one, self.zero)
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
overflow = self.reshape(overflow, (()))
if is_accu_step:
succ = False
accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
succ = F.depend(succ, accu_succ)
else:
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
if overflow:
succ = False
else:
if self.enable_global_norm:
grads = C.clip_by_global_norm(grads, 1.0, None)
else:
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
succ = self.optimizer(grads)
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
succ = F.depend(succ, accu_succ)
ret = (mean_loss, overflow, scaling_sens)
return F.depend(ret, succ)
class BertNetworkMatchBucket(nn.Cell):
'''
Bert execute according to different sentence lengths.
'''
def __init__(self, network, seq_length, bucket_list=None):
super(BertNetworkMatchBucket, self).__init__()
self.network = network
if not bucket_list or not isinstance(bucket_list, list):
bucket_list = [seq_length]
self.bucket_list = [bucket for bucket in bucket_list if bucket <= seq_length]
if network.reducer_flag:
reuse_attr = 'reuse_communication_node'
if not network.grad_reducer.split_fusion:
hccl_op = network.grad_reducer.allreduce
network.grad_reducer.allreduce = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion'))
else:
new_op_list = []
for hccl_op in network.grad_reducer.op_list:
new_op = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion'))
new_op_list.append(new_op)
network.grad_reducer.op_list = new_op_list
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
sentence_flag):
"""Switch network according to sentence length."""
for bucket in self.bucket_list:
if sentence_flag == bucket:
input_ids = input_ids[:, :bucket]
input_mask = input_mask[:, :bucket]
token_type_id = token_type_id[:, :bucket]
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
return loss
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
return loss
class BertPretrainEval(nn.Cell):
'''
Evaluate MaskedLM prediction scores
'''
def __init__(self, config, network=None):
super(BertPretrainEval, self).__init__(auto_prefix=False)
if network is None:
self.network = BertPreTraining(config, False, False)
else:
self.network = network
self.argmax = P.Argmax(axis=-1, output_type=mstype.int32)
self.equal = P.Equal()
self.sum = P.ReduceSum()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.cast = P.Cast()
self.allreduce = P.AllReduce()
self.reduce_flag = False
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reduce_flag = True
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights):
"""Calculate prediction scores"""
bs, _ = self.shape(input_ids)
mlm, _ = self.network(input_ids, input_mask, token_type_id, masked_lm_positions)
index = self.argmax(mlm)
index = self.reshape(index, (bs, -1))
eval_acc = self.equal(index, masked_lm_ids)
eval_acc = self.cast(eval_acc, mstype.float32)
real_acc = eval_acc * masked_lm_weights
acc = self.sum(real_acc)
total = self.sum(masked_lm_weights)
if self.reduce_flag:
acc = self.allreduce(acc)
total = self.allreduce(total)
return acc, total

View File

@ -0,0 +1,857 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert model."""
import copy
import math
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
class BertConfig:
"""
Configuration for `BertModel`.
Args:
seq_length (int): Length of input sequence. Default: 128.
vocab_size (int): The shape of each embedding vector. Default: 32000.
hidden_size (int): Size of the bert encoder layers. Default: 768.
num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
cell. Default: 12.
num_attention_heads (int): Number of attention heads in the BertTransformer
encoder cell. Default: 12.
intermediate_size (int): Size of intermediate layer in the BertTransformer
encoder cell. Default: 3072.
hidden_act (str): Activation function used in the BertTransformer encoder
cell. Default: "gelu".
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.1.
max_position_embeddings (int): Maximum length of sequences used in this
model. Default: 512.
type_vocab_size (int): Size of token type vocab. Default: 16.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
"""
def __init__(self,
seq_length=128,
vocab_size=32000,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
use_relative_positions=False,
dtype=mstype.float32,
compute_type=mstype.float32):
self.seq_length = seq_length
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.use_relative_positions = use_relative_positions
self.dtype = dtype
self.compute_type = compute_type
class EmbeddingLookup(nn.Cell):
"""
A embeddings lookup table with a fixed dictionary and size.
Args:
vocab_size (int): Size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
each embedding vector.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
"""
def __init__(self,
vocab_size,
embedding_size,
embedding_shape,
use_one_hot_embeddings=False,
initializer_range=0.02):
super(EmbeddingLookup, self).__init__()
self.vocab_size = vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(initializer_range),
[vocab_size, embedding_size]))
self.expand = P.ExpandDims()
self.shape_flat = (-1,)
self.gather = P.Gather()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
def construct(self, input_ids):
"""Get output and embeddings lookup table"""
extended_ids = self.expand(input_ids, -1)
flat_ids = self.reshape(extended_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
output_for_reshape = self.array_mul(
one_hot_ids, self.embedding_table)
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
output = self.reshape(output_for_reshape, self.shape)
return output, self.embedding_table
class EmbeddingPostprocessor(nn.Cell):
"""
Postprocessors apply positional and token type embeddings to word embeddings.
Args:
embedding_size (int): The size of each embedding vector.
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
each embedding vector.
use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
token_type_vocab_size (int): Size of token type vocab. Default: 16.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
max_position_embeddings (int): Maximum length of sequences used in this
model. Default: 512.
dropout_prob (float): The dropout probability. Default: 0.1.
"""
def __init__(self,
embedding_size,
embedding_shape,
use_relative_positions=False,
use_token_type=False,
token_type_vocab_size=16,
use_one_hot_embeddings=False,
initializer_range=0.02,
max_position_embeddings=512,
dropout_prob=0.1):
super(EmbeddingPostprocessor, self).__init__()
self.use_token_type = use_token_type
self.token_type_vocab_size = token_type_vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.max_position_embeddings = max_position_embeddings
self.token_type_embedding = nn.Embedding(
vocab_size=token_type_vocab_size,
embedding_size=embedding_size,
use_one_hot=use_one_hot_embeddings)
self.shape_flat = (-1,)
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.1, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
self.dropout = nn.Dropout(1 - dropout_prob)
self.gather = P.Gather()
self.use_relative_positions = use_relative_positions
self.slice = P.StridedSlice()
_, seq, _ = self.shape
self.full_position_embedding = nn.Embedding(
vocab_size=max_position_embeddings,
embedding_size=embedding_size,
use_one_hot=False)
self.layernorm = nn.LayerNorm((embedding_size,))
self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
self.add = P.Add()
def construct(self, token_type_ids, word_embeddings):
"""Postprocessors apply positional and token type embeddings to word embeddings."""
output = word_embeddings
if self.use_token_type:
token_type_embeddings = self.token_type_embedding(token_type_ids)
output = self.add(output, token_type_embeddings)
if not self.use_relative_positions:
shape = F.shape(output)
position_ids = self.position_ids[:, :shape[1]]
position_embeddings = self.full_position_embedding(position_ids)
output = self.add(output, position_embeddings)
output = self.layernorm(output)
output = self.dropout(output)
return output
class BertOutput(nn.Cell):
"""
Apply a linear computation to hidden status and a residual computation to input.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
dropout_prob (float): The dropout probability. Default: 0.1.
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
"""
def __init__(self,
in_channels,
out_channels,
initializer_range=0.02,
dropout_prob=0.1,
compute_type=mstype.float32):
super(BertOutput, self).__init__()
self.dense = nn.Dense(in_channels, out_channels,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
self.dropout_prob = dropout_prob
self.add = P.Add()
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.cast = P.Cast()
def construct(self, hidden_status, input_tensor):
output = self.dense(hidden_status)
output = self.dropout(output)
output = self.add(input_tensor, output)
output = self.layernorm(output)
return output
class RelaPosMatrixGenerator(nn.Cell):
"""
Generates matrix of relative positions between inputs.
Args:
length (int): Length of one dim for the matrix to be generated.
max_relative_position (int): Max value of relative position.
"""
def __init__(self, max_relative_position):
super(RelaPosMatrixGenerator, self).__init__()
self._max_relative_position = max_relative_position
self._min_relative_position = -max_relative_position
self.tile = P.Tile()
self.range_mat = P.Reshape()
self.sub = P.Sub()
self.expanddims = P.ExpandDims()
self.cast = P.Cast()
def construct(self, length):
"""Generates matrix of relative positions between inputs."""
range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(length)), mstype.int32)
range_vec_col_out = self.range_mat(range_vec_row_out, (length, -1))
tile_row_out = self.tile(range_vec_row_out, (length,))
tile_col_out = self.tile(range_vec_col_out, (1, length))
range_mat_out = self.range_mat(tile_row_out, (length, length))
transpose_out = self.range_mat(tile_col_out, (length, length))
distance_mat = self.sub(range_mat_out, transpose_out)
distance_mat_clipped = C.clip_by_value(distance_mat,
self._min_relative_position,
self._max_relative_position)
# Shift values to be >=0. Each integer still uniquely identifies a
# relative position difference.
final_mat = distance_mat_clipped + self._max_relative_position
return final_mat
class RelaPosEmbeddingsGenerator(nn.Cell):
"""
Generates tensor of size [length, length, depth].
Args:
length (int): Length of one dim for the matrix to be generated.
depth (int): Size of each attention head.
max_relative_position (int): Maxmum value of relative position.
initializer_range (float): Initialization value of TruncatedNormal.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
"""
def __init__(self,
depth,
max_relative_position,
initializer_range,
use_one_hot_embeddings=False):
super(RelaPosEmbeddingsGenerator, self).__init__()
self.depth = depth
self.vocab_size = max_relative_position * 2 + 1
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embeddings_table = Parameter(
initializer(TruncatedNormal(initializer_range),
[self.vocab_size, self.depth]))
self.relative_positions_matrix = RelaPosMatrixGenerator(max_relative_position=max_relative_position)
self.reshape = P.Reshape()
self.one_hot = nn.OneHot(depth=self.vocab_size)
self.shape = P.Shape()
self.gather = P.Gather() # index_select
self.matmul = P.BatchMatMul()
def construct(self, length):
"""Generate embedding for each relative position of dimension depth."""
relative_positions_matrix_out = self.relative_positions_matrix(length)
if self.use_one_hot_embeddings:
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
one_hot_relative_positions_matrix = self.one_hot(
flat_relative_positions_matrix)
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
embeddings = self.reshape(embeddings, my_shape)
else:
embeddings = self.gather(self.embeddings_table,
relative_positions_matrix_out, 0)
return embeddings
class SaturateCast(nn.Cell):
"""
Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
the danger that the value will overflow or underflow.
Args:
src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
"""
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
super(SaturateCast, self).__init__()
np_type = mstype.dtype_to_nptype(dst_type)
self.tensor_min_type = float(np.finfo(np_type).min)
self.tensor_max_type = float(np.finfo(np_type).max)
self.min_op = P.Minimum()
self.max_op = P.Maximum()
self.cast = P.Cast()
self.dst_type = dst_type
def construct(self, x):
out = self.max_op(x, self.tensor_min_type)
out = self.min_op(out, self.tensor_max_type)
return self.cast(out, self.dst_type)
class BertAttention(nn.Cell):
"""
Apply multi-headed attention from "from_tensor" to "to_tensor".
Args:
from_tensor_width (int): Size of last dim of from_tensor.
to_tensor_width (int): Size of last dim of to_tensor.
num_attention_heads (int): Number of attention heads. Default: 1.
size_per_head (int): Size of each attention head. Default: 512.
query_act (str): Activation function for the query transform. Default: None.
key_act (str): Activation function for the key transform. Default: None.
value_act (str): Activation function for the value transform. Default: None.
has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.0.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
"""
def __init__(self,
from_tensor_width,
to_tensor_width,
num_attention_heads=1,
size_per_head=512,
query_act=None,
key_act=None,
value_act=None,
has_attention_mask=False,
attention_probs_dropout_prob=0.0,
use_one_hot_embeddings=False,
initializer_range=0.02,
use_relative_positions=False,
compute_type=mstype.float32):
super(BertAttention, self).__init__()
self.num_attention_heads = num_attention_heads
self.size_per_head = size_per_head
self.has_attention_mask = has_attention_mask
self.use_relative_positions = use_relative_positions
self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
self.reshape = P.Reshape()
self.shape_from_2d = (-1, from_tensor_width)
self.shape_to_2d = (-1, to_tensor_width)
weight = TruncatedNormal(initializer_range)
units = num_attention_heads * size_per_head
self.query_layer = nn.Dense(from_tensor_width,
units,
activation=query_act,
weight_init=weight).to_float(compute_type)
self.key_layer = nn.Dense(to_tensor_width,
units,
activation=key_act,
weight_init=weight).to_float(compute_type)
self.value_layer = nn.Dense(to_tensor_width,
units,
activation=value_act,
weight_init=weight).to_float(compute_type)
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.multiply = P.Mul()
self.transpose = P.Transpose()
self.trans_shape = (0, 2, 1, 3)
self.trans_shape_relative = (2, 0, 1, 3)
self.trans_shape_position = (1, 2, 0, 3)
self.multiply_data = -10000.0
self.matmul = P.BatchMatMul()
self.softmax = nn.Softmax()
self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
if self.has_attention_mask:
self.expand_dims = P.ExpandDims()
self.sub = P.Sub()
self.add = P.Add()
self.cast = P.Cast()
self.get_dtype = P.DType()
self.shape_return = (-1, num_attention_heads * size_per_head)
self.cast_compute_type = SaturateCast(dst_type=compute_type)
if self.use_relative_positions:
self._generate_relative_positions_embeddings = \
RelaPosEmbeddingsGenerator(depth=size_per_head,
max_relative_position=16,
initializer_range=initializer_range,
use_one_hot_embeddings=use_one_hot_embeddings)
def construct(self, from_tensor, to_tensor, attention_mask):
"""reshape 2d/3d input tensors to 2d"""
shape_from = F.shape(attention_mask)[2]
from_tensor = F.depend(from_tensor, shape_from)
from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
query_out = self.query_layer(from_tensor_2d)
key_out = self.key_layer(to_tensor_2d)
value_out = self.value_layer(to_tensor_2d)
query_layer = self.reshape(query_out, (-1, shape_from, self.num_attention_heads, self.size_per_head))
query_layer = self.transpose(query_layer, self.trans_shape)
key_layer = self.reshape(key_out, (-1, shape_from, self.num_attention_heads, self.size_per_head))
key_layer = self.transpose(key_layer, self.trans_shape)
attention_scores = self.matmul_trans_b(query_layer, key_layer)
# use_relative_position, supplementary logic
if self.use_relative_positions:
relations_keys = self._generate_relative_positions_embeddings(shape_from)
relations_keys = self.cast_compute_type(relations_keys)
query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
query_layer_r = self.reshape(query_layer_t,
(shape_from,
-1,
self.size_per_head))
key_position_scores = self.matmul_trans_b(query_layer_r,
relations_keys)
key_position_scores_r = self.reshape(key_position_scores,
(shape_from,
-1,
self.num_attention_heads,
shape_from))
key_position_scores_r_t = self.transpose(key_position_scores_r,
self.trans_shape_position)
attention_scores = attention_scores + key_position_scores_r_t
attention_scores = self.multiply(self.scores_mul, attention_scores)
if self.has_attention_mask:
attention_mask = self.expand_dims(attention_mask, 1)
multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
self.cast(attention_mask, self.get_dtype(attention_scores)))
adder = self.multiply(multiply_out, self.multiply_data)
attention_scores = self.add(adder, attention_scores)
attention_probs = self.softmax(attention_scores)
attention_probs = self.dropout(attention_probs)
value_layer = self.reshape(value_out, (-1, shape_from, self.num_attention_heads, self.size_per_head))
value_layer = self.transpose(value_layer, self.trans_shape)
context_layer = self.matmul(attention_probs, value_layer)
# use_relative_position, supplementary logic
if self.use_relative_positions:
relations_values = self._generate_relative_positions_embeddings(shape_from)
relations_values = self.cast_compute_type(relations_values)
attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
attention_probs_r = self.reshape(
attention_probs_t,
(shape_from,
-1,
shape_from))
value_position_scores = self.matmul(attention_probs_r,
relations_values)
value_position_scores_r = self.reshape(value_position_scores,
(shape_from,
-1,
self.num_attention_heads,
self.size_per_head))
value_position_scores_r_t = self.transpose(value_position_scores_r,
self.trans_shape_position)
context_layer = context_layer + value_position_scores_r_t
context_layer = self.transpose(context_layer, self.trans_shape)
context_layer = self.reshape(context_layer, self.shape_return)
return context_layer
class BertSelfAttention(nn.Cell):
"""
Apply self-attention.
Args:
hidden_size (int): Size of the bert encoder layers.
num_attention_heads (int): Number of attention heads. Default: 12.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.1.
use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
"""
def __init__(self,
hidden_size,
num_attention_heads=12,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
compute_type=mstype.float32):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
"of attention heads (%d)" % (hidden_size, num_attention_heads))
self.size_per_head = int(hidden_size / num_attention_heads)
self.attention = BertAttention(
from_tensor_width=hidden_size,
to_tensor_width=hidden_size,
num_attention_heads=num_attention_heads,
size_per_head=self.size_per_head,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
use_relative_positions=use_relative_positions,
has_attention_mask=True,
compute_type=compute_type)
self.output = BertOutput(in_channels=hidden_size,
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
def construct(self, input_tensor, attention_mask):
attention_output = self.attention(input_tensor, input_tensor, attention_mask)
output = self.output(attention_output, input_tensor)
return output
class BertEncoderCell(nn.Cell):
"""
Encoder cells used in BertTransformer.
Args:
hidden_size (int): Size of the bert encoder layers. Default: 768.
num_attention_heads (int): Number of attention heads. Default: 12.
intermediate_size (int): Size of intermediate layer. Default: 3072.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.02.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
hidden_act (str): Activation function. Default: "gelu".
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
"""
def __init__(self,
hidden_size=768,
num_attention_heads=12,
intermediate_size=3072,
attention_probs_dropout_prob=0.02,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32):
super(BertEncoderCell, self).__init__()
self.attention = BertSelfAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
compute_type=compute_type)
self.intermediate = nn.Dense(in_channels=hidden_size,
out_channels=intermediate_size,
activation=hidden_act,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.output = BertOutput(in_channels=intermediate_size,
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
def construct(self, hidden_states, attention_mask):
# self-attention
attention_output = self.attention(hidden_states, attention_mask)
# feed construct
intermediate_output = self.intermediate(attention_output)
# add and normalize
output = self.output(intermediate_output, attention_output)
return output
class BertTransformer(nn.Cell):
"""
Multi-layer bert transformer.
Args:
hidden_size (int): Size of the encoder layers.
num_hidden_layers (int): Number of hidden layers in encoder cells.
num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.1.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
"""
def __init__(self,
hidden_size,
num_hidden_layers,
num_attention_heads=12,
intermediate_size=3072,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
return_all_encoders=False):
super(BertTransformer, self).__init__()
self.return_all_encoders = return_all_encoders
layers = []
for _ in range(num_hidden_layers):
layer = BertEncoderCell(hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
hidden_act=hidden_act,
compute_type=compute_type)
layers.append(layer)
self.layers = nn.CellList(layers)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
def construct(self, input_tensor, attention_mask):
"""Multi-layer bert transformer."""
prev_output = self.reshape(input_tensor, self.shape)
all_encoder_layers = ()
for layer_module in self.layers:
layer_output = layer_module(prev_output, attention_mask)
prev_output = layer_output
if self.return_all_encoders:
shape = F.shape(input_tensor)
layer_output = self.reshape(layer_output, shape)
all_encoder_layers = all_encoder_layers + (layer_output,)
if not self.return_all_encoders:
shape = F.shape(input_tensor)
prev_output = self.reshape(prev_output, shape)
all_encoder_layers = all_encoder_layers + (prev_output,)
return all_encoder_layers
class CreateAttentionMaskFromInputMask(nn.Cell):
"""
Create attention mask according to input mask.
Args:
config (Class): Configuration for BertModel.
"""
def __init__(self, config):
super(CreateAttentionMaskFromInputMask, self).__init__()
self.input_mask = None
self.cast = P.Cast()
self.reshape = P.Reshape()
def construct(self, input_mask):
seq_length = F.shape(input_mask)[1]
attention_mask = self.cast(self.reshape(input_mask, (-1, 1, seq_length)), mstype.float32)
return attention_mask
class BertModel(nn.Cell):
"""
Bidirectional Encoder Representations from Transformers.
Args:
config (Class): Configuration for BertModel.
is_training (bool): True for training mode. False for eval mode.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
"""
def __init__(self,
config,
is_training,
use_one_hot_embeddings=False):
super(BertModel, self).__init__()
config = copy.deepcopy(config)
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
self.hidden_size = config.hidden_size
self.num_hidden_layers = config.num_hidden_layers
self.embedding_size = config.hidden_size
self.token_type_ids = None
self.last_idx = self.num_hidden_layers - 1
output_embedding_shape = [-1, config.seq_length, self.embedding_size]
self.bert_embedding_lookup = nn.Embedding(
vocab_size=config.vocab_size,
embedding_size=self.embedding_size,
use_one_hot=use_one_hot_embeddings,
embedding_table=TruncatedNormal(config.initializer_range))
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
embedding_shape=output_embedding_shape,
use_relative_positions=config.use_relative_positions,
use_token_type=True,
token_type_vocab_size=config.type_vocab_size,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=0.02,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)
self.bert_encoder = BertTransformer(
hidden_size=self.hidden_size,
num_attention_heads=config.num_attention_heads,
num_hidden_layers=self.num_hidden_layers,
intermediate_size=config.intermediate_size,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range,
hidden_dropout_prob=config.hidden_dropout_prob,
use_relative_positions=config.use_relative_positions,
hidden_act=config.hidden_act,
compute_type=config.compute_type,
return_all_encoders=True)
self.cast = P.Cast()
self.dtype = config.dtype
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
self.slice = P.StridedSlice()
self.squeeze_1 = P.Squeeze(axis=1)
self.dense = nn.Dense(self.hidden_size, self.hidden_size,
activation="tanh",
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
def construct(self, input_ids, token_type_ids, input_mask):
"""Bidirectional Encoder Representations from Transformers."""
# embedding
embedding_tables = self.bert_embedding_lookup.embedding_table
word_embeddings = self.bert_embedding_lookup(input_ids)
embedding_output = self.bert_embedding_postprocessor(token_type_ids,
word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
attention_mask = self._create_attention_mask_from_input_mask(input_mask)
# bert encoder
encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output),
attention_mask)
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
# pooler
batch_size = P.Shape()(input_ids)[0]
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(batch_size, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)
pooled_output = self.cast(pooled_output, self.dtype)
return sequence_output, pooled_output, embedding_tables

View File

@ -0,0 +1,274 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Functional Cells used in Bert finetune and evaluation.
"""
import os
import collections
import math
import numpy as np
import mindspore.nn as nn
from mindspore import log as logger
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from mindspore.nn.metrics import Metric
from mindspore.ops import operations as P
from mindspore.train.callback import Callback
class CrossEntropyCalculation(nn.Cell):
"""
Cross Entropy loss
"""
def __init__(self, is_training=True):
super(CrossEntropyCalculation, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.reshape = P.Reshape()
self.last_idx = (-1,)
self.neg = P.Neg()
self.cast = P.Cast()
self.is_training = is_training
def construct(self, logits, label_ids, num_labels):
if self.is_training:
label_ids = self.reshape(label_ids, self.last_idx)
one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
loss = self.reduce_mean(per_example_loss, self.last_idx)
return_value = self.cast(loss, mstype.float32)
else:
return_value = logits * 1.0
return return_value
def make_directory(path: str):
"""Make directory."""
if path is None or not isinstance(path, str) or path.strip() == "":
logger.error("The path(%r) is invalid type.", path)
raise TypeError("Input path is invalid type")
# convert the relative paths
path = os.path.realpath(path)
logger.debug("The abs path is %r", path)
# check the path is exist and write permissions?
if os.path.exists(path):
real_path = path
else:
# All exceptions need to be caught because create directory maybe have some limit(permissions)
logger.debug("The directory(%s) doesn't exist, will create it", path)
try:
os.makedirs(path, exist_ok=True)
real_path = path
except PermissionError as e:
logger.error("No write permission on the directory(%r), error = %r", path, e)
raise TypeError("No write permission on the directory.")
return real_path
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss in NAN or INF terminating training.
Note:
if per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, dataset_size=-1):
super(LossCallBack, self).__init__()
self._dataset_size = dataset_size
def step_end(self, run_context):
"""
Print loss after each step
"""
cb_params = run_context.original_args()
if self._dataset_size > 0:
percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size)
if percent == 0:
percent = 1
epoch_num -= 1
print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)),
flush=True)
else:
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)), flush=True)
class BertLearningRate(LearningRateSchedule):
"""
Warmup-decay learning rate for Bert network.
"""
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(BertLearningRate, self).__init__()
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
decay_lr = self.decay_lr(global_step)
if self.warmup_flag:
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
else:
lr = decay_lr
return lr
def convert_labels_to_index(label_list):
"""
Convert label_list to indices for NER task.
"""
label2id = collections.OrderedDict()
label2id["O"] = 0
prefix = ["S_", "B_", "M_", "E_"]
index = 0
for label in label_list:
for pre in prefix:
index += 1
sub_label = pre + label
label2id[sub_label] = index
return label2id
def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
"""
generate learning rate array
Args:
global_step(int): current step
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_steps(int): number of warmup epochs
total_steps(int): total epoch of training
poly_power(int): poly learning rate power
Returns:
np.array, learning rate array
"""
lr_each_step = []
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max - lr_end) * (base ** poly_power)
lr = lr + lr_end
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
learning_rate = np.array(lr_each_step).astype(np.float32)
current_step = global_step
learning_rate = learning_rate[current_step:]
return learning_rate
def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000):
learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0,
total_steps=lr_total_steps, poly_power=lr_power)
return Tensor(learning_rate)
def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000):
damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0,
total_steps=damping_total_steps, poly_power=damping_power)
return Tensor(damping)
class EvalCallBack(Callback):
"""
Evaluate after a certain amount of training samples.
Args:
model (Model): The network model.
eval_ds (Dataset): The eval dataset.
global_batch (int): The batchsize of the sum of all devices.
eval_samples (int): The number of eval interval samples.
"""
def __init__(self, model, eval_ds, global_batch, eval_samples):
super(EvalCallBack, self).__init__()
self.model = model
self.eval_ds = eval_ds
self.global_batch = global_batch
self.eval_samples = eval_samples
self.last_eval_step = 0
def epoch_end(self, run_context):
"""
Evaluate after training a certain number of samples.
"""
cb_params = run_context.original_args()
num_samples = (cb_params.cur_step_num - self.last_eval_step) * self.global_batch
if num_samples < self.eval_samples:
return
self.last_eval_step = cb_params.cur_step_num
total_sumples = cb_params.cur_step_num * self.global_batch
res = self.model.eval(self.eval_ds, dataset_sink_mode=True)
res = res['bert_acc']
print("====================================", flush=True)
print("Accuracy is: ", "%.6f" % res, ", current samples is: ", total_sumples)
print("====================================", flush=True)
class BertMetric(Metric):
"""
The metric of bert network.
Args:
batch_size (int): The batchsize of each device.
"""
def __init__(self, batch_size):
super(BertMetric, self).__init__()
self.clear()
self.batch_size = batch_size
def clear(self):
self.mlm_total = 0
self.mlm_acc = 0
def update(self, *inputs):
mlm_acc = self._convert_data(inputs[0])
mlm_total = self._convert_data(inputs[1])
self.mlm_acc += mlm_acc
self.mlm_total += mlm_total
def eval(self):
return self.mlm_acc / self.mlm_total

View File

@ -33,9 +33,11 @@ from mindspore.train.model import Model
from mindspore.train.train_thor import ConvertModelUtils
import mindspore.dataset.transforms as C
from tests.models.official.nlp.bert.src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepCell
from tests.models.official.nlp.bert.src.utils import get_bert_thor_lr, get_bert_thor_damping
from tests.models.official.nlp.bert.src.bert_model import BertConfig
from tests.st.networks.models.bert.bert_performance.src.bert_for_pre_training import BertNetworkWithLoss, \
BertTrainOneStepCell
from tests.st.networks.models.bert.bert_performance.src.utils import get_bert_thor_lr, get_bert_thor_damping
from tests.st.networks.models.bert.bert_performance.src.bert_model import BertConfig
MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_table_8p.json"
DATASET_PATH = "/home/workspace/mindspore_dataset/bert/thor/en-wiki-512_test_first1wan"
@ -188,7 +190,7 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
q.put({'loss': loss_list, 'cost': per_step_mseconds})
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single