From 1561337f1e6d7d31772571ed72cc51c482b03ba2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=A8=8B=E6=B5=A9?= Date: Tue, 31 May 2022 00:58:50 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20ST=20probabilistic=20failu?= =?UTF-8?q?re=20in=20daily=20version?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/bert_for_pre_training.py | 911 ++++++++++++++++++ .../bert/bert_performance/src/bert_model.py | 857 ++++++++++++++++ .../models/bert/bert_performance/src/utils.py | 274 ++++++ .../bert/bert_performance/test_bert_thor.py | 10 +- 4 files changed, 2048 insertions(+), 4 deletions(-) create mode 100644 tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py create mode 100644 tests/st/networks/models/bert/bert_performance/src/bert_model.py create mode 100644 tests/st/networks/models/bert/bert_performance/src/utils.py diff --git a/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py b/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py new file mode 100644 index 00000000000..eea55185fe2 --- /dev/null +++ b/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py @@ -0,0 +1,911 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert for pretraining.""" +import numpy as np +import mindspore.nn as nn +from mindspore import context +from mindspore.common import dtype as mstype +from mindspore.common.initializer import initializer, TruncatedNormal +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.communication.management import get_group_size +from mindspore.context import ParallelMode +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from src.bert_model import BertModel + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 1.0 + +clip_grad = C.MultitypeFuncGraph("clip_grad") + + +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """ + Clip gradients. + + Inputs: + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. + + Outputs: + tuple[Tensor], clipped gradients. + """ + if clip_type not in (0, 1): + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad + + +class GetMaskedLMOutput(nn.Cell): + """ + Get masked lm output. + + Args: + config (BertConfig): The config of BertModel. + + Returns: + Tensor, masked lm output. 
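+
+    Example (an illustrative sketch; the miniature BertConfig values below are
+    hypothetical, not the settings used by this test):
+        >>> from src.bert_model import BertConfig
+        >>> config = BertConfig(seq_length=8, vocab_size=64, hidden_size=32,
+        ...                     num_hidden_layers=1, num_attention_heads=2,
+        ...                     intermediate_size=64)
+        >>> mlm_head = GetMaskedLMOutput(config)
+        >>> # construct() takes the encoder sequence output, the embedding table
+        >>> # returned by BertModel and the masked token positions, and returns
+        >>> # log-probabilities over the vocabulary.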
+ """ + + def __init__(self, config): + super(GetMaskedLMOutput, self).__init__() + self.width = config.hidden_size + self.reshape = P.Reshape() + self.gather = P.Gather() + + weight_init = TruncatedNormal(config.initializer_range) + self.dense = nn.Dense(self.width, + config.hidden_size, + weight_init=weight_init, + activation=config.hidden_act).to_float(config.compute_type) + self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type) + self.output_bias = Parameter( + initializer( + 'zero', + config.vocab_size)) + self.matmul = P.MatMul(transpose_b=True) + self.log_softmax = nn.LogSoftmax(axis=-1) + self.shape_flat_offsets = (-1, 1) + self.last_idx = (-1,) + self.shape_flat_sequence_tensor = (-1, self.width) + self.cast = P.Cast() + self.compute_type = config.compute_type + self.dtype = config.dtype + + def construct(self, + input_tensor, + output_weights, + positions): + """Get output log_probs""" + input_shape = P.Shape()(input_tensor) + rng = F.tuple_to_array(F.make_range(input_shape[0])) + flat_offsets = self.reshape(rng * input_shape[1], self.shape_flat_offsets) + flat_position = self.reshape(positions + flat_offsets, self.last_idx) + flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) + input_tensor = self.gather(flat_sequence_tensor, flat_position, 0) + input_tensor = self.cast(input_tensor, self.compute_type) + output_weights = self.cast(output_weights, self.compute_type) + input_tensor = self.dense(input_tensor) + input_tensor = self.layernorm(input_tensor) + logits = self.matmul(input_tensor, output_weights) + logits = self.cast(logits, self.dtype) + logits = logits + self.output_bias + log_probs = self.log_softmax(logits) + return log_probs + + +class GetNextSentenceOutput(nn.Cell): + """ + Get next sentence output. + + Args: + config (BertConfig): The config of Bert. + + Returns: + Tensor, next sentence output. + """ + + def __init__(self, config): + super(GetNextSentenceOutput, self).__init__() + self.log_softmax = P.LogSoftmax() + weight_init = TruncatedNormal(config.initializer_range) + self.dense = nn.Dense(config.hidden_size, 2, + weight_init=weight_init, has_bias=True).to_float(config.compute_type) + self.dtype = config.dtype + self.cast = P.Cast() + + def construct(self, input_tensor): + logits = self.dense(input_tensor) + logits = self.cast(logits, self.dtype) + log_prob = self.log_softmax(logits) + return log_prob + + +class BertPreTraining(nn.Cell): + """ + Bert pretraining network. + + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. + + Returns: + Tensor, prediction_scores, seq_relationship_score. + """ + + def __init__(self, config, is_training, use_one_hot_embeddings): + super(BertPreTraining, self).__init__() + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cls1 = GetMaskedLMOutput(config) + self.cls2 = GetNextSentenceOutput(config) + + def construct(self, input_ids, input_mask, token_type_id, + masked_lm_positions): + sequence_output, pooled_output, embedding_table = \ + self.bert(input_ids, token_type_id, input_mask) + prediction_scores = self.cls1(sequence_output, + embedding_table, + masked_lm_positions) + seq_relationship_score = self.cls2(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPretrainingLoss(nn.Cell): + """ + Provide bert pre-training loss. 
+ + Args: + config (BertConfig): The config of BertModel. + + Returns: + Tensor, total loss. + """ + + def __init__(self, config): + super(BertPretrainingLoss, self).__init__() + self.vocab_size = config.vocab_size + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.neg = P.Neg() + self.cast = P.Cast() + + def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids, + masked_lm_weights, next_sentence_labels): + """Defines the computation performed.""" + label_ids = self.reshape(masked_lm_ids, self.last_idx) + label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32) + one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) + + per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) + numerator = self.reduce_sum(label_weights * per_example_loss, ()) + denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32) + masked_lm_loss = numerator / denominator + + # next_sentence_loss + labels = self.reshape(next_sentence_labels, self.last_idx) + one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value) + per_example_loss = self.neg(self.reduce_sum( + one_hot_labels * seq_relationship_score, self.last_idx)) + next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx) + + # total_loss + total_loss = masked_lm_loss + next_sentence_loss + + return total_loss + + +class BertNetworkWithLoss(nn.Cell): + """ + Provide bert pre-training loss through network. + + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + + Returns: + Tensor, the loss of the network. + """ + + def __init__(self, config, is_training, use_one_hot_embeddings=False): + super(BertNetworkWithLoss, self).__init__() + self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings) + self.loss = BertPretrainingLoss(config) + self.cast = P.Cast() + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + """Get pre-training loss""" + prediction_scores, seq_relationship_score = \ + self.bert(input_ids, input_mask, token_type_id, masked_lm_positions) + total_loss = self.loss(prediction_scores, seq_relationship_score, + masked_lm_ids, masked_lm_weights, next_sentence_labels) + return self.cast(total_loss, mstype.float32) + + +class BertTrainOneStepCell(nn.TrainOneStepCell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default: 1.0. + enable_clip_grad (boolean): If True, clip gradients in BertTrainOneStepCell. Default: True. 
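+
+    Example (a minimal sketch; the optimizer choice and hyper-parameters are
+    illustrative assumptions, not the ones used by the ST test):
+        >>> import mindspore.nn as nn
+        >>> from src.bert_model import BertConfig
+        >>> config = BertConfig(seq_length=8, vocab_size=64, hidden_size=32,
+        ...                     num_hidden_layers=1, num_attention_heads=2)
+        >>> net_with_loss = BertNetworkWithLoss(config, is_training=True)
+        >>> optimizer = nn.Momentum(net_with_loss.trainable_params(),
+        ...                         learning_rate=0.01, momentum=0.9)
+        >>> train_cell = BertTrainOneStepCell(net_with_loss, optimizer)
+        >>> # Each call to train_cell(...) runs forward, backward, gradient
+        >>> # clipping (enabled by default) and a single optimizer update.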
+ """ + + def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=True): + super(BertTrainOneStepCell, self).__init__(network, optimizer, sens) + self.cast = P.Cast() + self.hyper_map = C.HyperMap() + self.enable_clip_grad = enable_clip_grad + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + """Defines the computation performed.""" + weights = self.weights + + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(F.tuple_to_array((self.sens,)), + mstype.float32)) + if self.enable_clip_grad: + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + grads = self.grad_reducer(grads) + self.optimizer(grads) + return loss + + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + + +_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") +grad_overflow = P.FloatStatus() + + +@_grad_overflow.register("Tensor") +def _tensor_grad_overflow(grad): + return grad_overflow(grad) + + +class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. 
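+
+    Example (a sketch under assumed settings; the Lamb optimizer and loss-scale
+    numbers are illustrative, not this test's configuration):
+        >>> import mindspore.nn as nn
+        >>> # net_with_loss: a BertNetworkWithLoss instance, as in the example
+        >>> # for BertTrainOneStepCell above.
+        >>> optimizer = nn.Lamb(net_with_loss.trainable_params(), learning_rate=1e-4)
+        >>> update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**16,
+        ...                                             scale_factor=2,
+        ...                                             scale_window=1000)
+        >>> train_cell = BertTrainOneStepWithLossScaleCell(
+        ...     net_with_loss, optimizer, scale_update_cell=update_cell)
+        >>> # Each step returns (loss, overflow_flag, current_loss_scale); the
+        >>> # parameter update is skipped when an overflow is detected.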
+ """ + + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell) + self.cast = P.Cast() + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + degree_sens = self.cast(scaling_sens * self.degree, mstype.float32) + grads = self.hyper_map(F.partial(grad_scale, degree_sens), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + + cond = self.get_overflow_status(status, grads) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if not overflow: + self.optimizer(grads) + return (loss, cond, scaling_sens) + + +class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow + condition as input. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. 
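+
+    Example (illustrative; the AdamWeightDecay settings are assumptions):
+        >>> import mindspore.nn as nn
+        >>> # net_with_loss: a BertNetworkWithLoss instance, as in the example
+        >>> # for BertTrainOneStepCell above.
+        >>> optimizer = nn.AdamWeightDecay(net_with_loss.trainable_params(),
+        ...                                learning_rate=5e-5)
+        >>> update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**16,
+        ...                                             scale_factor=2,
+        ...                                             scale_window=1000)
+        >>> train_cell = BertTrainOneStepWithLossScaleCellForAdam(
+        ...     net_with_loss, optimizer, scale_update_cell=update_cell)
+        >>> # Unlike the cell above, the optimizer is always invoked and receives
+        >>> # the overflow flag, so skipping the update is delegated to it.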
+ """ + + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell) + self.cast = P.Cast() + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + cond = self.get_overflow_status(status, grads) + overflow = cond + if self.loss_scaling_manager is not None: + overflow = self.loss_scaling_manager(scaling_sens, cond) + self.optimizer(grads, overflow) + return (loss, cond, scaling_sens) + + +cast = P.Cast() +add_grads = C.MultitypeFuncGraph("add_grads") + + +@add_grads.register("Tensor", "Tensor") +def _add_grads(accu_grad, grad): + return accu_grad + cast(grad, mstype.float32) + + +update_accu_grads = C.MultitypeFuncGraph("update_accu_grads") + + +@update_accu_grads.register("Tensor", "Tensor") +def _update_accu_grads(accu_grad, grad): + succ = True + return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32))) + + +accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads") + + +@accumulate_accu_grads.register("Tensor", "Tensor") +def _accumulate_accu_grads(accu_grad, grad): + succ = True + return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32))) + + +zeroslike = P.ZerosLike() +reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads") + + +@reset_accu_grads.register("Tensor") +def _reset_accu_grads(accu_grad): + succ = True + return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad))) + + +class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + To mimic higher batch size, gradients are accumulated N times before weight update. + + For distribution mode, allreduce will only be implemented in the weight updated step, + i.e. the sub-step after gradients accumulated N times. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. 
+ accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = + batch_size * accumulation_steps. Default: 1. + """ + + def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False): + super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.weights = optimizer.parameters + self.optimizer = optimizer + self.accumulation_steps = accumulation_steps + self.enable_global_norm = enable_global_norm + self.one = Tensor(np.array([1]).astype(np.int32)) + self.zero = Tensor(np.array([0]).astype(np.int32)) + self.local_step = Parameter(initializer(0, [1], mstype.int32)) + self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') + self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) + self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) + + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.overflow_reducer = F.identity + if self.is_distributed: + self.overflow_reducer = P.AllReduce() + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_status = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.logical_or = P.LogicalOr() + self.not_equal = P.NotEqual() + self.select = P.Select() + self.reshape = P.Reshape() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) + # update accumulation parameters + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) + self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss) + mean_loss = self.accu_loss / self.local_step + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + + accu_succ = 
self.hyper_map(accumulate_accu_grads, self.accu_grads, grads) + mean_loss = F.depend(mean_loss, accu_succ) + + init = F.depend(init, mean_loss) + get_status = self.get_status(init) + init = F.depend(init, get_status) + flag_sum = self.reduce_sum(init, (0,)) + overflow = self.less_equal(self.base, flag_sum) + overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) + accu_overflow = self.select(overflow, self.one, self.zero) + self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) + + if not is_accu_step: + # apply grad reducer on grads + grads = self.grad_reducer(self.accu_grads) + scaling = scaling_sens * self.degree * self.accumulation_steps + grads = self.hyper_map(F.partial(grad_scale, scaling), grads) + if self.enable_global_norm: + grads = C.clip_by_global_norm(grads, 1.0, None) + else: + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + accu_overflow = F.depend(accu_overflow, grads) + accu_overflow = self.overflow_reducer(accu_overflow) + overflow = self.less_equal(self.base, accu_overflow) + accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads) + overflow = F.depend(overflow, accu_succ) + overflow = self.reshape(overflow, (())) + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, overflow) + if not overflow: + self.optimizer(grads) + + return (mean_loss, overflow, scaling_sens) + + +class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + To mimic higher batch size, gradients are accumulated N times before weight update. + + For distribution mode, allreduce will be implemented after each sub-step and the trailing time + will be overided by backend optimization pass. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = + batch_size * accumulation_steps. Default: 1. 
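+
+    Example (a sketch; the accumulation and loss-scale settings are illustrative):
+        >>> import mindspore.nn as nn
+        >>> # net_with_loss and optimizer: as in the BertTrainOneStepCell example.
+        >>> update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**16,
+        ...                                             scale_factor=2,
+        ...                                             scale_window=1000)
+        >>> train_cell = BertTrainAccumulationAllReduceEachWithLossScaleCell(
+        ...     net_with_loss, optimizer, scale_update_cell=update_cell,
+        ...     accumulation_steps=4)
+        >>> # With accumulation_steps=4, parameters are updated on every fourth
+        >>> # call; the other calls only accumulate gradients into accu_grads.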
+ """ + + def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False): + super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.weights = optimizer.parameters + self.optimizer = optimizer + self.accumulation_steps = accumulation_steps + self.enable_global_norm = enable_global_norm + self.one = Tensor(np.array([1]).astype(np.int32)) + self.zero = Tensor(np.array([0]).astype(np.int32)) + self.local_step = Parameter(initializer(0, [1], mstype.int32)) + self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') + self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) + self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) + + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.overflow_reducer = F.identity + if self.is_distributed: + self.overflow_reducer = P.AllReduce() + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.logical_or = P.LogicalOr() + self.not_equal = P.NotEqual() + self.select = P.Select() + self.reshape = P.Reshape() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + + # update accumulation parameters + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) + self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss) + mean_loss = self.accu_loss / self.local_step + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + + accu_grads = self.hyper_map(add_grads, self.accu_grads, grads) + scaling = scaling_sens * self.degree * self.accumulation_steps + grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads) + grads = 
self.grad_reducer(grads) + + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + flag_reduce = self.overflow_reducer(flag_sum) + overflow = self.less_equal(self.base, flag_reduce) + overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) + accu_overflow = self.select(overflow, self.one, self.zero) + self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) + overflow = self.reshape(overflow, (())) + + if is_accu_step: + succ = False + accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads) + succ = F.depend(succ, accu_succ) + else: + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, overflow) + if overflow: + succ = False + else: + if self.enable_global_norm: + grads = C.clip_by_global_norm(grads, 1.0, None) + else: + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + + succ = self.optimizer(grads) + + accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads) + succ = F.depend(succ, accu_succ) + + ret = (mean_loss, overflow, scaling_sens) + return F.depend(ret, succ) + + +class BertNetworkMatchBucket(nn.Cell): + ''' + Bert execute according to different sentence lengths. + ''' + + def __init__(self, network, seq_length, bucket_list=None): + super(BertNetworkMatchBucket, self).__init__() + self.network = network + if not bucket_list or not isinstance(bucket_list, list): + bucket_list = [seq_length] + self.bucket_list = [bucket for bucket in bucket_list if bucket <= seq_length] + + if network.reducer_flag: + reuse_attr = 'reuse_communication_node' + if not network.grad_reducer.split_fusion: + hccl_op = network.grad_reducer.allreduce + network.grad_reducer.allreduce = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion')) + else: + new_op_list = [] + for hccl_op in network.grad_reducer.op_list: + new_op = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion')) + new_op_list.append(new_op) + network.grad_reducer.op_list = new_op_list + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sentence_flag): + """Switch network according to sentence length.""" + for bucket in self.bucket_list: + if sentence_flag == bucket: + input_ids = input_ids[:, :bucket] + input_mask = input_mask[:, :bucket] + token_type_id = token_type_id[:, :bucket] + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + return loss + + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + return loss + + +class BertPretrainEval(nn.Cell): + ''' + Evaluate MaskedLM prediction scores + ''' + + def __init__(self, config, network=None): + super(BertPretrainEval, self).__init__(auto_prefix=False) + if network is None: + self.network = BertPreTraining(config, False, False) + else: + self.network = network + self.argmax = P.Argmax(axis=-1, output_type=mstype.int32) + self.equal = P.Equal() + self.sum = P.ReduceSum() + self.reshape = P.Reshape() + self.shape = P.Shape() + self.cast = P.Cast() + self.allreduce = P.AllReduce() + self.reduce_flag = False + parallel_mode = context.get_auto_parallel_context("parallel_mode") + if parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reduce_flag = True + + def construct(self, + input_ids, + input_mask, + token_type_id, + 
next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + """Calculate prediction scores""" + bs, _ = self.shape(input_ids) + mlm, _ = self.network(input_ids, input_mask, token_type_id, masked_lm_positions) + index = self.argmax(mlm) + index = self.reshape(index, (bs, -1)) + eval_acc = self.equal(index, masked_lm_ids) + eval_acc = self.cast(eval_acc, mstype.float32) + real_acc = eval_acc * masked_lm_weights + acc = self.sum(real_acc) + total = self.sum(masked_lm_weights) + + if self.reduce_flag: + acc = self.allreduce(acc) + total = self.allreduce(total) + + return acc, total diff --git a/tests/st/networks/models/bert/bert_performance/src/bert_model.py b/tests/st/networks/models/bert/bert_performance/src/bert_model.py new file mode 100644 index 00000000000..77d2c9e441f --- /dev/null +++ b/tests/st/networks/models/bert/bert_performance/src/bert_model.py @@ -0,0 +1,857 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert model.""" + +import copy +import math +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops.functional as F +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.ops import composite as C +from mindspore.ops import operations as P + + +class BertConfig: + """ + Configuration for `BertModel`. + + Args: + seq_length (int): Length of input sequence. Default: 128. + vocab_size (int): The shape of each embedding vector. Default: 32000. + hidden_size (int): Size of the bert encoder layers. Default: 768. + num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder + cell. Default: 12. + num_attention_heads (int): Number of attention heads in the BertTransformer + encoder cell. Default: 12. + intermediate_size (int): Size of intermediate layer in the BertTransformer + encoder cell. Default: 3072. + hidden_act (str): Activation function used in the BertTransformer encoder + cell. Default: "gelu". + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + type_vocab_size (int): Size of token type vocab. Default: 16. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. 
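+
+    Example (a deliberately tiny configuration for illustration; the actual test
+    uses a much larger BERT setting):
+        >>> config = BertConfig(seq_length=16, vocab_size=1000, hidden_size=64,
+        ...                     num_hidden_layers=2, num_attention_heads=4,
+        ...                     intermediate_size=256,
+        ...                     compute_type=mstype.float16)
+        >>> config.hidden_size
+        64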
+ """ + + def __init__(self, + seq_length=128, + vocab_size=32000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + use_relative_positions=False, + dtype=mstype.float32, + compute_type=mstype.float32): + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.use_relative_positions = use_relative_positions + self.dtype = dtype + self.compute_type = compute_type + + +class EmbeddingLookup(nn.Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + """ + + def __init__(self, + vocab_size, + embedding_size, + embedding_shape, + use_one_hot_embeddings=False, + initializer_range=0.02): + super(EmbeddingLookup, self).__init__() + self.vocab_size = vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [vocab_size, embedding_size])) + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.Gather() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + + def construct(self, input_ids): + """Get output and embeddings lookup table""" + extended_ids = self.expand(input_ids, -1) + flat_ids = self.reshape(extended_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul( + one_hot_ids, self.embedding_table) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + output = self.reshape(output_for_reshape, self.shape) + return output, self.embedding_table + + +class EmbeddingPostprocessor(nn.Cell): + """ + Postprocessors apply positional and token type embeddings to word embeddings. + + Args: + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_token_type (bool): Specifies whether to use token type embeddings. Default: False. + token_type_vocab_size (int): Size of token type vocab. Default: 16. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. 
+ dropout_prob (float): The dropout probability. Default: 0.1. + """ + + def __init__(self, + embedding_size, + embedding_shape, + use_relative_positions=False, + use_token_type=False, + token_type_vocab_size=16, + use_one_hot_embeddings=False, + initializer_range=0.02, + max_position_embeddings=512, + dropout_prob=0.1): + super(EmbeddingPostprocessor, self).__init__() + self.use_token_type = use_token_type + self.token_type_vocab_size = token_type_vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.max_position_embeddings = max_position_embeddings + self.token_type_embedding = nn.Embedding( + vocab_size=token_type_vocab_size, + embedding_size=embedding_size, + use_one_hot=use_one_hot_embeddings) + self.shape_flat = (-1,) + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.1, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + self.dropout = nn.Dropout(1 - dropout_prob) + self.gather = P.Gather() + self.use_relative_positions = use_relative_positions + self.slice = P.StridedSlice() + _, seq, _ = self.shape + self.full_position_embedding = nn.Embedding( + vocab_size=max_position_embeddings, + embedding_size=embedding_size, + use_one_hot=False) + self.layernorm = nn.LayerNorm((embedding_size,)) + self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) + self.add = P.Add() + + def construct(self, token_type_ids, word_embeddings): + """Postprocessors apply positional and token type embeddings to word embeddings.""" + output = word_embeddings + if self.use_token_type: + token_type_embeddings = self.token_type_embedding(token_type_ids) + output = self.add(output, token_type_embeddings) + if not self.use_relative_positions: + shape = F.shape(output) + position_ids = self.position_ids[:, :shape[1]] + position_embeddings = self.full_position_embedding(position_ids) + output = self.add(output, position_embeddings) + output = self.layernorm(output) + output = self.dropout(output) + return output + + +class BertOutput(nn.Cell): + """ + Apply a linear computation to hidden status and a residual computation to input. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + dropout_prob (float): The dropout probability. Default: 0.1. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + """ + + def __init__(self, + in_channels, + out_channels, + initializer_range=0.02, + dropout_prob=0.1, + compute_type=mstype.float32): + super(BertOutput, self).__init__() + self.dense = nn.Dense(in_channels, out_channels, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.dropout = nn.Dropout(1 - dropout_prob) + self.dropout_prob = dropout_prob + self.add = P.Add() + self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) + self.cast = P.Cast() + + def construct(self, hidden_status, input_tensor): + output = self.dense(hidden_status) + output = self.dropout(output) + output = self.add(input_tensor, output) + output = self.layernorm(output) + return output + + +class RelaPosMatrixGenerator(nn.Cell): + """ + Generates matrix of relative positions between inputs. + + Args: + length (int): Length of one dim for the matrix to be generated. + max_relative_position (int): Max value of relative position. 
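+
+    A worked example of the logic below: with max_relative_position=2 and
+    length=3 the raw distance matrix is j - i, i.e.
+    [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]; clipping to [-2, 2] and shifting by
+    +2 gives [[2, 3, 4], [1, 2, 3], [0, 1, 2]], so every entry is a valid index
+    into an embedding table of size 2 * max_relative_position + 1 = 5.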
+ """ + + def __init__(self, max_relative_position): + super(RelaPosMatrixGenerator, self).__init__() + self._max_relative_position = max_relative_position + self._min_relative_position = -max_relative_position + + self.tile = P.Tile() + self.range_mat = P.Reshape() + self.sub = P.Sub() + self.expanddims = P.ExpandDims() + self.cast = P.Cast() + + def construct(self, length): + """Generates matrix of relative positions between inputs.""" + range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(length)), mstype.int32) + range_vec_col_out = self.range_mat(range_vec_row_out, (length, -1)) + tile_row_out = self.tile(range_vec_row_out, (length,)) + tile_col_out = self.tile(range_vec_col_out, (1, length)) + range_mat_out = self.range_mat(tile_row_out, (length, length)) + transpose_out = self.range_mat(tile_col_out, (length, length)) + distance_mat = self.sub(range_mat_out, transpose_out) + + distance_mat_clipped = C.clip_by_value(distance_mat, + self._min_relative_position, + self._max_relative_position) + + # Shift values to be >=0. Each integer still uniquely identifies a + # relative position difference. + final_mat = distance_mat_clipped + self._max_relative_position + return final_mat + + +class RelaPosEmbeddingsGenerator(nn.Cell): + """ + Generates tensor of size [length, length, depth]. + + Args: + length (int): Length of one dim for the matrix to be generated. + depth (int): Size of each attention head. + max_relative_position (int): Maxmum value of relative position. + initializer_range (float): Initialization value of TruncatedNormal. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + + def __init__(self, + depth, + max_relative_position, + initializer_range, + use_one_hot_embeddings=False): + super(RelaPosEmbeddingsGenerator, self).__init__() + self.depth = depth + self.vocab_size = max_relative_position * 2 + 1 + self.use_one_hot_embeddings = use_one_hot_embeddings + + self.embeddings_table = Parameter( + initializer(TruncatedNormal(initializer_range), + [self.vocab_size, self.depth])) + + self.relative_positions_matrix = RelaPosMatrixGenerator(max_relative_position=max_relative_position) + self.reshape = P.Reshape() + self.one_hot = nn.OneHot(depth=self.vocab_size) + self.shape = P.Shape() + self.gather = P.Gather() # index_select + self.matmul = P.BatchMatMul() + + def construct(self, length): + """Generate embedding for each relative position of dimension depth.""" + relative_positions_matrix_out = self.relative_positions_matrix(length) + + if self.use_one_hot_embeddings: + flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) + one_hot_relative_positions_matrix = self.one_hot( + flat_relative_positions_matrix) + embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) + my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) + embeddings = self.reshape(embeddings, my_shape) + else: + embeddings = self.gather(self.embeddings_table, + relative_positions_matrix_out, 0) + return embeddings + + +class SaturateCast(nn.Cell): + """ + Performs a safe saturating cast. This operation applies proper clamping before casting to prevent + the danger that the value will overflow or underflow. + + Args: + src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32. + dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. 
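+
+    Example (an illustrative sketch; the sample value is arbitrary):
+        >>> import numpy as np
+        >>> import mindspore.common.dtype as mstype
+        >>> from mindspore import Tensor
+        >>> cast16 = SaturateCast(src_type=mstype.float32, dst_type=mstype.float16)
+        >>> out = cast16(Tensor(np.array([70000.0], np.float32)))
+        >>> # 70000.0 is clamped to 65504.0 (the float16 maximum) before the
+        >>> # cast, so the result is finite instead of overflowing to inf.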
+ """ + + def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): + super(SaturateCast, self).__init__() + np_type = mstype.dtype_to_nptype(dst_type) + + self.tensor_min_type = float(np.finfo(np_type).min) + self.tensor_max_type = float(np.finfo(np_type).max) + + self.min_op = P.Minimum() + self.max_op = P.Maximum() + self.cast = P.Cast() + self.dst_type = dst_type + + def construct(self, x): + out = self.max_op(x, self.tensor_min_type) + out = self.min_op(out, self.tensor_max_type) + return self.cast(out, self.dst_type) + + +class BertAttention(nn.Cell): + """ + Apply multi-headed attention from "from_tensor" to "to_tensor". + + Args: + from_tensor_width (int): Size of last dim of from_tensor. + to_tensor_width (int): Size of last dim of to_tensor. + num_attention_heads (int): Number of attention heads. Default: 1. + size_per_head (int): Size of each attention head. Default: 512. + query_act (str): Activation function for the query transform. Default: None. + key_act (str): Activation function for the key transform. Default: None. + value_act (str): Activation function for the value transform. Default: None. + has_attention_mask (bool): Specifies whether to use attention mask. Default: False. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.0. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. + """ + + def __init__(self, + from_tensor_width, + to_tensor_width, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + has_attention_mask=False, + attention_probs_dropout_prob=0.0, + use_one_hot_embeddings=False, + initializer_range=0.02, + use_relative_positions=False, + compute_type=mstype.float32): + + super(BertAttention, self).__init__() + self.num_attention_heads = num_attention_heads + self.size_per_head = size_per_head + self.has_attention_mask = has_attention_mask + self.use_relative_positions = use_relative_positions + + self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head)) + self.reshape = P.Reshape() + self.shape_from_2d = (-1, from_tensor_width) + self.shape_to_2d = (-1, to_tensor_width) + weight = TruncatedNormal(initializer_range) + units = num_attention_heads * size_per_head + self.query_layer = nn.Dense(from_tensor_width, + units, + activation=query_act, + weight_init=weight).to_float(compute_type) + self.key_layer = nn.Dense(to_tensor_width, + units, + activation=key_act, + weight_init=weight).to_float(compute_type) + self.value_layer = nn.Dense(to_tensor_width, + units, + activation=value_act, + weight_init=weight).to_float(compute_type) + + self.matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.multiply = P.Mul() + self.transpose = P.Transpose() + self.trans_shape = (0, 2, 1, 3) + self.trans_shape_relative = (2, 0, 1, 3) + self.trans_shape_position = (1, 2, 0, 3) + self.multiply_data = -10000.0 + self.matmul = P.BatchMatMul() + + self.softmax = nn.Softmax() + self.dropout = nn.Dropout(1 - attention_probs_dropout_prob) + + if self.has_attention_mask: + self.expand_dims = P.ExpandDims() + self.sub = P.Sub() + self.add = P.Add() + self.cast = P.Cast() + self.get_dtype = P.DType() + + self.shape_return = (-1, num_attention_heads * 
size_per_head) + + self.cast_compute_type = SaturateCast(dst_type=compute_type) + if self.use_relative_positions: + self._generate_relative_positions_embeddings = \ + RelaPosEmbeddingsGenerator(depth=size_per_head, + max_relative_position=16, + initializer_range=initializer_range, + use_one_hot_embeddings=use_one_hot_embeddings) + + def construct(self, from_tensor, to_tensor, attention_mask): + """reshape 2d/3d input tensors to 2d""" + shape_from = F.shape(attention_mask)[2] + from_tensor = F.depend(from_tensor, shape_from) + from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) + to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) + query_out = self.query_layer(from_tensor_2d) + key_out = self.key_layer(to_tensor_2d) + value_out = self.value_layer(to_tensor_2d) + + query_layer = self.reshape(query_out, (-1, shape_from, self.num_attention_heads, self.size_per_head)) + query_layer = self.transpose(query_layer, self.trans_shape) + key_layer = self.reshape(key_out, (-1, shape_from, self.num_attention_heads, self.size_per_head)) + key_layer = self.transpose(key_layer, self.trans_shape) + + attention_scores = self.matmul_trans_b(query_layer, key_layer) + + # use_relative_position, supplementary logic + if self.use_relative_positions: + relations_keys = self._generate_relative_positions_embeddings(shape_from) + relations_keys = self.cast_compute_type(relations_keys) + + query_layer_t = self.transpose(query_layer, self.trans_shape_relative) + + query_layer_r = self.reshape(query_layer_t, + (shape_from, + -1, + self.size_per_head)) + + key_position_scores = self.matmul_trans_b(query_layer_r, + relations_keys) + + key_position_scores_r = self.reshape(key_position_scores, + (shape_from, + -1, + self.num_attention_heads, + shape_from)) + + key_position_scores_r_t = self.transpose(key_position_scores_r, + self.trans_shape_position) + attention_scores = attention_scores + key_position_scores_r_t + + attention_scores = self.multiply(self.scores_mul, attention_scores) + + if self.has_attention_mask: + attention_mask = self.expand_dims(attention_mask, 1) + multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), + self.cast(attention_mask, self.get_dtype(attention_scores))) + + adder = self.multiply(multiply_out, self.multiply_data) + attention_scores = self.add(adder, attention_scores) + + attention_probs = self.softmax(attention_scores) + attention_probs = self.dropout(attention_probs) + + value_layer = self.reshape(value_out, (-1, shape_from, self.num_attention_heads, self.size_per_head)) + value_layer = self.transpose(value_layer, self.trans_shape) + context_layer = self.matmul(attention_probs, value_layer) + + # use_relative_position, supplementary logic + if self.use_relative_positions: + relations_values = self._generate_relative_positions_embeddings(shape_from) + relations_values = self.cast_compute_type(relations_values) + + attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative) + + attention_probs_r = self.reshape( + attention_probs_t, + (shape_from, + -1, + shape_from)) + + value_position_scores = self.matmul(attention_probs_r, + relations_values) + + value_position_scores_r = self.reshape(value_position_scores, + (shape_from, + -1, + self.num_attention_heads, + self.size_per_head)) + + value_position_scores_r_t = self.transpose(value_position_scores_r, + self.trans_shape_position) + context_layer = context_layer + value_position_scores_r_t + + context_layer = self.transpose(context_layer, self.trans_shape) + 
context_layer = self.reshape(context_layer, self.shape_return) + + return context_layer + + +class BertSelfAttention(nn.Cell): + """ + Apply self-attention. + + Args: + hidden_size (int): Size of the bert encoder layers. + num_attention_heads (int): Number of attention heads. Default: 12. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. + """ + + def __init__(self, + hidden_size, + num_attention_heads=12, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + compute_type=mstype.float32): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError("The hidden size (%d) is not a multiple of the number " + "of attention heads (%d)" % (hidden_size, num_attention_heads)) + + self.size_per_head = int(hidden_size / num_attention_heads) + + self.attention = BertAttention( + from_tensor_width=hidden_size, + to_tensor_width=hidden_size, + num_attention_heads=num_attention_heads, + size_per_head=self.size_per_head, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + use_relative_positions=use_relative_positions, + has_attention_mask=True, + compute_type=compute_type) + + self.output = BertOutput(in_channels=hidden_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + + def construct(self, input_tensor, attention_mask): + attention_output = self.attention(input_tensor, input_tensor, attention_mask) + output = self.output(attention_output, input_tensor) + return output + + +class BertEncoderCell(nn.Cell): + """ + Encoder cells used in BertTransformer. + + Args: + hidden_size (int): Size of the bert encoder layers. Default: 768. + num_attention_heads (int): Number of attention heads. Default: 12. + intermediate_size (int): Size of intermediate layer. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. 
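+
+    Example (a sketch with hypothetical miniature sizes):
+        >>> cell = BertEncoderCell(hidden_size=32, num_attention_heads=2,
+        ...                        intermediate_size=64)
+        >>> # construct() expects hidden_states of shape
+        >>> # [batch_size * seq_length, hidden_size] and an attention mask whose
+        >>> # last dimension is seq_length (see CreateAttentionMaskFromInputMask),
+        >>> # and returns a tensor with the same shape as hidden_states.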
+ """ + + def __init__(self, + hidden_size=768, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.02, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32): + super(BertEncoderCell, self).__init__() + self.attention = BertSelfAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + compute_type=compute_type) + self.intermediate = nn.Dense(in_channels=hidden_size, + out_channels=intermediate_size, + activation=hidden_act, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.output = BertOutput(in_channels=intermediate_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + + def construct(self, hidden_states, attention_mask): + # self-attention + attention_output = self.attention(hidden_states, attention_mask) + # feed construct + intermediate_output = self.intermediate(attention_output) + # add and normalize + output = self.output(intermediate_output, attention_output) + return output + + +class BertTransformer(nn.Cell): + """ + Multi-layer bert transformer. + + Args: + hidden_size (int): Size of the encoder layers. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 12. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + return_all_encoders (bool): Specifies whether to return all encoders. Default: False. 
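+
+    Example (a sketch; the sizes are hypothetical, the shape notes follow the
+    reshape logic in construct below):
+        >>> encoder = BertTransformer(hidden_size=32, num_hidden_layers=2,
+        ...                           num_attention_heads=2, intermediate_size=64,
+        ...                           return_all_encoders=True)
+        >>> # For input_tensor of shape [batch_size, seq_length, 32] plus an
+        >>> # attention mask, construct() returns one entry per layer (or a
+        >>> # 1-tuple of the last layer when return_all_encoders=False), each
+        >>> # reshaped back to the shape of input_tensor.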
+ """ + + def __init__(self, + hidden_size, + num_hidden_layers, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + return_all_encoders=False): + super(BertTransformer, self).__init__() + self.return_all_encoders = return_all_encoders + + layers = [] + for _ in range(num_hidden_layers): + layer = BertEncoderCell(hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + hidden_act=hidden_act, + compute_type=compute_type) + layers.append(layer) + + self.layers = nn.CellList(layers) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + + def construct(self, input_tensor, attention_mask): + """Multi-layer bert transformer.""" + prev_output = self.reshape(input_tensor, self.shape) + + all_encoder_layers = () + for layer_module in self.layers: + layer_output = layer_module(prev_output, attention_mask) + prev_output = layer_output + + if self.return_all_encoders: + shape = F.shape(input_tensor) + layer_output = self.reshape(layer_output, shape) + all_encoder_layers = all_encoder_layers + (layer_output,) + + if not self.return_all_encoders: + shape = F.shape(input_tensor) + prev_output = self.reshape(prev_output, shape) + all_encoder_layers = all_encoder_layers + (prev_output,) + return all_encoder_layers + + +class CreateAttentionMaskFromInputMask(nn.Cell): + """ + Create attention mask according to input mask. + + Args: + config (Class): Configuration for BertModel. + """ + + def __init__(self, config): + super(CreateAttentionMaskFromInputMask, self).__init__() + self.input_mask = None + + self.cast = P.Cast() + self.reshape = P.Reshape() + + def construct(self, input_mask): + seq_length = F.shape(input_mask)[1] + attention_mask = self.cast(self.reshape(input_mask, (-1, 1, seq_length)), mstype.float32) + return attention_mask + + +class BertModel(nn.Cell): + """ + Bidirectional Encoder Representations from Transformers. + + Args: + config (Class): Configuration for BertModel. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. 
+ """ + + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(BertModel, self).__init__() + config = copy.deepcopy(config) + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + self.token_type_ids = None + + self.last_idx = self.num_hidden_layers - 1 + output_embedding_shape = [-1, config.seq_length, self.embedding_size] + + self.bert_embedding_lookup = nn.Embedding( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + use_one_hot=use_one_hot_embeddings, + embedding_table=TruncatedNormal(config.initializer_range)) + + self.bert_embedding_postprocessor = EmbeddingPostprocessor( + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_relative_positions=config.use_relative_positions, + use_token_type=True, + token_type_vocab_size=config.type_vocab_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + + self.bert_encoder = BertTransformer( + hidden_size=self.hidden_size, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + use_relative_positions=config.use_relative_positions, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + return_all_encoders=True) + + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.slice = P.StridedSlice() + + self.squeeze_1 = P.Squeeze(axis=1) + self.dense = nn.Dense(self.hidden_size, self.hidden_size, + activation="tanh", + weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) + + def construct(self, input_ids, token_type_ids, input_mask): + """Bidirectional Encoder Representations from Transformers.""" + # embedding + embedding_tables = self.bert_embedding_lookup.embedding_table + word_embeddings = self.bert_embedding_lookup(input_ids) + embedding_output = self.bert_embedding_postprocessor(token_type_ids, + word_embeddings) + + # attention mask [batch_size, seq_length, seq_length] + attention_mask = self._create_attention_mask_from_input_mask(input_mask) + + # bert encoder + encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output), + attention_mask) + + sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) + + # pooler + batch_size = P.Shape()(input_ids)[0] + sequence_slice = self.slice(sequence_output, + (0, 0, 0), + (batch_size, 1, self.hidden_size), + (1, 1, 1)) + first_token = self.squeeze_1(sequence_slice) + pooled_output = self.dense(first_token) + pooled_output = self.cast(pooled_output, self.dtype) + + return sequence_output, pooled_output, embedding_tables diff --git a/tests/st/networks/models/bert/bert_performance/src/utils.py b/tests/st/networks/models/bert/bert_performance/src/utils.py new file mode 100644 index 00000000000..b009dbf6b77 --- /dev/null +++ b/tests/st/networks/models/bert/bert_performance/src/utils.py @@ -0,0 
+1,274 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+Functional Cells used in Bert finetune and evaluation.
+"""
+
+import os
+import collections
+import math
+import numpy as np
+import mindspore.nn as nn
+from mindspore import log as logger
+from mindspore.common import dtype as mstype
+from mindspore.common.tensor import Tensor
+from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
+from mindspore.nn.metrics import Metric
+from mindspore.ops import operations as P
+from mindspore.train.callback import Callback
+
+
+class CrossEntropyCalculation(nn.Cell):
+    """
+    Cross Entropy loss
+    """
+
+    def __init__(self, is_training=True):
+        super(CrossEntropyCalculation, self).__init__()
+        self.onehot = P.OneHot()
+        self.on_value = Tensor(1.0, mstype.float32)
+        self.off_value = Tensor(0.0, mstype.float32)
+        self.reduce_sum = P.ReduceSum()
+        self.reduce_mean = P.ReduceMean()
+        self.reshape = P.Reshape()
+        self.last_idx = (-1,)
+        self.neg = P.Neg()
+        self.cast = P.Cast()
+        self.is_training = is_training
+
+    def construct(self, logits, label_ids, num_labels):
+        if self.is_training:
+            label_ids = self.reshape(label_ids, self.last_idx)
+            one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
+            per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
+            loss = self.reduce_mean(per_example_loss, self.last_idx)
+            return_value = self.cast(loss, mstype.float32)
+        else:
+            return_value = logits * 1.0
+        return return_value
+
+
+def make_directory(path: str):
+    """Make directory."""
+    if path is None or not isinstance(path, str) or path.strip() == "":
+        logger.error("The path (%r) has an invalid type.", path)
+        raise TypeError("Input path is of an invalid type")
+
+    # convert a relative path to an absolute path
+    path = os.path.realpath(path)
+    logger.debug("The abs path is %r", path)
+
+    # check whether the path exists and is writable
+    if os.path.exists(path):
+        real_path = path
+    else:
+        # directory creation may fail (e.g. missing permissions), so catch the error explicitly
+        logger.debug("The directory(%s) doesn't exist, will create it", path)
+        try:
+            os.makedirs(path, exist_ok=True)
+            real_path = path
+        except PermissionError as e:
+            logger.error("No write permission on the directory(%r), error = %r", path, e)
+            raise TypeError("No write permission on the directory.") from e
+    return real_path
+
+
+class LossCallBack(Callback):
+    """
+    Monitor the loss during training.
+
+    Args:
+        dataset_size (int): Number of steps in one epoch. If positive, the current epoch and its
+            completed percentage are printed together with the loss; otherwise only the raw step
+            counter is printed. Default: -1.
+ """ + + def __init__(self, dataset_size=-1): + super(LossCallBack, self).__init__() + self._dataset_size = dataset_size + + def step_end(self, run_context): + """ + Print loss after each step + """ + cb_params = run_context.original_args() + if self._dataset_size > 0: + percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size) + if percent == 0: + percent = 1 + epoch_num -= 1 + print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}" + .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)), + flush=True) + else: + print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, + str(cb_params.net_outputs)), flush=True) + + +class BertLearningRate(LearningRateSchedule): + """ + Warmup-decay learning rate for Bert network. + """ + + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_flag = False + if warmup_steps > 0: + self.warmup_flag = True + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + decay_lr = self.decay_lr(global_step) + if self.warmup_flag: + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + else: + lr = decay_lr + return lr + + +def convert_labels_to_index(label_list): + """ + Convert label_list to indices for NER task. 
+    """
+    label2id = collections.OrderedDict()
+    label2id["O"] = 0
+    prefix = ["S_", "B_", "M_", "E_"]
+    index = 0
+    for label in label_list:
+        for pre in prefix:
+            index += 1
+            sub_label = pre + label
+            label2id[sub_label] = index
+    return label2id
+
+
+def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
+    """
+    Generate the learning rate array.
+
+    Args:
+        global_step(int): current step
+        lr_init(float): initial learning rate
+        lr_end(float): end learning rate
+        lr_max(float): max learning rate
+        warmup_steps(int): number of warmup steps
+        total_steps(int): total number of training steps
+        poly_power(float): power of the polynomial decay
+
+    Returns:
+        np.array, learning rate array
+    """
+    lr_each_step = []
+    if warmup_steps != 0:
+        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
+    else:
+        inc_each_step = 0
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr = float(lr_init) + inc_each_step * float(i)
+        else:
+            base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
+            lr = float(lr_max - lr_end) * (base ** poly_power)
+            lr = lr + lr_end
+            if lr < 0.0:
+                lr = 0.0
+        lr_each_step.append(lr)
+
+    learning_rate = np.array(lr_each_step).astype(np.float32)
+    current_step = global_step
+    learning_rate = learning_rate[current_step:]
+    return learning_rate
+
+
+def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000):
+    learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0,
+                                 total_steps=lr_total_steps, poly_power=lr_power)
+    return Tensor(learning_rate)
+
+
+def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000):
+    damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0,
+                           total_steps=damping_total_steps, poly_power=damping_power)
+    return Tensor(damping)
+
+
+class EvalCallBack(Callback):
+    """
+    Evaluate after a certain number of training samples.
+    Args:
+        model (Model): The network model.
+        eval_ds (Dataset): The eval dataset.
+        global_batch (int): The global batch size summed over all devices.
+        eval_samples (int): The number of training samples between two evaluations.
+    """
+
+    def __init__(self, model, eval_ds, global_batch, eval_samples):
+        super(EvalCallBack, self).__init__()
+        self.model = model
+        self.eval_ds = eval_ds
+        self.global_batch = global_batch
+        self.eval_samples = eval_samples
+        self.last_eval_step = 0
+
+    def epoch_end(self, run_context):
+        """
+        Evaluate after training a certain number of samples.
+        """
+        cb_params = run_context.original_args()
+        num_samples = (cb_params.cur_step_num - self.last_eval_step) * self.global_batch
+        if num_samples < self.eval_samples:
+            return
+        self.last_eval_step = cb_params.cur_step_num
+        total_samples = cb_params.cur_step_num * self.global_batch
+        res = self.model.eval(self.eval_ds, dataset_sink_mode=True)
+        res = res['bert_acc']
+        print("====================================", flush=True)
+        print("Accuracy is: ", "%.6f" % res, ", current samples is: ", total_samples)
+        print("====================================", flush=True)
+
+
+class BertMetric(Metric):
+    """
+    The metric of the bert network.
+    Args:
+        batch_size (int): The batch size of each device.
+ """ + + def __init__(self, batch_size): + super(BertMetric, self).__init__() + self.clear() + self.batch_size = batch_size + + def clear(self): + self.mlm_total = 0 + self.mlm_acc = 0 + + def update(self, *inputs): + mlm_acc = self._convert_data(inputs[0]) + mlm_total = self._convert_data(inputs[1]) + self.mlm_acc += mlm_acc + self.mlm_total += mlm_total + + def eval(self): + return self.mlm_acc / self.mlm_total diff --git a/tests/st/networks/models/bert/bert_performance/test_bert_thor.py b/tests/st/networks/models/bert/bert_performance/test_bert_thor.py index 2936488e0bf..dedcd714448 100644 --- a/tests/st/networks/models/bert/bert_performance/test_bert_thor.py +++ b/tests/st/networks/models/bert/bert_performance/test_bert_thor.py @@ -33,9 +33,11 @@ from mindspore.train.model import Model from mindspore.train.train_thor import ConvertModelUtils import mindspore.dataset.transforms.c_transforms as C -from tests.models.official.nlp.bert.src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepCell -from tests.models.official.nlp.bert.src.utils import get_bert_thor_lr, get_bert_thor_damping -from tests.models.official.nlp.bert.src.bert_model import BertConfig +from tests.st.networks.models.bert.bert_performance.src.bert_for_pre_training import BertNetworkWithLoss, \ + BertTrainOneStepCell + +from tests.st.networks.models.bert.bert_performance.src.utils import get_bert_thor_lr, get_bert_thor_damping +from tests.st.networks.models.bert.bert_performance.src.bert_model import BertConfig MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_table_8p.json" DATASET_PATH = "/home/workspace/mindspore_dataset/bert/thor/en-wiki-512_test_first1wan" @@ -188,7 +190,7 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num): q.put({'loss': loss_list, 'cost': per_step_mseconds}) -@pytest.mark.level1 +@pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.env_single
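For reference, a minimal sketch of how the classes added by this patch are wired together, mirroring the imports in test_bert_thor.py. It is illustrative only: the BertConfig keyword names are assumed to match the config attributes read by BertModel above, the BertNetworkWithLoss and BertTrainOneStepCell argument lists are assumed from how the test uses them, and create_pretrain_dataset is a hypothetical dataset helper (the ST itself reads MindRecord data from DATASET_PATH).

import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model

from src.bert_model import BertConfig
from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepCell
from src.utils import BertLearningRate, LossCallBack

context.set_context(mode=context.GRAPH_MODE)

# A deliberately small config; keyword names are assumed to match the attributes
# BertModel reads (seq_length, vocab_size, hidden_size, num_hidden_layers, ...).
config = BertConfig(seq_length=128, vocab_size=21128, hidden_size=256,
                    num_hidden_layers=2, num_attention_heads=4, intermediate_size=1024)

# Backbone plus the masked-LM / next-sentence pre-training loss.
net_with_loss = BertNetworkWithLoss(config, is_training=True)

# Warmup + polynomial-decay schedule from src/utils.py, fed into a Lamb optimizer.
lr_schedule = BertLearningRate(learning_rate=1e-4, end_learning_rate=1e-6,
                               warmup_steps=100, decay_steps=1000, power=1.0)
optimizer = nn.Lamb(net_with_loss.trainable_params(), learning_rate=lr_schedule)

# One-step training cell (forward, backward and gradient clipping) wrapped in a Model.
train_net = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
model = Model(train_net)

# dataset = create_pretrain_dataset(...)  # hypothetical helper; the ST reads MindRecord data
# model.train(1, dataset, callbacks=[LossCallBack(dataset.get_dataset_size())],
#             dataset_sink_mode=False)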