From b2412ea571b854172c114a0feaa30cc0ecfde4cd Mon Sep 17 00:00:00 2001 From: VectorSL Date: Sat, 20 Jun 2020 10:52:21 +0800 Subject: [PATCH] gpu update bert scripts --- model_zoo/bert/evaluation.py | 18 ++++++++++++++++-- model_zoo/bert/finetune.py | 4 ++++ tests/st/networks/test_gpu_alexnet.py | 4 ++-- tests/st/networks/test_gpu_lenet.py | 4 ++-- 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/model_zoo/bert/evaluation.py b/model_zoo/bert/evaluation.py index 320c99cc59a..4877b60cef3 100644 --- a/model_zoo/bert/evaluation.py +++ b/model_zoo/bert/evaluation.py @@ -18,9 +18,11 @@ Bert evaluation script. """ import os +import argparse import numpy as np import mindspore.common.dtype as mstype from mindspore import context +from mindspore import log as logger from mindspore.common.tensor import Tensor import mindspore.dataset as de import mindspore.dataset.transforms.c_transforms as C @@ -105,8 +107,17 @@ def bert_predict(Evaluation): ''' prediction function ''' - devid = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) + target = args_opt.device_target + if target == "Ascend": + devid = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) + elif target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + if bert_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + bert_net_cfg.compute_type = mstype.float32 + else: + raise Exception("Target error, GPU or Ascend is supported.") dataset = get_dataset(bert_net_cfg.batch_size, 1) if cfg.use_crf: net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True, @@ -147,6 +158,9 @@ def test_eval(): callback.acc_num / callback.total_num)) print("==============================================================") +parser = argparse.ArgumentParser(description='Bert eval') +parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') +args_opt = parser.parse_args() if __name__ == "__main__": num_labels = cfg.num_labels test_eval() diff --git a/model_zoo/bert/finetune.py b/model_zoo/bert/finetune.py index 94b1931bc65..df16e3c91d9 100644 --- a/model_zoo/bert/finetune.py +++ b/model_zoo/bert/finetune.py @@ -23,6 +23,7 @@ from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCe from src.finetune_config import cfg, bert_net_cfg, tag_to_index import mindspore.common.dtype as mstype from mindspore import context +from mindspore import log as logger import mindspore.dataset as de import mindspore.dataset.transforms.c_transforms as C from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell @@ -105,6 +106,9 @@ def test_train(): context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) elif target == "GPU": context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + if bert_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + bert_net_cfg.compute_type = mstype.float32 else: raise Exception("Target error, GPU or Ascend is supported.") #BertCLSTrain for classification diff --git a/tests/st/networks/test_gpu_alexnet.py b/tests/st/networks/test_gpu_alexnet.py index 10099814642..7a55006571e 100644 --- a/tests/st/networks/test_gpu_alexnet.py +++ b/tests/st/networks/test_gpu_alexnet.py @@ -86,6 +86,6 @@ def test_trainTensor(num_classes=10, epoch=15, batch_size=32): for i in range(0, epoch): data = Tensor(np.ones([batch_size, 3, 227, 227]).astype(np.float32) * 0.01) label = Tensor(np.ones([batch_size]).astype(np.int32)) - loss = train_network(data, label) + loss = train_network(data, label).asnumpy() losses.append(loss) - assert (losses[-1].asnumpy() < 0.01) + assert losses[-1] < 0.01 diff --git a/tests/st/networks/test_gpu_lenet.py b/tests/st/networks/test_gpu_lenet.py index 038af922239..d1ca32c9edf 100644 --- a/tests/st/networks/test_gpu_lenet.py +++ b/tests/st/networks/test_gpu_lenet.py @@ -144,9 +144,9 @@ def test_train_lenet(): for i in range(epoch): data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01) label = Tensor(np.ones([net.batch_size]).astype(np.int32)) - loss = train_network(data, label) + loss = train_network(data, label).asnumpy() losses.append(loss) - print(losses) + assert losses[-1] < 0.01 def create_dataset(data_path, batch_size=32, repeat_size=1,