!2388 GPU update bert scripts
Merge pull request !2388 from VectorSL/bert-eval
This commit is contained in:
commit
748b0b1b64
|
@ -18,9 +18,11 @@ Bert evaluation script.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
from mindspore import log as logger
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
import mindspore.dataset as de
|
import mindspore.dataset as de
|
||||||
import mindspore.dataset.transforms.c_transforms as C
|
import mindspore.dataset.transforms.c_transforms as C
|
||||||
|
@ -105,8 +107,17 @@ def bert_predict(Evaluation):
|
||||||
'''
|
'''
|
||||||
prediction function
|
prediction function
|
||||||
'''
|
'''
|
||||||
devid = int(os.getenv('DEVICE_ID'))
|
target = args_opt.device_target
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
if target == "Ascend":
|
||||||
|
devid = int(os.getenv('DEVICE_ID'))
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
||||||
|
elif target == "GPU":
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
if bert_net_cfg.compute_type != mstype.float32:
|
||||||
|
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
||||||
|
bert_net_cfg.compute_type = mstype.float32
|
||||||
|
else:
|
||||||
|
raise Exception("Target error, GPU or Ascend is supported.")
|
||||||
dataset = get_dataset(bert_net_cfg.batch_size, 1)
|
dataset = get_dataset(bert_net_cfg.batch_size, 1)
|
||||||
if cfg.use_crf:
|
if cfg.use_crf:
|
||||||
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
|
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
|
||||||
|
@ -147,6 +158,9 @@ def test_eval():
|
||||||
callback.acc_num / callback.total_num))
|
callback.acc_num / callback.total_num))
|
||||||
print("==============================================================")
|
print("==============================================================")
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Bert eval')
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
num_labels = cfg.num_labels
|
num_labels = cfg.num_labels
|
||||||
test_eval()
|
test_eval()
|
||||||
|
|
|
@ -23,6 +23,7 @@ from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCe
|
||||||
from src.finetune_config import cfg, bert_net_cfg, tag_to_index
|
from src.finetune_config import cfg, bert_net_cfg, tag_to_index
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
from mindspore import log as logger
|
||||||
import mindspore.dataset as de
|
import mindspore.dataset as de
|
||||||
import mindspore.dataset.transforms.c_transforms as C
|
import mindspore.dataset.transforms.c_transforms as C
|
||||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||||
|
@ -105,6 +106,9 @@ def test_train():
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
||||||
elif target == "GPU":
|
elif target == "GPU":
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
if bert_net_cfg.compute_type != mstype.float32:
|
||||||
|
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
||||||
|
bert_net_cfg.compute_type = mstype.float32
|
||||||
else:
|
else:
|
||||||
raise Exception("Target error, GPU or Ascend is supported.")
|
raise Exception("Target error, GPU or Ascend is supported.")
|
||||||
#BertCLSTrain for classification
|
#BertCLSTrain for classification
|
||||||
|
|
|
@ -86,6 +86,6 @@ def test_trainTensor(num_classes=10, epoch=15, batch_size=32):
|
||||||
for i in range(0, epoch):
|
for i in range(0, epoch):
|
||||||
data = Tensor(np.ones([batch_size, 3, 227, 227]).astype(np.float32) * 0.01)
|
data = Tensor(np.ones([batch_size, 3, 227, 227]).astype(np.float32) * 0.01)
|
||||||
label = Tensor(np.ones([batch_size]).astype(np.int32))
|
label = Tensor(np.ones([batch_size]).astype(np.int32))
|
||||||
loss = train_network(data, label)
|
loss = train_network(data, label).asnumpy()
|
||||||
losses.append(loss)
|
losses.append(loss)
|
||||||
assert (losses[-1].asnumpy() < 0.01)
|
assert losses[-1] < 0.01
|
||||||
|
|
|
@ -144,9 +144,9 @@ def test_train_lenet():
|
||||||
for i in range(epoch):
|
for i in range(epoch):
|
||||||
data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01)
|
data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01)
|
||||||
label = Tensor(np.ones([net.batch_size]).astype(np.int32))
|
label = Tensor(np.ones([net.batch_size]).astype(np.int32))
|
||||||
loss = train_network(data, label)
|
loss = train_network(data, label).asnumpy()
|
||||||
losses.append(loss)
|
losses.append(loss)
|
||||||
print(losses)
|
assert losses[-1] < 0.01
|
||||||
|
|
||||||
|
|
||||||
def create_dataset(data_path, batch_size=32, repeat_size=1,
|
def create_dataset(data_path, batch_size=32, repeat_size=1,
|
||||||
|
|
Loading…
Reference in New Issue