ssd_ghostnet add 310infer

This commit is contained in:
chenweitao_295 2021-06-01 11:28:30 +08:00
parent 20b1e0291c
commit 6b653c5711
9 changed files with 629 additions and 7 deletions

View File

@ -1,3 +1,28 @@
# Contents
- [Contents](#contents)
- [SSD Description](#ssd-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training on Ascend](#training-on-ascend)
- [Evaluation Process](#evaluation-process)
- [Evaluation on Ascend](#evaluation-on-ascend)
- [Inference Process](#inference-process)
- [Export MindIR](#export-mindir)
- [Infer on Ascend310](#infer-on-ascend310)
- [result](#result)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#inference-performance)
- [310Inference Performance](#310inference-performance)
# [SSD Description](#contents)
SSD discretizes the output space of bounding boxes into a set of default boxes over different aspect ratios and scales per feature map location. At prediction time, the network generates scores for the presence of each object category in each default box and produces adjustments to the box to better match the object shape.Additionally, the network combines predictions from multiple feature maps with different resolutions to naturally handle objects of various sizes.
@ -122,8 +147,10 @@ If you want to run in modelarts, please check the official documentation of [mod
├── ssd_ghostnet
├── README.md ## readme file of ssd_ghostnet
├── ascend310_infer ## application for 310 inference
├── scripts
└─ run_distribute_train_ghostnet.sh ## shell script for distributed on ascend
├─ run_distribute_train_ghostnet.sh ## shell script for distributed on ascend
└─ run_infer_310.sh ## shell script for 310inference on ascend
├── src
├─ box_util.py ## bbox utils
├─ coco_eval.py ## coco metrics utils
@ -132,13 +159,15 @@ If you want to run in modelarts, please check the official documentation of [mod
├─ lr_schedule.py ## learning ratio generator
└─ ssd_ghostnet.py ## ssd architecture
├── model_utils
│ ├── config.py ## parameter configuration
│ ├── device_adapter.py ## device adapter
│ ├── local_adapter.py ## local adapter
│ ├── moxing_adapter.py ## moxing adapter
│ ├── config.py ## parameter configuration
│ ├── device_adapter.py ## device adapter
│ ├── local_adapter.py ## local adapter
│ ├── moxing_adapter.py ## moxing adapter
├── default_config.yaml ## parameter configuration
├── eval.py ## eval scripts
├── train.py ## train scripts
├── export.py ## export mindir script
├── postprocess.py ## postprocess scripts
├── mindspore_hub_conf.py ## export model for hub
```
@ -205,6 +234,49 @@ Training result will be stored in the current path, whose folder name begins wit
python eval.py --device_id 0 --dataset coco --checkpoint_path LOG4/ssd-500_458.ckpt
```
## [Inference Process](#contents)
### Export MindIR
```shell
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
The ckpt_file parameter is required,
`FILE_FORMAT` should be in ["AIR", "MINDIR"]
### Infer on Ascend310
Before performing inference, the mindir file must be exported by `export.py` script. We only provide an example of inference using MINDIR model.
Current batch_size can only be set to 1.
```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
```
- `DVPP` is mandatory, and must choose from ["DVPP", "CPU"], it's case-insensitive. SSD_ghostnet only support CPU mode. Note that the image shape of ssd_ghostnet inference is [300, 300], The DVPP hardware restricts width 16-alignment and height even-alignment. Therefore, the network needs to use the CPU operator to process images.
- `DEVICE_ID` is optional, default value is 0.
### result
Inference result is saved in current path, you can find result like this in acc.log file.
```bash
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.243
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.411
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.244
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.038
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.205
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.450
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.252
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.391
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.424
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.122
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.699
mAP: 0.24270569394180577
```
# [Model Description](#contents)
## [Performance](#contents)
@ -227,11 +299,25 @@ python eval.py --device_id 0 --dataset coco --checkpoint_path LOG4/ssd-500_458.c
| Parameters | Ascend |
| ------------------- | ----------------------------|
| Model Version | SSD ghostnet |
| Resource | Ascend 910; OS Euler2.8 |
| Resource | Ascend 910; OS Euler2.8 |
| Uploaded Date | 09/08/2020 (month/day/year) |
| MindSpore Version | 0.7.0 |
| Dataset | COCO2017 |
| batch_size | 1 |
| outputs | mAP |
| Accuracy | IoU=0.50: 24.1% |
| Model for inference | 55M(.ckpt file) |
| Model for inference | 55M(.ckpt file) |
#### 310Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | SSD ghostnet |
| Resource | Ascend 310; OS Euler2.8 |
| Uploaded Date | 06/01/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | COCO2017 |
| batch_size | 1 |
| outputs | mAP |
| Accuracy | IoU=0.50: 24.2% |
| Model for inference | 52.5M(.ckpt file) |

View File

@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

View File

@ -0,0 +1,23 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ ! -d out ]; then
mkdir out
fi
cd out || exit
cmake .. \
-DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif

View File

@ -0,0 +1,142 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/vision_ascend.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::Decode;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_int32(device_id, 0, "device id");
DEFINE_string(cpu_dvpp, "", "cpu or dvpp process");
DEFINE_int32(image_height, 224, "image height");
DEFINE_int32(image_width, 224, "image width");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
ascend310->SetBufferOptimizeMode("off_optimize");
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
auto decode = Decode();
auto resize = Resize({FLAGS_image_height, FLAGS_image_width});
auto normalize = Normalize({123.675, 116.28, 103.53}, {58.395, 57.120, 57.375});
auto hwc2chw = HWC2CHW();
mindspore::dataset::Execute transform({decode, resize, normalize, hwc2chw});
auto all_files = GetAllFiles(FLAGS_dataset_path);
std::map<double, double> costTime_map;
size_t size = all_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs;
double endTimeMs;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << all_files[i] << std::endl;
if (FLAGS_cpu_dvpp == "DVPP") {
// ssd_ghostnet currently only support CPU mode.
} else {
auto img = MSTensor();
auto image = ReadFileToTensor(all_files[i]);
transform(image, &img);
std::vector<MSTensor> model_inputs = model.GetInputs();
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
img.Data().get(), img.DataSize());
}
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << all_files[i] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(all_files[i], outputs);
}
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}

View File

@ -0,0 +1,130 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/utils.h"
#include <fstream>
#include <algorithm>
#include <iostream>
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE * outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}

View File

@ -82,6 +82,8 @@ coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
num_classes: 81
result_path: ""
img_path: ""
# The annotation.json position of voc validation dataset.
voc_root: ""
# voc original dataset.

View File

@ -0,0 +1,88 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""post process for 310 inference"""
import os
import numpy as np
from PIL import Image
from src.model_utils.config import config
from src.coco_eval import metrics
batch_size = 1
def get_imgSize(file_name):
img = Image.open(file_name)
return img.size
def get_result(result_path, img_id_file_path, drop=True):
""" get result"""
coco_root = os.path.join(config.data_path, "coco_ori")
anno_json = os.path.join(coco_root, config.instances_set.format(config.val_data_type))
if drop:
from pycocotools.coco import COCO
train_cls = config.coco_classes
train_cls_dict = {}
for i, cls in enumerate(train_cls):
train_cls_dict[cls] = i
coco = COCO(anno_json)
classs_dict = {}
cat_ids = coco.loadCats(coco.getCatIds())
for cat in cat_ids:
classs_dict[cat["id"]] = cat["name"]
files = os.listdir(img_id_file_path)
pred_data = []
for file in files:
img_ids_name = file.split('.')[0]
img_id = int(np.squeeze(img_ids_name))
if drop:
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = coco.loadAnns(anno_ids)
annos = []
iscrowd = False
for label in anno:
bbox = label["bbox"]
class_name = classs_dict[label["category_id"]]
iscrowd = iscrowd or label["iscrowd"]
if class_name in train_cls:
x_min, x_max = bbox[0], bbox[0] + bbox[2]
y_min, y_max = bbox[1], bbox[1] + bbox[3]
annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]])
if iscrowd or (not annos):
continue
img_size = get_imgSize(os.path.join(img_id_file_path, file))
image_shape = np.array([img_size[1], img_size[0]])
result_path_0 = os.path.join(result_path, img_ids_name + "_0.bin")
result_path_1 = os.path.join(result_path, img_ids_name + "_1.bin")
boxes = np.fromfile(result_path_0, dtype=np.float32).reshape(config.num_ssd_boxes, 4)
box_scores = np.fromfile(result_path_1, dtype=np.float32).reshape(config.num_ssd_boxes, config.num_classes)
pred_data.append({
"boxes": boxes,
"box_scores": box_scores,
"img_id": img_id,
"image_shape": image_shape
})
mAP = metrics(pred_data)
print(f" mAP:{mAP}")
if __name__ == '__main__':
get_result(config.result_path, config.img_path)

View File

@ -0,0 +1,105 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 3 || $# -gt 4 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
DVPP is mandatory, and must choose from [DVPP|CPU], it's case-insensitive
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
DVPP=${3^^}
device_id=0
if [ $# == 4 ]; then
device_id=$4
fi
echo "mindir name: "$model
echo "dataset path: "$data_path
echo "image process mode: "$DVPP
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function compile_app()
{
cd ../ascend310_infer || exit
bash build.sh &> build.log
}
function infer()
{
cd - || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
if [ "$DVPP" == "DVPP" ];then
echo "ssd_ghostnet currently only support CPU mode."
elif [ "$DVPP" == "CPU" ]; then
../ascend310_infer/out/main --mindir_path=$model --dataset_path=$data_path --cpu_dvpp=$DVPP --device_id=$device_id --image_height=300 --image_width=300 &> infer.log
else
echo "image process mode must be in [DVPP|CPU]"
exit 1
fi
}
function cal_acc()
{
python3.7 ../postprocess.py --result_path=./result_Files --img_path=$data_path &> acc.log &
}
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo " execute inference failed"
exit 1
fi
cal_acc
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi