From 7320e6b3f6b6e4bf47220636a8efa2b1967098ee Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 18 May 2021 16:43:10 +0800 Subject: [PATCH] psenet 310 infer modified: psenet/ascend310_infer/src/main.cc modified: psenet/postprocess.py modified: psenet/scripts/run_infer_310.sh modified: psenet/README_CN.md --- model_zoo/official/cv/psenet/README.md | 31 +++ model_zoo/official/cv/psenet/README_CN.md | 33 +++- .../cv/psenet/ascend310_infer/CMakeLists.txt | 15 ++ .../cv/psenet/ascend310_infer/build.sh | 29 +++ .../cv/psenet/ascend310_infer/inc/utils.h | 35 ++++ .../cv/psenet/ascend310_infer/src/main.cc | 161 +++++++++++++++ .../cv/psenet/ascend310_infer/src/utils.cc | 185 ++++++++++++++++++ model_zoo/official/cv/psenet/postprocess.py | 104 ++++++++++ .../cv/psenet/scripts/run_infer_310.sh | 96 +++++++++ 9 files changed, 688 insertions(+), 1 deletion(-) create mode 100644 model_zoo/official/cv/psenet/ascend310_infer/CMakeLists.txt create mode 100644 model_zoo/official/cv/psenet/ascend310_infer/build.sh create mode 100644 model_zoo/official/cv/psenet/ascend310_infer/inc/utils.h create mode 100644 model_zoo/official/cv/psenet/ascend310_infer/src/main.cc create mode 100644 model_zoo/official/cv/psenet/ascend310_infer/src/utils.cc create mode 100644 model_zoo/official/cv/psenet/postprocess.py create mode 100644 model_zoo/official/cv/psenet/scripts/run_infer_310.sh diff --git a/model_zoo/official/cv/psenet/README.md b/model_zoo/official/cv/psenet/README.md index 25346319350..c8cbe24213b 100644 --- a/model_zoo/official/cv/psenet/README.md +++ b/model_zoo/official/cv/psenet/README.md @@ -14,6 +14,10 @@ - [Distributed Training](#distributed-training) - [Evaluation Process](#evaluation-process) - [Evaluation](#evaluation) + - [Inference Process](#inference-process) + - [Export MindIR](#export-mindir) + - [Infer on Ascend310](#infer-on-ascend310) + - [result](#result) - [Model Description](#model-description) - [Performance](#performance) - [Evaluation Performance](#evaluation-performance) @@ -180,6 +184,33 @@ sh ./script/run_eval_ascend.sh.sh Calculated!{"precision": 0.814796668299853, "recall": 0.8006740491092923, "hmean": 0.8076736279747451, "AP": 0} +## Inference Process + +### [Export MindIR](#contents) + +```shell +python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT] +``` + +The ckpt_file parameter is required, +`EXPORT_FORMAT` should be in ["AIR", "MINDIR"] + +### Infer on Ascend310 + +Before performing inference, the mindir file must be exported by `export.py` script. We only provide an example of inference using MINDIR model. +Current batch_Size can only be set to 1. Before running the following process, please configure the environment by following the instructions provided in [Quick start](#quick-start). + +```shell +# Ascend310 inference +bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] +``` + +- `DEVICE_ID` is optional, default value is 0. + +### result + +The `res` folder is generated in the upper-level directory. For details about the final precision calculation, see [Eval Script for ICDAR2015](#eval-script-for-icdar2015). + # [Model Description](#contents) ## [Performance](#contents) diff --git a/model_zoo/official/cv/psenet/README_CN.md b/model_zoo/official/cv/psenet/README_CN.md index 503279b30ca..d3322cdafbd 100644 --- a/model_zoo/official/cv/psenet/README_CN.md +++ b/model_zoo/official/cv/psenet/README_CN.md @@ -16,7 +16,11 @@ - [运行测试代码](#运行测试代码) - [ICDAR2015评估脚本](#icdar2015评估脚本) - [用法](#用法) - - [结果](#结果) + - [结果](#结果) + - [推理过程](#推理过程) + - [导出MindIR](#导出mindir) + - [在Ascend310执行推理](#在ascend310执行推理) + - [结果](#结果) - [模型描述](#模型描述) - [性能](#性能) - [评估性能](#评估性能) @@ -178,6 +182,33 @@ sh ./script/run_eval_ascend.sh.sh Calculated!{"precision": 0.8147966668299853,"recall":0.8006740491092923,"hmean":0.8076736279747451,"AP":0} +## 推理过程 + +### [导出MindIR](#contents) + +```shell +python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT] +``` + +参数ckpt_file为必填项, +`EXPORT_FORMAT` 必须在 ["AIR", "MINDIR"]中选择。 + +### 在Ascend310执行推理 + +在执行推理前,mindir文件必须通过`export.py`脚本导出。以下展示了使用minir模型执行推理的示例。 +目前仅支持batch_Size为1的推理。在执行推理前,请按照[快速入门](#快速入门)配置环境。 + +```shell +# Ascend310 推理 +bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] +``` + +- `DEVICE_ID` 可选,默认值为0。 + +### result + +在运行目录的上一级目录将生成`res`文件夹,最终精度计算过程,请参照[ICDAR2015评估脚本](#icdar2015评估脚本). + # 模型描述 ## 性能 diff --git a/model_zoo/official/cv/psenet/ascend310_infer/CMakeLists.txt b/model_zoo/official/cv/psenet/ascend310_infer/CMakeLists.txt new file mode 100644 index 00000000000..cf49a4e9f25 --- /dev/null +++ b/model_zoo/official/cv/psenet/ascend310_infer/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.14.1) +project(Ascend310Infer) +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined") +set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/) +option(MINDSPORE_PATH "mindspore install path" "") +include_directories(${MINDSPORE_PATH}) +include_directories(${MINDSPORE_PATH}/include) +include_directories(${PROJECT_SRC_ROOT}) +find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib) +file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*) + +add_executable(main src/main.cc src/utils.cc) +target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags) + diff --git a/model_zoo/official/cv/psenet/ascend310_infer/build.sh b/model_zoo/official/cv/psenet/ascend310_infer/build.sh new file mode 100644 index 00000000000..d8ea19ff828 --- /dev/null +++ b/model_zoo/official/cv/psenet/ascend310_infer/build.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +if [ -d out ]; then + rm -rf out +fi + +mkdir out +cd out || exit + +if [ -f "Makefile" ]; then + make clean +fi + +cmake .. \ + -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`" +make \ No newline at end of file diff --git a/model_zoo/official/cv/psenet/ascend310_infer/inc/utils.h b/model_zoo/official/cv/psenet/ascend310_infer/inc/utils.h new file mode 100644 index 00000000000..f8ae1e5b473 --- /dev/null +++ b/model_zoo/official/cv/psenet/ascend310_infer/inc/utils.h @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_INFERENCE_UTILS_H_ +#define MINDSPORE_INFERENCE_UTILS_H_ + +#include +#include +#include +#include +#include +#include "include/api/types.h" + +std::vector GetAllFiles(std::string_view dirName); +DIR *OpenDir(std::string_view dirName); +std::string RealPath(std::string_view path); +mindspore::MSTensor ReadFileToTensor(const std::string &file); +int WriteResult(const std::string& imageFile, const std::vector &outputs); +std::vector GetAllFiles(std::string dir_name); +std::vector> GetAllInputData(std::string dir_name); + +#endif diff --git a/model_zoo/official/cv/psenet/ascend310_infer/src/main.cc b/model_zoo/official/cv/psenet/ascend310_infer/src/main.cc new file mode 100644 index 00000000000..958703eb735 --- /dev/null +++ b/model_zoo/official/cv/psenet/ascend310_infer/src/main.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "include/api/model.h" +#include "include/api/context.h" +#include "include/api/types.h" +#include "include/api/serialization.h" +#include "include/dataset/execute.h" +#include "include/dataset/vision.h" +#include "inc/utils.h" + +using mindspore::Context; +using mindspore::Serialization; +using mindspore::Model; +using mindspore::Status; +using mindspore::ModelType; +using mindspore::GraphCell; +using mindspore::kSuccess; +using mindspore::MSTensor; +using mindspore::dataset::Execute; +using mindspore::dataset::vision::Decode; +using mindspore::dataset::vision::Resize; +using mindspore::dataset::vision::Pad; +using mindspore::dataset::vision::Normalize; +using mindspore::dataset::vision::HWC2CHW; + +DEFINE_string(mindir_path, "", "mindir path"); +DEFINE_string(dataset_path, ".", "dataset path"); +DEFINE_int32(device_id, 0, "device id"); + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (RealPath(FLAGS_mindir_path).empty()) { + std::cout << "Invalid mindir" << std::endl; + return 1; + } + + auto context = std::make_shared(); + auto ascend310 = std::make_shared(); + ascend310->SetDeviceID(FLAGS_device_id); + context->MutableDeviceInfo().push_back(ascend310); + mindspore::Graph graph; + Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); + Model model; + Status ret = model.Build(GraphCell(graph), context); + if (ret != kSuccess) { + std::cout << "ERROR: Build failed." << std::endl; + return 1; + } + + std::vector model_inputs = model.GetInputs(); + if (model_inputs.empty()) { + std::cout << "Invalid model, inputs is empty." << std::endl; + return 1; + } + + auto all_files = GetAllFiles(FLAGS_dataset_path); + if (all_files.empty()) { + std::cout << "ERROR: no input data." << std::endl; + return 1; + } + + std::map costTime_map; + size_t size = all_files.size(); + auto decode = Decode(); + Execute composeDecode(decode); + + auto resize = Resize({1920, 1920}); + auto normalize = Normalize({123.675, 116.28, 103.53}, {58.395, 57.12, 57.375}); + auto hwc2chw = HWC2CHW(); + + for (size_t i = 0; i < size; ++i) { + struct timeval start = {0}; + struct timeval end = {0}; + double startTimeMs; + double endTimeMs; + std::vector inputs; + std::vector outputs; + std::cout << "Start predict input files:" << all_files[i] < shape = imgDecode.Shape(); + int imgHeight = shape[0]; + int imgWidth = shape[1]; + std::vector pad_size; + if (imgWidth < imgHeight) { + pad_size = {0, 0, (imgHeight - imgWidth), 0}; + } else { + pad_size = {0, 0, 0, (imgWidth - imgHeight)}; + } + auto pad = Pad(pad_size, {0}); + Execute trans_list({pad, resize, normalize, hwc2chw}); + auto img = MSTensor(); + ret = trans_list(imgDecode, &img); + if (ret != kSuccess) { + std::cout << "ERROR: Image transfer failed." << std::endl; + return 1; + } + + inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(), + img.Data().get(), img.DataSize()); + gettimeofday(&start, nullptr); + ret = model.Predict(inputs, &outputs); + gettimeofday(&end, nullptr); + if (ret != kSuccess) { + std::cout << "Predict " << all_files[i] << " failed." << std::endl; + return 1; + } + startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000; + endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000; + costTime_map.insert(std::pair(startTimeMs, endTimeMs)); + WriteResult(all_files[i], outputs); + } + double average = 0.0; + int inferCount = 0; + + for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) { + double diff = 0.0; + diff = iter->second - iter->first; + average += diff; + inferCount++; + } + average = average / inferCount; + std::stringstream timeCost; + timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl; + std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl; + std::string fileName = "./time_Result" + std::string("/test_perform_static.txt"); + std::ofstream fileStream(fileName.c_str(), std::ios::trunc); + fileStream << timeCost.str(); + fileStream.close(); + costTime_map.clear(); + return 0; +} diff --git a/model_zoo/official/cv/psenet/ascend310_infer/src/utils.cc b/model_zoo/official/cv/psenet/ascend310_infer/src/utils.cc new file mode 100644 index 00000000000..d71f388b83d --- /dev/null +++ b/model_zoo/official/cv/psenet/ascend310_infer/src/utils.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "inc/utils.h" + +using mindspore::MSTensor; +using mindspore::DataType; + + +std::vector> GetAllInputData(std::string dir_name) { + std::vector> ret; + + DIR *dir = OpenDir(dir_name); + if (dir == nullptr) { + return {}; + } + struct dirent *filename; + /* read all the files in the dir ~ */ + std::vector sub_dirs; + while ((filename = readdir(dir)) != nullptr) { + std::string d_name = std::string(filename->d_name); + // get rid of "." and ".." + if (d_name == "." || d_name == ".." || d_name.empty()) { + continue; + } + std::string dir_path = RealPath(std::string(dir_name) + "/" + filename->d_name); + struct stat s; + lstat(dir_path.c_str(), &s); + if (!S_ISDIR(s.st_mode)) { + continue; + } + + sub_dirs.emplace_back(dir_path); + } + std::sort(sub_dirs.begin(), sub_dirs.end()); + + (void)std::transform(sub_dirs.begin(), sub_dirs.end(), std::back_inserter(ret), + [](const std::string &d) { return GetAllFiles(d); }); + + return ret; +} + + +std::vector GetAllFiles(std::string dir_name) { + struct dirent *filename; + DIR *dir = OpenDir(dir_name); + if (dir == nullptr) { + return {}; + } + + std::vector res; + while ((filename = readdir(dir)) != nullptr) { + std::string d_name = std::string(filename->d_name); + if (d_name == "." || d_name == ".." || d_name.size() <= 3) { + continue; + } + res.emplace_back(std::string(dir_name) + "/" + filename->d_name); + } + std::sort(res.begin(), res.end()); + + return res; +} + + +std::vector GetAllFiles(std::string_view dirName) { + struct dirent *filename; + DIR *dir = OpenDir(dirName); + if (dir == nullptr) { + return {}; + } + std::vector res; + while ((filename = readdir(dir)) != nullptr) { + std::string dName = std::string(filename->d_name); + if (dName == "." || dName == ".." || filename->d_type != DT_REG) { + continue; + } + res.emplace_back(std::string(dirName) + "/" + filename->d_name); + } + std::sort(res.begin(), res.end()); + for (auto &f : res) { + std::cout << "image file: " << f << std::endl; + } + return res; +} + + +int WriteResult(const std::string& imageFile, const std::vector &outputs) { + std::string homePath = "./result_Files"; + for (size_t i = 0; i < outputs.size(); ++i) { + size_t outputSize; + std::shared_ptr netOutput; + netOutput = outputs[i].Data(); + outputSize = outputs[i].DataSize(); + int pos = imageFile.rfind('/'); + std::string fileName(imageFile, pos + 1); + fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin"); + std::string outFileName = homePath + "/" + fileName; + FILE *outputFile = fopen(outFileName.c_str(), "wb"); + fwrite(netOutput.get(), outputSize, sizeof(char), outputFile); + fclose(outputFile); + outputFile = nullptr; + } + return 0; +} + +mindspore::MSTensor ReadFileToTensor(const std::string &file) { + if (file.empty()) { + std::cout << "Pointer file is nullptr" << std::endl; + return mindspore::MSTensor(); + } + + std::ifstream ifs(file); + if (!ifs.good()) { + std::cout << "File: " << file << " is not exist" << std::endl; + return mindspore::MSTensor(); + } + + if (!ifs.is_open()) { + std::cout << "File: " << file << "open failed" << std::endl; + return mindspore::MSTensor(); + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast(size)}, nullptr, size); + + ifs.seekg(0, std::ios::beg); + ifs.read(reinterpret_cast(buffer.MutableData()), size); + ifs.close(); + + return buffer; +} + + +DIR *OpenDir(std::string_view dirName) { + if (dirName.empty()) { + std::cout << " dirName is null ! " << std::endl; + return nullptr; + } + std::string realPath = RealPath(dirName); + struct stat s; + lstat(realPath.c_str(), &s); + if (!S_ISDIR(s.st_mode)) { + std::cout << "dirName is not a valid directory !" << std::endl; + return nullptr; + } + DIR *dir; + dir = opendir(realPath.c_str()); + if (dir == nullptr) { + std::cout << "Can not open dir " << dirName << std::endl; + return nullptr; + } + std::cout << "Successfully opened the dir " << dirName << std::endl; + return dir; +} + +std::string RealPath(std::string_view path) { + char realPathMem[PATH_MAX] = {0}; + char *realPathRet = nullptr; + realPathRet = realpath(path.data(), realPathMem); + if (realPathRet == nullptr) { + std::cout << "File: " << path << " is not exist."; + return ""; + } + + std::string realPath(realPathMem); + std::cout << path << " realpath is: " << realPath << std::endl; + return realPath; +} diff --git a/model_zoo/official/cv/psenet/postprocess.py b/model_zoo/official/cv/psenet/postprocess.py new file mode 100644 index 00000000000..3fd591c7450 --- /dev/null +++ b/model_zoo/official/cv/psenet/postprocess.py @@ -0,0 +1,104 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + + +import os +import math +import operator +from functools import reduce +import argparse +import numpy as np +import cv2 + +from src.config import config +from src.ETSNET.pse import pse + +def sort_to_clockwise(points): + center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), points), [len(points)] * 2)) + clockwise_points = sorted(points, key=lambda coord: (-135 - math.degrees( + math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True) + return clockwise_points + +def write_result_as_txt(image_name, img_bboxes, path): + if not os.path.isdir(path): + os.makedirs(path) + filename = os.path.join(path, 'res_{}.txt'.format(os.path.splitext(image_name)[0])) + lines = [] + for _, img_bbox in enumerate(img_bboxes): + img_bbox = img_bbox.reshape(-1, 2) + img_bbox = np.array(list(sort_to_clockwise(img_bbox)))[[3, 0, 1, 2]].copy().reshape(-1) + values = [int(v) for v in img_bbox] + line = "%d,%d,%d,%d,%d,%d,%d,%d\n" % tuple(values) + lines.append(line) + with open(filename, 'w') as f: + for line in lines: + f.write(line) + +def get_img(image_path): + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image + +parser = argparse.ArgumentParser(description='postprocess') +parser.add_argument("--result_path", type=str, default="./scripts/result_Files", help='result Files path.') +parser.add_argument("--img_path", type=str, default="", help='image files path.') +args = parser.parse_args() + +if __name__ == "__main__": + if not os.path.isdir('./res/submit_ic15/'): + os.makedirs('./res/submit_ic15/') + if not os.path.isdir('./res/vis_ic15/'): + os.makedirs('./res/vis_ic15/') + + file_list = os.listdir(args.img_path) + for k in file_list: + if os.path.splitext(k)[-1].lower() in ['.jpg', '.jpeg', '.png']: + img_path = os.path.join(args.img_path, k) + img = get_img(img_path).reshape(1, 720, 1280, 3) + img = img[0].astype(np.uint8).copy() + img_name = os.path.split(img_path)[-1] + + score = np.fromfile(os.path.join(args.result_path, k.split('.')[0] + '_0.bin'), np.float32) + score = score.reshape(1, 1, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE) + kernels = np.fromfile(os.path.join(args.result_path, k.split('.')[0] + '_1.bin'), bool) + kernels = kernels.reshape(1, config.KERNEL_NUM, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE) + score = np.squeeze(score) + kernels = np.squeeze(kernels) + + # post-process + pred = pse(kernels, 5.0) + scale = max(img.shape[:2]) * 1.0 / config.INFER_LONG_SIZE + label = pred + label_num = np.max(label) + 1 + bboxes = [] + + for i in range(1, label_num): + pot = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1] + if pot.shape[0] < 600: + continue + + score_i = np.mean(score[label == i]) + if score_i < 0.93: + continue + + rect = cv2.minAreaRect(pot) + bbox = cv2.boxPoints(rect) * scale + bbox = bbox.astype('int32') + cv2.drawContours(img, [bbox], 0, (0, 255, 0), 3) + bboxes.append(bbox) + + # save res + cv2.imwrite('./res/vis_ic15/{}'.format(img_name), img[:, :, [2, 1, 0]].copy()) + write_result_as_txt(img_name, bboxes, './res/submit_ic15/') diff --git a/model_zoo/official/cv/psenet/scripts/run_infer_310.sh b/model_zoo/official/cv/psenet/scripts/run_infer_310.sh new file mode 100644 index 00000000000..eeb1ad694ce --- /dev/null +++ b/model_zoo/official/cv/psenet/scripts/run_infer_310.sh @@ -0,0 +1,96 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [[ $# -lt 2 || $# -gt 3 ]]; then + echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] + DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +model=$(get_real_path $1) +data_path=$(get_real_path $2) + +device_id=0 +if [ $# == 3 ]; then + device_id=$3 +fi + +echo "mindir name: "$model +echo "dataset path: "$data_path +echo "device id: "$device_id + +export ASCEND_HOME=/usr/local/Ascend/ +if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then + export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH + export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH + export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe + export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp +else + export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH + export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH + export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=$ASCEND_HOME/opp +fi + +function compile_app() +{ + cd ../ascend310_infer/ || exit + bash build.sh &> build.log +} + +function infer() +{ + cd - || exit + if [ -d result_Files ]; then + rm -rf ./result_Files + fi + if [ -d time_Result ]; then + rm -rf ./time_Result + fi + mkdir result_Files + mkdir time_Result + ../ascend310_infer/out/main --mindir_path=$model --dataset_path=$data_path --device_id=$device_id &> infer.log +} + +function cal_acc() +{ + cd .. || exit + python3.7 postprocess.py --result_path=./scripts/result_Files --img_path=$data_path &> ./scripts/acc.log +} + +compile_app +if [ $? -ne 0 ]; then + echo "compile app code failed" + exit 1 +fi +infer +if [ $? -ne 0 ]; then + echo " execute inference failed" + exit 1 +fi +cal_acc +if [ $? -ne 0 ]; then + echo "calculate accuracy failed" + exit 1 +fi \ No newline at end of file