!19566 push ICNet code, add 310 inference
Merge pull request !19566 from bigpingping/master
Commit: d2a2f14a48
@@ -13,6 +13,7 @@
 - [Evaluation Process](#evaluation-process)
 - [Evaluation](#evaluation)
 - [Evaluation Result](#evaluation-result)
+- [310 infer](#310-inference)
 - [Model Description](#model-description)
 - [Description of Random Situation](#description-of-random-situation)
 - [ModelZoo Homepage](#modelzoo-homepage)

@@ -50,27 +51,39 @@ It contains 5,000 finely annotated images split into training, validation and te
 ```python
 .
 └─ICNet
-├─configs
-  ├─icnet.yaml                     # config file
-├─models
-  ├─base_models
-    ├─resnt50_v1.py                # used resnet50
-  ├─__init__.py
-  ├─icnet.py                       # validation network
-  ├─icnet_dc.py                    # training network
-├─scripts
-  ├─run_distribute_train8p.sh      # Multi card distributed training in ascend
-  ├─run_eval.sh                    # validation script
-├─utils
-  ├─__init__.py
-  ├─logger.py                      # logger
-  ├─loss.py                        # loss
-  ├─losses.py                      # SoftmaxCrossEntropyLoss
-  ├─lr_scheduler.py                # lr
-  └─metric.py                      # metric
-├─eval.py                          # validation
-├─train.py                         # train
-└─visualize.py                     # inference visualization
+├── ascend310_infer
+│   ├── build.sh
+│   ├── CMakeLists.txt
+│   ├── inc
+│   │   └── utils.h
+│   └── src
+│       ├── main.cc
+│       └── utils.cc
+├── eval.py                        # validation
+├── export.py                      # export mindir
+├── postprocess.py                 # 310 infer calculate accuracy
+├── README.md                      # descriptions about ICNet
+├── scripts
+│   ├── run_distribute_train8p.sh  # multi cards distributed training in ascend
+│   ├── run_eval.sh                # validation script
+│   └── run_infer_310.sh           # 310 infer script
+├── src
+│   ├── cityscapes_mindrecord.py   # create mindrecord dataset
+│   ├── __init__.py
+│   ├── logger.py                  # logger
+│   ├── losses.py                  # used losses
+│   ├── loss.py                    # loss
+│   ├── lr_scheduler.py            # lr
+│   ├── metric.py                  # metric
+│   ├── models
+│   │   ├── icnet_1p.py            # net single card
+│   │   ├── icnet_dc.py            # net multi cards
+│   │   ├── icnet.py               # validation network
+│   │   └── resnet50_v1.py         # backbone
+│   ├── model_utils
+│   │   └── icnet.yaml             # config
+│   └── visualize.py               # inference visualization
+└── train.py                       # train
 ```
 
 ## Script Parameters

@@ -169,6 +182,14 @@ avg_pixacc 0.94285786
 avgtime 0.19648232793807982
 ````
 
+## 310 infer
+
+```shell
+bash run_infer_310.sh [The path of the MINDIR for 310 infer] [The path of the dataset for 310 infer] 0
+```
+
+Note: Before executing 310 infer, create the MINDIR/AIR model by running "python export.py --ckpt-file [The path of the CKPT for exporting]".
+
 # [Model Description](#Content)
 
 ## Performance

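Taken together, the 310 pieces added by this PR are used in two steps: export the trained checkpoint to MINDIR with export.py, then hand the exported model and the Cityscapes validation images to run_infer_310.sh, which builds the C++ app, runs inference and computes accuracy. A minimal sketch is shown below; the checkpoint, MINDIR and dataset paths are placeholders, and launching from the scripts/ directory is assumed (the script resolves ../ascend310_infer and ../postprocess.py relative to the current directory):

```shell
# Placeholder paths; adjust to the actual checkpoint, export output and dataset locations.
python export.py --ckpt-file /path/to/ICNet.ckpt

cd scripts
bash run_infer_310.sh /path/to/ICNet.mindir /path/to/cityscapes/leftImg8bit/val 0
```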
@@ -0,0 +1,14 @@
+cmake_minimum_required(VERSION 3.14.1)
+project(Ascend310Infer)
+add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
+set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
+option(MINDSPORE_PATH "mindspore install path" "")
+include_directories(${MINDSPORE_PATH})
+include_directories(${MINDSPORE_PATH}/include)
+include_directories(${PROJECT_SRC_ROOT})
+find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
+file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
+
+add_executable(main src/main.cc src/utils.cc)
+target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

@@ -0,0 +1,23 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ ! -d out ]; then
+    mkdir out
+fi
+cd out || exit
+cmake .. \
+    -DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
+make

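build.sh discovers the MindSpore install directory by querying pip for the mindspore-ascend package. If MindSpore was installed some other way, a workable alternative (an assumption on my part, not part of this commit) is to resolve the package directory from Python itself and pass it to cmake:

```shell
# Assumed alternative to the pip-based lookup in build.sh.
MINDSPORE_PATH="$(python -c 'import os, mindspore; print(os.path.dirname(mindspore.__file__))')"
mkdir -p out && cd out
cmake .. -DMINDSPORE_PATH="${MINDSPORE_PATH}"
make
```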
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_INFERENCE_UTILS_H_
+#define MINDSPORE_INFERENCE_UTILS_H_
+
+#include <sys/stat.h>
+#include <dirent.h>
+#include <vector>
+#include <string>
+#include <memory>
+#include "include/api/types.h"
+
+std::vector<std::string> GetAllFiles(std::string_view dirName);
+DIR *OpenDir(std::string_view dirName);
+std::string RealPath(std::string_view path);
+mindspore::MSTensor ReadFileToTensor(const std::string &file);
+int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
+#endif

@@ -0,0 +1,162 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <sys/time.h>
+#include <gflags/gflags.h>
+#include <dirent.h>
+#include <iostream>
+#include <string>
+#include <algorithm>
+#include <iosfwd>
+#include <vector>
+#include <fstream>
+#include <sstream>
+
+#include "../inc/utils.h"
+#include "minddata/dataset/include/execute.h"
+#include "minddata/dataset/include/transforms.h"
+#include "minddata/dataset/include/vision.h"
+#include "minddata/dataset/include/vision_ascend.h"
+#include "include/api/types.h"
+#include "include/api/model.h"
+#include "include/api/serialization.h"
+#include "include/api/context.h"
+
+using mindspore::Serialization;
+using mindspore::Model;
+using mindspore::Context;
+using mindspore::Status;
+using mindspore::ModelType;
+using mindspore::Graph;
+using mindspore::GraphCell;
+using mindspore::kSuccess;
+using mindspore::MSTensor;
+using mindspore::DataType;
+using mindspore::dataset::Execute;
+using mindspore::dataset::TensorTransform;
+using mindspore::dataset::vision::Decode;
+using mindspore::dataset::vision::Resize;
+using mindspore::dataset::vision::Rescale;
+using mindspore::dataset::vision::Normalize;
+using mindspore::dataset::vision::HWC2CHW;
+
+using mindspore::dataset::transforms::TypeCast;
+
+DEFINE_string(model_path, "/root/ICNet.mindir", "model path");
+DEFINE_string(dataset_path, "/data/cityscapes/leftImg8bit/val", "dataset path");
+DEFINE_int32(input_width, 2048, "input width");
+DEFINE_int32(input_height, 1024, "input height");
+DEFINE_int32(device_id, 0, "device id");
+DEFINE_string(precision_mode, "allow_fp32_to_fp16", "precision mode");
+DEFINE_string(op_select_impl_mode, "", "op select impl mode");
+DEFINE_string(device_target, "Ascend310", "device target");
+
+int main(int argc, char **argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  if (RealPath(FLAGS_model_path).empty()) {
+    std::cout << "Invalid model" << std::endl;
+    return 1;
+  }
+
+  auto context = std::make_shared<Context>();
+  auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
+  ascend310_info->SetDeviceID(FLAGS_device_id);
+  context->MutableDeviceInfo().push_back(ascend310_info);
+
+  Graph graph;
+  Status ret = Serialization::Load(FLAGS_model_path, ModelType::kMindIR, &graph);
+  if (ret != kSuccess) {
+    std::cout << "Load model failed." << std::endl;
+    return 1;
+  }
+
+  Model model;
+  ret = model.Build(GraphCell(graph), context);
+  if (ret != kSuccess) {
+    std::cout << "ERROR: Build failed." << std::endl;
+    return 1;
+  }
+
+  std::vector<MSTensor> modelInputs = model.GetInputs();
+
+  auto all_files = GetAllFiles(FLAGS_dataset_path);
+  if (all_files.empty()) {
+    std::cout << "ERROR: no input data." << std::endl;
+    return 1;
+  }
+
+  auto decode = Decode();
+  auto normalize = Normalize({0.485, 0.456, 0.406}, {0.229, 0.224, 0.225});
+  auto hwc2chw = HWC2CHW();
+  auto rescale = Rescale(1.0 / 255.0, 0);
+  auto typeCast = TypeCast("float32");
+
+  mindspore::dataset::Execute transformDecode(decode);
+  mindspore::dataset::Execute transform({rescale, normalize, hwc2chw});
+  mindspore::dataset::Execute transformCast(typeCast);
+
+  std::map<double, double> costTime_map;
+
+  size_t size = all_files.size();
+  for (size_t i = 0; i < size; ++i) {
+    struct timeval start;
+    struct timeval end;
+    double startTime_ms;
+    double endTime_ms;
+    std::vector<MSTensor> inputs;
+    std::vector<MSTensor> outputs;
+
+    std::cout << "Start predict input files:" << all_files[i] << std::endl;
+    mindspore::MSTensor image = ReadFileToTensor(all_files[i]);
+
+    transformDecode(image, &image);
+    std::vector<int64_t> shape = image.Shape();
+    transform(image, &image);
+
+    inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(),
+                        image.Data().get(), image.DataSize());
+
+    gettimeofday(&start, NULL);
+    model.Predict(inputs, &outputs);
+    gettimeofday(&end, NULL);
+
+    startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
+    endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
+    costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms));
+    WriteResult(all_files[i], outputs);
+  }
+  double average = 0.0;
+  int infer_cnt = 0;
+
+  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
+    double diff = 0.0;
+    diff = iter->second - iter->first;
+    average += diff;
+    infer_cnt++;
+  }
+
+  average = average / infer_cnt;
+
+  std::stringstream timeCost;
+  timeCost << "NN inference cost average time: " << average << " ms of infer_count " << infer_cnt << std::endl;
+  std::cout << "NN inference cost average time: " << average << " ms of infer_count " << infer_cnt << std::endl;
+  std::string file_name = "./time_Result" + std::string("/test_perform_static.txt");
+  std::ofstream file_stream(file_name.c_str(), std::ios::trunc);
+  file_stream << timeCost.str();
+  file_stream.close();
+  costTime_map.clear();
+  return 0;
+}

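The resulting out/main binary is normally driven by scripts/run_infer_310.sh, but it can also be invoked by hand; the flags mirror the DEFINE_* declarations above, and the values shown are only the compiled-in defaults. The result_Files and time_Result directories must already exist in the working directory, since WriteResult and the timing report write into them without creating them:

```shell
# Hand-run the 310 inference binary from the repository root (paths are the defaults from main.cc).
mkdir -p result_Files time_Result
./ascend310_infer/out/main \
    --model_path=/root/ICNet.mindir \
    --dataset_path=/data/cityscapes/leftImg8bit/val \
    --device_id=0
```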
@@ -0,0 +1,127 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "inc/utils.h"
+
+#include <fstream>
+#include <algorithm>
+#include <iostream>
+
+using mindspore::MSTensor;
+using mindspore::DataType;
+
+std::vector<std::string> GetAllFiles(std::string_view dirName) {
+  struct dirent *filename;
+  DIR *dir = OpenDir(dirName);
+  if (dir == nullptr) {
+    return {};
+  }
+  std::vector<std::string> res;
+  while ((filename = readdir(dir)) != nullptr) {
+    std::string dName = std::string(filename->d_name);
+    if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
+      continue;
+    }
+    res.emplace_back(std::string(dirName) + "/" + filename->d_name);
+  }
+  std::sort(res.begin(), res.end());
+  for (auto &f : res) {
+    std::cout << "image file: " << f << std::endl;
+  }
+  return res;
+}
+
+int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
+  std::string homePath = "./result_Files";
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    size_t outputSize;
+    std::shared_ptr<const void> netOutput = outputs[i].Data();
+    outputSize = outputs[i].DataSize();
+    int pos = imageFile.rfind('/');
+    std::string fileName(imageFile, pos + 1);
+    fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
+    std::string outFileName = homePath + "/" + fileName;
+    FILE *outputFile = fopen(outFileName.c_str(), "wb");
+    fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
+    fclose(outputFile);
+    outputFile = nullptr;
+  }
+  return 0;
+}
+
+mindspore::MSTensor ReadFileToTensor(const std::string &file) {
+  if (file.empty()) {
+    std::cout << "Pointer file is nullptr" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  std::ifstream ifs(file);
+  if (!ifs.good()) {
+    std::cout << "File: " << file << " does not exist" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  if (!ifs.is_open()) {
+    std::cout << "File: " << file << " open failed" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  ifs.seekg(0, std::ios::end);
+  size_t size = ifs.tellg();
+  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)},
+                             nullptr, size);
+
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
+  ifs.close();
+
+  return buffer;
+}
+
+DIR *OpenDir(std::string_view dirName) {
+  if (dirName.empty()) {
+    std::cout << " dirName is null ! " << std::endl;
+    return nullptr;
+  }
+  std::string realPath = RealPath(dirName);
+  struct stat s;
+  lstat(realPath.c_str(), &s);
+  if (!S_ISDIR(s.st_mode)) {
+    std::cout << "dirName is not a valid directory !" << std::endl;
+    return nullptr;
+  }
+  DIR *dir = opendir(realPath.c_str());
+  if (dir == nullptr) {
+    std::cout << "Can not open dir " << dirName << std::endl;
+    return nullptr;
+  }
+  std::cout << "Successfully opened the dir " << dirName << std::endl;
+  return dir;
+}
+
+std::string RealPath(std::string_view path) {
+  char realPathMem[PATH_MAX] = {0};
+  char *realPathRet = nullptr;
+  realPathRet = realpath(path.data(), realPathMem);
+  if (realPathRet == nullptr) {
+    std::cout << "File: " << path << " does not exist.";
+    return "";
+  }
+
+  std::string realPath(realPathMem);
+  std::cout << path << " realpath is: " << realPath << std::endl;
+  return realPath;
+}

@@ -0,0 +1,119 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluate mIoU and PixAcc"""
+import os
+import time
+import sys
+import argparse
+import yaml
+import numpy as np
+from PIL import Image
+
+parser = argparse.ArgumentParser(description="ICNet Evaluation")
+parser.add_argument("--dataset_path", type=str, default="/home/dataset",
+                    help="dataset path for evaluation")
+parser.add_argument("--project_path", type=str, default='/home/ICNet',
+                    help="project_path")
+parser.add_argument("--device_id", type=int, default=5, help="Device id, default is 5.")
+parser.add_argument("--result_path", type=str, default="", help="Image path.")
+
+args_opt = parser.parse_args()
+
+
+class Evaluator:
+    """evaluate"""
+
+    def __init__(self, config):
+        self.cfg = config
+
+        self.mask_folder = '/home/data'
+
+        # evaluation metrics
+        self.metric = SegmentationMetric(19)
+
+    def eval(self):
+        """evaluate"""
+        self.metric.reset()
+
+        list_time = []
+
+        for root, _, files in os.walk(args_opt.dataset_path):
+            for filename in files:
+                if filename.endswith('.png'):
+                    img_path = os.path.join(root, filename)
+                    file_name = filename.split('.')[0]
+                    output_file = os.path.join(args_opt.result_path, file_name + "_0.bin")
+                    output = np.fromfile(output_file, dtype=np.float32).reshape(1, 19, 1024, 2048)
+                    folder_name = os.path.basename(os.path.dirname(img_path))
+                    mask_name = filename.replace('leftImg8bit', 'gtFine_labelIds')
+                    mask_file = os.path.join(self.mask_folder, folder_name, mask_name)
+                    mask = Image.open(mask_file)  # mask shape: (W,H)
+
+                    mask = self._mask_transform(mask)  # mask shape: (H,W)
+
+                    start_time = time.time()
+                    end_time = time.time()
+                    step_time = end_time - start_time
+
+                    mask = np.expand_dims(mask, axis=0)
+                    self.metric.update(output, mask)
+                    list_time.append(step_time)
+
+        mIoU, pixAcc = self.metric.get()
+
+        average_time = sum(list_time) / len(list_time)
+
+        print("avgmiou", mIoU)
+        print("avg_pixacc", pixAcc)
+        print("avgtime", average_time)
+
+    def _mask_transform(self, mask):
+        mask = self._class_to_index(np.array(mask).astype('int32'))
+        return np.array(mask).astype('int32')
+
+    def _class_to_index(self, mask):
+        """assert the value"""
+        values = np.unique(mask)
+        self._key = np.array([-1, -1, -1, -1, -1, -1,
+                              -1, -1, 0, 1, -1, -1,
+                              2, 3, 4, -1, -1, -1,
+                              5, -1, 6, 7, 8, 9,
+                              10, 11, 12, 13, 14, 15,
+                              -1, -1, 16, 17, 18])
+        self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
+        for value in values:
+            assert value in self._mapping
+        # Get the index of each pixel value in the mask corresponding to _mapping
+        index = np.digitize(mask.ravel(), self._mapping, right=True)
+        # Map each index back through _key to get the corresponding label mask
+        return self._key[index].reshape(mask.shape)
+
+
+if __name__ == '__main__':
+    sys.path.append(args_opt.project_path)
+    from src.metric import SegmentationMetric
+    from src.logger import SetupLogger
+    # Set config file
+    config_file = "src/model_utils/icnet.yaml"
+    config_path = os.path.join(args_opt.project_path, config_file)
+    with open(config_path, "r") as yaml_file:
+        cfg = yaml.load(yaml_file.read())
+    logger = SetupLogger(name="semantic_segmentation",
+                         save_dir=cfg["train"]["ckpt_dir"],
+                         distributed_rank=0,
+                         filename='{}_{}_evaluate_log.txt'.format(cfg["model"]["name"], cfg["model"]["backbone"]))
+
+    evaluator = Evaluator(cfg)
+    evaluator.eval()

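postprocess.py walks --dataset_path for the original leftImg8bit .png images, loads the matching <name>_0.bin written by the 310 binary from --result_path, reshapes it to (1, 19, 1024, 2048), and scores it against the gtFine_labelIds masks read from the hardcoded mask_folder. A standalone call might look like the sketch below (placeholder paths; run_infer_310.sh issues the same call in its cal_acc step):

```shell
python postprocess.py \
    --dataset_path=/path/to/cityscapes/leftImg8bit/val \
    --project_path=/path/to/ICNet \
    --result_path=./result_Files
```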
@@ -0,0 +1,106 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 3 ]; then
+    echo "Usage: sh run_infer_310.sh [MODEL_PATH] [DATA_PATH] [DEVICE_ID]
+    DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
+    exit 1
+fi
+
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+
+model=$(get_real_path $1)
+data_path=$(get_real_path $2)
+device_id=$3
+
+echo $model
+echo $data_path
+echo $device_id
+
+export ASCEND_HOME=/usr/local/Ascend/
+if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
+    export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
+    export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
+    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
+    export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
+    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
+else
+    export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
+    export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
+    export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
+    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
+fi
+
+function compile_app()
+{
+    cd ../ascend310_infer || exit
+    if [ -f "Makefile" ]; then
+        make clean
+    fi
+    sh build.sh &> build.log
+
+    if [ $? -ne 0 ]; then
+        echo "compile app code failed"
+        exit 1
+    fi
+    cd - || exit
+}
+
+function infer()
+{
+    if [ -d result_Files ]; then
+        rm -rf ./result_Files
+    fi
+    if [ -d time_Result ]; then
+        rm -rf ./time_Result
+    fi
+    mkdir result_Files
+    mkdir time_Result
+    ../ascend310_infer/out/main --model_path=$model --dataset_path=$data_path --device_id=$device_id &> infer.log
+
+    if [ $? -ne 0 ]; then
+        echo "execute inference failed"
+        exit 1
+    fi
+}
+
+function cal_acc()
+{
+    if [ -d output ]; then
+        rm -rf ./output
+    fi
+    if [ -d output_img ]; then
+        rm -rf ./output_img
+    fi
+    mkdir output
+    mkdir output_img
+    python ../postprocess.py --dataset_path=$data_path --result_path=result_Files &> acc.log
+    if [ $? -ne 0 ]; then
+        echo "calculate accuracy failed"
+        exit 1
+    fi
+
+}
+
+compile_app
+infer
+cal_acc

@@ -0,0 +1,2 @@
+"""init"""
+from .loss import ICNetLoss

@@ -50,7 +50,7 @@ def _get_city_pairs(folder, split='train'):
     img_folder = os.path.join(folder, 'leftImg8bit/' + split)
     # "./Cityscapes/gtFine/train" or "./Cityscapes/gtFine/val"
     mask_folder = os.path.join(folder, 'gtFine/' + split)
-
+    # The order of img_paths and mask_paths is one-to-one correspondence
     img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
     return img_paths, mask_paths
 

@@ -13,9 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Evaluation Metrics for Semantic Segmentation"""
-from mindspore import Tensor
-import mindspore.ops as ops
-import mindspore.common.dtype as dtype
+import numpy as np
 
 __all__ = ['SegmentationMetric', 'batch_pix_accuracy', 'batch_intersection_union']
 

@@ -41,11 +39,11 @@ class SegmentationMetric:
         correct, labeled = batch_pix_accuracy(pred, label)
         inter, union = batch_intersection_union(pred, label, self.nclass)
 
-        self.total_correct += correct
-        self.total_label += labeled
+        self.total_correct = correct + self.total_correct
+        self.total_label = labeled + self.total_label
 
-        self.total_inter += inter
-        self.total_union += union
+        self.total_inter = inter + self.total_inter
+        self.total_union = union + self.total_union
 
     def get(self):
         """Gets the current evaluation result.

@@ -55,19 +53,17 @@ class SegmentationMetric:
         metrics : tuple of float
             pixAcc and mIoU
         """
-        mean = ops.ReduceMean(keep_dims=False)
-        pixAcc = 1.0 * self.total_correct / (2.220446049250313e-16 + self.total_label)  # remove c.spacing(1)
-        IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union)
+        pixAcc = np.true_divide(self.total_correct, (2.220446049250313e-16 + self.total_label))  # remove c.spacing(1)
+        IoU = np.true_divide(self.total_inter, (2.220446049250313e-16 + self.total_union))
 
-        mIoU = mean(IoU, axis=0)
+        mIoU = np.mean(IoU)
 
-        return pixAcc, mIoU
+        return mIoU, pixAcc
 
     def reset(self):
         """Resets the internal evaluation result to initial state."""
-        zeros = ops.Zeros()
-        self.total_inter = zeros(self.nclass, dtype.float32)
-        self.total_union = zeros(self.nclass, dtype.float32)
+        self.total_inter = np.zeros(self.nclass, dtype=np.float)
+        self.total_union = np.zeros(self.nclass, dtype=np.float)
         self.total_correct = 0
         self.total_label = 0
 

@@ -75,19 +71,15 @@ class SegmentationMetric:
 def batch_pix_accuracy(output, target):
     """PixAcc"""
 
-    predict = ops.Argmax(output_type=dtype.int32, axis=1)(output) + 1
+    predict = np.argmax(output, axis=1) + 1
     # (1,19, 1024,2048)-->(1, 1024,2048)
     target = target + 1
 
-    typetrue = dtype.float32
-    cast = ops.Cast()
-    sumtarget = ops.ReduceSum()
-    sumcorrect = ops.ReduceSum()
+    labeled = np.array(target > 0).astype(int)
+    pixel_labeled = np.sum(labeled)  # sum of pixels without 0
 
-    labeled = cast(target > 0, typetrue)
-    pixel_labeled = sumtarget(labeled)  # sum of pixels without 0
+    pixel_correct = np.sum(np.array(predict == target).astype(int) * np.array(target > 0).astype(int))
+    # Quantity of correct pixels
 
-    pixel_correct = sumcorrect(cast(predict == target, typetrue) * cast(target > 0, typetrue))  # sum of correctly labeled pixels
-
     assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
     return pixel_correct, pixel_labeled

@@ -96,26 +88,19 @@ def batch_pix_accuracy(output, target):
 def batch_intersection_union(output, target, nclass):
     """mIoU"""
     # inputs are numpy array, output 4D, target 3D
-    predict = ops.Argmax(output_type=dtype.int32, axis=1)(output) + 1  # [N,H,W]
-    target = target.astype(dtype.float32) + 1  # [N,H,W]
+    predict = np.argmax(output, axis=1) + 1  # [N,H,W]
+    target = target.astype(float) + 1  # [N,H,W]
 
-    typetrue = dtype.float32
-    cast = ops.Cast()
-    predict = cast(predict, typetrue) * cast(target > 0, typetrue)
-    intersection = cast(predict, typetrue) * cast(predict == target, typetrue)
+    predict = predict.astype(float) * np.array(target > 0).astype(float)
+    intersection = predict * np.array(predict == target).astype(float)
     # areas of intersection and union
     # element 0 in intersection occur the main difference from np.bincount. set boundary to -1 is necessary.
 
-    Range = Tensor([0.0, 20.0], dtype.float32)
-    hist = ops.HistogramFixedWidth(nclass + 1)
-    area_inter = hist(intersection, Range)
-    area_pred = hist(predict, Range)
-    area_lab = hist(target, Range)
+    area_inter, _ = np.array(np.histogram(intersection, bins=nclass, range=(1, nclass+1)))
+    area_pred, _ = np.array(np.histogram(predict, bins=nclass, range=(1, nclass+1)))
+    area_lab, _ = np.array(np.histogram(target, bins=nclass, range=(1, nclass+1)))
 
-    area_union = area_pred + area_lab - area_inter
+    area_all = area_pred + area_lab
+    area_union = area_all - area_inter
 
-    area_inter = area_inter[1:]
-    area_union = area_union[1:]
-    Sum = ops.ReduceSum()
-    assert Sum(cast(area_inter > area_union, typetrue)) == 0, "Intersection area should be smaller than Union area"
-    return cast(area_inter, typetrue), cast(area_union, typetrue)
+    return area_inter, area_union

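In formula terms, the rewritten metric accumulates per-image counts and reports, with $\varepsilon = 2.220446049250313 \times 10^{-16}$ guarding the divisions and $C$ the number of classes:

$$\mathrm{pixAcc}=\frac{\mathrm{total\_correct}}{\mathrm{total\_label}+\varepsilon},\qquad
\mathrm{mIoU}=\frac{1}{C}\sum_{c=1}^{C}\frac{\mathrm{total\_inter}_c}{\mathrm{total\_union}_c+\varepsilon}$$

Note that get() now returns (mIoU, pixAcc) in that order, which matches how postprocess.py unpacks it.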
@@ -19,7 +19,7 @@ train:
     epochs: 160
     val_epoch: 1                          # run validation every val-epoch
     ckpt_dir: "./ckpt/"                   # ckpt and training log will be saved here
-    mindrecord_dir: '/root/mindrecord'
+    mindrecord_dir: '/root/ICNet/mindrecord'
     save_checkpoint_epochs: 5
     keep_checkpoint_max: 10
 

@@ -1,3 +1,4 @@
 """__init__"""
 from .icnet import ICNet
 from .icnet_dc import ICNetdc
+from .icnet_1p import ICNet1p

@@ -67,7 +67,7 @@ class ICNet(nn.Cell):
 
         output = self.head(x_sub1, x_sub2, x_sub4)
 
-        return output
+        return output[0]
 
 
 class PyramidPoolingModule(nn.Cell):

@@ -22,7 +22,7 @@ from mindspore import Tensor
 from mindspore import load_param_into_net
 from mindspore import load_checkpoint
 import mindspore.dataset.vision.py_transforms as transforms
-from src.models.icnet import ICNet
+from models.icnet import ICNet
 
 __all__ = ['get_color_palette', 'set_img_color',
            'show_prediction', 'show_colorful_images', 'save_colorful_images']

@@ -115,28 +115,6 @@ def _getvocpalette(num_cls):
 
 vocpalette = _getvocpalette(256)
 
-cityspalette = [
-    128, 64, 128,
-    244, 35, 232,
-    70, 70, 70,
-    102, 102, 156,
-    190, 153, 153,
-    153, 153, 153,
-    250, 170, 30,
-    220, 220, 0,
-    107, 142, 35,
-    152, 251, 152,
-    0, 130, 180,
-    220, 20, 60,
-    255, 0, 0,
-    0, 0, 142,
-    0, 0, 70,
-    0, 60, 100,
-    0, 80, 100,
-    0, 0, 230,
-    119, 11, 32,
-]
 
 
 def _class_to_index(mask):
     """assert the value"""

@@ -150,9 +128,9 @@ def _class_to_index(mask):
     _mapping = np.array(range(-1, len(_key) - 1)).astype('int32')
     for value in values:
         assert value in _mapping
-
+    # Get the index of each pixel value in the mask corresponding to _mapping
     index = np.digitize(mask.ravel(), _mapping, right=True)
-
+    # Map each index back through _key to get the corresponding label mask
     return _key[index].reshape(mask.shape)
 
 

@@ -167,7 +145,7 @@ if __name__ == '__main__':
     ckpt_file_name = '/root/ICNet/ckpt/ICNet-160_93_699.ckpt'
     param_dict = load_checkpoint(ckpt_file_name)
     load_param_into_net(model, param_dict)
-    image_path = 'Test/val_lindau_000023_000019_leftImg8bit.png'
+    image_path = '../Test/val_lindau_000023_000019_leftImg8bit.png'
     image = Image.open(image_path).convert('RGB')
     image = _img_transform(image)
     image = Tensor(image)

@@ -181,4 +159,4 @@ if __name__ == '__main__':
     pred = pred.asnumpy()
     pred = pred.squeeze(0)
     pred = get_color_palette(pred, "citys")
-    pred.save('Test/visual_pred.png')
+    pred.save('Test/visual_pred_random.png')