!12776 add unet 310 mindir infer

From: @lihongkang1
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-04 14:22:21 +08:00 committed by Gitee
commit b7e977f590
10 changed files with 707 additions and 79 deletions

View File

@ -22,33 +22,29 @@
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Unet Description](#contents)
## [Unet Description](#contents)
Unet Medical model for 2D image segmentation. This implementation is as described in the original paper [UNet: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597). Unet, in the 2015 ISBI cell tracking competition, many of the best are obtained. In this paper, a network model for medical image segmentation is proposed, and a data enhancement method is proposed to effectively use the annotation data to solve the problem of insufficient annotation data in the medical field. A U-shaped network structure is also used to extract the context and location information.
[Paper](https://arxiv.org/abs/1505.04597): Olaf Ronneberger, Philipp Fischer, Thomas Brox. "U-Net: Convolutional Networks for Biomedical Image Segmentation." * conditionally accepted at MICCAI 2015*. 2015.
[Paper](https://arxiv.org/abs/1505.04597): Olaf Ronneberger, Philipp Fischer, Thomas Brox. "U-Net: Convolutional Networks for Biomedical Image Segmentation." *conditionally accepted at MICCAI 2015*. 2015.
# [Model Architecture](#contents)
Specifically, the U network structure is proposed in UNET, which can better extract and fuse high-level features and obtain context information and spatial location information. The U network structure is composed of encoder and decoder. The encoder is composed of two 3x3 conv and a 2x2 max pooling iteration. The number of channels is doubled after each down sampling. The decoder is composed of a 2x2 deconv, concat layer and two 3x3 convolutions, and then outputs after a 1x1 convolution.
# [Dataset](#contents)
Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home)
- Description: The training and test datasets are two stacks of 30 sections from a serial section Transmission Electron Microscopy (ssTEM) data set of the Drosophila first instar larva ventral nerve cord (VNC). The microcube measures 2 x 2 x 1.5 microns approx., with a resolution of 4x4x50 nm/pixel.
- License: You are free to use this data set for the purpose of generating or testing non-commercial image segmentation software. If any scientific publications derive from the usage of this data set, you must cite TrakEM2 and the following publication: Cardona A, Saalfeld S, Preibisch S, Schmid B, Cheng A, Pulokas J, Tomancak P, Hartenstein V. 2010. An Integrated Micro- and Macroarchitectural Analysis of the Drosophila Brain by Computer-Assisted Serial Section Electron Microscopy. PLoS Biol 8(10): e1000502. doi:10.1371/journal.pbio.1000502.
- Dataset size22.5M
- Dataset size22.5M,
- Train15M, 30 images (Training data contains 2 multi-page TIF files, each containing 30 2D-images. train-volume.tif and train-labels.tif respectly contain data and label.)
- Val(We randomly divde the training data into 5-fold and evaluate the model by across 5-fold cross-validation.)
- Val(We randomly divide the training data into 5-fold and evaluate the model by across 5-fold cross-validation.)
- Test7.5M, 30 images (Testing data contains 1 multi-page TIF files, each containing 30 2D-images. test-volume.tif respectly contain data.)
- Data formatbinary files(TIF file)
- NoteData will be processed in src/data_loader.py
# [Environment Requirements](#contents)
- HardwareAscend
@ -59,8 +55,6 @@ Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home)
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
@ -82,13 +76,11 @@ After installing MindSpore via the official website, you can start training and
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```
```text
├── model_zoo
├── README.md // descriptions about all the models
├── unet
@ -133,14 +125,13 @@ Parameters for both training and evaluation can be set in config.py
'resume_ckpt': './', # pretrain model path
```
## [Training Process](#contents)
### Training
- running on Ascend
```
```shell
python train.py --data_url=/path/to/data/ > train.log 2>&1 &
OR
bash scripts/run_standalone_train.sh [DATASET]
@ -150,7 +141,8 @@ Parameters for both training and evaluation can be set in config.py
After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows:
```
```shell
# grep "loss is " train.log
step: 1, loss is 0.7011719, fps is 0.25025035060906264
step: 2, loss is 0.69433594, fps is 56.77693756377044
@ -163,19 +155,20 @@ Parameters for both training and evaluation can be set in config.py
step: 598, loss is 0.19958496, fps is 57.95493929352674
step: 599, loss is 0.18371582, fps is 58.04039977720966
step: 600, loss is 0.22070312, fps is 56.99692546024671
```
The model checkpoint will be saved in the current directory.
### Distributed Training
```
```shell
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]
```
The above shell script will run distribute training in the background. You can view the results through the file `logs/device[X]/log.log`. The loss value will be achieved as follows:
```
```shell
# grep "loss is" logs/device0/log.log
step: 1, loss is 0.70524895, fps is 0.15914689861221412
step: 2, loss is 0.6925452, fps is 56.43668656967454
@ -192,7 +185,7 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet/ckpt_unet_medical_adam-48_600.ckpt".
```
```shell
python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]
@ -200,16 +193,16 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
```
```shell
# grep "Cross valid dice coeff is:" eval.log
============== Cross valid dice coeff is: {'dice_coeff': 0.9085704886070473}
```
# [Model Description](#contents)
## [Performance](#contents)
## Performance
### Evaluation Performance
@ -232,15 +225,16 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329
| Checkpoint for Fine tuning | 355.11M (.ckpt file) |
| Scripts | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
## [How to use](#contents)
### Inference
If you need to use the trained model to perform inference on multiple hardware platforms, such as Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). Following the steps below, this is a simple example:
- Running on Ascend
```
```python
# Set context
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",save_graphs=True,device_id=device_id)
@ -259,13 +253,41 @@ If you need to use the trained model to perform inference on multiple hardware p
print("============== Starting Evaluating ============")
dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
print("============== Cross valid dice coeff is:", dice_score)
```
- Running on Ascend 310
Export MindIR
```shell
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
The ckpt_file parameter is required,
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
Before performing inference, the MINDIR file must be exported by export script on the 910 environment.
Current batch_size can only be set to 1.
```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
`DEVICE_ID` is optional, default value is 0.
Inference result is saved in current path, you can find result in acc.log file.
```text
Cross valid dice coeff is: 0.9054352151297033
```
### Continue Training on the Pretrained Model
- running on Ascend
```
```python
# Define model
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
# Continue training if set 'resume' to be True
@ -298,11 +320,10 @@ If you need to use the trained model to perform inference on multiple hardware p
print("============== End Training ==============")
```
# [Description of Random Situation](#contents)
In data_loader.py, we set the seed inside “_get_val_train_indices" function. We also use random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -254,6 +254,33 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329
print("============== Starting Evaluating ============")
dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
print("============== Cross valid dice coeff is:", dice_score)
```
- Ascend 310环境运行
导出mindir模型
```shell
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
参数`ckpt_file` 是必需的,`EXPORT_FORMAT` 必须在 ["AIR", "MINDIR"]中进行选择。
在执行推理前MINDIR文件必须在910上通过export.py文件导出。
目前仅可处理batch_Size为1。
```shell
# Ascend310 推理
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
`DEVICE_ID` 可选,默认值为 0。
推理结果保存在当前路径可在acc.log中看到最终精度结果。
```text
Cross valid dice coeff is: 0.9054352151297033
```
### 继续训练预训练模型

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif

View File

@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(MindSporeCxxTestcase[CXX])
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT}/../inc)
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main main.cc utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

View File

@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
cmake . -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

View File

@ -0,0 +1,123 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/api/context.h"
#include "include/minddata/dataset/include/execute.h"
#include "include/minddata/dataset/include/vision.h"
#include "../inc/utils.h"
#include "include/api/types.h"
using mindspore::Context;
using mindspore::GlobalContext;
using mindspore::ModelContext;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::dataset::Execute;
using mindspore::MSTensor;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
auto model_context = std::make_shared<Context>();
Model model(GraphCell(graph), model_context);
Status ret = model.Build();
if (ret != kSuccess) {
std::cout << "EEEEEEEERROR Build failed." << std::endl;
return 1;
}
std::vector<MSTensor> model_inputs = model.GetInputs();
auto all_files = GetAllFiles(FLAGS_dataset_path);
if (all_files.empty()) {
std::cout << "ERROR: no input data." << std::endl;
return 1;
}
std::map<double, double> costTime_map;
size_t size = all_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start = {0};
struct timeval end = {0};
double startTime_ms;
double endTime_ms;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << all_files[i] << std::endl;
auto img = ReadFileToTensor(all_files[i]);
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
img.Data().get(), img.DataSize());
gettimeofday(&start, NULL);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, NULL);
if (ret != kSuccess) {
std::cout << "Predict " << all_files[i] << " failed." << std::endl;
return 1;
}
startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms));
WriteResult(all_files[i], outputs);
}
double average = 0.0;
int infer_cnt = 0;
char tmpCh[256] = {0};
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
infer_cnt++;
}
average = average/infer_cnt;
snprintf(tmpCh, sizeof(tmpCh), "NN inference cost average time: %4.3f ms of infer_count %d \n", average, infer_cnt);
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << infer_cnt << std::endl;
std::string file_name = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream file_stream(file_name.c_str(), std::ios::trunc);
file_stream << tmpCh;
file_stream.close();
costTime_map.clear();
return 0;
}

View File

@ -0,0 +1,136 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../inc/utils.h"
#include <fstream>
#include <algorithm>
#include <iostream>
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." ||
dName == ".." ||
filename->d_type != DT_REG)
continue;
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE * outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
MSTensor buffer(file, DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char real_path_mem[PATH_MAX] = {0};
char *real_path_ret = nullptr;
real_path_ret = realpath(path.data(), real_path_mem);
if (real_path_ret == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string real_path(real_path_mem);
std::cout << path << " realpath is: " << real_path << std::endl;
return real_path;
}

View File

@ -0,0 +1,97 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""unet 310 infer."""
import os
import argparse
import numpy as np
from src.data_loader import create_dataset
from src.config import cfg_unet
from scipy.special import softmax
class dice_coeff():
def __init__(self):
self.clear()
def clear(self):
self._dice_coeff_sum = 0
self._samples_num = 0
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
y_pred = inputs[0]
y = np.array(inputs[1])
self._samples_num += y.shape[0]
y_pred = y_pred.transpose(0, 2, 3, 1)
y = y.transpose(0, 2, 3, 1)
y_pred = softmax(y_pred, axis=3)
inter = np.dot(y_pred.flatten(), y.flatten())
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
single_dice_coeff = 2*float(inter)/float(union+1e-6)
print("single dice coeff is:", single_dice_coeff)
self._dice_coeff_sum += single_dice_coeff
def eval(self):
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return self._dice_coeff_sum / float(self._samples_num)
def test_net(data_dir,
cross_valid_ind=1,
cfg=None):
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False)
labels_list = []
for data in valid_dataset:
labels_list.append(data[1].asnumpy())
return labels_list
def get_args():
parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
help='data directory')
parser.add_argument('-p', '--rst_path', dest='rst_path', type=str, default='./result_Files/',
help='infer result path')
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
label_list = test_net(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet)
rst_path = args.rst_path
metrics = dice_coeff()
for j in range(len(os.listdir(rst_path))):
file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin"
output = np.fromfile(file_name, np.float32).reshape(1, 2, 388, 388)
label = label_list[j]
metrics.update(output, label)
print("Cross valid dice coeff is: ", metrics.eval())

View File

@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""unet 310 infer preprocess dataset"""
import argparse
from src.data_loader import create_dataset
from src.config import cfg_unet
def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None):
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False)
for i, data in enumerate(valid_dataset):
file_name = "ISBI_test_bs_1_" + str(i) + ".bin"
file_path = result_path + file_name
data[0].asnumpy().tofile(file_path)
def get_args():
parser = argparse.ArgumentParser(description='Preprocess the UNet dataset ',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
help='data directory')
parser.add_argument('-p', '--result_path', dest='result_path', type=str, default='./preprocess_Result/',
help='result path')
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet, result_path=
args.result_path)

View File

@ -0,0 +1,115 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 2 || $# -gt 3 ]]; then
echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
if [ $# == 3 ]; then
device_id=$3
if [ -z $device_id ]; then
device_id=0
else
device_id=$device_id
fi
fi
echo "mindir name: "$model
echo "dataset path: "$data_path
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages/te.egg:$ASCEND_HOME/atc/python/site-packages/topi.egg:$ASCEND_HOME/atc/python/site-packages/auto_tune.egg::$ASCEND_HOME/atc/python/site-packages/schedule_search.egg:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function preprocess_data()
{
if [ -d preprocess_Result ]; then
rm -rf ./preprocess_Result
fi
mkdir preprocess_Result
python3.7 ../preprocess.py --data_url=$data_path --result_path=./preprocess_Result/
}
function compile_app()
{
cd ../ascend310_infer/src
if [ -f "Makefile" ]; then
make clean
fi
sh build.sh &> build.log
}
function infer()
{
cd -
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
../ascend310_infer/src/main --mindir_path=$model --dataset_path=./preprocess_Result/ --device_id=$device_id &> infer.log
}
function cal_acc()
{
python3.7 ../postprocess.py --data_url=$data_path --rst_path=./result_Files/ &> acc.log &
}
preprocess_data
if [ $? -ne 0 ]; then
echo "preprocess dataset failed"
exit 1
fi
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo "execute inference failed"
exit 1
fi
cal_acc
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi