From 56fa56b17305a750ccf1cb7d2e2072ebb71c2c3b Mon Sep 17 00:00:00 2001
From: looop5
Date: Sat, 12 Dec 2020 09:34:53 +0800
Subject: [PATCH] add graph kernel testcases

---
 .../_extends/graph_kernel/expanders/tile.py   |   7 +-
 .../bert_precision/test_bert_tdt_lossscale.py |  34 ++-
 .../models/bert/test_bert_graph_kernel.py     | 215 ------------------
 .../test_clip_by_norm_no_div_sum.py           |  69 ++++++
 tests/st/ops/graph_kernel/test_sqrt_grad.py   |  60 +++++
 tests/st/ops/graph_kernel/test_tile.py        |  59 +++++
 6 files changed, 216 insertions(+), 228 deletions(-)
 delete mode 100644 tests/st/networks/models/bert/test_bert_graph_kernel.py
 create mode 100644 tests/st/ops/graph_kernel/test_clip_by_norm_no_div_sum.py
 create mode 100644 tests/st/ops/graph_kernel/test_sqrt_grad.py
 create mode 100644 tests/st/ops/graph_kernel/test_tile.py

diff --git a/mindspore/_extends/graph_kernel/expanders/tile.py b/mindspore/_extends/graph_kernel/expanders/tile.py
index 4a2638dc073..7f1ae7c6ca8 100644
--- a/mindspore/_extends/graph_kernel/expanders/tile.py
+++ b/mindspore/_extends/graph_kernel/expanders/tile.py
@@ -65,8 +65,7 @@ def expand_tile(expand_info):
     for item in attrs:
         if 'multiples' in item:
             multiples = item['multiples']
-    output_shape, input_reshape, output_reshape, shape_compatible = _get_tile_output_shape(input_desc['shape'],
-                                                                                           multiples)
+    output_shape, _, _, shape_compatible = _get_tile_output_shape(input_desc['shape'], multiples)
     graph_builder = builder.GraphBuilder()

     # generate a graph.
@@ -77,9 +76,7 @@ def expand_tile(expand_info):
         if shape_compatible:
             result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
         else:
-            input_x_reshape = graph_builder.emit('Reshape', [input_x], attrs={'shape': input_reshape})
-            reshape_broadcast = graph_builder.emit('BroadcastTo', [input_x_reshape], attrs={'shape': output_reshape})
-            result = graph_builder.emit('Reshape', [reshape_broadcast], attrs={'shape': output_shape})
+            result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples})

         # set graph output.
         graph_scope.set_output(result)
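
Note on the tile.py hunk above: Tile can be lowered to a single BroadcastTo only when
every axis it actually repeats has input size 1; in any other case the expander now
falls back to emitting a real Tile op instead of the old Reshape/BroadcastTo/Reshape
chain. A rough sketch of that compatibility rule (a hypothetical helper for
illustration only -- the real decision lives in _get_tile_output_shape, which also
returns the reshape shapes this patch stops using):

    def tile_is_broadcast_compatible(shape, multiples):
        # Tile left-pads the input shape with 1s up to the rank of multiples.
        aligned = [1] * (len(multiples) - len(shape)) + list(shape)
        # BroadcastTo can only repeat axes of size 1, so Tile reduces to a
        # single BroadcastTo iff no axis larger than 1 gets repeated.
        return all(s == 1 or m == 1 for s, m in zip(aligned, multiples))

    # e.g. shape (24, 1) with multiples (1, 2) is compatible (output (24, 2)),
    # while (24, 1) with multiples (2, 2, 2) is not (output (2, 48, 2)).
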
diff --git a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py
index 5d4c4a2297e..b431a554c53 100644
--- a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py
+++ b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py
@@ -35,7 +35,6 @@ from model_zoo.official.nlp.bert.src.bert_for_pre_training import BertNetworkWit
 from model_zoo.official.nlp.bert.src.bert_for_pre_training import BertTrainOneStepWithLossScaleCell
 from model_zoo.official.nlp.bert.src.bert_model import BertConfig

-
 _current_dir = os.path.dirname(os.path.realpath(__file__))
 DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]
 SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json"
@@ -155,13 +154,16 @@ class ModelCallback(Callback):
         self.lossscale_list.append(cb_params.net_outputs[2].asnumpy())
         print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))

+
 class TimeMonitor(Callback):
     """Time Monitor."""
+
     def __init__(self, data_size):
         super(TimeMonitor, self).__init__()
         self.data_size = data_size
         self.epoch_mseconds_list = []
         self.per_step_mseconds_list = []
+
     def epoch_begin(self, run_context):
         self.epoch_time = time.time()

@@ -170,18 +172,17 @@ class TimeMonitor(Callback):
         self.epoch_mseconds_list.append(epoch_mseconds)
         self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)

-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.env_onecard
-def test_bert_percision():
+
+def test_bert_percision(enable_graph_kernel=False):
     """test bert percision"""
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
+    # set the flag both ways so consecutive on/off runs cannot inherit stale state
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
     ds, new_repeat_count, _ = me_de_train_dataset()
     version = os.getenv('VERSION', 'large')
     config = get_config(version=version)
     netwithloss = BertNetworkWithLoss(config, True)
-    lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count,
+    lr = BertLearningRate(decay_steps=ds.get_dataset_size() * new_repeat_count,
                           learning_rate=5e-5, end_learning_rate=1e-9, power=10.0,
                           warmup_steps=0)
     decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
@@ -239,5 +240,22 @@ def test_bert_percision():
     assert np.allclose(loss_scale, expect_loss_scale, 0, 0)


+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_bert_percision_graph_kernel_off():
+    test_bert_percision(enable_graph_kernel=False)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_bert_percision_graph_kernel_on():
+    test_bert_percision(enable_graph_kernel=True)
+
+
 if __name__ == '__main__':
-    test_bert_percision()
+    test_bert_percision(enable_graph_kernel=False)
+    test_bert_percision(enable_graph_kernel=True)
diff --git a/tests/st/networks/models/bert/test_bert_graph_kernel.py b/tests/st/networks/models/bert/test_bert_graph_kernel.py
deleted file mode 100644
index 47d8f5d2469..00000000000
--- a/tests/st/networks/models/bert/test_bert_graph_kernel.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""train bert network without lossscale"""
-
-import os
-
-import numpy as np
-from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
-from src.bert_model import BertConfig
-
-import mindspore.common.dtype as mstype
-import mindspore.dataset.engine.datasets as de
-import mindspore.dataset.transforms.c_transforms as C
-from mindspore import context
-from mindspore import log as logger
-from mindspore.common.tensor import Tensor
-from mindspore.nn import learning_rate_schedule as lr_schedules
-from mindspore.nn.optim import Lamb
-from mindspore.train.callback import Callback
-from mindspore.train.loss_scale_manager import DynamicLossScaleManager
-from mindspore.train.model import Model
-
-DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]
-SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json"
-
-def get_config(version='base', batch_size=1):
-    """get config"""
-    if version == 'base':
-        bert_config = BertConfig(
-            batch_size=batch_size,
-            seq_length=128,
-            vocab_size=21136,
-            hidden_size=768,
-            num_hidden_layers=2,
-            num_attention_heads=12,
-            intermediate_size=3072,
-            hidden_act="gelu",
-            hidden_dropout_prob=0.1,
-            attention_probs_dropout_prob=0.1,
-            max_position_embeddings=512,
-            type_vocab_size=2,
-            initializer_range=0.02,
-            use_relative_positions=True,
-            input_mask_from_dataset=True,
-            token_type_ids_from_dataset=True,
-            dtype=mstype.float32,
-            compute_type=mstype.float32)
-    elif version == 'large':
-        bert_config = BertConfig(
-            batch_size=batch_size,
-            seq_length=128,
-            vocab_size=30522,
-            hidden_size=1024,
-            num_hidden_layers=2,
-            num_attention_heads=16,
-            intermediate_size=4096,
-            hidden_act="gelu",
-            hidden_dropout_prob=0.0,
-            attention_probs_dropout_prob=0.0,
-            max_position_embeddings=512,
-            type_vocab_size=2,
-            initializer_range=0.02,
-            use_relative_positions=True,
-            input_mask_from_dataset=True,
-            token_type_ids_from_dataset=True,
-            dtype=mstype.float32,
-            compute_type=mstype.float16,
-            enable_fused_layernorm=True)
-    else:
-        bert_config = BertConfig(batch_size=batch_size)
-    return bert_config
-
-
-def me_de_train_dataset():
-    """test me de train dataset"""
-    # apply repeat operations
-    repeat_count = 1
-    ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
-                                                                "next_sentence_labels", "masked_lm_positions",
-                                                                "masked_lm_ids", "masked_lm_weights"], shuffle=False)
-    type_cast_op = C.TypeCast(mstype.int32)
-    ds = ds.map(operations=type_cast_op, input_columns="masked_lm_ids")
-    ds = ds.map(operations=type_cast_op, input_columns="masked_lm_positions")
-    ds = ds.map(operations=type_cast_op, input_columns="next_sentence_labels")
-    ds = ds.map(operations=type_cast_op, input_columns="segment_ids")
-    ds = ds.map(operations=type_cast_op, input_columns="input_mask")
-    ds = ds.map(operations=type_cast_op, input_columns="input_ids")
-    # apply batch operations
-    batch_size = int(os.getenv('BATCH_SIZE', '16'))
-    ds = ds.batch(batch_size, drop_remainder=True)
-    ds = ds.repeat(repeat_count)
-    return ds
-
-
-def weight_variable(shape):
-    """weight variable"""
-    np.random.seed(1)
-    ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
-    return Tensor(ones)
-
-
-class BertLearningRate(lr_schedules.LearningRateSchedule):
-    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
-        super(BertLearningRate, self).__init__()
-        self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
-        self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
-        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
-
-        self.greater = P.Greater()
-        self.one = Tensor(np.array([1.0]).astype(np.float32))
-        self.cast = P.Cast()
-
-    def construct(self, global_step):
-        is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
-        warmup_lr = self.warmup_lr(global_step)
-        decay_lr = self.decay_lr(global_step)
-        lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
-        return lr
-
-
-class ModelCallback(Callback):
-    def __init__(self):
-        super(ModelCallback, self).__init__()
-        self.loss_list = []
-        self.overflow_list = []
-        self.lossscale_list = []
-
-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        self.loss_list.append(cb_params.net_outputs[0].asnumpy()[0])
-        self.overflow_list.append(cb_params.net_outputs[1].asnumpy())
-        self.lossscale_list.append(cb_params.net_outputs[2].asnumpy())
-        print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
-
-
-def test_bert_tdt():
-    """test bert tdt"""
-    np.random.seed(0)
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
-    context.set_context(enable_graph_kernel=True)
-    ds = me_de_train_dataset()
-    config = get_config(version='large', batch_size=16)
-    netwithloss = BertNetworkWithLoss(config, True)
-    lr = BertLearningRate(decay_steps=ds.get_dataset_size()*ds.get_repeat_count(), learning_rate=5e-5,
-                          end_learning_rate=1e-9, power=10.0, warmup_steps=0)
-    decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
-    no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
-    decay_params = list(filter(decay_filter, net_with_loss.trainable_params()))
-    other_params = list(filter(no_decay_filter, net_with_loss.trainable_params()))
-    group_params = [{'params': decay_params, 'weight_decay': 0.01},
-                    {'params': other_params}]
-    optimizer = Lamb(group_params, lr)
-    scale_window = 3
-    scale_manager = DynamicLossScaleManager(262144, 2, scale_window)
-    netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
-                                                     scale_update_cell=scale_manager.get_update_cell())
-    netwithgrads.set_train(True)
-    model = Model(netwithgrads)
-    callback = ModelCallback()
-    netwithloss.init_parameters_data()
-    params = netwithloss.trainable_params()
-    for param in params:
-        value = param.data
-        name = param.name
-        if isinstance(value, Tensor):
-            if name.split('.')[-1] in ['weight']:
-                if name.split('.')[-3] in ['cls2']:
-                    logger.info("***************** BERT param name is 1 {}".format(name))
-                    param.set_data(weight_variable(value.asnumpy().shape))
-                else:
-                    logger.info("***************** BERT param name is 2 {}".format(name))
-                    tempshape = value.asnumpy().shape
-                    shape = (tempshape[1], tempshape[0])
-                    weight_value = weight_variable(shape).asnumpy()
-                    param.set_data(Tensor(np.transpose(weight_value, [1, 0])))
-            else:
-                logger.info("***************** BERT param name is 3 {}".format(name))
-                param.set_data(weight_variable(value.asnumpy().shape))
-    model.train(1, ds, callbacks=callback, dataset_sink_mode=False)
-
-    # assertion occurs while the loss value, overflow state or loss_scale value is wrong
-    loss_value = np.array(callback.loss_list)
-    expect_loss_value = [12.559319, 12.333815, 12.339806, 12.350235, 12.343947, 12.830965, 12.375336, 12.973715,
-                         12.57929, 12.7766905]
-    error = loss_value - expect_loss_value
-    print("loss value: {}".format(loss_value))
-    print("error value: {}".format(error))
-    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
-
-    overflow = np.array(callback.overflow_list)
-    expect_overflow = [True, True, True, True, False, False, False, True, False, False]
-    print("overflow: {}".format(overflow))
-    assert (overflow == expect_overflow).all()
-
-    loss_scale = np.array(callback.lossscale_list)
-    expect_loss_scale = [131072.0, 65536.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0]
-    print("loss scale: {}".format(loss_scale))
-    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
-
-
-if __name__ == '__main__':
-    test_bert_tdt()
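
Side note on the loss-scale assertions (both in the deleted test above and in the
surviving precision test): with DynamicLossScaleManager(262144, 2, scale_window=3)
the expected sequence follows mechanically from the usual rule of halving the scale
on an overflow step and doubling it after scale_window consecutive clean steps. A
self-checking sketch, assuming that standard update rule:

    def simulate_dynamic_loss_scale(init_scale, factor, scale_window, overflows):
        # halve on overflow, double after scale_window consecutive clean steps
        scale, clean_steps, history = init_scale, 0, []
        for overflow in overflows:
            if overflow:
                scale /= factor
                clean_steps = 0
            else:
                clean_steps += 1
                if clean_steps == scale_window:
                    scale *= factor
                    clean_steps = 0
            history.append(scale)
        return history

    overflows = [True, True, True, True, False, False, False, True, False, False]
    assert simulate_dynamic_loss_scale(262144.0, 2, 3, overflows) == [
        131072.0, 65536.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0]
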
diff --git a/tests/st/ops/graph_kernel/test_clip_by_norm_no_div_sum.py b/tests/st/ops/graph_kernel/test_clip_by_norm_no_div_sum.py
new file mode 100644
index 00000000000..2143374a8f4
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_clip_by_norm_no_div_sum.py
@@ -0,0 +1,69 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class ClipByNormNoDivSum(nn.Cell):
+    def __init__(self):
+        super(ClipByNormNoDivSum, self).__init__()
+        self.greater = P.Greater()
+        self.select = P.Select()
+        self.sqrt = P.Sqrt()
+        self.maximum = P.Maximum()
+
+    def construct(self, i0, i1, i2, i3):
+        greater_res = self.greater(i0, i1)
+        select_res0 = self.select(greater_res, i0, i2)
+        sqrt_res = self.sqrt(select_res0)
+        select_res1 = self.select(greater_res, sqrt_res, i0)
+        res = self.maximum(select_res1, i3)
+        return res
+
+
+def get_output(x0, x1, x2, x3, enable_graph_kernel=False):
+    # set the flag both ways so the expect pass cannot inherit an enabled state
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    net = ClipByNormNoDivSum()
+    output = net(x0, x1, x2, x3)
+    return output
+
+
+def test_clip_by_norm_no_div_sum(shape0, shape1, shape2, shape3, dtype):
+    x0 = Tensor(np.random.normal(0, 1, shape0).astype(dtype))
+    x1 = Tensor(np.zeros(shape1, dtype))
+    x2 = Tensor(np.ones(shape2, dtype))
+    x3 = Tensor(np.ones(shape3, dtype))
+
+    expect = get_output(x0, x1, x2, x3, False)
+    output = get_output(x0, x1, x2, x3, True)
+
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+
+    assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_clip_by_norm_no_div_sum_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_clip_by_norm_no_div_sum((1, 1), (1,), (1, 1), (1,), np.float32)
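
For reference, the select/sqrt/select pattern in ClipByNormNoDivSum keeps Sqrt away
from non-positive inputs: wherever i0 <= i1 the ones in i2 are substituted before
the square root, and i0 is restored afterwards. A plain NumPy transcription of
construct, useful as a mental model of what both the fused and unfused graphs
should compute:

    import numpy as np

    def clip_by_norm_no_div_sum_ref(i0, i1, i2, i3):
        greater = np.greater(i0, i1)               # P.Greater
        select0 = np.where(greater, i0, i2)        # P.Select: guard the sqrt input
        sqrt_res = np.sqrt(select0)                # P.Sqrt on the guarded branch
        select1 = np.where(greater, sqrt_res, i0)  # P.Select back to i0 otherwise
        return np.maximum(select1, i3)             # P.Maximum
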
diff --git a/tests/st/ops/graph_kernel/test_sqrt_grad.py b/tests/st/ops/graph_kernel/test_sqrt_grad.py
new file mode 100644
index 00000000000..a200ae7d974
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_sqrt_grad.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.sqrt_grad = G.SqrtGrad()
+
+    def construct(self, x, dout):
+        return self.sqrt_grad(x, dout)
+
+
+def get_output(x, dout, enable_graph_kernel=False):
+    # set the flag both ways so the expect pass cannot inherit an enabled state
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    net = Net()
+    output = net(x, dout)
+    return output
+
+
+def test_sqrt_grad(shape_x, shape_dout, dtype):
+    x = Tensor(np.random.normal(0, 1, shape_x).astype(dtype))
+    dout = Tensor(np.random.normal(0, 1, shape_dout).astype(dtype))
+
+    expect = get_output(x, dout, False)
+    output = get_output(x, dout, True)
+
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+
+    assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_sqrt_grad_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_sqrt_grad((16, 16), (16, 16), np.float16)
+    test_sqrt_grad((16, 16), (16, 16), np.float32)
diff --git a/tests/st/ops/graph_kernel/test_tile.py b/tests/st/ops/graph_kernel/test_tile.py
new file mode 100644
index 00000000000..7ca00b42b71
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_tile.py
@@ -0,0 +1,59 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class Net(nn.Cell):
+    def __init__(self, multiples):
+        super(Net, self).__init__()
+        self.tile = P.Tile()
+        self.multiples = multiples
+
+    def construct(self, x):
+        return self.tile(x, self.multiples)
+
+
+def get_output(x, multiples, enable_graph_kernel=False):
+    # set the flag both ways so the expect pass cannot inherit an enabled state
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    net = Net(multiples)
+    output = net(x)
+    return output
+
+
+def test_tile(shape, dtype, multiples):
+    x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
+    expect = get_output(x, multiples, False)
+    output = get_output(x, multiples, True)
+
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+
+    assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_tile_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_tile((24, 1), np.float16, (2, 2, 2))
+    test_tile((24, 1), np.float32, (1, 2))
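
On test_sqrt_grad.py: assuming G.SqrtGrad follows the usual grad-op convention of
taking the forward output y = sqrt(x) together with the incoming gradient, the
value under test is dy / (2 * y). A hedged reference under that assumption:

    import numpy as np

    def sqrt_grad_ref(y, dy):
        # d sqrt(x) / dx = 1 / (2 * sqrt(x)) = 1 / (2 * y); assumes the first
        # input is the forward output rather than the original x
        return dy / (2.0 * y)
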
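On test_tile.py: numpy.tile uses the same left-padding rule as P.Tile, so it makes
a handy independent oracle for the two cases exercised in test_tile_ascend. A
hypothetical extra check, not part of the patch:

    import numpy as np

    x = np.random.normal(0, 1, (24, 1)).astype(np.float32)
    assert np.tile(x, (2, 2, 2)).shape == (2, 48, 2)  # rank grows, real Tile needed
    assert np.tile(x, (1, 2)).shape == (24, 2)        # broadcast-compatible case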