!2388 GPU update bert scripts
Merge pull request !2388 from VectorSL/bert-eval
This commit is contained in:
commit
748b0b1b64
|
@ -18,9 +18,11 @@ Bert evaluation script.
|
|||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
|
@ -105,8 +107,17 @@ def bert_predict(Evaluation):
|
|||
'''
|
||||
prediction function
|
||||
'''
|
||||
target = args_opt.device_target
|
||||
if target == "Ascend":
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
||||
elif target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
if bert_net_cfg.compute_type != mstype.float32:
|
||||
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
||||
bert_net_cfg.compute_type = mstype.float32
|
||||
else:
|
||||
raise Exception("Target error, GPU or Ascend is supported.")
|
||||
dataset = get_dataset(bert_net_cfg.batch_size, 1)
|
||||
if cfg.use_crf:
|
||||
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
|
||||
|
@ -147,6 +158,9 @@ def test_eval():
|
|||
callback.acc_num / callback.total_num))
|
||||
print("==============================================================")
|
||||
|
||||
parser = argparse.ArgumentParser(description='Bert eval')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||
args_opt = parser.parse_args()
|
||||
if __name__ == "__main__":
|
||||
num_labels = cfg.num_labels
|
||||
test_eval()
|
||||
|
|
|
@ -23,6 +23,7 @@ from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCe
|
|||
from src.finetune_config import cfg, bert_net_cfg, tag_to_index
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
|
@ -105,6 +106,9 @@ def test_train():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
||||
elif target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
if bert_net_cfg.compute_type != mstype.float32:
|
||||
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
||||
bert_net_cfg.compute_type = mstype.float32
|
||||
else:
|
||||
raise Exception("Target error, GPU or Ascend is supported.")
|
||||
#BertCLSTrain for classification
|
||||
|
|
|
@ -86,6 +86,6 @@ def test_trainTensor(num_classes=10, epoch=15, batch_size=32):
|
|||
for i in range(0, epoch):
|
||||
data = Tensor(np.ones([batch_size, 3, 227, 227]).astype(np.float32) * 0.01)
|
||||
label = Tensor(np.ones([batch_size]).astype(np.int32))
|
||||
loss = train_network(data, label)
|
||||
loss = train_network(data, label).asnumpy()
|
||||
losses.append(loss)
|
||||
assert (losses[-1].asnumpy() < 0.01)
|
||||
assert losses[-1] < 0.01
|
||||
|
|
|
@ -144,9 +144,9 @@ def test_train_lenet():
|
|||
for i in range(epoch):
|
||||
data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01)
|
||||
label = Tensor(np.ones([net.batch_size]).astype(np.int32))
|
||||
loss = train_network(data, label)
|
||||
loss = train_network(data, label).asnumpy()
|
||||
losses.append(loss)
|
||||
print(losses)
|
||||
assert losses[-1] < 0.01
|
||||
|
||||
|
||||
def create_dataset(data_path, batch_size=32, repeat_size=1,
|
||||
|
|
Loading…
Reference in New Issue