diff --git a/model_zoo/research/cv/FaceDetection/README.md b/model_zoo/research/cv/FaceDetection/README.md index 85d782f663e..9f068a73e9e 100644 --- a/model_zoo/research/cv/FaceDetection/README.md +++ b/model_zoo/research/cv/FaceDetection/README.md @@ -87,6 +87,7 @@ The entire code structure is as following: . └─ Face Detection ├─ README.md + ├─ ascend310_infer # application for 310 inference ├─ model_utils ├─ __init__.py # init file ├─ config.py # Parse arguments @@ -97,6 +98,7 @@ The entire code structure is as following: ├─ run_standalone_train.sh # launch standalone training(1p) in ascend ├─ run_distribute_train.sh # launch distributed training(8p) in ascend ├─ run_eval.sh # launch evaluating in ascend + ├─ run_infer_310.sh # launch inference on Ascend310 └─ run_export.sh # launch exporting air model ├─ src ├─ FaceDetection @@ -115,6 +117,9 @@ The entire code structure is as following: ├─ default_config.yaml # default configurations ├─ train.py # training scripts ├─ eval.py # evaluation scripts + ├─ postprocess.py # postprocess script + ├─ preprocess.py # preprocess script + ├─ bin.py # bin script └─ export.py # export air model ``` @@ -266,13 +271,39 @@ Saving ../../results/0-2441_61000/.._.._results_0-2441_61000_face_AP_0.760.png And the detect result and P-R graph will also be saved in "./results/[MODEL_NAME]/" -### Convert model +### Inference process -If you want to infer the network on Ascend 310, you should convert the model to AIR: +#### Convert model + +If you want to infer the network on Ascend 310, you should convert the model to MINDIR or AIR: + +```shell +# Ascend310 inference +python export.py --pretrained [PRETRAIN] --batch_size [BATCH_SIZE] --file_format [EXPORT_FORMAT] +``` + +The pretrained parameter is required. +`EXPORT_FORMAT` should be in ["AIR", "MINDIR"] +Current batch_size can only be set to 1. + +#### Infer on Ascend310 + +Before performing inference, the mindir file must be exported by `export.py` script. We only provide an example of inference using MINDIR model. + +```shell +# Ascend310 inference +bash run_infer_310.sh [MINDIR_PATH] [MINDRECORD_PATH] [DEVICE_ID] +``` + +- `DEVICE_ID` is optional, default value is 0. + +#### result + +Inference result is saved in current path, you can find result like this in map.log file. ```bash -cd ./scripts -bash run_export.sh [PLATFORM] [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] +calculate [recall | persicion | ap]... +Saving ../../results/0-2441_61000/.._.._results_0-2441_61000_face_AP_0.7575.png ``` # [Model Description](#contents) @@ -310,6 +341,20 @@ bash run_export.sh [PLATFORM] [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] | Accuracy | 8pcs: 76.0% | | Model for inference | 37M (.ckpt file) | +### Inference Performance + +| Parameters | Ascend | +| ------------------- | --------------------------- | +| Model Version | Face Detection | +| Resource | Ascend 310; Euler2.8 | +| Uploaded Date | 19/06/2021 (month/day/year) | +| MindSpore Version | 1.2.0 | +| Dataset | 3K images | +| batch_size | 1 | +| outputs | mAP | +| mAP | mAP=75.75% | +| Model for inference | 37M(.ckpt file) | + # [ModelZoo Homepage](#contents) Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/research/cv/FaceDetection/ascend310_infer/CMakeLists.txt b/model_zoo/research/cv/FaceDetection/ascend310_infer/CMakeLists.txt new file mode 100644 index 00000000000..ee3c8544734 --- /dev/null +++ b/model_zoo/research/cv/FaceDetection/ascend310_infer/CMakeLists.txt @@ -0,0 +1,14 @@ +cmake_minimum_required(VERSION 3.14.1) +project(Ascend310Infer) +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined") +set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/) +option(MINDSPORE_PATH "mindspore install path" "") +include_directories(${MINDSPORE_PATH}) +include_directories(${MINDSPORE_PATH}/include) +include_directories(${PROJECT_SRC_ROOT}) +find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib) +file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*) + +add_executable(main src/main.cc src/utils.cc) +target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags) diff --git a/model_zoo/research/cv/FaceDetection/ascend310_infer/build.sh b/model_zoo/research/cv/FaceDetection/ascend310_infer/build.sh new file mode 100644 index 00000000000..770a8851efa --- /dev/null +++ b/model_zoo/research/cv/FaceDetection/ascend310_infer/build.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ ! -d out ]; then + mkdir out +fi +cd out || exit +cmake .. \ + -DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`" +make diff --git a/model_zoo/research/cv/FaceDetection/ascend310_infer/inc/utils.h b/model_zoo/research/cv/FaceDetection/ascend310_infer/inc/utils.h new file mode 100644 index 00000000000..efebe03a8c1 --- /dev/null +++ b/model_zoo/research/cv/FaceDetection/ascend310_infer/inc/utils.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_INFERENCE_UTILS_H_ +#define MINDSPORE_INFERENCE_UTILS_H_ + +#include +#include +#include +#include +#include +#include "include/api/types.h" + +std::vector GetAllFiles(std::string_view dirName); +DIR *OpenDir(std::string_view dirName); +std::string RealPath(std::string_view path); +mindspore::MSTensor ReadFileToTensor(const std::string &file); +int WriteResult(const std::string& imageFile, const std::vector &outputs); +#endif diff --git a/model_zoo/research/cv/FaceDetection/ascend310_infer/src/main.cc b/model_zoo/research/cv/FaceDetection/ascend310_infer/src/main.cc new file mode 100644 index 00000000000..4daec31d9a0 --- /dev/null +++ b/model_zoo/research/cv/FaceDetection/ascend310_infer/src/main.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "include/api/model.h" +#include "include/api/context.h" +#include "include/api/types.h" +#include "include/api/serialization.h" +#include "inc/utils.h" + +using mindspore::Context; +using mindspore::Serialization; +using mindspore::Model; +using mindspore::Status; +using mindspore::ModelType; +using mindspore::GraphCell; +using mindspore::kSuccess; +using mindspore::MSTensor; + +DEFINE_string(mindir_path, "", "mindir path"); +DEFINE_string(input0_path, ".", "input0 path"); +DEFINE_int32(device_id, 0, "device id"); + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (RealPath(FLAGS_mindir_path).empty()) { + std::cout << "Invalid mindir" << std::endl; + return 1; + } + + auto context = std::make_shared(); + auto ascend310 = std::make_shared(); + ascend310->SetDeviceID(FLAGS_device_id); + context->MutableDeviceInfo().push_back(ascend310); + mindspore::Graph graph; + Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); + + Model model; + Status ret = model.Build(GraphCell(graph), context); + if (ret != kSuccess) { + std::cout << "ERROR: Build failed." << std::endl; + return 1; + } + + std::vector model_inputs = model.GetInputs(); + if (model_inputs.empty()) { + std::cout << "Invalid model, inputs is empty." << std::endl; + return 1; + } + + auto input0_files = GetAllFiles(FLAGS_input0_path); + + if (input0_files.empty()) { + std::cout << "ERROR: input data empty." << std::endl; + return 1; + } + + std::map costTime_map; + size_t size = input0_files.size(); + + for (size_t i = 0; i < size; ++i) { + struct timeval start = {0}; + struct timeval end = {0}; + double startTimeMs; + double endTimeMs; + std::vector inputs; + std::vector outputs; + std::cout << "Start predict input files:" << input0_files[i] << std::endl; + + auto input0 = ReadFileToTensor(input0_files[i]); + inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(), + input0.Data().get(), input0.DataSize()); + + gettimeofday(&start, nullptr); + ret = model.Predict(inputs, &outputs); + gettimeofday(&end, nullptr); + if (ret != kSuccess) { + std::cout << "Predict " << input0_files[i] << " failed." << std::endl; + return 1; + } + startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000; + endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000; + costTime_map.insert(std::pair(startTimeMs, endTimeMs)); + WriteResult(input0_files[i], outputs); + } + double average = 0.0; + int inferCount = 0; + + for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) { + double diff = 0.0; + diff = iter->second - iter->first; + average += diff; + inferCount++; + } + average = average / inferCount; + std::stringstream timeCost; + timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl; + std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl; + std::string fileName = "./time_Result" + std::string("/test_perform_static.txt"); + std::ofstream fileStream(fileName.c_str(), std::ios::trunc); + fileStream << timeCost.str(); + fileStream.close(); + costTime_map.clear(); + return 0; +} diff --git a/model_zoo/research/cv/FaceDetection/ascend310_infer/src/utils.cc b/model_zoo/research/cv/FaceDetection/ascend310_infer/src/utils.cc new file mode 100644 index 00000000000..b509c57f823 --- /dev/null +++ b/model_zoo/research/cv/FaceDetection/ascend310_infer/src/utils.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "inc/utils.h" + +#include +#include +#include + +using mindspore::MSTensor; +using mindspore::DataType; + +std::vector GetAllFiles(std::string_view dirName) { + struct dirent *filename; + DIR *dir = OpenDir(dirName); + if (dir == nullptr) { + return {}; + } + std::vector res; + while ((filename = readdir(dir)) != nullptr) { + std::string dName = std::string(filename->d_name); + if (dName == "." || dName == ".." || filename->d_type != DT_REG) { + continue; + } + res.emplace_back(std::string(dirName) + "/" + filename->d_name); + } + std::sort(res.begin(), res.end()); + for (auto &f : res) { + std::cout << "image file: " << f << std::endl; + } + return res; +} + +int WriteResult(const std::string& imageFile, const std::vector &outputs) { + std::string homePath = "./result_Files"; + for (size_t i = 0; i < outputs.size(); ++i) { + size_t outputSize; + std::shared_ptr netOutput; + netOutput = outputs[i].Data(); + outputSize = outputs[i].DataSize(); + int pos = imageFile.rfind('/'); + std::string fileName(imageFile, pos + 1); + fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin"); + std::string outFileName = homePath + "/" + fileName; + FILE * outputFile = fopen(outFileName.c_str(), "wb"); + fwrite(netOutput.get(), outputSize, sizeof(char), outputFile); + fclose(outputFile); + outputFile = nullptr; + } + return 0; +} + +mindspore::MSTensor ReadFileToTensor(const std::string &file) { + if (file.empty()) { + std::cout << "Pointer file is nullptr" << std::endl; + return mindspore::MSTensor(); + } + + std::ifstream ifs(file); + if (!ifs.good()) { + std::cout << "File: " << file << " is not exist" << std::endl; + return mindspore::MSTensor(); + } + + if (!ifs.is_open()) { + std::cout << "File: " << file << "open failed" << std::endl; + return mindspore::MSTensor(); + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast(size)}, nullptr, size); + + ifs.seekg(0, std::ios::beg); + ifs.read(reinterpret_cast(buffer.MutableData()), size); + ifs.close(); + + return buffer; +} + + +DIR *OpenDir(std::string_view dirName) { + if (dirName.empty()) { + std::cout << " dirName is null ! " << std::endl; + return nullptr; + } + std::string realPath = RealPath(dirName); + struct stat s; + lstat(realPath.c_str(), &s); + if (!S_ISDIR(s.st_mode)) { + std::cout << "dirName is not a valid directory !" << std::endl; + return nullptr; + } + DIR *dir; + dir = opendir(realPath.c_str()); + if (dir == nullptr) { + std::cout << "Can not open dir " << dirName << std::endl; + return nullptr; + } + std::cout << "Successfully opened the dir " << dirName << std::endl; + return dir; +} + +std::string RealPath(std::string_view path) { + char realPathMem[PATH_MAX] = {0}; + char *realPathRet = nullptr; + realPathRet = realpath(path.data(), realPathMem); + + if (realPathRet == nullptr) { + std::cout << "File: " << path << " is not exist."; + return ""; + } + + std::string realPath(realPathMem); + std::cout << path << " realpath is: " << realPath << std::endl; + return realPath; +} diff --git a/model_zoo/research/cv/FaceDetection/bin.py b/model_zoo/research/cv/FaceDetection/bin.py new file mode 100644 index 00000000000..bfcb0479d62 --- /dev/null +++ b/model_zoo/research/cv/FaceDetection/bin.py @@ -0,0 +1,76 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""bin for 310 inference""" +import os +import numpy as np +from PIL import Image, ImageOps +from model_utils.config import config + + +def tf_pil(img): + """ Letterbox an image to fit in the network """ + + net_w, net_h = config.input_shape + fill_color = 127 + im_w, im_h = img.size + + if im_w == net_w and im_h == net_h: + return img + + # Rescaling + if im_w / net_w >= im_h / net_h: + scale = net_w / im_w + else: + scale = net_h / im_h + if scale != 1: + resample_mode = Image.NEAREST + img = img.resize((int(scale * im_w), int(scale * im_h)), resample_mode) + im_w, im_h = img.size + + if im_w == net_w and im_h == net_h: + return img + + # Padding + img_np = np.array(img) + channels = img_np.shape[2] if len(img_np.shape) > 2 else 1 + pad_w = (net_w - im_w) / 2 + pad_h = (net_h - im_h) / 2 + pad = (int(pad_w), int(pad_h), int(pad_w + .5), int(pad_h + .5)) + img = ImageOps.expand(img, border=pad, fill=(fill_color,) * channels) + return img + + +def hwc2chw(img_np): + return img_np.transpose(2, 0, 1).copy() + + +def to_tensor(image): + image = np.asarray(image) + image = hwc2chw(image) + image = image / 255. + return image.astype(np.float32) + + +if __name__ == '__main__': + result_path = os.path.join(config.preprocess_path, 'images_bin') + if not os.path.isdir(result_path): + os.makedirs(result_path, exist_ok=True) + data_path = os.path.join(config.preprocess_path, "images") + files = os.listdir(data_path) + for file in files: + img_pil = Image.open(os.path.join(data_path, file)).convert("RGB") + img_pil = tf_pil(img_pil) + img_pil = to_tensor(img_pil) + img_pil.tofile(os.path.join(result_path, file.split('.')[0] + '.bin')) diff --git a/model_zoo/research/cv/FaceDetection/default_config.yaml b/model_zoo/research/cv/FaceDetection/default_config.yaml index 4ca377f65e2..f3ea79b54f2 100644 --- a/model_zoo/research/cv/FaceDetection/default_config.yaml +++ b/model_zoo/research/cv/FaceDetection/default_config.yaml @@ -58,6 +58,15 @@ anchors_mask: [[8, 9, 10, 11], [4, 5, 6, 7], [0, 1, 2, 3]] conf_thresh: 0.1 nms_thresh: 0.45 +#export +file_name: "FaceDetection" +file_format: "AIR" + +#310 infer +preprocess_path: "" +save_output_path: "" +data_dir: "" + --- # Help description for each configuration diff --git a/model_zoo/research/cv/FaceDetection/export.py b/model_zoo/research/cv/FaceDetection/export.py index a24d6e5714f..2864bdda0df 100644 --- a/model_zoo/research/cv/FaceDetection/export.py +++ b/model_zoo/research/cv/FaceDetection/export.py @@ -19,13 +19,17 @@ import numpy as np from mindspore import context from mindspore import Tensor from mindspore.train.serialization import export, load_checkpoint, load_param_into_net - +from src.network_define import BuildTestNetwork from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3 from model_utils.config import config def save_air(): - '''save air''' - print('============= yolov3 start save air ==================') + '''save air or mindir''' + anchors = config.anchors + reduction_0 = 64.0 + reduction_1 = 32.0 + reduction_2 = 16.0 + print('============= yolov3 start save air or mindir ==================') devid = int(os.getenv('DEVICE_ID', '0')) if config.run_platform != 'CPU' else 0 context.set_context(mode=context.GRAPH_MODE, device_target=config.run_platform, save_graphs=False, device_id=devid) @@ -47,12 +51,12 @@ def save_air(): param_dict_new[key] = values load_param_into_net(network, param_dict_new) print('load model {} success'.format(config.pretrained)) - + test_net = BuildTestNetwork(network, reduction_0, reduction_1, reduction_2, anchors, anchors_mask, num_classes, + config) input_data = np.random.uniform(low=0, high=1.0, size=(config.batch_size, 3, 448, 768)).astype(np.float32) tensor_input_data = Tensor(input_data) - export(network, tensor_input_data, - file_name=config.pretrained.replace('.ckpt', '_' + str(config.batch_size) + 'b.air'), file_format='AIR') + export(test_net, tensor_input_data, file_name=config.file_name, file_format=config.file_format) print("export model success.") diff --git a/model_zoo/research/cv/FaceDetection/postprocess.py b/model_zoo/research/cv/FaceDetection/postprocess.py new file mode 100644 index 00000000000..a19d6a8e48f --- /dev/null +++ b/model_zoo/research/cv/FaceDetection/postprocess.py @@ -0,0 +1,115 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""post process for 310 inference""" +import os +import matplotlib.pyplot as plt +import numpy as np +from mindspore import Tensor +from mindspore.common import dtype as mstype +from src.FaceDetection import voc_wrapper +from src.network_define import get_bounding_boxes, tensor_to_brambox, \ + parse_gt_from_anno, parse_rets, calc_recall_precision_ap + +from model_utils.config import config + + +def cal_map(result_path, data_dir, save_output_path): + """cal map""" + labels = ['face'] + det = {} + img_size = {} + img_anno = {} + eval_times = 0 + classes = {0: 'face'} + ret_files_set = {'face': os.path.join(save_output_path, 'comp4_det_test_face_rm5050.txt')} + files = os.listdir(os.path.join(data_dir, "labels")) + for file in files: + image_name = file.split('.')[0] + label = np.fromfile(os.path.join(data_dir, "labels", file), dtype=np.float64).reshape((1, 200, 6)) + image_size = np.fromfile(os.path.join(data_dir, "image_size", file), dtype=np.int32).reshape((1, 1, 2)) + eval_times += 1 + dets = [] + tdets = [] + file_path = os.path.join(result_path, image_name) + coords_0 = np.fromfile(file_path + '_0.bin', dtype=np.float32).reshape((1, 4, 84, 4)) + coords_0 = Tensor(coords_0, mstype.float32) + cls_scores_0 = np.fromfile(file_path + '_1.bin', dtype=np.float32).reshape((1, 4, 84)) + cls_scores_0 = Tensor(cls_scores_0, mstype.float32) + coords_1 = np.fromfile(file_path + '_2.bin', dtype=np.float32).reshape((1, 4, 336, 4)) + coords_1 = Tensor(coords_1, mstype.float32) + cls_scores_1 = np.fromfile(file_path + '_3.bin', dtype=np.float32).reshape((1, 4, 336)) + cls_scores_1 = Tensor(cls_scores_1, mstype.float32) + coords_2 = np.fromfile(file_path + '_4.bin', dtype=np.float32).reshape((1, 4, 1344, 4)) + coords_2 = Tensor(coords_2, mstype.float32) + cls_scores_2 = np.fromfile(file_path + '_5.bin', dtype=np.float32).reshape((1, 4, 1344)) + cls_scores_2 = Tensor(cls_scores_2, mstype.float32) + + boxes_0, boxes_1, boxes_2 = get_bounding_boxes(coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2, + cls_scores_2, config.conf_thresh, config.input_shape, + config.num_classes) + + converted_boxes_0, converted_boxes_1, converted_boxes_2 = tensor_to_brambox(boxes_0, boxes_1, boxes_2, + config.input_shape, labels) + + tdets.append(converted_boxes_0) + tdets.append(converted_boxes_1) + tdets.append(converted_boxes_2) + batch = len(tdets[0]) + for b in range(batch): + single_dets = [] + for op in range(3): + single_dets.extend(tdets[op][b]) + dets.append(single_dets) + + det.update({image_name: v for k, v in enumerate(dets)}) + img_size.update({image_name: v for k, v in enumerate(image_size)}) + img_anno.update({image_name: v for k, v in enumerate(label)}) + + netw, neth = config.input_shape + reorg_dets = voc_wrapper.reorg_detection(det, netw, neth, img_size) + voc_wrapper.gen_results(reorg_dets, save_output_path, img_size, config.nms_thresh) + + # compute mAP + ground_truth = parse_gt_from_anno(img_anno, classes) + + ret_list = parse_rets(ret_files_set) + iou_thr = 0.5 + evaluate = calc_recall_precision_ap(ground_truth, ret_list, iou_thr) + print(evaluate) + + aps_str = '' + for cls in evaluate: + per_line, = plt.plot(evaluate[cls]['recall'], evaluate[cls]['precision'], 'b-') + per_line.set_label('%s:AP=%.4f' % (cls, evaluate[cls]['ap'])) + aps_str += '_%s_AP_%.4f' % (cls, evaluate[cls]['ap']) + plt.plot([i / 1000.0 for i in range(1, 1001)], [i / 1000.0 for i in range(1, 1001)], 'y--') + plt.axis([0, 1.2, 0, 1.2]) # [x_min, x_max, y_min, y_max] + plt.xlabel('recall') + plt.ylabel('precision') + plt.grid() + + plt.legend() + plt.title('PR') + + # save mAP + ap_save_path = os.path.join(save_output_path, save_output_path.replace('/', '_') + aps_str + '.png') + print('Saving {}'.format(ap_save_path)) + plt.savefig(ap_save_path) + + print('=============yolov3 evaluating finished==================') + + +if __name__ == '__main__': + cal_map(config.result_path, config.data_dir, config.save_output_path) diff --git a/model_zoo/research/cv/FaceDetection/preprocess.py b/model_zoo/research/cv/FaceDetection/preprocess.py new file mode 100644 index 00000000000..3255a9de057 --- /dev/null +++ b/model_zoo/research/cv/FaceDetection/preprocess.py @@ -0,0 +1,88 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""pre process for 310 inference""" +import os +import numpy as np +from PIL import Image +import mindspore.dataset.vision.py_transforms as P +import mindspore.dataset as de +from model_utils.config import config + + +class SingleScaleTrans_Infer: + '''SingleScaleTrans''' + + def __init__(self, resize, max_anno_count=200): + self.resize = (resize[0], resize[1]) + self.max_anno_count = max_anno_count + + def __call__(self, imgs, ann, image_names, image_size, batch_info): + + decode = P.Decode() + ret_imgs = [] + ret_anno = [] + + for i, image in enumerate(imgs): + img_pil = decode(image) + input_data = img_pil, ann[i] + ret_imgs.append(np.array(input_data[0])) + ret_anno.append(input_data[1]) + + for i, anno in enumerate(ret_anno): + anno_count = anno.shape[0] + if anno_count < self.max_anno_count: + ret_anno[i] = np.concatenate( + (ret_anno[i], np.zeros((self.max_anno_count - anno_count, 6), dtype=float)), axis=0) + else: + ret_anno[i] = ret_anno[i][:self.max_anno_count] + + return np.array(ret_imgs), np.array(ret_anno), image_names, image_size + + +def preprocess(): + """preprocess""" + preprocess_path = config.preprocess_path + images_path = os.path.join(preprocess_path, 'images') + if not os.path.isdir(images_path): + os.makedirs(images_path, exist_ok=True) + + labels_path = os.path.join(preprocess_path, 'labels') + if not os.path.isdir(labels_path): + os.makedirs(labels_path, exist_ok=True) + + image_name_path = os.path.join(preprocess_path, 'image_name') + if not os.path.isdir(image_name_path): + os.makedirs(image_name_path, exist_ok=True) + image_size_path = os.path.join(preprocess_path, 'image_size') + if not os.path.isdir(image_size_path): + os.makedirs(image_size_path, exist_ok=True) + + ds = de.MindDataset(os.path.join(config.mindrecord_path, "data.mindrecord0"), + columns_list=["image", "annotation", "image_name", "image_size"]) + single_scale_trans = SingleScaleTrans_Infer(resize=config.input_shape) + ds = ds.batch(config.batch_size, per_batch_map=single_scale_trans, + input_columns=["image", "annotation", "image_name", "image_size"], num_parallel_workers=8) + ds = ds.repeat(1) + for data in ds.create_tuple_iterator(output_numpy=True): + images, labels, image_name, image_size = data[0:4] + images = Image.fromarray(images[0].astype('uint8')).convert('RGB') + images.save(os.path.join(images_path, image_name[0].decode() + ".jpg")) + labels.tofile(os.path.join(labels_path, image_name[0].decode() + ".bin")) + image_name.tofile(os.path.join(image_name_path, image_name[0].decode() + ".bin")) + image_size.tofile(os.path.join(image_size_path, image_name[0].decode() + ".bin")) + + +if __name__ == '__main__': + preprocess() diff --git a/model_zoo/research/cv/FaceDetection/scripts/run_infer_310.sh b/model_zoo/research/cv/FaceDetection/scripts/run_infer_310.sh new file mode 100644 index 00000000000..15b609c9e6d --- /dev/null +++ b/model_zoo/research/cv/FaceDetection/scripts/run_infer_310.sh @@ -0,0 +1,114 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [[ $# -lt 2 || $# -gt 3 ]]; then + echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [MINDRECORD_DIR] [DEVICE_ID] + DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +model=$(get_real_path $1) +mindrecord_dir=$(get_real_path $2) +device_id=0 +if [ $# == 3 ]; then + device_id=$3 +fi + +echo "mindir name: "$model +echo "mindrecord dir: "$mindrecord_dir +echo "device id: "$device_id + +export ASCEND_HOME=/usr/local/Ascend/ +if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then + export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH + export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH + export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe + export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp +else + export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH + export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH + export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=$ASCEND_HOME/opp +fi + +function preprocess_data() +{ + if [ -d preprocess_Result ]; then + rm -rf ./preprocess_Result + fi + mkdir preprocess_Result + python3.7 ../preprocess.py --preprocess_path=./preprocess_Result --mindrecord_path=$mindrecord_dir --batch_size=1 &> preprocess.log + python3.7 ../bin.py --preprocess_path=./preprocess_Result + data_dir=./preprocess_Result + input0_path=./preprocess_Result/images_bin +} + +function compile_app() +{ + cd ../ascend310_infer || exit + bash build.sh &> build.log +} + +function infer() +{ + cd - || exit + if [ -d result_Files ]; then + rm -rf ./result_Files + fi + if [ -d time_Result ]; then + rm -rf ./time_Result + fi + mkdir result_Files + mkdir time_Result + ../ascend310_infer/out/main --mindir_path=$model --input0_path=$input0_path --device_id=$device_id &> infer.log +} + +function cal_map() +{ + if [ -d infer_output ]; then + rm -rf ./infer_output + fi + mkdir infer_output + python3.7 ../postprocess.py --result_path=./result_Files --data_dir=$data_dir --save_output_path=./infer_output &> map.log & +} +preprocess_data +if [ $? -ne 0 ]; then + echo "preprocess data failed" + exit 1 +fi +compile_app +if [ $? -ne 0 ]; then + echo "compile app code failed" + exit 1 +fi +infer +if [ $? -ne 0 ]; then + echo " execute inference failed" + exit 1 +fi +cal_map +if [ $? -ne 0 ]; then + echo "calculate map failed" + exit 1 +fi \ No newline at end of file