From 7c25a77d7cd115ea4100d429aaf889f92561bb36 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 8 Jun 2021 18:28:24 +0800 Subject: [PATCH] resnetv2 310 infer modified: official/cv/resnet/create_imagenet2012_label.py modified: research/cv/resnetv2/create_imagenet2012_label.py --- .../cv/resnet/create_imagenet2012_label.py | 2 +- model_zoo/official/cv/resnet/postprocess.py | 2 +- model_zoo/research/cv/resnetv2/README_CN.md | 28 ++++ .../resnetv2/ascend310_infer/CMakeLists.txt | 14 ++ .../cv/resnetv2/ascend310_infer/build.sh | 29 ++++ .../cv/resnetv2/ascend310_infer/inc/utils.h | 33 ++++ .../cv/resnetv2/ascend310_infer/src/main.cc | 140 +++++++++++++++++ .../cv/resnetv2/ascend310_infer/src/utils.cc | 145 ++++++++++++++++++ .../cv/resnetv2/create_imagenet2012_label.py | 49 ++++++ model_zoo/research/cv/resnetv2/postprocess.py | 84 ++++++++++ model_zoo/research/cv/resnetv2/preprocess.py | 59 +++++++ .../cv/resnetv2/scripts/run_infer_310.sh | 123 +++++++++++++++ 12 files changed, 706 insertions(+), 2 deletions(-) create mode 100644 model_zoo/research/cv/resnetv2/ascend310_infer/CMakeLists.txt create mode 100644 model_zoo/research/cv/resnetv2/ascend310_infer/build.sh create mode 100644 model_zoo/research/cv/resnetv2/ascend310_infer/inc/utils.h create mode 100644 model_zoo/research/cv/resnetv2/ascend310_infer/src/main.cc create mode 100644 model_zoo/research/cv/resnetv2/ascend310_infer/src/utils.cc create mode 100644 model_zoo/research/cv/resnetv2/create_imagenet2012_label.py create mode 100644 model_zoo/research/cv/resnetv2/postprocess.py create mode 100644 model_zoo/research/cv/resnetv2/preprocess.py create mode 100644 model_zoo/research/cv/resnetv2/scripts/run_infer_310.sh diff --git a/model_zoo/official/cv/resnet/create_imagenet2012_label.py b/model_zoo/official/cv/resnet/create_imagenet2012_label.py index 38f6ee94284..20d86ae69ed 100644 --- a/model_zoo/official/cv/resnet/create_imagenet2012_label.py +++ b/model_zoo/official/cv/resnet/create_imagenet2012_label.py @@ -6,7 +6,7 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# less required by applicable law or agreed to in writing, software +# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and diff --git a/model_zoo/official/cv/resnet/postprocess.py b/model_zoo/official/cv/resnet/postprocess.py index e438b627884..5f91bcc81eb 100644 --- a/model_zoo/official/cv/resnet/postprocess.py +++ b/model_zoo/official/cv/resnet/postprocess.py @@ -6,7 +6,7 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# less required by applicable law or agreed to in writing, software +# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and diff --git a/model_zoo/research/cv/resnetv2/README_CN.md b/model_zoo/research/cv/resnetv2/README_CN.md index 5991b447cfc..6535b1958b8 100644 --- a/model_zoo/research/cv/resnetv2/README_CN.md +++ b/model_zoo/research/cv/resnetv2/README_CN.md @@ -250,6 +250,34 @@ result: {'top_5_accuracy': 0.9988982371794872, 'top_1_accuracy': 0.9502283653846 result: {'top_1_accuracy': 0.7606515786082474, 'top_5_accuracy': 0.9271504510309279} ``` +## 推理过程 + +### [导出MindIR](#contents) + +```shell +python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT] +``` + +参数ckpt_file为必填项, +`file_format` 必须在 ["AIR", "MINDIR"]中选择。 + +### 在Ascend310执行推理 + +在执行推理前,mindir文件必须通过`export.py`脚本导出。以下展示了使用mindir模型执行推理的示例。 + +```shell +# Ascend310 inference +bash run_infer_310.sh [MINDIR_PATH] [DATASET] [DATA_PATH] [DEVICE_ID] +``` + +- `DATASET` 为数据集类型,如cifar10, cifar100等。 +- `DATA_PATH`为数据集路径。 +- `DEVICE_ID` 可选,默认值为0。 + +### 结果 + +推理结果保存在脚本执行的当前路径,你可以在acc.log中看到精度计算结果。 + # 模型描述 ## 性能 diff --git a/model_zoo/research/cv/resnetv2/ascend310_infer/CMakeLists.txt b/model_zoo/research/cv/resnetv2/ascend310_infer/CMakeLists.txt new file mode 100644 index 00000000000..170e6c5275e --- /dev/null +++ b/model_zoo/research/cv/resnetv2/ascend310_infer/CMakeLists.txt @@ -0,0 +1,14 @@ +cmake_minimum_required(VERSION 3.14.1) +project(Ascend310Infer) +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined") +set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/) +option(MINDSPORE_PATH "mindspore install path" "") +include_directories(${MINDSPORE_PATH}) +include_directories(${MINDSPORE_PATH}/include) +include_directories(${PROJECT_SRC_ROOT}) +find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib) +file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*) + +add_executable(main src/main.cc src/utils.cc) +target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags) diff --git a/model_zoo/research/cv/resnetv2/ascend310_infer/build.sh b/model_zoo/research/cv/resnetv2/ascend310_infer/build.sh new file mode 100644 index 00000000000..285514e19f2 --- /dev/null +++ b/model_zoo/research/cv/resnetv2/ascend310_infer/build.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +if [ -d out ]; then + rm -rf out +fi + +mkdir out +cd out || exit + +if [ -f "Makefile" ]; then + make clean +fi + +cmake .. \ + -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`" +make diff --git a/model_zoo/research/cv/resnetv2/ascend310_infer/inc/utils.h b/model_zoo/research/cv/resnetv2/ascend310_infer/inc/utils.h new file mode 100644 index 00000000000..0b400632f51 --- /dev/null +++ b/model_zoo/research/cv/resnetv2/ascend310_infer/inc/utils.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_INFERENCE_UTILS_H_ +#define MINDSPORE_INFERENCE_UTILS_H_ + +#include +#include +#include +#include +#include +#include "include/api/types.h" + +DIR *OpenDir(std::string_view dirName); +std::string RealPath(std::string_view path); +mindspore::MSTensor ReadFileToTensor(const std::string &file); +int WriteResult(const std::string& imageFile, const std::vector &outputs); +std::vector GetAllFiles(std::string dir_name); + +#endif diff --git a/model_zoo/research/cv/resnetv2/ascend310_infer/src/main.cc b/model_zoo/research/cv/resnetv2/ascend310_infer/src/main.cc new file mode 100644 index 00000000000..00ef7813fdb --- /dev/null +++ b/model_zoo/research/cv/resnetv2/ascend310_infer/src/main.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "include/api/model.h" +#include "include/api/context.h" +#include "include/api/types.h" +#include "include/api/serialization.h" +#include "include/dataset/vision_ascend.h" +#include "include/dataset/execute.h" +#include "include/dataset/transforms.h" +#include "include/dataset/vision.h" +#include "inc/utils.h" + +using mindspore::dataset::vision::Decode; +using mindspore::dataset::vision::Resize; +using mindspore::dataset::vision::CenterCrop; +using mindspore::dataset::vision::Normalize; +using mindspore::dataset::vision::HWC2CHW; +using mindspore::dataset::TensorTransform; +using mindspore::Context; +using mindspore::Serialization; +using mindspore::Model; +using mindspore::Status; +using mindspore::ModelType; +using mindspore::GraphCell; +using mindspore::kSuccess; +using mindspore::MSTensor; +using mindspore::dataset::Execute; + +DEFINE_string(mindir_path, "", "mindir path"); +DEFINE_string(dataset_path, ".", "dataset path"); +DEFINE_string(dataset, "imagenet2012", "dataset"); +DEFINE_int32(device_id, 0, "device id"); + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (RealPath(FLAGS_mindir_path).empty()) { + std::cout << "Invalid mindir" << std::endl; + return 1; + } + + auto context = std::make_shared(); + auto ascend310 = std::make_shared(); + ascend310->SetDeviceID(FLAGS_device_id); + context->MutableDeviceInfo().push_back(ascend310); + mindspore::Graph graph; + Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); + Model model; + Status ret = model.Build(GraphCell(graph), context); + if (ret != kSuccess) { + std::cout << "ERROR: Build failed." << std::endl; + return 1; + } + + auto all_files = GetAllFiles(FLAGS_dataset_path); + if (all_files.empty()) { + std::cout << "ERROR: no input data." << std::endl; + return 1; + } + + std::vector modelInputs = model.GetInputs(); + std::map costTime_map; + size_t size = all_files.size(); + + auto decode = Decode(); + auto resize = Resize({256}); + auto centercrop = CenterCrop({224}); + auto normalize = Normalize({123.675, 116.28, 103.53}, {58.395, 57.12, 57.375}); + auto hwc2chw = HWC2CHW(); + mindspore::dataset::Execute SingleOp({decode, resize, centercrop, normalize, hwc2chw}); + + for (size_t i = 0; i < size; ++i) { + struct timeval start = {0}; + struct timeval end = {0}; + double startTimeMs; + double endTimeMs; + std::vector inputs; + std::vector outputs; + std::cout << "Start predict input files:" << all_files[i] <(startTimeMs, endTimeMs)); + WriteResult(all_files[i], outputs); + } + double average = 0.0; + int inferCount = 0; + + for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) { + average += iter->second - iter->first; + inferCount++; + } + average = average / inferCount; + std::stringstream timeCost; + timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl; + std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl; + std::string fileName = "./time_Result" + std::string("/test_perform_static.txt"); + std::ofstream fileStream(fileName.c_str(), std::ios::trunc); + fileStream << timeCost.str(); + fileStream.close(); + costTime_map.clear(); + return 0; +} diff --git a/model_zoo/research/cv/resnetv2/ascend310_infer/src/utils.cc b/model_zoo/research/cv/resnetv2/ascend310_infer/src/utils.cc new file mode 100644 index 00000000000..728d57d9362 --- /dev/null +++ b/model_zoo/research/cv/resnetv2/ascend310_infer/src/utils.cc @@ -0,0 +1,145 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "inc/utils.h" + +using mindspore::MSTensor; +using mindspore::DataType; + +std::vector GetAllFiles(std::string dirName) { + struct dirent *filename; + DIR *dir = OpenDir(dirName); + if (dir == nullptr) { + return {}; + } + std::vector dirs; + std::vector files; + while ((filename = readdir(dir)) != nullptr) { + std::string dName = std::string(filename->d_name); + if (dName == "." || dName == "..") { + continue; + } else if (filename->d_type == DT_DIR) { + dirs.emplace_back(std::string(dirName) + "/" + filename->d_name); + } else if (filename->d_type == DT_REG) { + files.emplace_back(std::string(dirName) + "/" + filename->d_name); + } else { + continue; + } + } + + for (auto d : dirs) { + dir = OpenDir(d); + while ((filename = readdir(dir)) != nullptr) { + std::string dName = std::string(filename->d_name); + if (dName == "." || dName == ".." || filename->d_type != DT_REG) { + continue; + } + files.emplace_back(std::string(d) + "/" + filename->d_name); + } + } + std::sort(files.begin(), files.end()); + for (auto &f : files) { + std::cout << "image file: " << f << std::endl; + } + return files; +} + +int WriteResult(const std::string& imageFile, const std::vector &outputs) { + std::string homePath = "./result_Files"; + for (size_t i = 0; i < outputs.size(); ++i) { + size_t outputSize; + std::shared_ptr netOutput; + netOutput = outputs[i].Data(); + outputSize = outputs[i].DataSize(); + int pos = imageFile.rfind('/'); + std::string fileName(imageFile, pos + 1); + fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin"); + std::string outFileName = homePath + "/" + fileName; + FILE *outputFile = fopen(outFileName.c_str(), "wb"); + fwrite(netOutput.get(), outputSize, sizeof(char), outputFile); + fclose(outputFile); + outputFile = nullptr; + } + return 0; +} + +mindspore::MSTensor ReadFileToTensor(const std::string &file) { + if (file.empty()) { + std::cout << "Pointer file is nullptr" << std::endl; + return mindspore::MSTensor(); + } + + std::ifstream ifs(file); + if (!ifs.good()) { + std::cout << "File: " << file << " is not exist" << std::endl; + return mindspore::MSTensor(); + } + + if (!ifs.is_open()) { + std::cout << "File: " << file << "open failed" << std::endl; + return mindspore::MSTensor(); + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast(size)}, nullptr, size); + + ifs.seekg(0, std::ios::beg); + ifs.read(reinterpret_cast(buffer.MutableData()), size); + ifs.close(); + + return buffer; +} + + +DIR *OpenDir(std::string_view dirName) { + if (dirName.empty()) { + std::cout << " dirName is null ! " << std::endl; + return nullptr; + } + std::string realPath = RealPath(dirName); + struct stat s; + lstat(realPath.c_str(), &s); + if (!S_ISDIR(s.st_mode)) { + std::cout << "dirName is not a valid directory !" << std::endl; + return nullptr; + } + DIR *dir; + dir = opendir(realPath.c_str()); + if (dir == nullptr) { + std::cout << "Can not open dir " << dirName << std::endl; + return nullptr; + } + std::cout << "Successfully opened the dir " << dirName << std::endl; + return dir; +} + +std::string RealPath(std::string_view path) { + char realPathMem[PATH_MAX] = {0}; + char *realPathRet = nullptr; + realPathRet = realpath(path.data(), realPathMem); + if (realPathRet == nullptr) { + std::cout << "File: " << path << " is not exist."; + return ""; + } + + std::string realPath(realPathMem); + std::cout << path << " realpath is: " << realPath << std::endl; + return realPath; +} diff --git a/model_zoo/research/cv/resnetv2/create_imagenet2012_label.py b/model_zoo/research/cv/resnetv2/create_imagenet2012_label.py new file mode 100644 index 00000000000..d29a6719fd6 --- /dev/null +++ b/model_zoo/research/cv/resnetv2/create_imagenet2012_label.py @@ -0,0 +1,49 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""create_imagenet2012_label""" +import os +import json +import argparse + +parser = argparse.ArgumentParser(description="resnet imagenet2012 label") +parser.add_argument("--img_path", type=str, required=True, help="imagenet2012 file path.") +args = parser.parse_args() + + +def create_label(file_path): + '''create imagenet2012 label''' + print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!") + dirs = os.listdir(file_path) + file_list = [] + for file in dirs: + file_list.append(file) + file_list = sorted(file_list) + + total = 0 + img_label = {} + for i, file_dir in enumerate(file_list): + files = os.listdir(os.path.join(file_path, file_dir)) + for f in files: + img_label[f] = i + total += len(files) + + with open("imagenet_label.json", "w+") as label: + json.dump(img_label, label) + + print("[INFO] Completed! Total {} data.".format(total)) + + +if __name__ == '__main__': + create_label(args.img_path) diff --git a/model_zoo/research/cv/resnetv2/postprocess.py b/model_zoo/research/cv/resnetv2/postprocess.py new file mode 100644 index 00000000000..863774c95b2 --- /dev/null +++ b/model_zoo/research/cv/resnetv2/postprocess.py @@ -0,0 +1,84 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""postprocess""" +import os +import json +import argparse +import numpy as np +from mindspore.nn import Top1CategoricalAccuracy, Top5CategoricalAccuracy + +parser = argparse.ArgumentParser(description="postprocess") +parser.add_argument("--dataset", type=str, required=True, help="dataset type.") +parser.add_argument("--result_path", type=str, required=True, help="result files path.") +parser.add_argument("--label_path", type=str, required=True, help="image file path.") +args_opt = parser.parse_args() + +if args_opt.dataset == "cifar10": + from src.config import config1 as config +elif args_opt.dataset == "cifar100": + from src.config import config2 as config +elif args_opt.dataset == 'imagenet2012': + from src.config import config3 as config +else: + raise ValueError("dataset is not support.") + +def cal_acc_cifar(result_path, label_path): + '''calculate cifar accuracy''' + top1_acc = Top1CategoricalAccuracy() + top5_acc = Top5CategoricalAccuracy() + result_shape = (config.batch_size, config.class_num) + + file_num = len(os.listdir(result_path)) + label_list = np.load(label_path) + for i in range(file_num): + f_name = args_opt.dataset + "_bs" + str(config.batch_size) + "_" + str(i) + "_0.bin" + full_file_path = os.path.join(result_path, f_name) + if os.path.isfile(full_file_path): + result = np.fromfile(full_file_path, dtype=np.float32).reshape(result_shape) + gt_classes = label_list[i] + + top1_acc.update(result, gt_classes) + top5_acc.update(result, gt_classes) + print("top1 acc: ", top1_acc.eval()) + print("top5 acc: ", top5_acc.eval()) + +def cal_acc_imagenet(result_path, label_path): + '''calculate imagenet2012 accuracy''' + batch_size = 1 + files = os.listdir(result_path) + with open(label_path, "r") as label: + labels = json.load(label) + + top1 = 0 + top5 = 0 + total_data = len(files) + for file in files: + img_ids_name = file.split('_0.')[0] + data_path = os.path.join(result_path, img_ids_name + "_0.bin") + result = np.fromfile(data_path, dtype=np.float32).reshape(batch_size, config.class_num) + for batch in range(batch_size): + predict = np.argsort(-result[batch], axis=-1) + if labels[img_ids_name+".JPEG"] == predict[0]: + top1 += 1 + if labels[img_ids_name+".JPEG"] in predict[:5]: + top5 += 1 + print(f"Total data: {total_data}, top1 accuracy: {top1/total_data}, top5 accuracy: {top5/total_data}.") + + +if __name__ == '__main__': + if args_opt.dataset.lower() == "cifar10" or args_opt.dataset.lower() == "cifar100": + cal_acc_cifar(args_opt.result_path, args_opt.label_path) + else: + cal_acc_imagenet(args_opt.result_path, args_opt.label_path) diff --git a/model_zoo/research/cv/resnetv2/preprocess.py b/model_zoo/research/cv/resnetv2/preprocess.py new file mode 100644 index 00000000000..6ddbaa9b576 --- /dev/null +++ b/model_zoo/research/cv/resnetv2/preprocess.py @@ -0,0 +1,59 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" preprocess """ +import os +import argparse +import numpy as np + +parser = argparse.ArgumentParser(description='preprocess') +parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset, cifar10, imagenet2012') +parser.add_argument('--dataset_path', type=str, default="../cifar-10/cifar-10-verify-bin", + help='Dataset path.') +parser.add_argument('--output_path', type=str, default="./preprocess_Result", + help='preprocess Result path.') +args_opt = parser.parse_args() + +# import dataset +if args_opt.dataset == "cifar10": + from src.dataset import create_dataset1 as create_dataset + from src.config import config1 as config +elif args_opt.dataset == "cifar100": + from src.dataset import create_dataset2 as create_dataset + from src.config import config2 as config +else: + raise ValueError("dataset is not support.") + + +def get_cifar_bin(): + '''generate cifar bin files.''' + ds = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) + img_path = os.path.join(args_opt.output_path, "00_img_data") + label_path = os.path.join(args_opt.output_path, "label.npy") + os.makedirs(img_path) + label_list = [] + + for i, data in enumerate(ds.create_dict_iterator(output_numpy=True)): + img_data = data["image"] + img_label = data["label"] + + file_name = args_opt.dataset + "_bs" + str(config.batch_size) + "_" + str(i) + ".bin" + img_file_path = os.path.join(img_path, file_name) + img_data.tofile(img_file_path) + label_list.append(img_label) + np.save(label_path, label_list) + print("=" * 20, "export bin files finished", "=" * 20) + +if __name__ == '__main__': + get_cifar_bin() diff --git a/model_zoo/research/cv/resnetv2/scripts/run_infer_310.sh b/model_zoo/research/cv/resnetv2/scripts/run_infer_310.sh new file mode 100644 index 00000000000..2bcc168814a --- /dev/null +++ b/model_zoo/research/cv/resnetv2/scripts/run_infer_310.sh @@ -0,0 +1,123 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [[ $# -lt 3 || $# -gt 4 ]]; then + echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATASET] [DATA_PATH] [DEVICE_ID] + DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +model=$(get_real_path $1) +if [ ${2,,} == 'cifar10' ] || [ ${2,,} == 'cifar100' ] || [ ${2,,} == 'imagenet2012' ]; then + dataset=$2 +else + echo "dataset must choose from [cifar10, cifar100, imagenet2012]" + exit 1 +fi + +data_path=$(get_real_path $3) + +device_id=0 +if [ $# == 4 ]; then + device_id=$4 +fi + +echo "mindir name: "$model +echo "dataset path: "$data_path +echo "dataset: "$dataset +echo "device id: "$device_id + +export ASCEND_HOME=/usr/local/Ascend/ +if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then + export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH + export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH + export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe + export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp +else + export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH + export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH + export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=$ASCEND_HOME/opp +fi + +function compile_app() +{ + cd ../ascend310_infer || exit + bash build.sh &> build.log +} + +function preprocess_data() +{ + if [ -d preprocess_Result ]; then + rm -rf ./preprocess_Result + fi + mkdir preprocess_Result + + python3.7 ../preprocess.py --dataset=$dataset --dataset_path=$data_path --output_path=./preprocess_Result +} + +function infer() +{ + cd - || exit + if [ -d result_Files ]; then + rm -rf ./result_Files + fi + if [ -d time_Result ]; then + rm -rf ./time_Result + fi + mkdir result_Files + mkdir time_Result + ../ascend310_infer/out/main --mindir_path=$model --dataset_path=$data_path --dataset=$dataset --device_id=$device_id &> infer.log +} + +function cal_acc() +{ + if [ "${dataset}" == "cifar10" ] || [ "${dataset}" == "cifar100" ]; then + python ../postprocess.py --dataset=$dataset --label_path=./preprocess_Result/label.npy --result_path=result_Files &> acc.log + else + python3.7 ../create_imagenet2012_label.py --img_path=$data_path + python3.7 ../postprocess.py --dataset=$dataset --result_path=./result_Files --label_path=./imagenet_label.json &> acc.log + fi +} + +if [ "${dataset}" == "cifar10" ] || [ "${dataset}" == "cifar100" ]; then + preprocess_data + data_path=./preprocess_Result/00_img_data +fi + +compile_app +if [ $? -ne 0 ]; then + echo "compile app code failed" + exit 1 +fi +infer +if [ $? -ne 0 ]; then + echo " execute inference failed" + exit 1 +fi +cal_acc +if [ $? -ne 0 ]; then + echo "calculate accuracy failed" + exit 1 +fi \ No newline at end of file