!21956 [特性][Ascend] Add post training quantization of yolov3_darknet53

Merge pull request !21956 from chenzhuo/quant
This commit is contained in:
i-robot 2021-08-19 01:27:17 +00:00 committed by Gitee
commit cf7bb217b0
17 changed files with 1637 additions and 0 deletions

View File

@ -23,6 +23,7 @@ from mindspore import Tensor, context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
def quant_unet(network, dataset, input_data):
"""
Export post training quantization model of AIR format.

View File

@ -16,6 +16,7 @@
- [Evaluation](#evaluation)
- [Export MindIR](#export-mindir)
- [Inference Process](#inference-process)
- [Post Training Quantization](#post-training-quantization)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
@ -440,6 +441,52 @@ Inference result is saved in current path, you can find result in acc.log file.
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.551
```
### [Post Training Quantization](#contents)
Relative executing script files reside in the directory "ascend310_quant_infer". Please implement following steps sequentially to complete post quantization.
Current quantization project bases on COCO2014 dataset.
1. Generate data of .bin format required for AIR model inference at Ascend310 platform.
```shell
python export_bin.py --config_path [YMAL CONFIG PATH] --data_dir [DATA DIR] --annFile [ANNOTATION FILE PATH]
```
2. Export quantized AIR model.
Post quantization of model requires special toolkits for exporting quantized AIR model. Please refer to [official website](https://www.hiascend.com/software/cann/community).
```shell
python post_quant.py --config_path [YMAL CONFIG PATH] --ckpt_file [CKPT_PATH] --data_dir [DATASET PATH] --annFile [ANNOTATION FILE PATH]
```
The quantized AIR file will be stored as "./results/yolov3_quant.air".
3. Implement inference at Ascend310 platform.
```shell
# Ascend310 quant inference
bash run_quant_infer.sh [AIR_PATH] [DATA_PATH] [IMAGE_ID] [IMAGE_SHAPE] [ANN_FILE]
```
Inference result is saved in current path, you can find result like this in acc.log file.
```bash
=============coco eval result=========
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.306
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.524
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.314
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.122
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.319
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.423
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.256
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.395
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.419
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.219
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.438
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.548
```
## [Model Description](#contents)
### [Performance](#contents)

View File

@ -20,6 +20,7 @@
- [推理过程](#推理过程)
- [用法](#用法-2)
- [结果](#结果-2)
- [训练后量化推理](#训练后量化推理)
- [模型描述](#模型描述)
- [性能](#性能)
- [评估性能](#评估性能)
@ -434,6 +435,51 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [ANNO_PATH] [DEVICE_ID]
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.551
```
## [训练后量化推理](#contents)
训练后量化推理的相关执行脚本文件在"ascend310_quant_infer"目录下依次执行以下步骤实现训练后量化推理。本训练后量化工程基于COCO2014数据集。
1、生成Ascend310平台AIR模型推理需要的.bin格式数据。
```shell
python export_bin.py --config_path [YMAL CONFIG PATH] --data_dir [DATA DIR] --annFile [ANNOTATION FILE PATH]
```
2、导出训练后量化的AIR格式模型。
导出训练后量化模型需要配套的量化工具包,参考[官方地址](https://www.hiascend.com/software/cann/community)
```shell
python post_quant.py --config_path [YMAL CONFIG PATH] --ckpt_file [CKPT_PATH] --data_dir [DATASET PATH] --annFile [ANNOTATION FILE PATH]
```
导出的模型会存储在./result/yolov3_quant.air。
3、在Ascend310执行推理量化模型。
```shell
# Ascend310 quant inference
bash run_quant_infer.sh [AIR_PATH] [DATA_PATH] [IMAGE_ID] [IMAGE_SHAPE] [ANN_FILE]
```
推理结果保存在脚本执行的当前路径可以在acc.log中看到精度计算结果。
```bash
=============coco eval result=========
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.306
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.524
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.314
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.122
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.319
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.423
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.256
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.395
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.419
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.219
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.438
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.548
```
# 模型描述
## 性能

View File

@ -0,0 +1,231 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""postprocess for 310 inference"""
import os
import datetime
import argparse
import sys
from collections import defaultdict
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
parser = argparse.ArgumentParser('YoloV3 quant postprocess')
parser.add_argument('--result_path', type=str, required=True, help='result files path.')
parser.add_argument('--batch_size', default=1, type=int, help='batch size for per gpu')
parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS')
parser.add_argument('--eval_ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
parser.add_argument('--annFile', type=str, default='', help='path to annotation')
parser.add_argument('--image_shape', type=str, default='./image_shape.npy', help='path to image_shape.npy')
parser.add_argument('--image_id', type=str, default='./image_id.npy', help='path to image_id.npy')
parser.add_argument('--log_path', type=str, default='outputs/', help='inference result save location')
args, _ = parser.parse_known_args()
class Redirct:
def __init__(self):
self.content = ""
def write(self, content):
self.content += content
def flush(self):
self.content = ""
class DetectionEngine:
"""Detection engine."""
def __init__(self):
self.eval_ignore_threshold = args.eval_ignore_threshold
self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
self.num_classes = len(self.labels)
self.results = {}
self.file_path = ''
self.save_prefix = args.outputs_dir
self.annFile = args.annFile
self._coco = COCO(self.annFile)
self._img_ids = list(sorted(self._coco.imgs.keys()))
self.det_boxes = []
self.nms_thresh = args.nms_thresh
self.coco_catIds = self._coco.getCatIds()
def do_nms_for_results(self):
"""Get result boxes."""
for img_id in self.results:
for clsi in self.results[img_id]:
dets = self.results[img_id][clsi]
dets = np.array(dets)
keep_index = self._nms(dets, self.nms_thresh)
keep_box = [{'image_id': int(img_id),
'category_id': int(clsi),
'bbox': list(dets[i][:4].astype(float)),
'score': dets[i][4].astype(float)}
for i in keep_index]
self.det_boxes.extend(keep_box)
def _nms(self, predicts, threshold):
"""Calculate NMS."""
# convert xywh -> xmin ymin xmax ymax
x1 = predicts[:, 0]
y1 = predicts[:, 1]
x2 = x1 + predicts[:, 2]
y2 = y1 + predicts[:, 3]
scores = predicts[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
reserved_boxes = []
while order.size > 0:
i = order[0]
reserved_boxes.append(i)
max_x1 = np.maximum(x1[i], x1[order[1:]])
max_y1 = np.maximum(y1[i], y1[order[1:]])
min_x2 = np.minimum(x2[i], x2[order[1:]])
min_y2 = np.minimum(y2[i], y2[order[1:]])
intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
intersect_area = intersect_w * intersect_h
ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
indexes = np.where(ovr <= threshold)[0]
order = order[indexes + 1]
return reserved_boxes
def write_result(self):
"""Save result to file."""
import json
t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
try:
self.file_path = self.save_prefix + '/predict' + t + '.json'
f = open(self.file_path, 'w')
json.dump(self.det_boxes, f)
except IOError as e:
raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
else:
f.close()
return self.file_path
def get_eval_result(self):
"""Get eval result."""
cocoGt = COCO(self.annFile)
cocoDt = cocoGt.loadRes(self.file_path)
cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
cocoEval.evaluate()
cocoEval.accumulate()
rdct = Redirct()
stdout = sys.stdout
sys.stdout = rdct
cocoEval.summarize()
sys.stdout = stdout
return rdct.content
def detect(self, outputs, batch, image_shape, image_id):
"""Detect boxes."""
outputs_num = len(outputs)
# output [|32, 52, 52, 3, 85| ]
for batch_id in range(batch):
for out_id in range(outputs_num):
# 32, 52, 52, 3, 85
out_item = outputs[out_id]
# 52, 52, 3, 85
out_item_single = out_item[batch_id, :]
# get number of items in one head, [B, gx, gy, anchors, 5+80]
dimensions = out_item_single.shape[:-1]
out_num = 1
for d in dimensions:
out_num *= d
ori_w, ori_h = image_shape[batch_id]
img_id = int(image_id[batch_id])
x = out_item_single[..., 0] * ori_w
y = out_item_single[..., 1] * ori_h
w = out_item_single[..., 2] * ori_w
h = out_item_single[..., 3] * ori_h
conf = out_item_single[..., 4:5]
cls_emb = out_item_single[..., 5:]
cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
x = x.reshape(-1)
y = y.reshape(-1)
w = w.reshape(-1)
h = h.reshape(-1)
cls_emb = cls_emb.reshape(-1, self.num_classes)
conf = conf.reshape(-1)
cls_argmax = cls_argmax.reshape(-1)
x_top_left = x - w / 2.
y_top_left = y - h / 2.
# create all False
flag = np.random.random(cls_emb.shape) > sys.maxsize
for i in range(flag.shape[0]):
c = cls_argmax[i]
flag[i, c] = True
confidence = cls_emb[flag] * conf
for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
if confi < self.eval_ignore_threshold:
continue
if img_id not in self.results:
self.results[img_id] = defaultdict(list)
x_lefti = max(0, x_lefti)
y_lefti = max(0, y_lefti)
wi = min(wi, ori_w)
hi = min(hi, ori_h)
# transform catId to match coco
coco_clsi = self.coco_catIds[clsi]
self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
if __name__ == "__main__":
args.outputs_dir = os.path.join(args.log_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
if not os.path.exists(args.outputs_dir):
os.makedirs(args.outputs_dir)
detection = DetectionEngine()
bs = args.batch_size
shape_list = np.load(args.image_shape)
id_list = np.load(args.image_id)
prefix = "YoloV3-DarkNet_coco_bs_" + str(bs) + "_"
iter_num = 0
for id_img in id_list:
shape_img = shape_list[iter_num]
path_small = os.path.join(args.result_path, prefix + str(iter_num) + '_output_0.bin')
path_medium = os.path.join(args.result_path, prefix + str(iter_num) + '_output_1.bin')
path_big = os.path.join(args.result_path, prefix + str(iter_num) + '_output_2.bin')
if os.path.exists(path_small) and os.path.exists(path_medium) and os.path.exists(path_big):
output_small = np.fromfile(path_small, np.float32).reshape(bs, 13, 13, 3, 85)
output_medium = np.fromfile(path_medium, np.float32).reshape(bs, 26, 26, 3, 85)
output_big = np.fromfile(path_big, np.float32).reshape(bs, 52, 52, 3, 85)
detection.detect([output_small, output_medium, output_big], bs, shape_img, id_img)
else:
print("Error: Image ", iter_num, " is not exist.")
iter_num += 1
detection.do_nms_for_results()
result_file_path = detection.write_result()
eval_result = detection.get_eval_result()
print('\n=============coco eval result=========\n' + eval_result)

View File

@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate data and label needed for AIR model inference"""
import os
import sys
import numpy as np
def generate_data():
"""
Generate data and label needed for AIR model inference at Ascend310 platform.
"""
config.batch_size = 1
data_path = os.path.join(config.data_dir, "val2014")
ds, data_size = create_yolo_dataset(data_path, config.annFile, is_training=False, batch_size=config.batch_size,
max_epoch=1, device_num=1, rank=0, shuffle=False, config=config)
print('testing shape : {}'.format(config.test_img_shape))
print('total {} images to eval'.format(data_size))
save_folder = "./data"
image_folder = os.path.join(save_folder, "image_bin")
if not os.path.exists(image_folder):
os.makedirs(image_folder)
list_image_shape = []
list_image_id = []
for i, data in enumerate(ds.create_dict_iterator()):
image = data["image"].asnumpy()
image_shape = data["image_shape"]
image_id = data["img_id"]
file_name = "YoloV3-DarkNet_coco_bs_" + str(config.batch_size) + "_" + str(i) + ".bin"
file_path = image_folder + "/" + file_name
image.tofile(file_path)
list_image_shape.append(image_shape.asnumpy())
list_image_id.append(image_id.asnumpy())
shapes = np.array(list_image_shape)
ids = np.array(list_image_id)
np.save(save_folder + "/image_shape.npy", shapes)
np.save(save_folder + "/image_id.npy", ids)
if __name__ == '__main__':
sys.path.append("..")
from model_utils.config import config
from src.yolo_dataset import create_yolo_dataset
generate_data()

View File

@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <iostream>
#include "../inc/utils.h"
#include "acl/acl.h"
/**
* ModelProcess
*/
class ModelProcess {
public:
/**
* @brief Constructor
*/
ModelProcess();
/**
* @brief Destructor
*/
~ModelProcess();
/**
* @brief load model from file with mem
* @param [in] modelPath: model path
* @return result
*/
Result LoadModelFromFileWithMem(const char *modelPath);
/**
* @brief unload model
*/
void Unload();
/**
* @brief create model desc
* @return result
*/
Result CreateDesc();
/**
* @brief destroy desc
*/
void DestroyDesc();
/**
* @brief create model input
* @param [in] inputDataBuffer: input buffer
* @param [in] bufferSize: input buffer size
* @return result
*/
Result CreateInput(void *inputDataBuffer, size_t bufferSize);
/**
* @brief destroy input resource
*/
void DestroyInput();
/**
* @brief create output buffer
* @return result
*/
Result CreateOutput();
/**
* @brief destroy output resource
*/
void DestroyOutput();
/**
* @brief model execute
* @return result
*/
Result Execute();
/**
* @brief dump model output result to file
*/
void DumpModelOutputResult(char *output_name);
/**
* @brief get model output result
*/
void OutputModelResult();
private:
uint32_t modelId_;
size_t modelMemSize_;
size_t modelWeightSize_;
void *modelMemPtr_;
void *modelWeightPtr_;
bool loadFlag_; // model load flag
aclmdlDesc *modelDesc_;
aclmdlDataset *input_;
aclmdlDataset *output_;
};

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <string>
#include <vector>
#include "../inc/utils.h"
#include "acl/acl.h"
/**
* SampleProcess
*/
class SampleProcess {
public:
/**
* @brief Constructor
*/
SampleProcess();
/**
* @brief Destructor
*/
~SampleProcess();
/**
* @brief init reousce
* @return result
*/
Result InitResource();
/**
* @brief sample process
* @return result
*/
Result Process(char *om_path, char *input_folder);
void GetAllFiles(std::string path, std::vector<std::string> *files);
private:
void DestroyResource();
int32_t deviceId_;
aclrtContext context_;
aclrtStream stream_;
};

View File

@ -0,0 +1,52 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <iostream>
#include <string>
#define INFO_LOG(fmt, args...) fprintf(stdout, "[INFO] " fmt "\n", ##args)
#define WARN_LOG(fmt, args...) fprintf(stdout, "[WARN] " fmt "\n", ##args)
#define ERROR_LOG(fmt, args...) fprintf(stdout, "[ERROR] " fmt "\n", ##args)
typedef enum Result {
SUCCESS = 0,
FAILED = 1
} Result;
/**
* Utils
*/
class Utils {
public:
/**
* @brief create device buffer of file
* @param [in] fileName: file name
* @param [out] fileSize: size of file
* @return device buffer of file
*/
static void *GetDeviceBufferOfFile(std::string fileName, uint32_t *fileSize);
/**
* @brief create buffer of file
* @param [in] fileName: file name
* @param [out] fileSize: size of file
* @return buffer of pic
*/
static void* ReadBinFile(std::string fileName, uint32_t *fileSize);
};
#pragma once

View File

@ -0,0 +1,77 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""do post training quantization for Ascend310"""
import os
import sys
import numpy as np
from amct_mindspore.quantize_tool import create_quant_config
from amct_mindspore.quantize_tool import quantize_model
from amct_mindspore.quantize_tool import save_model
import mindspore as ms
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
def quant_yolov3(network, dataset, input_data):
"""
Export post training quantization model of AIR format.
Args:
network: the origin network for inference.
dataset: the data for inference.
input_data: the data used for constructing network. The shape and format of input data should be the same as
actual data for inference.
"""
# step2: create the quant config json file
create_quant_config("./config.json", network, input_data)
# step3: do some network modification and return the modified network
calibration_network = quantize_model("./config.json", network, input_data)
calibration_network.set_train(False)
# step4: perform the evaluation of network to do activation calibration
for _, data in enumerate(dataset.create_dict_iterator(num_epochs=1)):
image = data["image"]
_ = calibration_network(image)
# step5: export the air file
save_model("results/yolov3_quant", calibration_network, input_data)
print("[INFO] the quantized AIR file has been stored at: \n {}".format("results/yolov3_quant.air"))
if __name__ == "__main__":
sys.path.append("..")
from src.yolo import YOLOV3DarkNet53
from src.yolo_dataset import create_yolo_dataset
from model_utils.config import config
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = YOLOV3DarkNet53(is_training=False)
param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(net, param_dict)
net.set_train(False)
config.batch_size = 1
data_path = os.path.join(config.data_dir, "val2014")
datasets, data_size = create_yolo_dataset(data_path, config.annFile, is_training=False,
batch_size=config.batch_size, max_epoch=1, device_num=1, rank=0,
shuffle=False, config=config)
ds = datasets.take(1)
shape = [config.batch_size, 3] + config.test_img_shape
inputs = Tensor(np.zeros(shape), ms.float32)
quant_yolov3(net, ds, inputs)

View File

@ -0,0 +1,109 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 5 ]; then
echo "Usage: bash run_quant_infer.sh [AIR_PATH] [DATA_PATH] [IMAGE_ID] [IMAGE_SHAPE] [ANN_FILE]"
echo "Example: bash run_quant_infer.sh ./yolov3_quant.air ./image_bin ./image_id.npy ./image_shape.npy \
./instances_val2014.json"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
id_path=$(get_real_path $3)
shape_path=$(get_real_path $4)
ann_path=$(get_real_path $5)
echo "air name: "$model
echo "dataset path: "$data_path
echo "id path: "$id_path
echo "shape path: "$shape_path
echo "ann path: "$ann_path
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/fwkacllib/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function air_to_om()
{
atc --input_format=NCHW --framework=1 --model=$model --output=yolov3_quant --soc_version=Ascend310 &> atc.log
}
function compile_app()
{
bash ./src/build.sh &> build.log
}
function infer()
{
if [ -d result ]; then
rm -rf ./result
fi
mkdir result
./out/main ./yolov3_quant.om $data_path &> infer.log
}
function cal_acc()
{
python3.7 ./acc.py --result_path=./result --annFile=$ann_path --image_shape=$shape_path \
--image_id=$id_path &> acc.log
}
echo "start atc================================================"
air_to_om
if [ $? -ne 0 ]; then
echo "air to om code failed"
exit 1
fi
echo "start compile============================================"
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
echo "start infer=============================================="
infer
if [ $? -ne 0 ]; then
echo " execute inference failed"
exit 1
fi
echo "start calculate acc======================================"
cal_acc
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi

View File

@ -0,0 +1,43 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
# CMake lowest version requirement
cmake_minimum_required(VERSION 3.5.1)
# project information
project(InferClassification)
# Check environment variable
if(NOT DEFINED ENV{ASCEND_HOME})
message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
endif()
# Compile options
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
# Skip build rpath
set(CMAKE_SKIP_BUILD_RPATH True)
# Set output directory
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/../out)
# Set include directory and library directory
set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
# Header path
include_directories(${ACL_LIB_DIR}/include/)
include_directories(${FWKACL_LIB_DIR}/include/)
include_directories(${ATLAS_ACL_LIB_DIR}/include/)
include_directories(${PROJECT_SRC_ROOT}/../inc)
# add host lib path
link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
add_executable(main utils.cpp
sample_process.cpp
model_process.cpp
main.cpp)
target_link_libraries(main ${acl} gflags pthread)

View File

@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
path_cur=$(cd "`dirname $0`" || exit; pwd)
function preparePath() {
rm -rf $1
mkdir -p $1
cd $1 || exit
}
function buildA300() {
if [ ! "${ARCH_PATTERN}" ]; then
# set ARCH_PATTERN to acllib when it was not specified by user
export ARCH_PATTERN=acllib
echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
else
echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user, reset it to ${ARCH_PATTERN}/acllib"
export ARCH_PATTERN=${ARCH_PATTERN}/acllib
fi
path_build=$path_cur/build
preparePath $path_build
cmake ..
make -j
ret=$?
cd ..
return ${ret}
}
# set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
if [ ! "${ASCEND_VERSION}" ]; then
export ASCEND_VERSION=ascend-toolkit/latest
echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
else
echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
fi
buildA300
if [ $? -ne 0 ]; then
exit 1
fi

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "../inc/sample_process.h"
#include "../inc/utils.h"
bool g_is_device = false;
int main(int argc, char **argv) {
if (argc != 3) {
ERROR_LOG("usage:./main path_of_om path_of_inputFolder");
return FAILED;
}
SampleProcess processSample;
Result ret = processSample.InitResource();
if (ret != SUCCESS) {
ERROR_LOG("sample init resource failed");
return FAILED;
}
ret = processSample.Process(argv[1], argv[2]);
if (ret != SUCCESS) {
ERROR_LOG("sample process failed");
return FAILED;
}
INFO_LOG("execute sample success");
return SUCCESS;
}

View File

@ -0,0 +1,339 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../inc/model_process.h"
#include <iostream>
#include <map>
#include <sstream>
#include <algorithm>
#include "../inc/utils.h"
extern bool g_is_device;
ModelProcess::ModelProcess() :modelId_(0), modelMemSize_(0), modelWeightSize_(0), modelMemPtr_(nullptr),
modelWeightPtr_(nullptr), loadFlag_(false), modelDesc_(nullptr), input_(nullptr), output_(nullptr) {
}
ModelProcess::~ModelProcess() {
Unload();
DestroyDesc();
DestroyInput();
DestroyOutput();
}
Result ModelProcess::LoadModelFromFileWithMem(const char *modelPath) {
if (loadFlag_) {
ERROR_LOG("has already loaded a model");
return FAILED;
}
aclError ret = aclmdlQuerySize(modelPath, &modelMemSize_, &modelWeightSize_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("query model failed, model file is %s", modelPath);
return FAILED;
}
ret = aclrtMalloc(&modelMemPtr_, modelMemSize_, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc buffer for mem failed, require size is %zu", modelMemSize_);
return FAILED;
}
ret = aclrtMalloc(&modelWeightPtr_, modelWeightSize_, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc buffer for weight failed, require size is %zu", modelWeightSize_);
return FAILED;
}
ret = aclmdlLoadFromFileWithMem(modelPath, &modelId_, modelMemPtr_,
modelMemSize_, modelWeightPtr_, modelWeightSize_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("load model from file failed, model file is %s", modelPath);
return FAILED;
}
loadFlag_ = true;
INFO_LOG("load model %s success", modelPath);
return SUCCESS;
}
Result ModelProcess::CreateDesc() {
modelDesc_ = aclmdlCreateDesc();
if (modelDesc_ == nullptr) {
ERROR_LOG("create model description failed");
return FAILED;
}
aclError ret = aclmdlGetDesc(modelDesc_, modelId_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("get model description failed");
return FAILED;
}
INFO_LOG("create model description success");
return SUCCESS;
}
void ModelProcess::DestroyDesc() {
if (modelDesc_ != nullptr) {
(void)aclmdlDestroyDesc(modelDesc_);
modelDesc_ = nullptr;
}
}
Result ModelProcess::CreateInput(void *inputDataBuffer, size_t bufferSize) {
input_ = aclmdlCreateDataset();
if (input_ == nullptr) {
ERROR_LOG("can't create dataset, create input failed");
return FAILED;
}
aclDataBuffer* inputData = aclCreateDataBuffer(inputDataBuffer, bufferSize);
if (inputData == nullptr) {
ERROR_LOG("can't create data buffer, create input failed");
return FAILED;
}
aclError ret = aclmdlAddDatasetBuffer(input_, inputData);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("add input dataset buffer failed");
aclDestroyDataBuffer(inputData);
inputData = nullptr;
return FAILED;
}
return SUCCESS;
}
void ModelProcess::DestroyInput() {
if (input_ == nullptr) {
return;
}
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(input_); ++i) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(input_, i);
aclDestroyDataBuffer(dataBuffer);
}
aclmdlDestroyDataset(input_);
input_ = nullptr;
}
Result ModelProcess::CreateOutput() {
if (modelDesc_ == nullptr) {
ERROR_LOG("no model description, create output failed");
return FAILED;
}
output_ = aclmdlCreateDataset();
if (output_ == nullptr) {
ERROR_LOG("can't create dataset, create output failed");
return FAILED;
}
size_t outputSize = aclmdlGetNumOutputs(modelDesc_);
for (size_t i = 0; i < outputSize; ++i) {
size_t buffer_size = aclmdlGetOutputSizeByIndex(modelDesc_, i);
void *outputBuffer = nullptr;
aclError ret = aclrtMalloc(&outputBuffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("can't malloc buffer, size is %zu, create output failed", buffer_size);
return FAILED;
}
aclDataBuffer* outputData = aclCreateDataBuffer(outputBuffer, buffer_size);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("can't create data buffer, create output failed");
aclrtFree(outputBuffer);
return FAILED;
}
ret = aclmdlAddDatasetBuffer(output_, outputData);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("can't add data buffer, create output failed");
aclrtFree(outputBuffer);
aclDestroyDataBuffer(outputData);
return FAILED;
}
}
INFO_LOG("create model output success");
return SUCCESS;
}
void ModelProcess::DumpModelOutputResult(char *output_name) {
size_t outputNum = aclmdlGetDatasetNumBuffers(output_);
for (size_t i = 0; i < outputNum; ++i) {
std::stringstream ss;
ss << "result/" << output_name << "_output_" << i << ".bin";
std::string outputFileName = ss.str();
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
if (outputFile) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
void* data = aclGetDataBufferAddr(dataBuffer);
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
void* outHostData = NULL;
aclError ret = ACL_ERROR_NONE;
if (!g_is_device) {
ret = aclrtMallocHost(&outHostData, len);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtMallocHost failed, ret[%d]", ret);
return;
}
ret = aclrtMemcpy(outHostData, len, data, len, ACL_MEMCPY_DEVICE_TO_HOST);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtMemcpy failed, ret[%d]", ret);
(void)aclrtFreeHost(outHostData);
return;
}
fwrite(outHostData, len, sizeof(char), outputFile);
ret = aclrtFreeHost(outHostData);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtFreeHost failed, ret[%d]", ret);
return;
}
} else {
fwrite(data, len, sizeof(char), outputFile);
}
fclose(outputFile);
outputFile = nullptr;
} else {
ERROR_LOG("create output file [%s] failed", outputFileName.c_str());
return;
}
}
INFO_LOG("dump data success");
return;
}
void ModelProcess::OutputModelResult() {
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(output_); ++i) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
void* data = aclGetDataBufferAddr(dataBuffer);
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
void *outHostData = NULL;
aclError ret = ACL_ERROR_NONE;
float *outData = NULL;
if (!g_is_device) {
ret = aclrtMallocHost(&outHostData, len);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtMallocHost failed, ret[%d]", ret);
return;
}
ret = aclrtMemcpy(outHostData, len, data, len, ACL_MEMCPY_DEVICE_TO_HOST);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtMemcpy failed, ret[%d]", ret);
return;
}
outData = reinterpret_cast<float*>(outHostData);
} else {
outData = reinterpret_cast<float*>(data);
}
std::map<float, unsigned int, std::greater<float> > resultMap;
for (unsigned int j = 0; j < len / sizeof(float); ++j) {
resultMap[*outData] = j;
outData++;
}
int cnt = 0;
for (auto it = resultMap.begin(); it != resultMap.end(); ++it) {
// print top 5
if (++cnt > 5) {
break;
}
INFO_LOG("top %d: index[%d] value[%lf]", cnt, it->second, it->first);
}
if (!g_is_device) {
ret = aclrtFreeHost(outHostData);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtFreeHost failed, ret[%d]", ret);
return;
}
}
}
INFO_LOG("output data success");
return;
}
void ModelProcess::DestroyOutput() {
if (output_ == nullptr) {
return;
}
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(output_); ++i) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
void* data = aclGetDataBufferAddr(dataBuffer);
(void)aclrtFree(data);
(void)aclDestroyDataBuffer(dataBuffer);
}
(void)aclmdlDestroyDataset(output_);
output_ = nullptr;
}
Result ModelProcess::Execute() {
aclError ret = aclmdlExecute(modelId_, input_, output_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("execute model failed, modelId is %u", modelId_);
return FAILED;
}
INFO_LOG("model execute success");
return SUCCESS;
}
void ModelProcess::Unload() {
if (!loadFlag_) {
WARN_LOG("no model had been loaded, unload failed");
return;
}
aclError ret = aclmdlUnload(modelId_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("unload model failed, modelId is %u", modelId_);
}
if (modelDesc_ != nullptr) {
(void)aclmdlDestroyDesc(modelDesc_);
modelDesc_ = nullptr;
}
if (modelMemPtr_ != nullptr) {
aclrtFree(modelMemPtr_);
modelMemPtr_ = nullptr;
modelMemSize_ = 0;
}
if (modelWeightPtr_ != nullptr) {
aclrtFree(modelWeightPtr_);
modelWeightPtr_ = nullptr;
modelWeightSize_ = 0;
}
loadFlag_ = false;
INFO_LOG("unload model success, modelId is %u", modelId_);
}

View File

@ -0,0 +1,252 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../inc/sample_process.h"
#include <sys/time.h>
#include <sys/types.h>
#include <dirent.h>
#include <string.h>
#include <iostream>
#include <fstream>
#include "../inc/model_process.h"
#include "acl/acl.h"
#include "../inc/utils.h"
extern bool g_is_device;
using std::string;
using std::vector;
SampleProcess::SampleProcess() :deviceId_(0), context_(nullptr), stream_(nullptr) {
}
SampleProcess::~SampleProcess() {
DestroyResource();
}
Result SampleProcess::InitResource() {
// ACL init
const char *aclConfigPath = "./src/acl.json";
aclError ret = aclInit(aclConfigPath);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("acl init failed");
return FAILED;
}
INFO_LOG("acl init success");
// open device
ret = aclrtSetDevice(deviceId_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("acl open device %d failed", deviceId_);
return FAILED;
}
INFO_LOG("open device %d success", deviceId_);
// create context (set current)
ret = aclrtCreateContext(&context_, deviceId_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("acl create context failed");
return FAILED;
}
INFO_LOG("create context success");
// create stream
ret = aclrtCreateStream(&stream_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("acl create stream failed");
return FAILED;
}
INFO_LOG("create stream success");
// get run mode
aclrtRunMode runMode;
ret = aclrtGetRunMode(&runMode);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("acl get run mode failed");
return FAILED;
}
g_is_device = (runMode == ACL_DEVICE);
INFO_LOG("get run mode success");
return SUCCESS;
}
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
DIR *pDir = NULL;
struct dirent* ptr;
if (!(pDir = opendir(path.c_str()))) {
return;
}
while ((ptr = readdir(pDir)) != 0) {
if (strcmp(ptr->d_name, ".") != 0 && strcmp(ptr->d_name, "..") != 0) {
files->push_back(path + "/" + ptr->d_name);
}
}
closedir(pDir);
}
Result SampleProcess::Process(char *om_path, char *input_folder) {
// model init
double second_to_millisecond = 1000;
double second_to_microsecond = 1000000;
double whole_cost_time = 0.0;
struct timeval start_global = {0};
struct timeval end_global = {0};
double startTimeMs_global = 0.0;
double endTimeMs_global = 0.0;
gettimeofday(&start_global, nullptr);
ModelProcess processModel;
const char* omModelPath = om_path;
Result ret = processModel.LoadModelFromFileWithMem(omModelPath);
if (ret != SUCCESS) {
ERROR_LOG("execute LoadModelFromFileWithMem failed");
return FAILED;
}
ret = processModel.CreateDesc();
if (ret != SUCCESS) {
ERROR_LOG("execute CreateDesc failed");
return FAILED;
}
ret = processModel.CreateOutput();
if (ret != SUCCESS) {
ERROR_LOG("execute CreateOutput failed");
return FAILED;
}
std::vector<string> testFile;
GetAllFiles(input_folder, &testFile);
if (testFile.size() == 0) {
WARN_LOG("no input data under folder");
}
double model_cost_time = 0.0;
double edge_to_edge_model_cost_time = 0.0;
for (size_t index = 0; index < testFile.size(); ++index) {
INFO_LOG("start to process file:%s", testFile[index].c_str());
// model process
struct timeval time_init = {0};
double timeval_init = 0.0;
gettimeofday(&time_init, nullptr);
timeval_init = (time_init.tv_sec * second_to_microsecond + time_init.tv_usec) / second_to_millisecond;
uint32_t devBufferSize;
void *picDevBuffer = Utils::GetDeviceBufferOfFile(testFile[index], &devBufferSize);
if (picDevBuffer == nullptr) {
ERROR_LOG("get pic device buffer failed,index is %zu", index);
return FAILED;
}
ret = processModel.CreateInput(picDevBuffer, devBufferSize);
if (ret != SUCCESS) {
ERROR_LOG("execute CreateInput failed");
aclrtFree(picDevBuffer);
return FAILED;
}
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs = 0.0;
double endTimeMs = 0.0;
gettimeofday(&start, nullptr);
startTimeMs = (start.tv_sec * second_to_microsecond + start.tv_usec) / second_to_millisecond;
ret = processModel.Execute();
gettimeofday(&end, nullptr);
endTimeMs = (end.tv_sec * second_to_microsecond + end.tv_usec) / second_to_millisecond;
double cost_time = endTimeMs - startTimeMs;
INFO_LOG("model infer time: %lf ms", cost_time);
model_cost_time += cost_time;
double edge_to_edge_cost_time = endTimeMs - timeval_init;
edge_to_edge_model_cost_time += edge_to_edge_cost_time;
if (ret != SUCCESS) {
ERROR_LOG("execute inference failed");
aclrtFree(picDevBuffer);
return FAILED;
}
int pos = testFile[index].find_last_of('/');
std::string name = testFile[index].substr(pos+1);
std::string outputname = name.substr(0, name.rfind("."));
// dump output result to file in the current directory
processModel.DumpModelOutputResult(const_cast<char *>(outputname.c_str()));
// release model input buffer
aclrtFree(picDevBuffer);
processModel.DestroyInput();
}
double test_file_size = 0.0;
test_file_size = testFile.size();
INFO_LOG("infer dataset size:%lf", test_file_size);
gettimeofday(&end_global, nullptr);
startTimeMs_global = (start_global.tv_sec * second_to_microsecond + start_global.tv_usec) / second_to_millisecond;
endTimeMs_global = (end_global.tv_sec * second_to_microsecond + end_global.tv_usec) / second_to_millisecond;
whole_cost_time = (endTimeMs_global - startTimeMs_global) / test_file_size;
model_cost_time /= test_file_size;
INFO_LOG("model cost time per sample: %lf ms", model_cost_time);
edge_to_edge_model_cost_time /= test_file_size;
INFO_LOG("edge-to-edge model cost time per sample:%lf ms", edge_to_edge_model_cost_time);
INFO_LOG("whole cost time per sample: %lf ms", whole_cost_time);
return SUCCESS;
}
void SampleProcess::DestroyResource() {
aclError ret;
if (stream_ != nullptr) {
ret = aclrtDestroyStream(stream_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("destroy stream failed");
}
stream_ = nullptr;
}
INFO_LOG("end to destroy stream");
if (context_ != nullptr) {
ret = aclrtDestroyContext(context_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("destroy context failed");
}
context_ = nullptr;
}
INFO_LOG("end to destroy context");
ret = aclrtResetDevice(deviceId_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("reset device failed");
}
INFO_LOG("end to reset device is %d", deviceId_);
ret = aclFinalize();
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("finalize acl failed");
}
INFO_LOG("end to finalize acl");
}

View File

@ -0,0 +1,113 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../inc/utils.h"
#include <sys/stat.h>
#include <iostream>
#include <fstream>
#include <cstring>
#include "acl/acl.h"
extern bool g_is_device;
void* Utils::ReadBinFile(std::string fileName, uint32_t *fileSize) {
struct stat sBuf;
int fileStatus = stat(fileName.data(), &sBuf);
if (fileStatus == -1) {
ERROR_LOG("failed to get file");
return nullptr;
}
if (S_ISREG(sBuf.st_mode) == 0) {
ERROR_LOG("%s is not a file, please enter a file", fileName.c_str());
return nullptr;
}
std::ifstream binFile(fileName, std::ifstream::binary);
if (binFile.is_open() == false) {
ERROR_LOG("open file %s failed", fileName.c_str());
return nullptr;
}
binFile.seekg(0, binFile.end);
uint32_t binFileBufferLen = binFile.tellg();
if (binFileBufferLen == 0) {
ERROR_LOG("binfile is empty, filename is %s", fileName.c_str());
binFile.close();
return nullptr;
}
binFile.seekg(0, binFile.beg);
void* binFileBufferData = nullptr;
aclError ret = ACL_ERROR_NONE;
if (!g_is_device) {
ret = aclrtMallocHost(&binFileBufferData, binFileBufferLen);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc for binFileBufferData failed");
binFile.close();
return nullptr;
}
if (binFileBufferData == nullptr) {
ERROR_LOG("malloc binFileBufferData failed");
binFile.close();
return nullptr;
}
} else {
ret = aclrtMalloc(&binFileBufferData, binFileBufferLen, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc device buffer failed. size is %u", binFileBufferLen);
binFile.close();
return nullptr;
}
}
binFile.read(static_cast<char *>(binFileBufferData), binFileBufferLen);
binFile.close();
*fileSize = binFileBufferLen;
return binFileBufferData;
}
void* Utils::GetDeviceBufferOfFile(std::string fileName, uint32_t *fileSize) {
uint32_t inputHostBuffSize = 0;
void* inputHostBuff = Utils::ReadBinFile(fileName, &inputHostBuffSize);
if (inputHostBuff == nullptr) {
return nullptr;
}
if (!g_is_device) {
void *inBufferDev = nullptr;
uint32_t inBufferSize = inputHostBuffSize;
aclError ret = aclrtMalloc(&inBufferDev, inBufferSize, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc device buffer failed. size is %u", inBufferSize);
aclrtFreeHost(inputHostBuff);
return nullptr;
}
ret = aclrtMemcpy(inBufferDev, inBufferSize, inputHostBuff, inputHostBuffSize, ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("memcpy failed. device buffer size is %u, input host buffer size is %u",
inBufferSize, inputHostBuffSize);
aclrtFree(inBufferDev);
aclrtFreeHost(inputHostBuff);
return nullptr;
}
aclrtFreeHost(inputHostBuff);
*fileSize = inBufferSize;
return inBufferDev;
} else {
*fileSize = inputHostBuffSize;
return inputHostBuff;
}
}