From e071f04d4bade82d41412b7701f6edc2b6c1f349 Mon Sep 17 00:00:00 2001
From: wsc
Date: Mon, 20 Apr 2020 11:19:25 +0800
Subject: [PATCH] Add ST test script of bert with loss scale

---
 .../Bert_NEZHA/bert_for_pre_training.py       |   2 +-
 .../models/bert/bert_tdt_lossscale.py         | 198 ++++++++++++++++++
 2 files changed, 199 insertions(+), 1 deletion(-)
 create mode 100644 tests/st/networks/models/bert/bert_tdt_lossscale.py

diff --git a/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py b/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py
index 046b2adbe2c..53a0d039330 100644
--- a/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py
+++ b/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py
@@ -445,5 +445,5 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             succ = False
         else:
             succ = self.optimizer(grads)
-        ret = (loss, cond)
+        ret = (loss, cond, scaling_sens)
         return F.depend(ret, succ)

diff --git a/tests/st/networks/models/bert/bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_tdt_lossscale.py
new file mode 100644
index 00000000000..cfd0b556975
--- /dev/null
+++ b/tests/st/networks/models/bert/bert_tdt_lossscale.py
@@ -0,0 +1,198 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""train bert network with loss scale"""
+
+import os
+import pytest
+import numpy as np
+from numpy import allclose
+import mindspore.common.dtype as mstype
+import mindspore.dataset.engine.datasets as de
+import mindspore.dataset.transforms.c_transforms as C
+from mindspore import context
+from mindspore.common.tensor import Tensor
+from mindspore.train.model import Model
+from mindspore.train.callback import Callback, LossMonitor
+from mindspore.train.loss_scale_manager import DynamicLossScaleManager
+from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
+from mindspore.nn.optim import Momentum
+from mindspore import log as logger
+
+_current_dir = os.path.dirname(os.path.realpath(__file__))
+DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]
+SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json"
+
+def get_config(version='base', batch_size=1):
+    """get the BertConfig for the requested model version"""
+    if version == 'base':
+        bert_config = BertConfig(
+            batch_size=batch_size,
+            seq_length=128,
+            vocab_size=21136,
+            hidden_size=768,
+            num_hidden_layers=2,
+            num_attention_heads=12,
+            intermediate_size=3072,
+            hidden_act="gelu",
+            hidden_dropout_prob=0.1,
+            attention_probs_dropout_prob=0.1,
+            max_position_embeddings=512,
+            type_vocab_size=2,
+            initializer_range=0.02,
+            use_relative_positions=True,
+            input_mask_from_dataset=True,
+            token_type_ids_from_dataset=True,
+            dtype=mstype.float32,
+            compute_type=mstype.float32)
+    elif version == 'large':
+        bert_config = BertConfig(
+            batch_size=batch_size,
+            seq_length=128,
+            vocab_size=21136,
+            hidden_size=1024,
+            num_hidden_layers=2,
+            num_attention_heads=16,
+            intermediate_size=4096,
+            hidden_act="gelu",
+            hidden_dropout_prob=0.0,
+            attention_probs_dropout_prob=0.0,
+            max_position_embeddings=512,
+            type_vocab_size=2,
+            initializer_range=0.02,
+            use_relative_positions=True,
+            input_mask_from_dataset=True,
+            token_type_ids_from_dataset=True,
+            dtype=mstype.float32,
+            compute_type=mstype.float16)
+    elif version == 'large_mixed':
+        bert_config = BertConfig(
+            batch_size=batch_size,
+            seq_length=128,
+            vocab_size=21136,
+            hidden_size=1024,
+            num_hidden_layers=24,
+            num_attention_heads=16,
+            intermediate_size=4096,
+            hidden_act="gelu",
+            hidden_dropout_prob=0.0,
+            attention_probs_dropout_prob=0.0,
+            max_position_embeddings=512,
+            type_vocab_size=2,
+            initializer_range=0.02,
+            use_relative_positions=True,
+            input_mask_from_dataset=True,
+            token_type_ids_from_dataset=True,
+            dtype=mstype.float32,
+            compute_type=mstype.float32)
+    else:
+        bert_config = BertConfig(batch_size=batch_size)
+    return bert_config
+
+def me_de_train_dataset():
+    """create the MindData training dataset from the TFRecord files"""
+    # apply repeat operations
+    repeat_count = 1
+    ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
+                                                                "next_sentence_labels", "masked_lm_positions",
+                                                                "masked_lm_ids", "masked_lm_weights"], shuffle=False)
+    # cast every integer column to int32
+    type_cast_op = C.TypeCast(mstype.int32)
+    ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
+    ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
+    ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
+    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
+    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
+    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
+    # apply batch operations
+    batch_size = int(os.getenv('BATCH_SIZE', '16'))
+    ds = ds.batch(batch_size, drop_remainder=True)
+    ds = ds.repeat(repeat_count)
+    return ds
+
+def weight_variable(shape):
+    """generate a deterministic random weight tensor of the given shape"""
+    np.random.seed(1)
+    values = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
+    return Tensor(values)
+
+class ModelCallback(Callback):
+    """record per-step loss, overflow flag and loss scale"""
+    def __init__(self):
+        super(ModelCallback, self).__init__()
+        self.loss_list = []
+        self.overflow_list = []
+        self.lossscale_list = []
+
+    def step_end(self, run_context):
+        # net_outputs is the (loss, cond, scaling_sens) tuple returned by
+        # BertTrainOneStepWithLossScaleCell (see the change above)
+        cb_params = run_context.original_args()
+        self.loss_list.append(cb_params.net_outputs[0])
+        self.overflow_list.append(cb_params.net_outputs[1])
+        self.lossscale_list.append(cb_params.net_outputs[2])
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_bert_tdt():
+    """test bert training with dynamic loss scale"""
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
+    context.set_context(enable_task_sink=True)
+    context.set_context(enable_loop_sink=True)
+    context.set_context(enable_mem_reuse=True)
+    ds = me_de_train_dataset()
+    version = os.getenv('VERSION', 'large')
+    batch_size = int(os.getenv('BATCH_SIZE', '16'))
+    config = get_config(version=version, batch_size=batch_size)
+    netwithloss = BertNetworkWithLoss(config, True)
+    optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
+    scale_window = 3
+    scale_manager = DynamicLossScaleManager(2**32, 2, scale_window)
+    netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
+                                                     scale_update_cell=scale_manager.get_update_cell())
+    netwithgrads.set_train(True)
+    model = Model(netwithgrads)
+    callback = ModelCallback()
+    params = netwithloss.trainable_params()
+    # deterministically re-initialize the trainable weights so the run is reproducible
+    for param in params:
+        value = param.default_input
+        name = param.name
+        if isinstance(value, Tensor):
+            if name.split('.')[-1] in ['weight']:
+                if name.split('.')[-3] in ['cls2']:
+                    logger.info("***************** BERT param name is 1 {}".format(name))
+                    param.default_input = weight_variable(value.asnumpy().shape)
+                else:
+                    logger.info("***************** BERT param name is 2 {}".format(name))
+                    tempshape = value.asnumpy().shape
+                    shape = (tempshape[1], tempshape[0])
+                    weight_value = weight_variable(shape).asnumpy()
+                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
+            else:
+                logger.info("***************** BERT param name is 3 {}".format(name))
+                param.default_input = weight_variable(value.asnumpy().shape)
+    model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=False)
+
+    # assert that the recorded loss scale follows the dynamic update rule:
+    # halved on an overflowed step, doubled after scale_window consecutive
+    # steps without overflow
+    count = 0
+    for i in range(len(callback.overflow_list)):
+        if callback.overflow_list[i] == Tensor(True, mstype.bool_) and i > 0:
+            count = 0
+            assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(0.5, mstype.float32)
+        if callback.overflow_list[i] == Tensor(False, mstype.bool_):
+            count = count + 1
+            if count == scale_window:
+                count = 0
+                assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(2.0, mstype.float32)
+
+
+if __name__ == '__main__':
+    test_bert_tdt()
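
Note on the final assertions: they encode the update rule configured by
DynamicLossScaleManager(2**32, 2, scale_window) with scale_window = 3 — the
loss scale is halved whenever a step overflows, and doubled after scale_window
consecutive steps without overflow. Below is a minimal pure-Python sketch of
that rule for reference; the DynamicLossScale class and its names are
illustrative only, not MindSpore API.

    class DynamicLossScale:
        """Track a loss scale the way the test expects it to evolve."""

        def __init__(self, init_scale=2 ** 32, factor=2, scale_window=3):
            self.scale = float(init_scale)    # current loss scale
            self.factor = factor              # multiply/divide step (2 in the test)
            self.scale_window = scale_window  # clean steps required before growing
            self.good_steps = 0               # consecutive steps without overflow

        def update(self, overflow):
            if overflow:
                # gradients overflowed: shrink the scale and restart the window
                self.scale /= self.factor
                self.good_steps = 0
            else:
                self.good_steps += 1
                if self.good_steps == self.scale_window:
                    # scale_window clean steps in a row: grow the scale
                    self.scale *= self.factor
                    self.good_steps = 0
            return self.scale

    # Three clean steps double the scale, one overflow halves it again:
    scaler = DynamicLossScale()
    for overflow in (False, False, False, True):
        print(scaler.update(overflow))   # 2**32, 2**32, 2**33, 2**32

This mirrors exactly what the test's loop checks against callback.overflow_list
and callback.lossscale_list, step by step.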