gpu update bert scripts

This commit is contained in:
VectorSL 2020-06-20 10:52:21 +08:00
parent bc4b1c2460
commit b2412ea571
4 changed files with 24 additions and 6 deletions

View File

@ -18,9 +18,11 @@ Bert evaluation script.
"""
import os
import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore.common.tensor import Tensor
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
@ -105,8 +107,17 @@ def bert_predict(Evaluation):
'''
prediction function
'''
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
target = args_opt.device_target
if target == "Ascend":
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
elif target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
if bert_net_cfg.compute_type != mstype.float32:
logger.warning('GPU only support fp32 temporarily, run with fp32.')
bert_net_cfg.compute_type = mstype.float32
else:
raise Exception("Target error, GPU or Ascend is supported.")
dataset = get_dataset(bert_net_cfg.batch_size, 1)
if cfg.use_crf:
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
@ -147,6 +158,9 @@ def test_eval():
callback.acc_num / callback.total_num))
print("==============================================================")
parser = argparse.ArgumentParser(description='Bert eval')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()
if __name__ == "__main__":
num_labels = cfg.num_labels
test_eval()

View File

@ -23,6 +23,7 @@ from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCe
from src.finetune_config import cfg, bert_net_cfg, tag_to_index
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
@ -105,6 +106,9 @@ def test_train():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
elif target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
if bert_net_cfg.compute_type != mstype.float32:
logger.warning('GPU only support fp32 temporarily, run with fp32.')
bert_net_cfg.compute_type = mstype.float32
else:
raise Exception("Target error, GPU or Ascend is supported.")
#BertCLSTrain for classification

View File

@ -86,6 +86,6 @@ def test_trainTensor(num_classes=10, epoch=15, batch_size=32):
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 3, 227, 227]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
loss = train_network(data, label)
loss = train_network(data, label).asnumpy()
losses.append(loss)
assert (losses[-1].asnumpy() < 0.01)
assert losses[-1] < 0.01

View File

@ -144,9 +144,9 @@ def test_train_lenet():
for i in range(epoch):
data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([net.batch_size]).astype(np.int32))
loss = train_network(data, label)
loss = train_network(data, label).asnumpy()
losses.append(loss)
print(losses)
assert losses[-1] < 0.01
def create_dataset(data_path, batch_size=32, repeat_size=1,