forked from mindspore-Ecosystem/mindspore
add eval onnx script for r1.3
This commit is contained in:
parent
1bf2a57ff4
commit
3957acfb6e
|
@ -1162,7 +1162,7 @@ class Tensor(Tensor_):
|
|||
>>> x = Tensor([1, 2, 3, -4, 0, 3, 2, 0]).astype("float32")
|
||||
>>> output = x.clip(0, 2)
|
||||
>>> print(output)
|
||||
[1 2 2 0 0 2 2 0]
|
||||
[1. 2. 2. 0. 0. 2. 2. 0.]
|
||||
"""
|
||||
if xmin is None and xmax is None:
|
||||
raise ValueError("One of max or min must be given.")
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
- [Inference Process](#inference-process)
|
||||
- [Usage](#usage)
|
||||
- [result](#result)
|
||||
- [Export ONNX model and inference](#export-onnx-model-and-inference)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Pretraining Performance](#pretraining-performance)
|
||||
|
@ -712,6 +713,31 @@ Inference result is saved in current path, you can find result in acc.log file.
|
|||
F1 0.931243
|
||||
```
|
||||
|
||||
### [Export ONNX model and inference](#contents)
|
||||
|
||||
Currently, the ONNX model of the BERT classification task can be exported, and third-party tools such as ONNX Runtime can be used to load the ONNX model for inference.
|
||||
|
||||
- export ONNX
|
||||
|
||||
```shell
|
||||
python export.py --config_path [../../task_classifier_config.yaml] --file_format ["ONNX"] --export_ckpt_file [CKPT_PATH] --num_class [NUM_CLASS] --export_file_name [EXPORT_FILE_NAME]
|
||||
```
|
||||
|
||||
'CKPT_PATH' is mandatory, it is the path of the CKPT file that has been trained for a certain classification task model.
|
||||
'NUM_CLASS' is mandatory, it is the number of categories in the classification task model.
|
||||
'EXPORT_FILE_NAME' is optional, it is the name of the exported ONNX model. If not set, the ONNX model will be saved in the current directory with the default name.
|
||||
|
||||
After the export finishes, the ONNX model of BERT for this classification task is saved in the current directory.
|
||||
|
||||
- Load ONNX and inference
|
||||
|
||||
```shell
|
||||
python run_eval_onnx.py --config_path [../../task_classifier_config.yaml] --eval_data_file_path [EVAL_DATA_FILE_PATH] --export_file_name [EXPORT_FILE_NAME]
|
||||
```
|
||||
|
||||
'EVAL_DATA_FILE_PATH' is mandatory, it is the path of the eval data file of the dataset used by the classification task.
|
||||
'EXPORT_FILE_NAME' is optional, it is the name given to the ONNX model in the export step, and is used here to load that ONNX model for inference.
|
||||
|
||||
## [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
- [推理过程](#推理过程)
|
||||
- [用法](#用法-2)
|
||||
- [结果](#结果)
|
||||
- [导出onnx模型与推理](#导出onnx模型与推理)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [预训练性能](#预训练性能)
|
||||
|
@ -672,6 +673,31 @@ bash run_infer_310.sh [MINDIR_PATH] [LABEL_PATH] [DATA_FILE_PATH] [DATASET_FORMA
|
|||
F1 0.931243
|
||||
```
|
||||
|
||||
## 导出onnx模型与推理
|
||||
|
||||
当前已支持导出bert分类任务的ONNX模型, 并可通过ONNXRuntime等第三方工具加载ONNX进行推理。
|
||||
|
||||
- 导出ONNX
|
||||
|
||||
```shell
|
||||
python export.py --config_path [../../task_classifier_config.yaml] --file_format ["ONNX"] --export_ckpt_file [CKPT_PATH] --num_class [NUM_CLASS] --export_file_name [EXPORT_FILE_NAME]
|
||||
```
|
||||
|
||||
`CKPT_PATH`为必选项, 是某个分类任务模型训练完毕的ckpt文件路径。
|
||||
`NUM_CLASS`为必选项, 是该分类任务模型的类别数。
|
||||
`EXPORT_FILE_NAME`为可选项, 是导出ONNX模型的名字, 如果未设置则ONNX模型会以默认名保存在当前目录下。
|
||||
|
||||
运行结束后, 当前文件目录下会保存bert该分类任务模型的ONNX模型。
|
||||
|
||||
- 加载ONNX并推理
|
||||
|
||||
```shell
|
||||
python run_eval_onnx.py --config_path [../../task_classifier_config.yaml] --eval_data_file_path [EVAL_DATA_FILE_PATH] --export_file_name [EXPORT_FILE_NAME]
|
||||
```
|
||||
|
||||
`EVAL_DATA_FILE_PATH`为必选项, 是该分类任务所用数据集的eval数据。
|
||||
`EXPORT_FILE_NAME`为可选项, 是导出ONNX步骤中ONNX的模型名, 此处用于加载指定ONNX模型进行推理。
|
||||
|
||||
## 模型描述
|
||||
|
||||
## 性能
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -44,28 +44,26 @@ def run_export():
|
|||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
label_list = []
|
||||
with open(args.label_file_path) as f:
|
||||
for label in f:
|
||||
label_list.append(label.strip())
|
||||
|
||||
tag_to_index = convert_labels_to_index(label_list)
|
||||
|
||||
if args.use_crf.lower() == "true":
|
||||
max_val = max(tag_to_index.values())
|
||||
tag_to_index["<START>"] = max_val + 1
|
||||
tag_to_index["<STOP>"] = max_val + 2
|
||||
number_labels = len(tag_to_index)
|
||||
else:
|
||||
number_labels = len(tag_to_index)
|
||||
if args.description == "run_ner":
|
||||
label_list = []
|
||||
with open(args.label_file_path) as f:
|
||||
for label in f:
|
||||
label_list.append(label.strip())
|
||||
|
||||
tag_to_index = convert_labels_to_index(label_list)
|
||||
|
||||
if args.use_crf.lower() == "true":
|
||||
max_val = max(tag_to_index.values())
|
||||
tag_to_index["<START>"] = max_val + 1
|
||||
tag_to_index["<STOP>"] = max_val + 2
|
||||
number_labels = len(tag_to_index)
|
||||
net = BertNER(bert_net_cfg, args.export_batch_size, False, num_labels=number_labels,
|
||||
use_crf=True, tag_to_index=tag_to_index)
|
||||
else:
|
||||
number_labels = len(tag_to_index)
|
||||
net = BertNERModel(bert_net_cfg, False, number_labels, use_crf=(args.use_crf.lower() == "true"))
|
||||
elif args.description == "run_classifier":
|
||||
net = BertCLSModel(bert_net_cfg, False, num_labels=number_labels)
|
||||
net = BertCLSModel(bert_net_cfg, False, num_labels=args.num_class)
|
||||
elif args.description == "run_squad":
|
||||
net = BertSquadModel(bert_net_cfg, False)
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
'''
|
||||
Inference script of ONNX exported from the Bert classification model.
|
||||
'''
|
||||
|
||||
import os
|
||||
from src.dataset import create_classification_dataset
|
||||
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
|
||||
from src.model_utils.config import config as args_opt
|
||||
from mindspore import Tensor, dtype
|
||||
import onnxruntime as rt
|
||||
|
||||
|
||||
def eval_result_print(assessment_method="accuracy", callback=None):
    """Print the evaluation result accumulated in ``callback``.

    Args:
        assessment_method (str): One of "accuracy", "f1", "mcc" or
            "spearman_correlation".
        callback: Metric object matching ``assessment_method``. This function
            reads ``acc_num`` / ``total_num`` for accuracy, ``TP`` / ``FP`` /
            ``FN`` for f1, and calls ``cal()`` for mcc and
            spearman_correlation.

    Raises:
        ValueError: If ``assessment_method`` is not one of the supported names.
    """

    def _ratio(numerator, denominator):
        # The counters stay at zero for an empty dataset or when no positive
        # sample is seen; report 0 instead of failing on a division by zero.
        return numerator / denominator if denominator else 0.0

    if assessment_method == "accuracy":
        print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                  _ratio(callback.acc_num, callback.total_num)))
    elif assessment_method == "f1":
        print("Precision {:.6f} ".format(_ratio(callback.TP, callback.TP + callback.FP)))
        print("Recall {:.6f} ".format(_ratio(callback.TP, callback.TP + callback.FN)))
        print("F1 {:.6f} ".format(_ratio(2 * callback.TP, 2 * callback.TP + callback.FP + callback.FN)))
    elif assessment_method == "mcc":
        print("MCC {:.6f} ".format(callback.cal()))
    elif assessment_method == "spearman_correlation":
        # Only the first element of the value returned by cal() is reported.
        print("Spearman Correlation is {:.6f} ".format(callback.cal()[0]))
    else:
        raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
|
||||
|
||||
|
||||
def do_eval_onnx(dataset=None, num_class=15, assessment_method="accuracy"):
    """Evaluate an exported ONNX classification model with ONNX Runtime.

    The model file is located through ``args_opt.export_file_name``; a relative
    name is resolved against the current working directory.

    Args:
        dataset: Object providing ``create_dict_iterator``. Every item must
            contain the columns "input_ids", "input_mask", "segment_ids" and
            "label_ids".
        num_class (int): Number of classes, forwarded to the F1 metric.
        assessment_method (str): One of "accuracy", "f1", "mcc" or
            "spearman_correlation".

    Raises:
        ValueError: If ``assessment_method`` is not one of the supported names.
    """
    if assessment_method == "accuracy":
        callback = Accuracy()
    elif assessment_method == "f1":
        callback = F1(False, num_class)
    elif assessment_method == "mcc":
        callback = MCC()
    elif assessment_method == "spearman_correlation":
        callback = Spearman_Correlation()
    else:
        raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")

    onnx_file_name = args_opt.export_file_name
    # Only add the suffix when it is missing, so that a name that already ends
    # with ".onnx" does not turn into "<name>.onnx.onnx".
    if not onnx_file_name.endswith('.onnx'):
        onnx_file_name += '.onnx'
    if not os.path.isabs(onnx_file_name):
        onnx_file_name = os.path.join(os.getcwd(), onnx_file_name)

    sess = rt.InferenceSession(onnx_file_name)
    # Query the model signature once; the first three inputs are fed with
    # input_ids, input_mask and segment_ids, in that order.
    model_inputs = sess.get_inputs()
    input_name_0 = model_inputs[0].name
    input_name_1 = model_inputs[1].name
    input_name_2 = model_inputs[2].name
    output_name_0 = sess.get_outputs()[0].name

    for data in dataset.create_dict_iterator(num_epochs=1):
        input_ids = data["input_ids"]
        input_mask = data["input_mask"]
        token_type_id = data["segment_ids"]
        label_ids = data["label_ids"]

        # ONNX Runtime consumes numpy arrays, so convert the dataset tensors.
        feeds = {
            input_name_0: input_ids.asnumpy(),
            input_name_1: input_mask.asnumpy(),
            input_name_2: token_type_id.asnumpy(),
        }
        result = sess.run([output_name_0], feeds)

        # The metric callbacks are updated with a Tensor of logits.
        logits = Tensor(result[0], dtype.float32)
        callback.update(logits, label_ids)

    print("==============================================================")
    eval_result_print(assessment_method, callback)
    print("==============================================================")
|
||||
|
||||
|
||||
def run_classifier_onnx():
    """Run the classification evaluation task on an exported ONNX model.

    Reads all settings from ``args_opt``, builds the evaluation dataset and
    hands it to ``do_eval_onnx``.

    Raises:
        ValueError: If ``eval_data_file_path`` or ``export_file_name`` is empty.
    """
    if args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when do onnx evaluation task")
    # The model is loaded through `export_file_name` in do_eval_onnx, so that is
    # the option that has to be validated here.
    if args_opt.export_file_name == "":
        raise ValueError("'export_file_name' must be set when do onnx evaluation task")
    assessment_method = args_opt.assessment_method.lower()
    ds = create_classification_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
                                       assessment_method=assessment_method,
                                       data_file_path=args_opt.eval_data_file_path,
                                       schema_file_path=args_opt.schema_file_path,
                                       do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
    do_eval_onnx(ds, args_opt.num_class, assessment_method)


if __name__ == "__main__":
    run_classifier_onnx()
|
Loading…
Reference in New Issue