forked from mindspore-Ecosystem/mindspore
!21956 [特性][Ascend] Add post training quantization of yolov3_darknet53
Merge pull request !21956 from chenzhuo/quant
This commit is contained in:
commit
cf7bb217b0
|
@ -23,6 +23,7 @@ from mindspore import Tensor, context
|
|||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
|
||||
def quant_unet(network, dataset, input_data):
|
||||
"""
|
||||
Export post training quantization model of AIR format.
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
- [Evaluation](#evaluation)
|
||||
- [Export MindIR](#export-mindir)
|
||||
- [Inference Process](#inference-process)
|
||||
- [Post Training Quantization](#post-training-quantization)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
|
@ -440,6 +441,52 @@ Inference result is saved in current path, you can find result in acc.log file.
|
|||
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.551
|
||||
```
|
||||
|
||||
### [Post Training Quantization](#contents)
|
||||
|
||||
Relative executing script files reside in the directory "ascend310_quant_infer". Please implement following steps sequentially to complete post quantization.
|
||||
Current quantization project bases on COCO2014 dataset.
|
||||
|
||||
1. Generate data of .bin format required for AIR model inference at Ascend310 platform.
|
||||
|
||||
```shell
|
||||
python export_bin.py --config_path [YMAL CONFIG PATH] --data_dir [DATA DIR] --annFile [ANNOTATION FILE PATH]
|
||||
```
|
||||
|
||||
2. Export quantized AIR model.
|
||||
|
||||
Post quantization of model requires special toolkits for exporting quantized AIR model. Please refer to [official website](https://www.hiascend.com/software/cann/community).
|
||||
|
||||
```shell
|
||||
python post_quant.py --config_path [YMAL CONFIG PATH] --ckpt_file [CKPT_PATH] --data_dir [DATASET PATH] --annFile [ANNOTATION FILE PATH]
|
||||
```
|
||||
|
||||
The quantized AIR file will be stored as "./results/yolov3_quant.air".
|
||||
|
||||
3. Implement inference at Ascend310 platform.
|
||||
|
||||
```shell
|
||||
# Ascend310 quant inference
|
||||
bash run_quant_infer.sh [AIR_PATH] [DATA_PATH] [IMAGE_ID] [IMAGE_SHAPE] [ANN_FILE]
|
||||
```
|
||||
|
||||
Inference result is saved in current path, you can find result like this in acc.log file.
|
||||
|
||||
```bash
|
||||
=============coco eval result=========
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.306
|
||||
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.524
|
||||
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.314
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.122
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.319
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.423
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.256
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.395
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.419
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.219
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.438
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.548
|
||||
```
|
||||
|
||||
## [Model Description](#contents)
|
||||
|
||||
### [Performance](#contents)
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
- [推理过程](#推理过程)
|
||||
- [用法](#用法-2)
|
||||
- [结果](#结果-2)
|
||||
- [训练后量化推理](#训练后量化推理)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
|
@ -434,6 +435,51 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [ANNO_PATH] [DEVICE_ID]
|
|||
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.551
|
||||
```
|
||||
|
||||
## [训练后量化推理](#contents)
|
||||
|
||||
训练后量化推理的相关执行脚本文件在"ascend310_quant_infer"目录下,依次执行以下步骤实现训练后量化推理。本训练后量化工程基于COCO2014数据集。
|
||||
|
||||
1、生成Ascend310平台AIR模型推理需要的.bin格式数据。
|
||||
|
||||
```shell
|
||||
python export_bin.py --config_path [YMAL CONFIG PATH] --data_dir [DATA DIR] --annFile [ANNOTATION FILE PATH]
|
||||
```
|
||||
|
||||
2、导出训练后量化的AIR格式模型。
|
||||
|
||||
导出训练后量化模型需要配套的量化工具包,参考[官方地址](https://www.hiascend.com/software/cann/community)
|
||||
|
||||
```shell
|
||||
python post_quant.py --config_path [YMAL CONFIG PATH] --ckpt_file [CKPT_PATH] --data_dir [DATASET PATH] --annFile [ANNOTATION FILE PATH]
|
||||
```
|
||||
|
||||
导出的模型会存储在./result/yolov3_quant.air。
|
||||
|
||||
3、在Ascend310执行推理量化模型。
|
||||
|
||||
```shell
|
||||
# Ascend310 quant inference
|
||||
bash run_quant_infer.sh [AIR_PATH] [DATA_PATH] [IMAGE_ID] [IMAGE_SHAPE] [ANN_FILE]
|
||||
```
|
||||
|
||||
推理结果保存在脚本执行的当前路径,可以在acc.log中看到精度计算结果。
|
||||
|
||||
```bash
|
||||
=============coco eval result=========
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.306
|
||||
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.524
|
||||
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.314
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.122
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.319
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.423
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.256
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.395
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.419
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.219
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.438
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.548
|
||||
```
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
|
|
@ -0,0 +1,231 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""postprocess for 310 inference"""
|
||||
import os
|
||||
import datetime
|
||||
import argparse
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser('YoloV3 quant postprocess')
|
||||
parser.add_argument('--result_path', type=str, required=True, help='result files path.')
|
||||
parser.add_argument('--batch_size', default=1, type=int, help='batch size for per gpu')
|
||||
parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS')
|
||||
parser.add_argument('--eval_ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
|
||||
parser.add_argument('--annFile', type=str, default='', help='path to annotation')
|
||||
parser.add_argument('--image_shape', type=str, default='./image_shape.npy', help='path to image_shape.npy')
|
||||
parser.add_argument('--image_id', type=str, default='./image_id.npy', help='path to image_id.npy')
|
||||
parser.add_argument('--log_path', type=str, default='outputs/', help='inference result save location')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
class Redirct:
|
||||
def __init__(self):
|
||||
self.content = ""
|
||||
|
||||
def write(self, content):
|
||||
self.content += content
|
||||
|
||||
def flush(self):
|
||||
self.content = ""
|
||||
|
||||
|
||||
class DetectionEngine:
|
||||
"""Detection engine."""
|
||||
def __init__(self):
|
||||
self.eval_ignore_threshold = args.eval_ignore_threshold
|
||||
self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
|
||||
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
|
||||
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
|
||||
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
|
||||
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
||||
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
|
||||
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
|
||||
self.num_classes = len(self.labels)
|
||||
self.results = {}
|
||||
self.file_path = ''
|
||||
self.save_prefix = args.outputs_dir
|
||||
self.annFile = args.annFile
|
||||
self._coco = COCO(self.annFile)
|
||||
self._img_ids = list(sorted(self._coco.imgs.keys()))
|
||||
self.det_boxes = []
|
||||
self.nms_thresh = args.nms_thresh
|
||||
self.coco_catIds = self._coco.getCatIds()
|
||||
|
||||
def do_nms_for_results(self):
|
||||
"""Get result boxes."""
|
||||
for img_id in self.results:
|
||||
for clsi in self.results[img_id]:
|
||||
dets = self.results[img_id][clsi]
|
||||
dets = np.array(dets)
|
||||
keep_index = self._nms(dets, self.nms_thresh)
|
||||
|
||||
keep_box = [{'image_id': int(img_id),
|
||||
'category_id': int(clsi),
|
||||
'bbox': list(dets[i][:4].astype(float)),
|
||||
'score': dets[i][4].astype(float)}
|
||||
for i in keep_index]
|
||||
self.det_boxes.extend(keep_box)
|
||||
|
||||
def _nms(self, predicts, threshold):
|
||||
"""Calculate NMS."""
|
||||
# convert xywh -> xmin ymin xmax ymax
|
||||
x1 = predicts[:, 0]
|
||||
y1 = predicts[:, 1]
|
||||
x2 = x1 + predicts[:, 2]
|
||||
y2 = y1 + predicts[:, 3]
|
||||
scores = predicts[:, 4]
|
||||
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
reserved_boxes = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
reserved_boxes.append(i)
|
||||
max_x1 = np.maximum(x1[i], x1[order[1:]])
|
||||
max_y1 = np.maximum(y1[i], y1[order[1:]])
|
||||
min_x2 = np.minimum(x2[i], x2[order[1:]])
|
||||
min_y2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
|
||||
intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
|
||||
intersect_area = intersect_w * intersect_h
|
||||
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
|
||||
|
||||
indexes = np.where(ovr <= threshold)[0]
|
||||
order = order[indexes + 1]
|
||||
return reserved_boxes
|
||||
|
||||
def write_result(self):
|
||||
"""Save result to file."""
|
||||
import json
|
||||
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
|
||||
try:
|
||||
self.file_path = self.save_prefix + '/predict' + t + '.json'
|
||||
f = open(self.file_path, 'w')
|
||||
json.dump(self.det_boxes, f)
|
||||
except IOError as e:
|
||||
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
|
||||
else:
|
||||
f.close()
|
||||
return self.file_path
|
||||
|
||||
def get_eval_result(self):
|
||||
"""Get eval result."""
|
||||
cocoGt = COCO(self.annFile)
|
||||
cocoDt = cocoGt.loadRes(self.file_path)
|
||||
cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
|
||||
cocoEval.evaluate()
|
||||
cocoEval.accumulate()
|
||||
rdct = Redirct()
|
||||
stdout = sys.stdout
|
||||
sys.stdout = rdct
|
||||
cocoEval.summarize()
|
||||
sys.stdout = stdout
|
||||
return rdct.content
|
||||
|
||||
def detect(self, outputs, batch, image_shape, image_id):
|
||||
"""Detect boxes."""
|
||||
outputs_num = len(outputs)
|
||||
# output [|32, 52, 52, 3, 85| ]
|
||||
for batch_id in range(batch):
|
||||
for out_id in range(outputs_num):
|
||||
# 32, 52, 52, 3, 85
|
||||
out_item = outputs[out_id]
|
||||
# 52, 52, 3, 85
|
||||
out_item_single = out_item[batch_id, :]
|
||||
# get number of items in one head, [B, gx, gy, anchors, 5+80]
|
||||
dimensions = out_item_single.shape[:-1]
|
||||
out_num = 1
|
||||
for d in dimensions:
|
||||
out_num *= d
|
||||
ori_w, ori_h = image_shape[batch_id]
|
||||
img_id = int(image_id[batch_id])
|
||||
x = out_item_single[..., 0] * ori_w
|
||||
y = out_item_single[..., 1] * ori_h
|
||||
w = out_item_single[..., 2] * ori_w
|
||||
h = out_item_single[..., 3] * ori_h
|
||||
|
||||
conf = out_item_single[..., 4:5]
|
||||
cls_emb = out_item_single[..., 5:]
|
||||
|
||||
cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
|
||||
x = x.reshape(-1)
|
||||
y = y.reshape(-1)
|
||||
w = w.reshape(-1)
|
||||
h = h.reshape(-1)
|
||||
cls_emb = cls_emb.reshape(-1, self.num_classes)
|
||||
conf = conf.reshape(-1)
|
||||
cls_argmax = cls_argmax.reshape(-1)
|
||||
|
||||
x_top_left = x - w / 2.
|
||||
y_top_left = y - h / 2.
|
||||
# create all False
|
||||
flag = np.random.random(cls_emb.shape) > sys.maxsize
|
||||
for i in range(flag.shape[0]):
|
||||
c = cls_argmax[i]
|
||||
flag[i, c] = True
|
||||
confidence = cls_emb[flag] * conf
|
||||
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
|
||||
if confi < self.eval_ignore_threshold:
|
||||
continue
|
||||
if img_id not in self.results:
|
||||
self.results[img_id] = defaultdict(list)
|
||||
x_lefti = max(0, x_lefti)
|
||||
y_lefti = max(0, y_lefti)
|
||||
wi = min(wi, ori_w)
|
||||
hi = min(hi, ori_h)
|
||||
# transform catId to match coco
|
||||
coco_clsi = self.coco_catIds[clsi]
|
||||
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args.outputs_dir = os.path.join(args.log_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
if not os.path.exists(args.outputs_dir):
|
||||
os.makedirs(args.outputs_dir)
|
||||
detection = DetectionEngine()
|
||||
bs = args.batch_size
|
||||
shape_list = np.load(args.image_shape)
|
||||
id_list = np.load(args.image_id)
|
||||
prefix = "YoloV3-DarkNet_coco_bs_" + str(bs) + "_"
|
||||
iter_num = 0
|
||||
for id_img in id_list:
|
||||
shape_img = shape_list[iter_num]
|
||||
path_small = os.path.join(args.result_path, prefix + str(iter_num) + '_output_0.bin')
|
||||
path_medium = os.path.join(args.result_path, prefix + str(iter_num) + '_output_1.bin')
|
||||
path_big = os.path.join(args.result_path, prefix + str(iter_num) + '_output_2.bin')
|
||||
if os.path.exists(path_small) and os.path.exists(path_medium) and os.path.exists(path_big):
|
||||
output_small = np.fromfile(path_small, np.float32).reshape(bs, 13, 13, 3, 85)
|
||||
output_medium = np.fromfile(path_medium, np.float32).reshape(bs, 26, 26, 3, 85)
|
||||
output_big = np.fromfile(path_big, np.float32).reshape(bs, 52, 52, 3, 85)
|
||||
detection.detect([output_small, output_medium, output_big], bs, shape_img, id_img)
|
||||
else:
|
||||
print("Error: Image ", iter_num, " is not exist.")
|
||||
iter_num += 1
|
||||
|
||||
detection.do_nms_for_results()
|
||||
result_file_path = detection.write_result()
|
||||
eval_result = detection.get_eval_result()
|
||||
print('\n=============coco eval result=========\n' + eval_result)
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""generate data and label needed for AIR model inference"""
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_data():
|
||||
"""
|
||||
Generate data and label needed for AIR model inference at Ascend310 platform.
|
||||
"""
|
||||
config.batch_size = 1
|
||||
data_path = os.path.join(config.data_dir, "val2014")
|
||||
ds, data_size = create_yolo_dataset(data_path, config.annFile, is_training=False, batch_size=config.batch_size,
|
||||
max_epoch=1, device_num=1, rank=0, shuffle=False, config=config)
|
||||
print('testing shape : {}'.format(config.test_img_shape))
|
||||
print('total {} images to eval'.format(data_size))
|
||||
|
||||
save_folder = "./data"
|
||||
image_folder = os.path.join(save_folder, "image_bin")
|
||||
if not os.path.exists(image_folder):
|
||||
os.makedirs(image_folder)
|
||||
|
||||
list_image_shape = []
|
||||
list_image_id = []
|
||||
|
||||
for i, data in enumerate(ds.create_dict_iterator()):
|
||||
image = data["image"].asnumpy()
|
||||
image_shape = data["image_shape"]
|
||||
image_id = data["img_id"]
|
||||
file_name = "YoloV3-DarkNet_coco_bs_" + str(config.batch_size) + "_" + str(i) + ".bin"
|
||||
file_path = image_folder + "/" + file_name
|
||||
image.tofile(file_path)
|
||||
list_image_shape.append(image_shape.asnumpy())
|
||||
list_image_id.append(image_id.asnumpy())
|
||||
shapes = np.array(list_image_shape)
|
||||
ids = np.array(list_image_id)
|
||||
np.save(save_folder + "/image_shape.npy", shapes)
|
||||
np.save(save_folder + "/image_id.npy", ids)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.path.append("..")
|
||||
from model_utils.config import config
|
||||
from src.yolo_dataset import create_yolo_dataset
|
||||
|
||||
generate_data()
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include "../inc/utils.h"
|
||||
#include "acl/acl.h"
|
||||
|
||||
/**
|
||||
* ModelProcess
|
||||
*/
|
||||
class ModelProcess {
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor
|
||||
*/
|
||||
ModelProcess();
|
||||
|
||||
/**
|
||||
* @brief Destructor
|
||||
*/
|
||||
~ModelProcess();
|
||||
|
||||
/**
|
||||
* @brief load model from file with mem
|
||||
* @param [in] modelPath: model path
|
||||
* @return result
|
||||
*/
|
||||
Result LoadModelFromFileWithMem(const char *modelPath);
|
||||
|
||||
/**
|
||||
* @brief unload model
|
||||
*/
|
||||
void Unload();
|
||||
|
||||
/**
|
||||
* @brief create model desc
|
||||
* @return result
|
||||
*/
|
||||
Result CreateDesc();
|
||||
|
||||
/**
|
||||
* @brief destroy desc
|
||||
*/
|
||||
void DestroyDesc();
|
||||
|
||||
/**
|
||||
* @brief create model input
|
||||
* @param [in] inputDataBuffer: input buffer
|
||||
* @param [in] bufferSize: input buffer size
|
||||
* @return result
|
||||
*/
|
||||
Result CreateInput(void *inputDataBuffer, size_t bufferSize);
|
||||
|
||||
/**
|
||||
* @brief destroy input resource
|
||||
*/
|
||||
void DestroyInput();
|
||||
|
||||
/**
|
||||
* @brief create output buffer
|
||||
* @return result
|
||||
*/
|
||||
Result CreateOutput();
|
||||
|
||||
/**
|
||||
* @brief destroy output resource
|
||||
*/
|
||||
void DestroyOutput();
|
||||
|
||||
/**
|
||||
* @brief model execute
|
||||
* @return result
|
||||
*/
|
||||
Result Execute();
|
||||
|
||||
/**
|
||||
* @brief dump model output result to file
|
||||
*/
|
||||
void DumpModelOutputResult(char *output_name);
|
||||
|
||||
/**
|
||||
* @brief get model output result
|
||||
*/
|
||||
void OutputModelResult();
|
||||
|
||||
private:
|
||||
uint32_t modelId_;
|
||||
size_t modelMemSize_;
|
||||
size_t modelWeightSize_;
|
||||
void *modelMemPtr_;
|
||||
void *modelWeightPtr_;
|
||||
bool loadFlag_; // model load flag
|
||||
aclmdlDesc *modelDesc_;
|
||||
aclmdlDataset *input_;
|
||||
aclmdlDataset *output_;
|
||||
};
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "../inc/utils.h"
|
||||
#include "acl/acl.h"
|
||||
|
||||
/**
|
||||
* SampleProcess
|
||||
*/
|
||||
class SampleProcess {
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor
|
||||
*/
|
||||
SampleProcess();
|
||||
|
||||
/**
|
||||
* @brief Destructor
|
||||
*/
|
||||
~SampleProcess();
|
||||
|
||||
/**
|
||||
* @brief init reousce
|
||||
* @return result
|
||||
*/
|
||||
Result InitResource();
|
||||
|
||||
/**
|
||||
* @brief sample process
|
||||
* @return result
|
||||
*/
|
||||
Result Process(char *om_path, char *input_folder);
|
||||
|
||||
void GetAllFiles(std::string path, std::vector<std::string> *files);
|
||||
|
||||
private:
|
||||
void DestroyResource();
|
||||
|
||||
int32_t deviceId_;
|
||||
aclrtContext context_;
|
||||
aclrtStream stream_;
|
||||
};
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#define INFO_LOG(fmt, args...) fprintf(stdout, "[INFO] " fmt "\n", ##args)
|
||||
#define WARN_LOG(fmt, args...) fprintf(stdout, "[WARN] " fmt "\n", ##args)
|
||||
#define ERROR_LOG(fmt, args...) fprintf(stdout, "[ERROR] " fmt "\n", ##args)
|
||||
|
||||
typedef enum Result {
|
||||
SUCCESS = 0,
|
||||
FAILED = 1
|
||||
} Result;
|
||||
|
||||
/**
|
||||
* Utils
|
||||
*/
|
||||
class Utils {
|
||||
public:
|
||||
/**
|
||||
* @brief create device buffer of file
|
||||
* @param [in] fileName: file name
|
||||
* @param [out] fileSize: size of file
|
||||
* @return device buffer of file
|
||||
*/
|
||||
static void *GetDeviceBufferOfFile(std::string fileName, uint32_t *fileSize);
|
||||
|
||||
/**
|
||||
* @brief create buffer of file
|
||||
* @param [in] fileName: file name
|
||||
* @param [out] fileSize: size of file
|
||||
* @return buffer of pic
|
||||
*/
|
||||
static void* ReadBinFile(std::string fileName, uint32_t *fileSize);
|
||||
};
|
||||
|
||||
#pragma once
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""do post training quantization for Ascend310"""
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
from amct_mindspore.quantize_tool import create_quant_config
|
||||
from amct_mindspore.quantize_tool import quantize_model
|
||||
from amct_mindspore.quantize_tool import save_model
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
|
||||
def quant_yolov3(network, dataset, input_data):
|
||||
"""
|
||||
Export post training quantization model of AIR format.
|
||||
|
||||
Args:
|
||||
network: the origin network for inference.
|
||||
dataset: the data for inference.
|
||||
input_data: the data used for constructing network. The shape and format of input data should be the same as
|
||||
actual data for inference.
|
||||
"""
|
||||
|
||||
# step2: create the quant config json file
|
||||
create_quant_config("./config.json", network, input_data)
|
||||
|
||||
# step3: do some network modification and return the modified network
|
||||
calibration_network = quantize_model("./config.json", network, input_data)
|
||||
calibration_network.set_train(False)
|
||||
|
||||
# step4: perform the evaluation of network to do activation calibration
|
||||
for _, data in enumerate(dataset.create_dict_iterator(num_epochs=1)):
|
||||
image = data["image"]
|
||||
_ = calibration_network(image)
|
||||
|
||||
# step5: export the air file
|
||||
save_model("results/yolov3_quant", calibration_network, input_data)
|
||||
print("[INFO] the quantized AIR file has been stored at: \n {}".format("results/yolov3_quant.air"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.path.append("..")
|
||||
from src.yolo import YOLOV3DarkNet53
|
||||
from src.yolo_dataset import create_yolo_dataset
|
||||
from model_utils.config import config
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net = YOLOV3DarkNet53(is_training=False)
|
||||
|
||||
param_dict = load_checkpoint(config.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
config.batch_size = 1
|
||||
data_path = os.path.join(config.data_dir, "val2014")
|
||||
datasets, data_size = create_yolo_dataset(data_path, config.annFile, is_training=False,
|
||||
batch_size=config.batch_size, max_epoch=1, device_num=1, rank=0,
|
||||
shuffle=False, config=config)
|
||||
ds = datasets.take(1)
|
||||
shape = [config.batch_size, 3] + config.test_img_shape
|
||||
inputs = Tensor(np.zeros(shape), ms.float32)
|
||||
quant_yolov3(net, ds, inputs)
|
|
@ -0,0 +1,109 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -lt 5 ]; then
|
||||
echo "Usage: bash run_quant_infer.sh [AIR_PATH] [DATA_PATH] [IMAGE_ID] [IMAGE_SHAPE] [ANN_FILE]"
|
||||
echo "Example: bash run_quant_infer.sh ./yolov3_quant.air ./image_bin ./image_id.npy ./image_shape.npy \
|
||||
./instances_val2014.json"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
model=$(get_real_path $1)
|
||||
data_path=$(get_real_path $2)
|
||||
id_path=$(get_real_path $3)
|
||||
shape_path=$(get_real_path $4)
|
||||
ann_path=$(get_real_path $5)
|
||||
|
||||
echo "air name: "$model
|
||||
echo "dataset path: "$data_path
|
||||
echo "id path: "$id_path
|
||||
echo "shape path: "$shape_path
|
||||
echo "ann path: "$ann_path
|
||||
|
||||
export ASCEND_HOME=/usr/local/Ascend/
|
||||
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
|
||||
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
|
||||
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
|
||||
else
|
||||
export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/fwkacllib/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$LD_LIBRARY_PATH
|
||||
export TBE_IMPL_PATH=$ASCEND_HOME/opp/op_impl/built-in/ai_core/tbe
|
||||
export PYTHONPATH=${TBE_IMPL_PATH}:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
|
||||
fi
|
||||
|
||||
function air_to_om()
|
||||
{
|
||||
atc --input_format=NCHW --framework=1 --model=$model --output=yolov3_quant --soc_version=Ascend310 &> atc.log
|
||||
}
|
||||
|
||||
function compile_app()
|
||||
{
|
||||
bash ./src/build.sh &> build.log
|
||||
}
|
||||
|
||||
function infer()
|
||||
{
|
||||
if [ -d result ]; then
|
||||
rm -rf ./result
|
||||
fi
|
||||
mkdir result
|
||||
./out/main ./yolov3_quant.om $data_path &> infer.log
|
||||
}
|
||||
|
||||
function cal_acc()
|
||||
{
|
||||
python3.7 ./acc.py --result_path=./result --annFile=$ann_path --image_shape=$shape_path \
|
||||
--image_id=$id_path &> acc.log
|
||||
}
|
||||
|
||||
echo "start atc================================================"
|
||||
air_to_om
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "air to om code failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "start compile============================================"
|
||||
compile_app
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "compile app code failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "start infer=============================================="
|
||||
infer
|
||||
if [ $? -ne 0 ]; then
|
||||
echo " execute inference failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "start calculate acc======================================"
|
||||
cal_acc
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "calculate accuracy failed"
|
||||
exit 1
|
||||
fi
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
|
||||
|
||||
# CMake lowest version requirement
|
||||
cmake_minimum_required(VERSION 3.5.1)
|
||||
# project information
|
||||
project(InferClassification)
|
||||
# Check environment variable
|
||||
if(NOT DEFINED ENV{ASCEND_HOME})
|
||||
message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
|
||||
endif()
|
||||
|
||||
# Compile options
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
|
||||
|
||||
# Skip build rpath
|
||||
set(CMAKE_SKIP_BUILD_RPATH True)
|
||||
|
||||
# Set output directory
|
||||
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/../out)
|
||||
|
||||
# Set include directory and library directory
|
||||
set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
|
||||
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
|
||||
set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
|
||||
|
||||
# Header path
|
||||
include_directories(${ACL_LIB_DIR}/include/)
|
||||
include_directories(${FWKACL_LIB_DIR}/include/)
|
||||
include_directories(${ATLAS_ACL_LIB_DIR}/include/)
|
||||
include_directories(${PROJECT_SRC_ROOT}/../inc)
|
||||
|
||||
# add host lib path
|
||||
link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
|
||||
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
|
||||
add_executable(main utils.cpp
|
||||
sample_process.cpp
|
||||
model_process.cpp
|
||||
main.cpp)
|
||||
|
||||
target_link_libraries(main ${acl} gflags pthread)
|
|
@ -0,0 +1 @@
|
|||
{}
|
|
@ -0,0 +1,55 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
path_cur=$(cd "`dirname $0`" || exit; pwd)
|
||||
|
||||
function preparePath() {
|
||||
rm -rf $1
|
||||
mkdir -p $1
|
||||
cd $1 || exit
|
||||
}
|
||||
|
||||
function buildA300() {
|
||||
if [ ! "${ARCH_PATTERN}" ]; then
|
||||
# set ARCH_PATTERN to acllib when it was not specified by user
|
||||
export ARCH_PATTERN=acllib
|
||||
echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
|
||||
else
|
||||
echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user, reset it to ${ARCH_PATTERN}/acllib"
|
||||
export ARCH_PATTERN=${ARCH_PATTERN}/acllib
|
||||
fi
|
||||
|
||||
path_build=$path_cur/build
|
||||
preparePath $path_build
|
||||
cmake ..
|
||||
make -j
|
||||
ret=$?
|
||||
cd ..
|
||||
return ${ret}
|
||||
}
|
||||
|
||||
# set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
|
||||
if [ ! "${ASCEND_VERSION}" ]; then
|
||||
export ASCEND_VERSION=ascend-toolkit/latest
|
||||
echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
|
||||
else
|
||||
echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
|
||||
fi
|
||||
|
||||
buildA300
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include "../inc/sample_process.h"
|
||||
#include "../inc/utils.h"
|
||||
bool g_is_device = false;
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc != 3) {
|
||||
ERROR_LOG("usage:./main path_of_om path_of_inputFolder");
|
||||
return FAILED;
|
||||
}
|
||||
SampleProcess processSample;
|
||||
Result ret = processSample.InitResource();
|
||||
if (ret != SUCCESS) {
|
||||
ERROR_LOG("sample init resource failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
ret = processSample.Process(argv[1], argv[2]);
|
||||
if (ret != SUCCESS) {
|
||||
ERROR_LOG("sample process failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
INFO_LOG("execute sample success");
|
||||
return SUCCESS;
|
||||
}
|
|
@ -0,0 +1,339 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "../inc/model_process.h"
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include "../inc/utils.h"
|
||||
extern bool g_is_device;
|
||||
|
||||
ModelProcess::ModelProcess() :modelId_(0), modelMemSize_(0), modelWeightSize_(0), modelMemPtr_(nullptr),
|
||||
modelWeightPtr_(nullptr), loadFlag_(false), modelDesc_(nullptr), input_(nullptr), output_(nullptr) {
|
||||
}
|
||||
|
||||
ModelProcess::~ModelProcess() {
|
||||
Unload();
|
||||
DestroyDesc();
|
||||
DestroyInput();
|
||||
DestroyOutput();
|
||||
}
|
||||
|
||||
Result ModelProcess::LoadModelFromFileWithMem(const char *modelPath) {
|
||||
if (loadFlag_) {
|
||||
ERROR_LOG("has already loaded a model");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
aclError ret = aclmdlQuerySize(modelPath, &modelMemSize_, &modelWeightSize_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("query model failed, model file is %s", modelPath);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
ret = aclrtMalloc(&modelMemPtr_, modelMemSize_, ACL_MEM_MALLOC_HUGE_FIRST);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("malloc buffer for mem failed, require size is %zu", modelMemSize_);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
ret = aclrtMalloc(&modelWeightPtr_, modelWeightSize_, ACL_MEM_MALLOC_HUGE_FIRST);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("malloc buffer for weight failed, require size is %zu", modelWeightSize_);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
ret = aclmdlLoadFromFileWithMem(modelPath, &modelId_, modelMemPtr_,
|
||||
modelMemSize_, modelWeightPtr_, modelWeightSize_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("load model from file failed, model file is %s", modelPath);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
loadFlag_ = true;
|
||||
INFO_LOG("load model %s success", modelPath);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Result ModelProcess::CreateDesc() {
|
||||
modelDesc_ = aclmdlCreateDesc();
|
||||
if (modelDesc_ == nullptr) {
|
||||
ERROR_LOG("create model description failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
aclError ret = aclmdlGetDesc(modelDesc_, modelId_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("get model description failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
INFO_LOG("create model description success");
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyDesc() {
|
||||
if (modelDesc_ != nullptr) {
|
||||
(void)aclmdlDestroyDesc(modelDesc_);
|
||||
modelDesc_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Result ModelProcess::CreateInput(void *inputDataBuffer, size_t bufferSize) {
|
||||
input_ = aclmdlCreateDataset();
|
||||
if (input_ == nullptr) {
|
||||
ERROR_LOG("can't create dataset, create input failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
aclDataBuffer* inputData = aclCreateDataBuffer(inputDataBuffer, bufferSize);
|
||||
if (inputData == nullptr) {
|
||||
ERROR_LOG("can't create data buffer, create input failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
aclError ret = aclmdlAddDatasetBuffer(input_, inputData);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("add input dataset buffer failed");
|
||||
aclDestroyDataBuffer(inputData);
|
||||
inputData = nullptr;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyInput() {
|
||||
if (input_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(input_); ++i) {
|
||||
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(input_, i);
|
||||
aclDestroyDataBuffer(dataBuffer);
|
||||
}
|
||||
aclmdlDestroyDataset(input_);
|
||||
input_ = nullptr;
|
||||
}
|
||||
|
||||
Result ModelProcess::CreateOutput() {
|
||||
if (modelDesc_ == nullptr) {
|
||||
ERROR_LOG("no model description, create output failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
output_ = aclmdlCreateDataset();
|
||||
if (output_ == nullptr) {
|
||||
ERROR_LOG("can't create dataset, create output failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
size_t outputSize = aclmdlGetNumOutputs(modelDesc_);
|
||||
for (size_t i = 0; i < outputSize; ++i) {
|
||||
size_t buffer_size = aclmdlGetOutputSizeByIndex(modelDesc_, i);
|
||||
|
||||
void *outputBuffer = nullptr;
|
||||
aclError ret = aclrtMalloc(&outputBuffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("can't malloc buffer, size is %zu, create output failed", buffer_size);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
aclDataBuffer* outputData = aclCreateDataBuffer(outputBuffer, buffer_size);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("can't create data buffer, create output failed");
|
||||
aclrtFree(outputBuffer);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
ret = aclmdlAddDatasetBuffer(output_, outputData);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("can't add data buffer, create output failed");
|
||||
aclrtFree(outputBuffer);
|
||||
aclDestroyDataBuffer(outputData);
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
INFO_LOG("create model output success");
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void ModelProcess::DumpModelOutputResult(char *output_name) {
|
||||
size_t outputNum = aclmdlGetDatasetNumBuffers(output_);
|
||||
|
||||
for (size_t i = 0; i < outputNum; ++i) {
|
||||
std::stringstream ss;
|
||||
ss << "result/" << output_name << "_output_" << i << ".bin";
|
||||
std::string outputFileName = ss.str();
|
||||
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
|
||||
if (outputFile) {
|
||||
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
||||
void* data = aclGetDataBufferAddr(dataBuffer);
|
||||
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
|
||||
|
||||
void* outHostData = NULL;
|
||||
aclError ret = ACL_ERROR_NONE;
|
||||
if (!g_is_device) {
|
||||
ret = aclrtMallocHost(&outHostData, len);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("aclrtMallocHost failed, ret[%d]", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = aclrtMemcpy(outHostData, len, data, len, ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("aclrtMemcpy failed, ret[%d]", ret);
|
||||
(void)aclrtFreeHost(outHostData);
|
||||
return;
|
||||
}
|
||||
|
||||
fwrite(outHostData, len, sizeof(char), outputFile);
|
||||
|
||||
ret = aclrtFreeHost(outHostData);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("aclrtFreeHost failed, ret[%d]", ret);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
fwrite(data, len, sizeof(char), outputFile);
|
||||
}
|
||||
fclose(outputFile);
|
||||
outputFile = nullptr;
|
||||
} else {
|
||||
ERROR_LOG("create output file [%s] failed", outputFileName.c_str());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
INFO_LOG("dump data success");
|
||||
return;
|
||||
}
|
||||
|
||||
void ModelProcess::OutputModelResult() {
|
||||
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(output_); ++i) {
|
||||
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
||||
void* data = aclGetDataBufferAddr(dataBuffer);
|
||||
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
|
||||
|
||||
void *outHostData = NULL;
|
||||
aclError ret = ACL_ERROR_NONE;
|
||||
float *outData = NULL;
|
||||
if (!g_is_device) {
|
||||
ret = aclrtMallocHost(&outHostData, len);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("aclrtMallocHost failed, ret[%d]", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = aclrtMemcpy(outHostData, len, data, len, ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("aclrtMemcpy failed, ret[%d]", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
outData = reinterpret_cast<float*>(outHostData);
|
||||
} else {
|
||||
outData = reinterpret_cast<float*>(data);
|
||||
}
|
||||
std::map<float, unsigned int, std::greater<float> > resultMap;
|
||||
for (unsigned int j = 0; j < len / sizeof(float); ++j) {
|
||||
resultMap[*outData] = j;
|
||||
outData++;
|
||||
}
|
||||
|
||||
int cnt = 0;
|
||||
for (auto it = resultMap.begin(); it != resultMap.end(); ++it) {
|
||||
// print top 5
|
||||
if (++cnt > 5) {
|
||||
break;
|
||||
}
|
||||
|
||||
INFO_LOG("top %d: index[%d] value[%lf]", cnt, it->second, it->first);
|
||||
}
|
||||
if (!g_is_device) {
|
||||
ret = aclrtFreeHost(outHostData);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("aclrtFreeHost failed, ret[%d]", ret);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
INFO_LOG("output data success");
|
||||
return;
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyOutput() {
|
||||
if (output_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(output_); ++i) {
|
||||
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
||||
void* data = aclGetDataBufferAddr(dataBuffer);
|
||||
(void)aclrtFree(data);
|
||||
(void)aclDestroyDataBuffer(dataBuffer);
|
||||
}
|
||||
|
||||
(void)aclmdlDestroyDataset(output_);
|
||||
output_ = nullptr;
|
||||
}
|
||||
|
||||
Result ModelProcess::Execute() {
|
||||
aclError ret = aclmdlExecute(modelId_, input_, output_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("execute model failed, modelId is %u", modelId_);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
INFO_LOG("model execute success");
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void ModelProcess::Unload() {
|
||||
if (!loadFlag_) {
|
||||
WARN_LOG("no model had been loaded, unload failed");
|
||||
return;
|
||||
}
|
||||
|
||||
aclError ret = aclmdlUnload(modelId_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("unload model failed, modelId is %u", modelId_);
|
||||
}
|
||||
|
||||
if (modelDesc_ != nullptr) {
|
||||
(void)aclmdlDestroyDesc(modelDesc_);
|
||||
modelDesc_ = nullptr;
|
||||
}
|
||||
|
||||
if (modelMemPtr_ != nullptr) {
|
||||
aclrtFree(modelMemPtr_);
|
||||
modelMemPtr_ = nullptr;
|
||||
modelMemSize_ = 0;
|
||||
}
|
||||
|
||||
if (modelWeightPtr_ != nullptr) {
|
||||
aclrtFree(modelWeightPtr_);
|
||||
modelWeightPtr_ = nullptr;
|
||||
modelWeightSize_ = 0;
|
||||
}
|
||||
|
||||
loadFlag_ = false;
|
||||
INFO_LOG("unload model success, modelId is %u", modelId_);
|
||||
}
|
|
@ -0,0 +1,252 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "../inc/sample_process.h"
|
||||
#include <sys/time.h>
|
||||
#include <sys/types.h>
|
||||
#include <dirent.h>
|
||||
#include <string.h>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include "../inc/model_process.h"
|
||||
#include "acl/acl.h"
|
||||
#include "../inc/utils.h"
|
||||
extern bool g_is_device;
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
SampleProcess::SampleProcess() :deviceId_(0), context_(nullptr), stream_(nullptr) {
|
||||
}
|
||||
|
||||
SampleProcess::~SampleProcess() {
|
||||
DestroyResource();
|
||||
}
|
||||
|
||||
Result SampleProcess::InitResource() {
|
||||
// ACL init
|
||||
|
||||
const char *aclConfigPath = "./src/acl.json";
|
||||
aclError ret = aclInit(aclConfigPath);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("acl init failed");
|
||||
return FAILED;
|
||||
}
|
||||
INFO_LOG("acl init success");
|
||||
|
||||
// open device
|
||||
ret = aclrtSetDevice(deviceId_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("acl open device %d failed", deviceId_);
|
||||
return FAILED;
|
||||
}
|
||||
INFO_LOG("open device %d success", deviceId_);
|
||||
|
||||
// create context (set current)
|
||||
ret = aclrtCreateContext(&context_, deviceId_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("acl create context failed");
|
||||
return FAILED;
|
||||
}
|
||||
INFO_LOG("create context success");
|
||||
|
||||
// create stream
|
||||
ret = aclrtCreateStream(&stream_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("acl create stream failed");
|
||||
return FAILED;
|
||||
}
|
||||
INFO_LOG("create stream success");
|
||||
|
||||
// get run mode
|
||||
aclrtRunMode runMode;
|
||||
ret = aclrtGetRunMode(&runMode);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("acl get run mode failed");
|
||||
return FAILED;
|
||||
}
|
||||
g_is_device = (runMode == ACL_DEVICE);
|
||||
INFO_LOG("get run mode success");
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
|
||||
DIR *pDir = NULL;
|
||||
struct dirent* ptr;
|
||||
if (!(pDir = opendir(path.c_str()))) {
|
||||
return;
|
||||
}
|
||||
while ((ptr = readdir(pDir)) != 0) {
|
||||
if (strcmp(ptr->d_name, ".") != 0 && strcmp(ptr->d_name, "..") != 0) {
|
||||
files->push_back(path + "/" + ptr->d_name);
|
||||
}
|
||||
}
|
||||
closedir(pDir);
|
||||
}
|
||||
|
||||
Result SampleProcess::Process(char *om_path, char *input_folder) {
|
||||
// model init
|
||||
double second_to_millisecond = 1000;
|
||||
double second_to_microsecond = 1000000;
|
||||
|
||||
double whole_cost_time = 0.0;
|
||||
struct timeval start_global = {0};
|
||||
struct timeval end_global = {0};
|
||||
double startTimeMs_global = 0.0;
|
||||
double endTimeMs_global = 0.0;
|
||||
|
||||
gettimeofday(&start_global, nullptr);
|
||||
|
||||
ModelProcess processModel;
|
||||
const char* omModelPath = om_path;
|
||||
|
||||
Result ret = processModel.LoadModelFromFileWithMem(omModelPath);
|
||||
if (ret != SUCCESS) {
|
||||
ERROR_LOG("execute LoadModelFromFileWithMem failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
ret = processModel.CreateDesc();
|
||||
if (ret != SUCCESS) {
|
||||
ERROR_LOG("execute CreateDesc failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
ret = processModel.CreateOutput();
|
||||
if (ret != SUCCESS) {
|
||||
ERROR_LOG("execute CreateOutput failed");
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::vector<string> testFile;
|
||||
GetAllFiles(input_folder, &testFile);
|
||||
|
||||
if (testFile.size() == 0) {
|
||||
WARN_LOG("no input data under folder");
|
||||
}
|
||||
|
||||
double model_cost_time = 0.0;
|
||||
double edge_to_edge_model_cost_time = 0.0;
|
||||
|
||||
for (size_t index = 0; index < testFile.size(); ++index) {
|
||||
INFO_LOG("start to process file:%s", testFile[index].c_str());
|
||||
// model process
|
||||
|
||||
struct timeval time_init = {0};
|
||||
double timeval_init = 0.0;
|
||||
gettimeofday(&time_init, nullptr);
|
||||
timeval_init = (time_init.tv_sec * second_to_microsecond + time_init.tv_usec) / second_to_millisecond;
|
||||
|
||||
uint32_t devBufferSize;
|
||||
void *picDevBuffer = Utils::GetDeviceBufferOfFile(testFile[index], &devBufferSize);
|
||||
if (picDevBuffer == nullptr) {
|
||||
ERROR_LOG("get pic device buffer failed,index is %zu", index);
|
||||
return FAILED;
|
||||
}
|
||||
ret = processModel.CreateInput(picDevBuffer, devBufferSize);
|
||||
if (ret != SUCCESS) {
|
||||
ERROR_LOG("execute CreateInput failed");
|
||||
aclrtFree(picDevBuffer);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
struct timeval start = {0};
|
||||
struct timeval end = {0};
|
||||
double startTimeMs = 0.0;
|
||||
double endTimeMs = 0.0;
|
||||
gettimeofday(&start, nullptr);
|
||||
startTimeMs = (start.tv_sec * second_to_microsecond + start.tv_usec) / second_to_millisecond;
|
||||
|
||||
ret = processModel.Execute();
|
||||
|
||||
gettimeofday(&end, nullptr);
|
||||
endTimeMs = (end.tv_sec * second_to_microsecond + end.tv_usec) / second_to_millisecond;
|
||||
|
||||
double cost_time = endTimeMs - startTimeMs;
|
||||
INFO_LOG("model infer time: %lf ms", cost_time);
|
||||
|
||||
model_cost_time += cost_time;
|
||||
|
||||
double edge_to_edge_cost_time = endTimeMs - timeval_init;
|
||||
edge_to_edge_model_cost_time += edge_to_edge_cost_time;
|
||||
|
||||
if (ret != SUCCESS) {
|
||||
ERROR_LOG("execute inference failed");
|
||||
aclrtFree(picDevBuffer);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
int pos = testFile[index].find_last_of('/');
|
||||
std::string name = testFile[index].substr(pos+1);
|
||||
std::string outputname = name.substr(0, name.rfind("."));
|
||||
|
||||
// dump output result to file in the current directory
|
||||
processModel.DumpModelOutputResult(const_cast<char *>(outputname.c_str()));
|
||||
|
||||
// release model input buffer
|
||||
aclrtFree(picDevBuffer);
|
||||
processModel.DestroyInput();
|
||||
}
|
||||
double test_file_size = 0.0;
|
||||
test_file_size = testFile.size();
|
||||
INFO_LOG("infer dataset size:%lf", test_file_size);
|
||||
|
||||
gettimeofday(&end_global, nullptr);
|
||||
startTimeMs_global = (start_global.tv_sec * second_to_microsecond + start_global.tv_usec) / second_to_millisecond;
|
||||
endTimeMs_global = (end_global.tv_sec * second_to_microsecond + end_global.tv_usec) / second_to_millisecond;
|
||||
whole_cost_time = (endTimeMs_global - startTimeMs_global) / test_file_size;
|
||||
|
||||
model_cost_time /= test_file_size;
|
||||
INFO_LOG("model cost time per sample: %lf ms", model_cost_time);
|
||||
edge_to_edge_model_cost_time /= test_file_size;
|
||||
INFO_LOG("edge-to-edge model cost time per sample:%lf ms", edge_to_edge_model_cost_time);
|
||||
INFO_LOG("whole cost time per sample: %lf ms", whole_cost_time);
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void SampleProcess::DestroyResource() {
|
||||
aclError ret;
|
||||
if (stream_ != nullptr) {
|
||||
ret = aclrtDestroyStream(stream_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("destroy stream failed");
|
||||
}
|
||||
stream_ = nullptr;
|
||||
}
|
||||
INFO_LOG("end to destroy stream");
|
||||
|
||||
if (context_ != nullptr) {
|
||||
ret = aclrtDestroyContext(context_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("destroy context failed");
|
||||
}
|
||||
context_ = nullptr;
|
||||
}
|
||||
INFO_LOG("end to destroy context");
|
||||
|
||||
ret = aclrtResetDevice(deviceId_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("reset device failed");
|
||||
}
|
||||
INFO_LOG("end to reset device is %d", deviceId_);
|
||||
|
||||
ret = aclFinalize();
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("finalize acl failed");
|
||||
}
|
||||
INFO_LOG("end to finalize acl");
|
||||
}
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "../inc/utils.h"
|
||||
#include <sys/stat.h>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cstring>
|
||||
#include "acl/acl.h"
|
||||
|
||||
extern bool g_is_device;
|
||||
|
||||
void* Utils::ReadBinFile(std::string fileName, uint32_t *fileSize) {
|
||||
struct stat sBuf;
|
||||
int fileStatus = stat(fileName.data(), &sBuf);
|
||||
if (fileStatus == -1) {
|
||||
ERROR_LOG("failed to get file");
|
||||
return nullptr;
|
||||
}
|
||||
if (S_ISREG(sBuf.st_mode) == 0) {
|
||||
ERROR_LOG("%s is not a file, please enter a file", fileName.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::ifstream binFile(fileName, std::ifstream::binary);
|
||||
if (binFile.is_open() == false) {
|
||||
ERROR_LOG("open file %s failed", fileName.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
binFile.seekg(0, binFile.end);
|
||||
uint32_t binFileBufferLen = binFile.tellg();
|
||||
if (binFileBufferLen == 0) {
|
||||
ERROR_LOG("binfile is empty, filename is %s", fileName.c_str());
|
||||
binFile.close();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
binFile.seekg(0, binFile.beg);
|
||||
|
||||
void* binFileBufferData = nullptr;
|
||||
aclError ret = ACL_ERROR_NONE;
|
||||
if (!g_is_device) {
|
||||
ret = aclrtMallocHost(&binFileBufferData, binFileBufferLen);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("malloc for binFileBufferData failed");
|
||||
binFile.close();
|
||||
return nullptr;
|
||||
}
|
||||
if (binFileBufferData == nullptr) {
|
||||
ERROR_LOG("malloc binFileBufferData failed");
|
||||
binFile.close();
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
ret = aclrtMalloc(&binFileBufferData, binFileBufferLen, ACL_MEM_MALLOC_NORMAL_ONLY);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("malloc device buffer failed. size is %u", binFileBufferLen);
|
||||
binFile.close();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
binFile.read(static_cast<char *>(binFileBufferData), binFileBufferLen);
|
||||
binFile.close();
|
||||
*fileSize = binFileBufferLen;
|
||||
return binFileBufferData;
|
||||
}
|
||||
|
||||
void* Utils::GetDeviceBufferOfFile(std::string fileName, uint32_t *fileSize) {
|
||||
uint32_t inputHostBuffSize = 0;
|
||||
void* inputHostBuff = Utils::ReadBinFile(fileName, &inputHostBuffSize);
|
||||
if (inputHostBuff == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!g_is_device) {
|
||||
void *inBufferDev = nullptr;
|
||||
uint32_t inBufferSize = inputHostBuffSize;
|
||||
aclError ret = aclrtMalloc(&inBufferDev, inBufferSize, ACL_MEM_MALLOC_NORMAL_ONLY);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("malloc device buffer failed. size is %u", inBufferSize);
|
||||
aclrtFreeHost(inputHostBuff);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ret = aclrtMemcpy(inBufferDev, inBufferSize, inputHostBuff, inputHostBuffSize, ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
ERROR_LOG("memcpy failed. device buffer size is %u, input host buffer size is %u",
|
||||
inBufferSize, inputHostBuffSize);
|
||||
aclrtFree(inBufferDev);
|
||||
aclrtFreeHost(inputHostBuff);
|
||||
return nullptr;
|
||||
}
|
||||
aclrtFreeHost(inputHostBuff);
|
||||
*fileSize = inBufferSize;
|
||||
return inBufferDev;
|
||||
} else {
|
||||
*fileSize = inputHostBuffSize;
|
||||
return inputHostBuff;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue