!2388 GPU update bert scripts

Merge pull request !2388 from VectorSL/bert-eval
This commit is contained in:
mindspore-ci-bot 2020-06-20 15:26:21 +08:00 committed by Gitee
commit 748b0b1b64
4 changed files with 24 additions and 6 deletions

View File

@ -18,9 +18,11 @@ Bert evaluation script.
""" """
import os import os
import argparse
import numpy as np import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
from mindspore import log as logger
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import mindspore.dataset as de import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
@ -105,8 +107,17 @@ def bert_predict(Evaluation):
''' '''
prediction function prediction function
''' '''
devid = int(os.getenv('DEVICE_ID')) target = args_opt.device_target
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) if target == "Ascend":
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
elif target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
if bert_net_cfg.compute_type != mstype.float32:
logger.warning('GPU only support fp32 temporarily, run with fp32.')
bert_net_cfg.compute_type = mstype.float32
else:
raise Exception("Target error, GPU or Ascend is supported.")
dataset = get_dataset(bert_net_cfg.batch_size, 1) dataset = get_dataset(bert_net_cfg.batch_size, 1)
if cfg.use_crf: if cfg.use_crf:
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True, net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
@ -147,6 +158,9 @@ def test_eval():
callback.acc_num / callback.total_num)) callback.acc_num / callback.total_num))
print("==============================================================") print("==============================================================")
parser = argparse.ArgumentParser(description='Bert eval')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
num_labels = cfg.num_labels num_labels = cfg.num_labels
test_eval() test_eval()

View File

@ -23,6 +23,7 @@ from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCe
from src.finetune_config import cfg, bert_net_cfg, tag_to_index from src.finetune_config import cfg, bert_net_cfg, tag_to_index
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
from mindspore import log as logger
import mindspore.dataset as de import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
@ -105,6 +106,9 @@ def test_train():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
elif target == "GPU": elif target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
if bert_net_cfg.compute_type != mstype.float32:
logger.warning('GPU only support fp32 temporarily, run with fp32.')
bert_net_cfg.compute_type = mstype.float32
else: else:
raise Exception("Target error, GPU or Ascend is supported.") raise Exception("Target error, GPU or Ascend is supported.")
#BertCLSTrain for classification #BertCLSTrain for classification

View File

@ -86,6 +86,6 @@ def test_trainTensor(num_classes=10, epoch=15, batch_size=32):
for i in range(0, epoch): for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 3, 227, 227]).astype(np.float32) * 0.01) data = Tensor(np.ones([batch_size, 3, 227, 227]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32)) label = Tensor(np.ones([batch_size]).astype(np.int32))
loss = train_network(data, label) loss = train_network(data, label).asnumpy()
losses.append(loss) losses.append(loss)
assert (losses[-1].asnumpy() < 0.01) assert losses[-1] < 0.01

View File

@ -144,9 +144,9 @@ def test_train_lenet():
for i in range(epoch): for i in range(epoch):
data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01) data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([net.batch_size]).astype(np.int32)) label = Tensor(np.ones([net.batch_size]).astype(np.int32))
loss = train_network(data, label) loss = train_network(data, label).asnumpy()
losses.append(loss) losses.append(loss)
print(losses) assert losses[-1] < 0.01
def create_dataset(data_path, batch_size=32, repeat_size=1, def create_dataset(data_path, batch_size=32, repeat_size=1,