add post training quantization of yolov3_resnet18

This commit is contained in:
chenzhuo 2021-08-20 09:54:15 +08:00
parent 5d1bb097e2
commit 53442e9ac0
25 changed files with 1697 additions and 19 deletions

View File

@ -43,9 +43,10 @@ if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/fwkacllib/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/atc/lib64:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

View File

@ -102,7 +102,9 @@ def generate_data():
batch_msk_lst = []
shape_lst = []
for i, line in enumerate(img_lst):
img_path, msk_path = line.strip().split(" ")
ori_img_path, ori_msk_path = line.strip().split(" ")
img_path = "VOCdevkit" + ori_img_path.split("VOCdevkit")[1]
msk_path = "VOCdevkit" + ori_msk_path.split("VOCdevkit")[1]
img_path = os.path.join(args.data_root, img_path)
msk_path = os.path.join(args.data_root, msk_path)
org_width, org_height = get_img_size(img_path)

View File

@ -87,7 +87,9 @@ def generate_batch_data():
# evaluate
batch_img_lst = []
img_path, _ = img_lst[0].strip().split(" ")
ori_img_path, _ = img_lst[0].strip().split(" ")
img_path = "VOCdevkit" + ori_img_path.split("VOCdevkit")[1]
img_path = os.path.join(args.data_root, img_path)
img_ = cv2.imread(img_path)
batch_img_lst.append(img_)

View File

@ -45,9 +45,10 @@ if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/fwkacllib/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/atc/lib64:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

View File

@ -45,9 +45,10 @@ if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/fwkacllib/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/atc/lib64:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

View File

@ -43,9 +43,10 @@ if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/fwkacllib/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/atc/lib64:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

View File

@ -43,9 +43,10 @@ if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/fwkacllib/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/atc/lib64:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

View File

@ -48,8 +48,8 @@ if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/fwkacllib/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$LD_LIBRARY_PATH
export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/fwkacllib/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/atc/lib64:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp

View File

@ -17,6 +17,7 @@
- [Inference Process](#inference-process)
- [Usage](#usage)
- [result](#result)
- [Post Training Quantization](#post-training-quantization)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
@ -358,6 +359,44 @@ Inference result is saved in current path, you can find result in acc.log file.
class 1 precision is 85.34%, recall is 79.13%
```
## [Post Training Quantization](#contents)
Relative executing script files reside in the directory "ascend310_quant_infer". Please implement following steps sequentially to complete post quantization.
Note the precision and recall values are results of two-classification(person and face) used our own annotations with COCO2017 dataset.
Note quantization-related config file utils.py is located in the directory ascend310_quant_infer.
1. Generate data of .bin format required for AIR model inference at Ascend310 platform.
```shell
python export_bin.py --image_dir [COCO DATA PATH] --eval_mindrecord_dir [MINDRECORD PATH] --anno_path [LABEL PATH]
```
Note that image_dir is set as the parent directory of COCO dataset.
2. Export quantized AIR model.
Post quantization of model requires special toolkits for exporting quantized AIR model. Please refer to [official website](https://www.hiascend.com/software/cann/community).
```shell
python post_quant.py --image_dir [COCO DATA PATH] --eval_mindrecord_dir [MINDRECORD PATH] --anno_path [LABEL PATH] --ckpt_file [CKPT_PATH]
```
The quantized AIR file will be stored as "./results/yolov3_resnet_quant.air".
3. Implement inference at Ascend310 platform.
```shell
# Ascend310 quant inference
bash run_quant_infer.sh [AIR_PATH] [DATA_PATH] [SHAPE_PATH] [ANNOTATION_PATH]
```
Inference result is saved in current path, you can find result like this in acc.log file.
```bash
class 0 precision is 91.34%, recall is 64.92%
class 1 precision is 94.61%, recall is 64.07%
```
# [Model Description](#contents)
## [Performance](#contents)

View File

@ -19,6 +19,7 @@
- [推理过程](#推理过程)
- [用法](#用法)
- [结果](#结果)
- [训练后量化推理](#训练后量化推理)
- [模型描述](#模型描述)
- [性能](#性能)
- [评估性能](#评估性能)
@ -355,6 +356,44 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [ANNO_PATH] [DEVICE_ID]
class 1 precision is 85.34%, recall is 79.13%
```
## [训练后量化推理](#contents)
训练后量化推理的相关执行脚本文件在"ascend310_quant_infer"目录下,依次执行以下步骤实现训练后量化推理。
注意精度和召回值是使用我们自己的标注和COCO2017的两种分类人与脸的结果。
注意训练后量化端测推理有关的文件utils.py位于ascend310_quant_infer目录下。
1、生成Ascend310平台AIR模型推理需要的.bin格式数据。
```shell
python export_bin.py --image_dir [COCO DATA PATH] --eval_mindrecord_dir [MINDRECORD PATH] --ann_file [ANNOTATION PATH]
```
注意image_dir设置成COCO数据集的上级目录。
2、导出训练后量化的AIR格式模型。
导出训练后量化模型需要配套的量化工具包,参考[官方地址](https://www.hiascend.com/software/cann/community)
```shell
python post_quant.py --image_dir [COCO DATA PATH] --eval_mindrecord_dir [MINDRECORD PATH] --ckpt_file [CKPT_PATH]
```
导出的模型会存储在./result/yolov3_resnet_quant.air。
3、在Ascend310执行推理量化模型。
```shell
# Ascend310 inference
bash run_quant_infer.sh [AIR_PATH] [DATA_PATH] [SHAPE_PATH] [ANNOTATION_PATH]
```
推理结果保存在脚本执行的当前路径可以在acc.log中看到精度计算结果。
```bash
class 0 precision is 91.34%, recall is 64.92%
class 1 precision is 94.61%, recall is 64.07%
```
# 模型描述
## 性能

View File

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""post process for 310 inference"""
import os
import argparse
import numpy as np
from utils import metrics
parser = argparse.ArgumentParser("yolov3_resnet18 quant postprocess")
parser.add_argument("--anno_path", type=str, required=True, help="path to annotation.npy")
parser.add_argument("--result_path", type=str, required=True, help="path to inference results.")
parser.add_argument("--batch_size", type=int, default=1, help="batch size of data.")
parser.add_argument("--num_classes", type=int, default=2, help="number of classed to detect.")
args, _ = parser.parse_known_args()
def calculate_acc():
""" Calculate accuracy of yolov3_resnet18 inference"""
ann = np.load(args.anno_path, allow_pickle=True)
pred_data = []
prefix = "Yolov3-resnet18_coco_bs_" + str(args.batch_size) + "_"
for i in range(len(ann)):
result0 = os.path.join(args.result_path, prefix + str(i) + "_output_0.bin")
result1 = os.path.join(args.result_path, prefix + str(i) + "_output_1.bin")
output0 = np.fromfile(result0, np.float32).reshape(args.batch_size, 13860, 4)
output1 = np.fromfile(result1, np.float32).reshape(args.batch_size, 13860, 2)
for batch_idx in range(args.batch_size):
pred_data.append({"boxes": output0[batch_idx],
"box_scores": output1[batch_idx],
"annotation": ann[i]})
precisions, recalls = metrics(pred_data)
for j in range(args.num_classes):
print("class {} precision is {:.2f}%, recall is {:.2f}%".format(j, precisions[j] * 100, recalls[j] * 100))
if __name__ == '__main__':
calculate_acc()

View File

@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate data and label needed for AIR model inference"""
import os
import sys
import numpy as np
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def generate_data(dataset_path):
"""
Generate data and label needed for AIR model inference at Ascend310 platform.
"""
ds = create_yolo_dataset(dataset_path, is_training=False)
cur_dir = os.getcwd() + "/data"
img_folder = cur_dir + "/00_image"
if not os.path.exists(img_folder):
os.makedirs(img_folder)
shape_folder = cur_dir + "/01_image_shape"
if not os.path.exists(shape_folder):
os.makedirs(shape_folder)
total = ds.get_dataset_size()
ann_list = []
print("\n========================================\n")
print("total images num: ", total)
print("Processing, please wait a moment.")
prefix = "Yolov3-resnet18_coco_bs_1_"
for i, data in enumerate(ds.create_dict_iterator(output_numpy=True, num_epochs=1)):
image_np = data['image']
image_shape = data['image_shape']
annotation = data['annotation']
file_name = prefix + str(i) + ".bin"
image_path = os.path.join(img_folder, file_name)
image_np.tofile(image_path)
shape_path = os.path.join(shape_folder, file_name)
image_shape.tofile(shape_path)
ann_list.append(annotation)
ann_file = np.array(ann_list)
np.save(os.path.join(cur_dir, "annotation_list.npy"), ann_file)
if __name__ == "__main__":
sys.path.append("..")
from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from model_utils.config import config as default_config
if not os.path.isdir(default_config.eval_mindrecord_dir):
os.makedirs(default_config.eval_mindrecord_dir)
yolo_prefix = "yolo.mindrecord"
mindrecord_file = os.path.join(default_config.eval_mindrecord_dir, yolo_prefix + "0")
if not os.path.exists(mindrecord_file):
if os.path.isdir(default_config.image_dir) and os.path.exists(default_config.anno_path):
print("Create Mindrecord")
data_to_mindrecord_byte_image(default_config.image_dir,
default_config.anno_path,
default_config.eval_mindrecord_dir,
prefix=yolo_prefix,
file_num=8)
print("Create Mindrecord Done, at {}".format(default_config.eval_mindrecord_dir))
else:
print("image_dir or anno_path not exits")
generate_data(mindrecord_file)

View File

@ -0,0 +1,112 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <iostream>
#include <vector>
#include "../inc/utils.h"
#include "acl/acl.h"
/**
* ModelProcess
*/
class ModelProcess {
public:
/**
* @brief Constructor
*/
ModelProcess();
/**
* @brief Destructor
*/
~ModelProcess();
/**
* @brief load model from file with mem
* @param [in] modelPath: model path
* @return result
*/
Result LoadModelFromFileWithMem(const char *modelPath);
/**
* @brief unload model
*/
void Unload();
/**
* @brief create model desc
* @return result
*/
Result CreateDesc();
/**
* @brief destroy desc
*/
void DestroyDesc();
/**
* @brief create model input
* @param [in] inputDataBuffer: input buffer
* @param [in] bufferSize: input buffer size
* @return result
*/
Result CreateInput(const std::vector<void *> &inputDataBuffer, const std::vector<size_t> &bufferSize);
/**
* @brief destroy input resource
*/
void DestroyInput();
/**
* @brief create output buffer
* @return result
*/
Result CreateOutput();
/**
* @brief destroy output resource
*/
void DestroyOutput();
/**
* @brief model execute
* @return result
*/
Result Execute();
/**
* @brief dump model output result to file
*/
void DumpModelOutputResult(char *output_name);
/**
* @brief get model output result
*/
void OutputModelResult();
private:
uint32_t modelId_;
size_t modelMemSize_;
size_t modelWeightSize_;
void *modelMemPtr_;
void *modelWeightPtr_;
bool loadFlag_; // model load flag
aclmdlDesc *modelDesc_;
aclmdlDataset *input_;
aclmdlDataset *output_;
};

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <string>
#include <vector>
#include "../inc/utils.h"
#include "acl/acl.h"
/**
* SampleProcess
*/
class SampleProcess {
public:
/**
* @brief Constructor
*/
SampleProcess();
/**
* @brief Destructor
*/
~SampleProcess();
/**
* @brief init reousce
* @return result
*/
Result InitResource();
/**
* @brief sample process
* @return result
*/
Result Process(char *om_path, char *input_folder, char *shape_folder);
void GetAllFiles(std::string path, std::vector<std::string> *files);
private:
void DestroyResource();
int32_t deviceId_;
aclrtContext context_;
aclrtStream stream_;
};

View File

@ -0,0 +1,52 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <iostream>
#include <string>
#define INFO_LOG(fmt, args...) fprintf(stdout, "[INFO] " fmt "\n", ##args)
#define WARN_LOG(fmt, args...) fprintf(stdout, "[WARN] " fmt "\n", ##args)
#define ERROR_LOG(fmt, args...) fprintf(stdout, "[ERROR] " fmt "\n", ##args)
typedef enum Result {
SUCCESS = 0,
FAILED = 1
} Result;
/**
* Utils
*/
class Utils {
public:
/**
* @brief create device buffer of file
* @param [in] fileName: file name
* @param [out] fileSize: size of file
* @return device buffer of file
*/
static void *GetDeviceBufferOfFile(std::string fileName, uint32_t *fileSize);
/**
* @brief create buffer of file
* @param [in] fileName: file name
* @param [out] fileSize: size of file
* @return buffer of pic
*/
static void* ReadBinFile(std::string fileName, uint32_t *fileSize);
};
#pragma once

View File

@ -0,0 +1,101 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""do post training quantization for Ascend310"""
import os
import sys
import numpy as np
from amct_mindspore.quantize_tool import create_quant_config
from amct_mindspore.quantize_tool import quantize_model
from amct_mindspore.quantize_tool import save_model
import mindspore as ms
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def quant_yolov3_resnet(network, dataset, input_data):
"""
Export post training quantization model of AIR format.
Args:
network: the origin network for inference.
dataset: the data for inference.
input_data: the data used for constructing network. The shape and format of input data should be the same as
actual data for inference.
"""
# step2: create the quant config json file
create_quant_config("./config.json", network, *input_data)
# step3: do some network modification and return the modified network
calibration_network = quantize_model("./config.json", network, *input_data)
calibration_network.set_train(False)
# step4: perform the evaluation of network to do activation calibration
for data in dataset.create_dict_iterator(num_epochs=1):
_ = calibration_network(data["image"], data["image_shape"])
# step5: export the air file
save_model("results/yolov3_resnet_quant", calibration_network, *input_data)
print("[INFO] the quantized AIR file has been stored at: \n {}".format("results/yolov3_resnet_quant.air"))
def export_yolov3_resnet():
""" prepare for quantization of yolov3_resnet """
cfg = ConfigYOLOV3ResNet18()
net = yolov3_resnet18(cfg)
eval_net = YoloWithEval(net, cfg)
param_dict = load_checkpoint(default_config.ckpt_file)
load_param_into_net(eval_net, param_dict)
eval_net.set_train(False)
default_config.export_batch_size = 1
shape = [default_config.export_batch_size, 3] + cfg.img_shape
input_data = Tensor(np.zeros(shape), ms.float32)
input_shape = Tensor(np.zeros([1, 2]), ms.float32)
inputs = (input_data, input_shape)
if not os.path.isdir(default_config.eval_mindrecord_dir):
os.makedirs(default_config.eval_mindrecord_dir)
yolo_prefix = "yolo.mindrecord"
mindrecord_file = os.path.join(default_config.eval_mindrecord_dir, yolo_prefix + "0")
if not os.path.exists(mindrecord_file):
if os.path.isdir(default_config.image_dir) and os.path.exists(default_config.anno_path):
print("Create Mindrecord")
data_to_mindrecord_byte_image(default_config.image_dir,
default_config.anno_path,
default_config.eval_mindrecord_dir,
prefix=yolo_prefix,
file_num=8)
print("Create Mindrecord Done, at {}".format(default_config.eval_mindrecord_dir))
else:
print("image_dir or anno_path not exits")
datasets = create_yolo_dataset(mindrecord_file, is_training=False)
ds = datasets.take(1)
quant_yolov3_resnet(eval_net, ds, inputs)
if __name__ == "__main__":
sys.path.append("..")
from src.yolov3 import yolov3_resnet18, YoloWithEval
from src.config import ConfigYOLOV3ResNet18
from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from model_utils.config import config as default_config
export_yolov3_resnet()

View File

@ -0,0 +1,105 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 4 ]; then
echo "Usage: bash run_quant_infer.sh [AIR_PATH] [DATA_PATH] [SHAPE_PATH] [ANNOTATION_PATH]"
echo "Example: bash run_quant_infer.sh ./yolov3_resnet_quant.air ./00_image ./01_image_shape ./annotation_list.npy"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
shape_path=$(get_real_path $3)
annotation_path=$(get_real_path $4)
echo "air name: "$model
echo "dataset path: "$data_path
echo "shape path: "$shape_path
echo "annotation path: "$annotation_path
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/fwkacllib/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/atc/lib64:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function air_to_om()
{
atc --input_format=NCHW --framework=1 --model=$model --output=yolov3_resnet_quant --soc_version=Ascend310 &> atc.log
}
function compile_app()
{
bash ./src/build.sh &> build.log
}
function infer()
{
if [ -d result ]; then
rm -rf ./result
fi
mkdir result
./out/main ./yolov3_resnet_quant.om $data_path $shape_path &> infer.log
}
function cal_acc()
{
python3.7 ./acc.py --result_path=./result --anno_path=$annotation_path &> acc.log
}
echo "start atc================================================"
air_to_om
if [ $? -ne 0 ]; then
echo "air to om code failed"
exit 1
fi
echo "start compile============================================"
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
echo "start infer=============================================="
infer
if [ $? -ne 0 ]; then
echo " execute inference failed"
exit 1
fi
echo "start calculate acc======================================"
cal_acc
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi

View File

@ -0,0 +1,43 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
# CMake lowest version requirement
cmake_minimum_required(VERSION 3.5.1)
# project information
project(InferClassification)
# Check environment variable
if(NOT DEFINED ENV{ASCEND_HOME})
message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
endif()
# Compile options
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
# Skip build rpath
set(CMAKE_SKIP_BUILD_RPATH True)
# Set output directory
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/../out)
# Set include directory and library directory
set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
# Header path
include_directories(${ACL_LIB_DIR}/include/)
include_directories(${FWKACL_LIB_DIR}/include/)
include_directories(${ATLAS_ACL_LIB_DIR}/include/)
include_directories(${PROJECT_SRC_ROOT}/../inc)
# add host lib path
link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
add_executable(main utils.cpp
sample_process.cpp
model_process.cpp
main.cpp)
target_link_libraries(main ${acl} gflags pthread)

View File

@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
path_cur=$(cd "`dirname $0`" || exit; pwd)
function preparePath() {
rm -rf $1
mkdir -p $1
cd $1 || exit
}
function buildA300() {
if [ ! "${ARCH_PATTERN}" ]; then
# set ARCH_PATTERN to acllib when it was not specified by user
export ARCH_PATTERN=acllib
echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
else
echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user, reset it to ${ARCH_PATTERN}/acllib"
export ARCH_PATTERN=${ARCH_PATTERN}/acllib
fi
path_build=$path_cur/build
preparePath $path_build
cmake ..
make -j
ret=$?
cd ..
return ${ret}
}
# set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
if [ ! "${ASCEND_VERSION}" ]; then
export ASCEND_VERSION=ascend-toolkit/latest
echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
else
echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
fi
buildA300
if [ $? -ne 0 ]; then
exit 1
fi

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "../inc/sample_process.h"
#include "../inc/utils.h"
bool g_is_device = false;
int main(int argc, char **argv) {
if (argc != 4) {
ERROR_LOG("usage:./main path_of_om path_of_inputFolder path_of_shapeFolder");
return FAILED;
}
SampleProcess processSample;
Result ret = processSample.InitResource();
if (ret != SUCCESS) {
ERROR_LOG("sample init resource failed");
return FAILED;
}
ret = processSample.Process(argv[1], argv[2], argv[3]);
if (ret != SUCCESS) {
ERROR_LOG("sample process failed");
return FAILED;
}
INFO_LOG("execute sample success");
return SUCCESS;
}

View File

@ -0,0 +1,339 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../inc/model_process.h"
#include <iostream>
#include <map>
#include <sstream>
#include <algorithm>
#include "../inc/utils.h"
extern bool g_is_device;
ModelProcess::ModelProcess() :modelId_(0), modelMemSize_(0), modelWeightSize_(0), modelMemPtr_(nullptr),
modelWeightPtr_(nullptr), loadFlag_(false), modelDesc_(nullptr), input_(nullptr), output_(nullptr) {
}
ModelProcess::~ModelProcess() {
Unload();
DestroyDesc();
DestroyInput();
DestroyOutput();
}
Result ModelProcess::LoadModelFromFileWithMem(const char *modelPath) {
if (loadFlag_) {
ERROR_LOG("has already loaded a model");
return FAILED;
}
aclError ret = aclmdlQuerySize(modelPath, &modelMemSize_, &modelWeightSize_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("query model failed, model file is %s", modelPath);
return FAILED;
}
ret = aclrtMalloc(&modelMemPtr_, modelMemSize_, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc buffer for mem failed, require size is %zu", modelMemSize_);
return FAILED;
}
ret = aclrtMalloc(&modelWeightPtr_, modelWeightSize_, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc buffer for weight failed, require size is %zu", modelWeightSize_);
return FAILED;
}
ret = aclmdlLoadFromFileWithMem(modelPath, &modelId_, modelMemPtr_,
modelMemSize_, modelWeightPtr_, modelWeightSize_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("load model from file failed, model file is %s", modelPath);
return FAILED;
}
loadFlag_ = true;
INFO_LOG("load model %s success", modelPath);
return SUCCESS;
}
Result ModelProcess::CreateDesc() {
modelDesc_ = aclmdlCreateDesc();
if (modelDesc_ == nullptr) {
ERROR_LOG("create model description failed");
return FAILED;
}
aclError ret = aclmdlGetDesc(modelDesc_, modelId_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("get model description failed");
return FAILED;
}
INFO_LOG("create model description success");
return SUCCESS;
}
void ModelProcess::DestroyDesc() {
if (modelDesc_ != nullptr) {
(void)aclmdlDestroyDesc(modelDesc_);
modelDesc_ = nullptr;
}
}
Result ModelProcess::CreateInput(const std::vector<void *> &inputDataBuffer, const std::vector<size_t> &bufferSize) {
input_ = aclmdlCreateDataset();
if (input_ == nullptr) {
ERROR_LOG("can't create dataset, create input failed");
return FAILED;
}
for (size_t i = 0; i < inputDataBuffer.size(); ++i) {
aclDataBuffer* inputData = aclCreateDataBuffer(inputDataBuffer[i], bufferSize[i]);
if (inputData == nullptr) {
ERROR_LOG("can't create data buffer, create input failed");
return FAILED;
}
aclError ret = aclmdlAddDatasetBuffer(input_, inputData);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("add input dataset buffer failed");
aclDestroyDataBuffer(inputData);
inputData = nullptr;
return FAILED;
}
}
return SUCCESS;
}
void ModelProcess::DestroyInput() {
if (input_ == nullptr) {
return;
}
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(input_); ++i) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(input_, i);
aclDestroyDataBuffer(dataBuffer);
}
aclmdlDestroyDataset(input_);
input_ = nullptr;
}
Result ModelProcess::CreateOutput() {
if (modelDesc_ == nullptr) {
ERROR_LOG("no model description, create output failed");
return FAILED;
}
output_ = aclmdlCreateDataset();
if (output_ == nullptr) {
ERROR_LOG("can't create dataset, create output failed");
return FAILED;
}
size_t outputSize = aclmdlGetNumOutputs(modelDesc_);
for (size_t i = 0; i < outputSize; ++i) {
size_t buffer_size = aclmdlGetOutputSizeByIndex(modelDesc_, i);
void *outputBuffer = nullptr;
aclError ret = aclrtMalloc(&outputBuffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("can't malloc buffer, size is %zu, create output failed", buffer_size);
return FAILED;
}
aclDataBuffer* outputData = aclCreateDataBuffer(outputBuffer, buffer_size);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("can't create data buffer, create output failed");
aclrtFree(outputBuffer);
return FAILED;
}
ret = aclmdlAddDatasetBuffer(output_, outputData);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("can't add data buffer, create output failed");
aclrtFree(outputBuffer);
aclDestroyDataBuffer(outputData);
return FAILED;
}
}
INFO_LOG("create model output success");
return SUCCESS;
}
void ModelProcess::DumpModelOutputResult(char *output_name) {
size_t outputNum = aclmdlGetDatasetNumBuffers(output_);
for (size_t i = 0; i < outputNum; ++i) {
std::stringstream ss;
ss << "result/" << output_name << "_output_" << i << ".bin";
std::string outputFileName = ss.str();
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
if (outputFile) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
void* data = aclGetDataBufferAddr(dataBuffer);
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
void* outHostData = NULL;
aclError ret = ACL_ERROR_NONE;
if (!g_is_device) {
ret = aclrtMallocHost(&outHostData, len);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtMallocHost failed, ret[%d]", ret);
return;
}
ret = aclrtMemcpy(outHostData, len, data, len, ACL_MEMCPY_DEVICE_TO_HOST);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtMemcpy failed, ret[%d]", ret);
(void)aclrtFreeHost(outHostData);
return;
}
fwrite(outHostData, len, sizeof(char), outputFile);
ret = aclrtFreeHost(outHostData);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtFreeHost failed, ret[%d]", ret);
return;
}
} else {
fwrite(data, len, sizeof(char), outputFile);
}
fclose(outputFile);
outputFile = nullptr;
} else {
ERROR_LOG("create output file [%s] failed", outputFileName.c_str());
return;
}
}
INFO_LOG("dump data success");
return;
}
void ModelProcess::OutputModelResult() {
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(output_); ++i) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
void* data = aclGetDataBufferAddr(dataBuffer);
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
void *outHostData = NULL;
aclError ret = ACL_ERROR_NONE;
float *outData = NULL;
if (!g_is_device) {
ret = aclrtMallocHost(&outHostData, len);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtMallocHost failed, ret[%d]", ret);
return;
}
ret = aclrtMemcpy(outHostData, len, data, len, ACL_MEMCPY_DEVICE_TO_HOST);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtMemcpy failed, ret[%d]", ret);
return;
}
outData = reinterpret_cast<float*>(outHostData);
} else {
outData = reinterpret_cast<float*>(data);
}
std::map<float, unsigned int, std::greater<float> > resultMap;
for (unsigned int j = 0; j < len / sizeof(float); ++j) {
resultMap[*outData] = j;
outData++;
}
int cnt = 0;
for (auto it = resultMap.begin(); it != resultMap.end(); ++it) {
// print top 5
if (++cnt > 5) {
break;
}
INFO_LOG("top %d: index[%d] value[%lf]", cnt, it->second, it->first);
}
if (!g_is_device) {
ret = aclrtFreeHost(outHostData);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("aclrtFreeHost failed, ret[%d]", ret);
return;
}
}
}
INFO_LOG("output data success");
return;
}
void ModelProcess::DestroyOutput() {
if (output_ == nullptr) {
return;
}
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(output_); ++i) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
void* data = aclGetDataBufferAddr(dataBuffer);
(void)aclrtFree(data);
(void)aclDestroyDataBuffer(dataBuffer);
}
(void)aclmdlDestroyDataset(output_);
output_ = nullptr;
}
Result ModelProcess::Execute() {
aclError ret = aclmdlExecute(modelId_, input_, output_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("execute model failed, modelId is %u", modelId_);
return FAILED;
}
INFO_LOG("model execute success");
return SUCCESS;
}
void ModelProcess::Unload() {
if (!loadFlag_) {
WARN_LOG("no model had been loaded, unload failed");
return;
}
aclError ret = aclmdlUnload(modelId_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("unload model failed, modelId is %u", modelId_);
}
if (modelDesc_ != nullptr) {
(void)aclmdlDestroyDesc(modelDesc_);
modelDesc_ = nullptr;
}
if (modelMemPtr_ != nullptr) {
aclrtFree(modelMemPtr_);
modelMemPtr_ = nullptr;
modelMemSize_ = 0;
}
if (modelWeightPtr_ != nullptr) {
aclrtFree(modelWeightPtr_);
modelWeightPtr_ = nullptr;
modelWeightSize_ = 0;
}
loadFlag_ = false;
INFO_LOG("unload model success, modelId is %u", modelId_);
}

View File

@ -0,0 +1,263 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../inc/sample_process.h"
#include <sys/time.h>
#include <sys/types.h>
#include <dirent.h>
#include <string.h>
#include <iostream>
#include <fstream>
#include "../inc/model_process.h"
#include "acl/acl.h"
#include "../inc/utils.h"
extern bool g_is_device;
using std::string;
using std::vector;
SampleProcess::SampleProcess() :deviceId_(0), context_(nullptr), stream_(nullptr) {
}
SampleProcess::~SampleProcess() {
DestroyResource();
}
Result SampleProcess::InitResource() {
// ACL init
const char *aclConfigPath = "./src/acl.json";
aclError ret = aclInit(aclConfigPath);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("acl init failed");
return FAILED;
}
INFO_LOG("acl init success");
// open device
ret = aclrtSetDevice(deviceId_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("acl open device %d failed", deviceId_);
return FAILED;
}
INFO_LOG("open device %d success", deviceId_);
// create context (set current)
ret = aclrtCreateContext(&context_, deviceId_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("acl create context failed");
return FAILED;
}
INFO_LOG("create context success");
// create stream
ret = aclrtCreateStream(&stream_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("acl create stream failed");
return FAILED;
}
INFO_LOG("create stream success");
// get run mode
aclrtRunMode runMode;
ret = aclrtGetRunMode(&runMode);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("acl get run mode failed");
return FAILED;
}
g_is_device = (runMode == ACL_DEVICE);
INFO_LOG("get run mode success");
return SUCCESS;
}
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
DIR *pDir = NULL;
struct dirent* ptr;
if (!(pDir = opendir(path.c_str()))) {
return;
}
while ((ptr = readdir(pDir)) != 0) {
if (strcmp(ptr->d_name, ".") != 0 && strcmp(ptr->d_name, "..") != 0) {
files->push_back(path + "/" + ptr->d_name);
}
}
closedir(pDir);
}
Result SampleProcess::Process(char *om_path, char *input_folder, char *shape_folder) {
// model init
double second_to_millisecond = 1000;
double second_to_microsecond = 1000000;
double whole_cost_time = 0.0;
struct timeval start_global = {0};
struct timeval end_global = {0};
double startTimeMs_global = 0.0;
double endTimeMs_global = 0.0;
gettimeofday(&start_global, nullptr);
ModelProcess processModel;
const char* omModelPath = om_path;
Result ret = processModel.LoadModelFromFileWithMem(omModelPath);
if (ret != SUCCESS) {
ERROR_LOG("execute LoadModelFromFileWithMem failed");
return FAILED;
}
ret = processModel.CreateDesc();
if (ret != SUCCESS) {
ERROR_LOG("execute CreateDesc failed");
return FAILED;
}
ret = processModel.CreateOutput();
if (ret != SUCCESS) {
ERROR_LOG("execute CreateOutput failed");
return FAILED;
}
std::vector<string> testFile;
GetAllFiles(input_folder, &testFile);
std::vector<string> shapeFile;
GetAllFiles(shape_folder, &shapeFile);
if (testFile.size() !=shapeFile.size()) {
ERROR_LOG("number of data files is not equal to shape file");
}
double model_cost_time = 0.0;
double edge_to_edge_model_cost_time = 0.0;
for (size_t index = 0; index < testFile.size(); ++index) {
INFO_LOG("start to process data file:%s", testFile[index].c_str());
INFO_LOG("start to process shape file:%s", shapeFile[index].c_str());
// model process
struct timeval time_init = {0};
double timeval_init = 0.0;
gettimeofday(&time_init, nullptr);
timeval_init = (time_init.tv_sec * second_to_microsecond + time_init.tv_usec) / second_to_millisecond;
uint32_t devBufferSize;
void *picDevBuffer = Utils::GetDeviceBufferOfFile(testFile[index], &devBufferSize);
if (picDevBuffer == nullptr) {
ERROR_LOG("get pic device buffer failed, index is %zu", index);
return FAILED;
}
uint32_t devBufferShapeSize;
void *shapeDevBuffer = Utils::GetDeviceBufferOfFile(shapeFile[index], &devBufferShapeSize);
if (shapeDevBuffer == nullptr) {
ERROR_LOG("get shape device buffer failed, index is %zu", index);
return FAILED;
}
std::vector<void *> inputBuffers({picDevBuffer, shapeDevBuffer});
std::vector<size_t> inputSizes({devBufferSize, devBufferShapeSize});
ret = processModel.CreateInput(inputBuffers, inputSizes);
if (ret != SUCCESS) {
ERROR_LOG("execute CreateInput failed");
aclrtFree(picDevBuffer);
return FAILED;
}
struct timeval start = {0};
struct timeval end = {0};
gettimeofday(&start, nullptr);
double startTimeMs = (start.tv_sec * second_to_microsecond + start.tv_usec) / second_to_millisecond;
ret = processModel.Execute();
gettimeofday(&end, nullptr);
double endTimeMs = (end.tv_sec * second_to_microsecond + end.tv_usec) / second_to_millisecond;
double cost_time = endTimeMs - startTimeMs;
INFO_LOG("model infer time: %lf ms", cost_time);
model_cost_time += cost_time;
double edge_to_edge_cost_time = endTimeMs - timeval_init;
edge_to_edge_model_cost_time += edge_to_edge_cost_time;
if (ret != SUCCESS) {
ERROR_LOG("execute inference failed");
aclrtFree(picDevBuffer);
return FAILED;
}
int pos = testFile[index].find_last_of('/');
std::string name = testFile[index].substr(pos+1);
std::string outputname = name.substr(0, name.rfind("."));
// dump output result to file in the current directory
processModel.DumpModelOutputResult(const_cast<char *>(outputname.c_str()));
// release model input buffer
aclrtFree(picDevBuffer);
processModel.DestroyInput();
}
double test_file_size = testFile.size();
INFO_LOG("infer dataset size:%lf", test_file_size);
gettimeofday(&end_global, nullptr);
startTimeMs_global = (start_global.tv_sec * second_to_microsecond + start_global.tv_usec) / second_to_millisecond;
endTimeMs_global = (end_global.tv_sec * second_to_microsecond + end_global.tv_usec) / second_to_millisecond;
whole_cost_time = (endTimeMs_global - startTimeMs_global) / test_file_size;
model_cost_time /= test_file_size;
INFO_LOG("model cost time per sample: %lf ms", model_cost_time);
edge_to_edge_model_cost_time /= test_file_size;
INFO_LOG("edge-to-edge model cost time per sample:%lf ms", edge_to_edge_model_cost_time);
INFO_LOG("whole cost time per sample: %lf ms", whole_cost_time);
return SUCCESS;
}
void SampleProcess::DestroyResource() {
aclError ret;
if (stream_ != nullptr) {
ret = aclrtDestroyStream(stream_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("destroy stream failed");
}
stream_ = nullptr;
}
INFO_LOG("end to destroy stream");
if (context_ != nullptr) {
ret = aclrtDestroyContext(context_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("destroy context failed");
}
context_ = nullptr;
}
INFO_LOG("end to destroy context");
ret = aclrtResetDevice(deviceId_);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("reset device failed");
}
INFO_LOG("end to reset device is %d", deviceId_);
ret = aclFinalize();
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("finalize acl failed");
}
INFO_LOG("end to finalize acl");
}

View File

@ -0,0 +1,113 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../inc/utils.h"
#include <sys/stat.h>
#include <iostream>
#include <fstream>
#include <cstring>
#include "acl/acl.h"
extern bool g_is_device;
void* Utils::ReadBinFile(std::string fileName, uint32_t *fileSize) {
struct stat sBuf;
int fileStatus = stat(fileName.data(), &sBuf);
if (fileStatus == -1) {
ERROR_LOG("failed to get file");
return nullptr;
}
if (S_ISREG(sBuf.st_mode) == 0) {
ERROR_LOG("%s is not a file, please enter a file", fileName.c_str());
return nullptr;
}
std::ifstream binFile(fileName, std::ifstream::binary);
if (binFile.is_open() == false) {
ERROR_LOG("open file %s failed", fileName.c_str());
return nullptr;
}
binFile.seekg(0, binFile.end);
uint32_t binFileBufferLen = binFile.tellg();
if (binFileBufferLen == 0) {
ERROR_LOG("binfile is empty, filename is %s", fileName.c_str());
binFile.close();
return nullptr;
}
binFile.seekg(0, binFile.beg);
void* binFileBufferData = nullptr;
aclError ret = ACL_ERROR_NONE;
if (!g_is_device) {
ret = aclrtMallocHost(&binFileBufferData, binFileBufferLen);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc for binFileBufferData failed");
binFile.close();
return nullptr;
}
if (binFileBufferData == nullptr) {
ERROR_LOG("malloc binFileBufferData failed");
binFile.close();
return nullptr;
}
} else {
ret = aclrtMalloc(&binFileBufferData, binFileBufferLen, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc device buffer failed. size is %u", binFileBufferLen);
binFile.close();
return nullptr;
}
}
binFile.read(static_cast<char *>(binFileBufferData), binFileBufferLen);
binFile.close();
*fileSize = binFileBufferLen;
return binFileBufferData;
}
void* Utils::GetDeviceBufferOfFile(std::string fileName, uint32_t *fileSize) {
uint32_t inputHostBuffSize = 0;
void* inputHostBuff = Utils::ReadBinFile(fileName, &inputHostBuffSize);
if (inputHostBuff == nullptr) {
return nullptr;
}
if (!g_is_device) {
void *inBufferDev = nullptr;
uint32_t inBufferSize = inputHostBuffSize;
aclError ret = aclrtMalloc(&inBufferDev, inBufferSize, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc device buffer failed. size is %u", inBufferSize);
aclrtFreeHost(inputHostBuff);
return nullptr;
}
ret = aclrtMemcpy(inBufferDev, inBufferSize, inputHostBuff, inputHostBuffSize, ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("memcpy failed. device buffer size is %u, input host buffer size is %u",
inBufferSize, inputHostBuffSize);
aclrtFree(inBufferDev);
aclrtFreeHost(inputHostBuff);
return nullptr;
}
aclrtFreeHost(inputHostBuff);
*fileSize = inBufferSize;
return inBufferDev;
} else {
*fileSize = inputHostBuffSize;
return inputHostBuff;
}
}

View File

@ -0,0 +1,176 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics utils"""
import numpy as np
def calc_iou(bbox_pred, bbox_ground):
"""Calculate iou of predicted bbox and ground truth."""
x1 = bbox_pred[0]
y1 = bbox_pred[1]
width1 = bbox_pred[2] - bbox_pred[0]
height1 = bbox_pred[3] - bbox_pred[1]
x2 = bbox_ground[0]
y2 = bbox_ground[1]
width2 = bbox_ground[2] - bbox_ground[0]
height2 = bbox_ground[3] - bbox_ground[1]
endx = max(x1 + width1, x2 + width2)
startx = min(x1, x2)
width = width1 + width2 - (endx - startx)
endy = max(y1 + height1, y2 + height2)
starty = min(y1, y2)
height = height1 + height2 - (endy - starty)
if width <= 0 or height <= 0:
iou = 0
else:
area = width * height
area1 = width1 * height1
area2 = width2 * height2
iou = area * 1. / (area1 + area2 - area)
return iou
def apply_nms(all_boxes, all_scores, thres, max_boxes):
"""Apply NMS to bboxes."""
x1 = all_boxes[:, 0]
y1 = all_boxes[:, 1]
x2 = all_boxes[:, 2]
y2 = all_boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = all_scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
if len(keep) >= max_boxes:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return keep
class ConfigYOLOV3ResNet18:
"""
Config parameters for YOLOv3.
Examples:
ConfigYoloV3ResNet18.
"""
def __init__(self):
self.img_shape = [352, 640]
self.feature_shape = [32, 3, 352, 640]
self.num_classes = 2
self.nms_max_num = 50
self.backbone_input_shape = [64, 64, 128, 256]
self.backbone_shape = [64, 128, 256, 512]
self.backbone_layers = [2, 2, 2, 2]
self.backbone_stride = [1, 2, 2, 2]
self.ignore_threshold = 0.5
self.obj_threshold = 0.3
self.nms_threshold = 0.4
self.anchor_scales = [(10, 13),
(16, 30),
(33, 23),
(30, 61),
(62, 45),
(59, 119),
(116, 90),
(156, 198),
(163, 326)]
self.out_channel = int(len(self.anchor_scales) / 3 * (self.num_classes + 5))
def metrics(pred_data):
"""Calculate precision and recall of predicted bboxes."""
config = ConfigYOLOV3ResNet18()
num_classes = config.num_classes
count_corrects = [1e-6 for _ in range(num_classes)]
count_grounds = [1e-6 for _ in range(num_classes)]
count_preds = [1e-6 for _ in range(num_classes)]
for i, sample in enumerate(pred_data):
gt_anno = sample["annotation"]
box_scores = sample['box_scores']
boxes = sample['boxes']
mask = box_scores >= config.obj_threshold
boxes_ = []
scores_ = []
classes_ = []
max_boxes = config.nms_max_num
for c in range(num_classes):
class_boxes = np.reshape(boxes, [-1, 4])[np.reshape(mask[:, c], [-1])]
class_box_scores = np.reshape(box_scores[:, c], [-1])[np.reshape(mask[:, c], [-1])]
nms_index = apply_nms(class_boxes, class_box_scores, config.nms_threshold, max_boxes)
class_boxes = class_boxes[nms_index]
class_box_scores = class_box_scores[nms_index]
classes = np.ones_like(class_box_scores, 'int32') * c
boxes_.append(class_boxes)
scores_.append(class_box_scores)
classes_.append(classes)
boxes = np.concatenate(boxes_, axis=0)
classes = np.concatenate(classes_, axis=0)
# metric
count_correct = [1e-6 for _ in range(num_classes)]
count_ground = [1e-6 for _ in range(num_classes)]
count_pred = [1e-6 for _ in range(num_classes)]
for anno in gt_anno:
count_ground[anno[4]] += 1
for box_index, box in enumerate(boxes):
bbox_pred = [box[1], box[0], box[3], box[2]]
count_pred[classes[box_index]] += 1
for anno in gt_anno:
class_ground = anno[4]
if classes[box_index] == class_ground:
iou = calc_iou(bbox_pred, anno)
if iou >= 0.5:
count_correct[class_ground] += 1
break
count_corrects = [count_corrects[i] + count_correct[i] for i in range(num_classes)]
count_preds = [count_preds[i] + count_pred[i] for i in range(num_classes)]
count_grounds = [count_grounds[i] + count_ground[i] for i in range(num_classes)]
precision = np.array([count_corrects[ix] / count_preds[ix] for ix in range(num_classes)])
recall = np.array([count_corrects[ix] / count_grounds[ix] for ix in range(num_classes)])
return precision, recall